# gangyan/langchain-chat/server/agent/tools/search_internet.py

import json
import asyncio
import unicodedata
from server.chat.KgoSearchAPIWrapper import KgoSearchAPIWrapper
from server.chat.search_engine_chat import search_engine_chat
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
from server.agent import model_container
from pydantic import BaseModel, Field
from configs import LLM_MODELS, TEMPERATURE
from configs.basic_config import *
# def get_kgo_search_type(query: str = "全部"):
# # Filter out all non-Chinese characters
# query = ''.join(char for char in query if unicodedata.category(char) == 'Lo')
# search_map = KgoSearchAPIWrapper().search_map
# if "论文" in query:
# return "1001"
# elif "外文" in query or "英文" in query:
# return "1013"
# elif "期刊" in query or "研究进展" in query:
# return "1002"
# else:
# matched_types = [value for key, value in search_map.items() if key in query]
# if matched_types:
# return ','.join(matched_types)
# else:
# print("未找到匹配的搜索类型返回默认值1000")
# return "1000"
@timing_decorator
async def search_engine_iter(query: str, uid: str):
    """Run a non-streaming search-engine chat and collect its answer and docs.

    Calls ``search_engine_chat`` with the "zhipu_search" engine, the first
    entry of ``LLM_MODELS`` and the "search" prompt, then drains the
    response's ``body_iterator``. Every item is decoded as a JSON object from
    which the "answer" and "docs" keys are read.

    Args:
        query: Search query forwarded to ``search_engine_chat``.
        uid: Forwarded unchanged to ``search_engine_chat`` as ``uid``.

    Returns:
        A ``(contents, docs)`` tuple on every path. ``contents`` is the
        "answer" value of the last chunk read ("" when absent) and ``docs``
        is the list of all "docs" entries accumulated so far.
    """
    response = await search_engine_chat(
        uid=uid,
        query=query,
        search_engine_name="zhipu_search",
        model_name=LLM_MODELS[0],
        # Original note: the agent's internet search is meant to run at a
        # temperature of 0.1; the value actually comes from configs.TEMPERATURE.
        temperature=TEMPERATURE,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="search",
        stream=False,
        kgo_search_type="1000",
    )

    contents = ""
    docs = []
    async for chunk in response.body_iterator:  # each chunk is a JSON string
        payload = json.loads(chunk)
        # Keep the same type as the initial value when "answer" is missing.
        contents = payload.get("answer", "")
        chunk_docs = payload.get("docs", [])
        if not chunk_docs:
            logging.error("No docs found in the response")
            # Stop reading, but fall through to the single return below so
            # callers always receive the same (contents, docs) shape.
            break
        docs.extend(chunk_docs)

    # Hand back both the summary answer and the supporting documents.
    return contents, docs
def search_internet(query: str, uid: str):
    """Synchronously run an internet search for ``query`` on behalf of ``uid``.

    The query is passed through unmodified: the keyword filtering and KGO
    search-type detection that once lived here are disabled.

    Args:
        query: Search query, forwarded as-is to ``search_engine_iter``.
        uid: Forwarded as-is to ``search_engine_iter``.

    Returns:
        Whatever the ``search_engine_iter`` coroutine produces.
    """
    pending_search = search_engine_iter(query, uid)
    # asyncio.run creates an event loop, runs the coroutine to completion
    # and closes the loop again.
    return asyncio.run(pending_search)
class SearchInternetInput(BaseModel):
    # Pydantic model declaring a single string field for the search input.
    # NOTE(review): the field is named `location` although its description
    # says it carries the search query; the name is kept because renaming it
    # would change the model's schema.
    location: str = Field(
        description="Query for Internet search",
    )
if __name__ == "__main__":
    # Manual smoke test. search_internet requires a uid in addition to the
    # query; a placeholder is supplied here.
    # NOTE(review): replace "local-debug" with a real uid if the backend
    # validates it.
    result = search_internet("人工智能领域的政策", uid="local-debug")
    print("答案:", result)