90 lines
3.4 KiB
Python
90 lines
3.4 KiB
Python
import json
|
||
import asyncio
|
||
import unicodedata
|
||
|
||
from server.chat.KgoSearchAPIWrapper import KgoSearchAPIWrapper
|
||
from server.chat.search_engine_chat import search_engine_chat
|
||
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
|
||
from server.agent import model_container
|
||
from pydantic import BaseModel, Field
|
||
from configs import LLM_MODELS, TEMPERATURE
|
||
from configs.basic_config import *
|
||
|
||
# def get_kgo_search_type(query: str = "全部"):
|
||
# # Filter out every non-Chinese character (keeps only Unicode category 'Lo')
|
||
# query = ''.join(char for char in query if unicodedata.category(char) == 'Lo')
|
||
# search_map = KgoSearchAPIWrapper().search_map
|
||
|
||
# if "论文" in query:
|
||
# return "1001"
|
||
# elif "外文" in query or "英文" in query:
|
||
# return "1013"
|
||
# elif "期刊" in query or "研究进展" in query:
|
||
# return "1002"
|
||
# else:
|
||
# matched_types = [value for key, value in search_map.items() if key in query]
|
||
# if matched_types:
|
||
# return ','.join(matched_types)
|
||
# else:
|
||
# print("未找到匹配的搜索类型,返回默认值:1000")
|
||
# return "1000"
|
||
|
||
@timing_decorator
async def search_engine_iter(query: str, uid: str):
    """Run one non-streaming "zhipu_search" chat request and collect its output.

    Args:
        query: Search query, forwarded unchanged to ``search_engine_chat``.
        uid: Caller identifier, forwarded unchanged to ``search_engine_chat``.

    Returns:
        A ``(contents, docs)`` tuple. ``contents`` is the ``"answer"`` value of
        the last chunk that was read (``""`` when that chunk carried none) and
        ``docs`` is the list of every ``"docs"`` entry gathered so far. The
        tuple shape is returned on every path, including the "no docs" path,
        so callers can always unpack two values.
    """
    response = await search_engine_chat(
        uid=uid,
        query=query,
        search_engine_name="zhipu_search",
        model_name=LLM_MODELS[0],
        # Original note: the agent searches the internet with temperature 0.1;
        # the effective value is whatever configs.TEMPERATURE holds.
        temperature=TEMPERATURE,
        history=[],
        top_k=VECTOR_SEARCH_TOP_K,
        max_tokens=MAX_TOKENS,
        prompt_name="search",
        stream=False,
        # "1000" is the default search type used by the disabled
        # get_kgo_search_type helper in this file.
        kgo_search_type="1000",
    )

    contents = ""
    docs = []

    # Each item yielded by the response body is a JSON string.
    async for chunk in response.body_iterator:
        payload = json.loads(chunk)
        # Default to "" so `contents` stays a string, matching its initial value.
        contents = payload.get("answer", "")
        current_docs = payload.get("docs", [])
        if not current_docs:
            logging.error("No docs found in the response")
            # Stop reading here, as before, but fall through to the single
            # return below instead of handing back a bare list.
            break
        docs.extend(current_docs)

    # Hand back both the answer text and the supporting search results.
    return contents, docs
|
||
|
||
|
||
def search_internet(query: str, uid: str):
    """Synchronously run ``search_engine_iter`` and return its result.

    The keyword filtering and KGO search-type detection that previously ran
    here are disabled, so ``query`` is forwarded exactly as received.

    Args:
        query: Search query text.
        uid: Caller identifier, forwarded to ``search_engine_iter``.

    Returns:
        Whatever ``search_engine_iter(query, uid)`` returns.
    """
    # asyncio.run starts a fresh event loop for this one coroutine.
    pending_search = search_engine_iter(query, uid)
    return asyncio.run(pending_search)
|
||
|
||
|
||
class SearchInternetInput(BaseModel):
    """Pydantic model with a single string field describing a search request.

    Presumably the argument schema for the ``search_internet`` agent tool;
    verify against where the tool is registered.

    NOTE(review): the field is named ``location`` although its description
    says it carries the search query. The name is kept because it is part of
    the model's external schema.
    """

    location: str = Field(
        description="Query for Internet search",
    )
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test. search_internet requires a uid; the original call
    # omitted it and raised TypeError before any search ran. The value below
    # is a placeholder for local runs only.
    result = search_internet("人工智能领域的政策", uid="debug")
    print("答案:", result)
|