Files
gangyan/langchain-chat/server/agent/tools/search_internet.py

90 lines
3.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import asyncio
import unicodedata
from server.chat.KgoSearchAPIWrapper import KgoSearchAPIWrapper
from server.chat.search_engine_chat import search_engine_chat
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
from server.agent import model_container
from pydantic import BaseModel, Field
from configs import LLM_MODELS, TEMPERATURE
from configs.basic_config import *
# def get_kgo_search_type(query: str = "全部"):
# # 过滤掉所有非汉字的字符
# query = ''.join(char for char in query if unicodedata.category(char) == 'Lo')
# search_map = KgoSearchAPIWrapper().search_map
# if "论文" in query:
# return "1001"
# elif "外文" in query or "英文" in query:
# return "1013"
# elif "期刊" in query or "研究进展" in query:
# return "1002"
# else:
# matched_types = [value for key, value in search_map.items() if key in query]
# if matched_types:
# return ','.join(matched_types)
# else:
# print("未找到匹配的搜索类型返回默认值1000")
# return "1000"
@timing_decorator
async def search_engine_iter(query: str, uid: str):
    """Run a non-streaming internet search and collect the answer and source docs.

    Calls ``search_engine_chat`` with the "zhipu_search" engine, the "search"
    prompt and ``stream=False``, then parses every chunk of the response body
    as JSON.

    Args:
        query: The search query text.
        uid: Identifier forwarded to ``search_engine_chat`` as ``uid``.

    Returns:
        A ``(contents, docs)`` tuple. ``contents`` is the ``"answer"`` field of
        the last chunk that carried one (``""`` if none did). ``docs`` is the
        list of ``"docs"`` entries accumulated so far. If a chunk arrives with
        no docs, an error is logged and iteration stops early. The result still
        has the same ``(contents, docs)`` shape, so callers can always unpack it.
    """
    response = await search_engine_chat(
        uid=uid,
        query=query,
        search_engine_name="zhipu_search",
        model_name=LLM_MODELS[0],
        # NOTE(review): uses the global configured TEMPERATURE, not a
        # search-specific value; confirm this is intended for agent search.
        temperature=TEMPERATURE,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="search",
        stream=False,
        kgo_search_type="1000",
    )
    contents = ""
    docs = []
    async for chunk in response.body_iterator:
        # Each chunk of the body is a JSON string.
        data = json.loads(chunk)
        # Keep the previous answer (a str) when this chunk has none, rather
        # than switching `contents` to a list.
        contents = data.get("answer", contents)
        current_docs = data.get("docs", [])
        if not current_docs:
            logging.error("No docs found in the response")
            # Same (contents, docs) shape as the normal return below, so
            # callers that unpack the result never get a bare list.
            return contents, docs
        docs.extend(current_docs)
    # Return both the search summary and the source documents.
    return contents, docs
def search_internet(query: str, uid: str):
    """Synchronously run ``search_engine_iter`` for *query* on behalf of *uid*.

    Drives the coroutine to completion with ``asyncio.run`` and returns
    whatever ``search_engine_iter`` produces. The query is passed through
    unmodified (no keyword filtering or search-type detection is applied).

    Note: ``asyncio.run`` cannot be called from a thread that already has a
    running event loop.
    """
    pending_search = search_engine_iter(query, uid)
    return asyncio.run(pending_search)
class SearchInternetInput(BaseModel):
    # Pydantic input model with a single required string field; presumably
    # used as the argument schema of the internet-search tool (verify at the
    # tool registration site).
    # A class docstring is deliberately not used here: pydantic would copy it
    # into the generated JSON schema description.
    # NOTE(review): the field is named `location`, but its description says it
    # carries the search query. Renaming it would change the exposed schema.
    location: str = Field(description="Query for Internet search")
if __name__ == "__main__":
    # Manual smoke test. search_internet requires a uid argument; the call
    # previously omitted it and raised TypeError. Replace the placeholder with
    # a real user id as needed.
    result = search_internet("人工智能领域的政策", uid="test_uid")
    print("答案:", result)