[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
89
langchain-chat/server/agent/tools/search_internet.py
Normal file
89
langchain-chat/server/agent/tools/search_internet.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import json
|
||||
import asyncio
|
||||
import unicodedata
|
||||
|
||||
from server.chat.KgoSearchAPIWrapper import KgoSearchAPIWrapper
|
||||
from server.chat.search_engine_chat import search_engine_chat
|
||||
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
|
||||
from server.agent import model_container
|
||||
from pydantic import BaseModel, Field
|
||||
from configs import LLM_MODELS, TEMPERATURE
|
||||
from configs.basic_config import *
|
||||
|
||||
# def get_kgo_search_type(query: str = "全部"):
|
||||
# # 过滤掉所有非汉字的字符
|
||||
# query = ''.join(char for char in query if unicodedata.category(char) == 'Lo')
|
||||
# search_map = KgoSearchAPIWrapper().search_map
|
||||
|
||||
# if "论文" in query:
|
||||
# return "1001"
|
||||
# elif "外文" in query or "英文" in query:
|
||||
# return "1013"
|
||||
# elif "期刊" in query or "研究进展" in query:
|
||||
# return "1002"
|
||||
# else:
|
||||
# matched_types = [value for key, value in search_map.items() if key in query]
|
||||
# if matched_types:
|
||||
# return ','.join(matched_types)
|
||||
# else:
|
||||
# print("未找到匹配的搜索类型,返回默认值:1000")
|
||||
# return "1000"
|
||||
|
||||
@timing_decorator
|
||||
async def search_engine_iter(query: str , uid: str):
|
||||
response = await search_engine_chat(uid = uid,
|
||||
query=query,
|
||||
search_engine_name="zhipu_search",
|
||||
model_name=LLM_MODELS[1],
|
||||
temperature=TEMPERATURE, # Agent搜索互联网的时候,温度设为0.1
|
||||
history=[],
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
max_tokens=MAX_TOKENS,
|
||||
prompt_name="search",
|
||||
stream=False,
|
||||
kgo_search_type="1000")
|
||||
|
||||
contents = ""
|
||||
docs = []
|
||||
|
||||
async for data in response.body_iterator: # 这里的data是一个json字符串
|
||||
data = json.loads(data)
|
||||
contents = data.get("answer", [])
|
||||
current_docs = data.get("docs", [])
|
||||
if current_docs:
|
||||
docs.extend(current_docs)
|
||||
else:
|
||||
logging.error("No docs found in the response")
|
||||
return docs
|
||||
# print("contents:", contents)
|
||||
# print("docs:", docs)
|
||||
# 回复搜索结果和搜索结果的总结
|
||||
return contents,docs
|
||||
|
||||
|
||||
def search_internet(query: str, uid: str):
|
||||
# filter_words = {
|
||||
# "统计数据", "视频", "数据集", "新闻", "专利", "期刊", "图书", "报告", "项目", "成果", "会议论文",
|
||||
# "政策", "外文期刊论文", "学位论文", "期刊论文", "全部论文", "原文", "全文", "pdf", "资料", " ",
|
||||
# "进展", "研究", "最新", "外文", "英文", "最新", "文件", "资料", "论文"
|
||||
# }
|
||||
# filtered_query = query
|
||||
# for word in filter_words:
|
||||
# if word in query:
|
||||
# filtered_query = filtered_query.replace(word, "")
|
||||
# print("filtered query:", filtered_query)
|
||||
|
||||
# kgo_search_type = get_kgo_search_type(query)
|
||||
# print("kgo_search_type:", kgo_search_type)
|
||||
|
||||
# 使用过滤后的查询字符串进行搜索
|
||||
return asyncio.run(search_engine_iter(query, uid))
|
||||
|
||||
|
||||
class SearchInternetInput(BaseModel):
|
||||
location: str = Field(description="Query for Internet search")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
result = search_internet("人工智能领域的政策")
|
||||
print("答案:", result)
|
||||
Reference in New Issue
Block a user