[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
331
langchain-chat/server/agent/tools/search_tool.py
Normal file
331
langchain-chat/server/agent/tools/search_tool.py
Normal file
@@ -0,0 +1,331 @@
|
||||
import asyncio
|
||||
import concurrent
|
||||
from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
from fastapi import logger
|
||||
|
||||
from configs import kb_config
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.agent.tools import duckduckgo_search
|
||||
from server.agent.tools.duckduckgo_search import duckduckgo_search_iter
|
||||
from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_iter
|
||||
from server.agent.tools.rag_search import rag_search1
|
||||
from server.chat import utils
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.chat.utils import doc_to_list, get_similar_documents1, solve_knowledge_map, solve_mental_data
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
|
||||
|
||||
|
||||
def rag_search(query: str,uid):
|
||||
"""
|
||||
根据用户输入的query,返回rag搜索结果
|
||||
"""
|
||||
source_docs = []
|
||||
try:
|
||||
search = json.loads(query)
|
||||
logging.info(f'模型输入: {search["query"]}')
|
||||
original_query = search["query"]
|
||||
search_query = get_llm_model_response(
|
||||
strategy_name="rag_search_rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="rag_search_rewrite",
|
||||
prompt_param_dict={"input": search["query"], "year": datetime.now().strftime("%Y")},
|
||||
temperature=0.3,
|
||||
max_tokens=512
|
||||
)
|
||||
logging.info(f'模型改写: {search_query}')
|
||||
search_keywords = []
|
||||
search_text = f"{search_query}"
|
||||
# if type(search["keywords"]) == list:
|
||||
# search_keywords = search["keywords"]
|
||||
# for keyword in search_keywords:
|
||||
# search_text += f" {keyword}"
|
||||
# else:
|
||||
# search_keywords = search["keywords"].split(",")
|
||||
# for keyword in search_keywords:
|
||||
# search_text += f" {keyword}"
|
||||
self_database = utils.get_shared_variable(uid)
|
||||
result = []
|
||||
|
||||
knownledge_name = []
|
||||
if type(search["knowledge_name"]) == list:
|
||||
knownledge_name=search["knowledge_name"]
|
||||
else:
|
||||
knownledge_name=search["knowledge_name"].split(",")
|
||||
if "美术专业知识库" in knownledge_name:
|
||||
knownledge_name.remove("美术专业知识库")
|
||||
if "database" in self_database:
|
||||
self_database["database"]= self_database["database"].append("p_cafa0101011")
|
||||
else:
|
||||
self_database["database"] = ["p_cafa0101011"]
|
||||
# 添加个人知识库
|
||||
if "database" in self_database:
|
||||
knownledge_name.extend(self_database["database"])
|
||||
knownledge_name = [knownledge for knownledge in knownledge_name
|
||||
if (knownledge in kb_config.CH_BASE_NAME
|
||||
or knownledge in kb_config.EN_BASE_NAME
|
||||
or knownledge in getattr(kb_config, "YJ_BASE_NAME", [])
|
||||
or kb_config.SELF_KNOWLEDGE_BASE.match(knownledge)
|
||||
or knownledge == "coding")]
|
||||
if len(knownledge_name)==0:
|
||||
#result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
|
||||
return result,source_docs
|
||||
# knownledge_name=kb_config.CH_BASE_NAME
|
||||
|
||||
knownledge_name=solve_knowledge_map(knownledge_name)
|
||||
#knownledge_name = ["p_c88859a3d06e4265bd01d816ef2650d1"]
|
||||
num = 0
|
||||
temp=utils.get_shared_variable(uid)
|
||||
for knownledge in knownledge_name:
|
||||
seen_docs = set()
|
||||
duplicate_indices = []
|
||||
# 针对中国钢铁行业动态库增加日期范围过滤
|
||||
expr_param = ""
|
||||
if knownledge == kb_config.STEEL_KB:
|
||||
time_today = datetime.now().strftime("%Y-%m-%d")
|
||||
# 调用LLM生成日期表达式,模板沿用 get_policy_time
|
||||
try:
|
||||
expr_candidate = get_llm_model_response(
|
||||
strategy_name="get steel time",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="get_steel_time",
|
||||
prompt_param_dict={"query": original_query, "time": time_today},
|
||||
temperature=0.01,
|
||||
max_tokens=512
|
||||
).replace("None", "").strip()
|
||||
expr_param = expr_candidate if expr_candidate else ""
|
||||
except Exception as _:
|
||||
expr_param = ""
|
||||
|
||||
doc_list = search_docs(
|
||||
usr_query=original_query,
|
||||
fileName=[],
|
||||
top_k=20,
|
||||
score_threshold=1.0,
|
||||
query=search_text,
|
||||
knowledge_base_name=knownledge,
|
||||
expr=expr_param
|
||||
)
|
||||
|
||||
if len(doc_list)==0:
|
||||
return result,source_docs
|
||||
titles = temp["title"]
|
||||
doc_list,title = utils.remove_docs1(titles,doc_list)
|
||||
titles.extend(title)
|
||||
for inum,doc in enumerate(doc_list):
|
||||
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
|
||||
|
||||
# 从policydocs中删除重复的文档(从后往前删除以防止索引错位)
|
||||
for index in sorted(duplicate_indices, reverse=True):
|
||||
del doc_list[index]
|
||||
# 处理原文来源进入数组。使用开关语句明确各个条件分支
|
||||
match knownledge:
|
||||
# 属于政策库分支,入参为中文政策库名称
|
||||
case kb_config.DEFAULT_POLICY_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs)
|
||||
# 属于期刊论文库分支,入参为期刊论文库的中文名称
|
||||
case kb_config.DEFAULT_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 属于报告库分支,入参为报告库中文名称
|
||||
case kb_config.DEFAULT_REPORT_BASE1:
|
||||
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
|
||||
case kb_config.GY_NEWS_BASE:
|
||||
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
|
||||
case kb_config.GY_REPORT_BASE:
|
||||
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs)
|
||||
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
|
||||
case kb_config.GY_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金新闻库(2024年以及之前)
|
||||
case kb_config.YJ_NEWS_BASE:
|
||||
doc_to_list(num,kb_config.YJ_NEWS_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金中文期刊库
|
||||
case kb_config.YJ_CH_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_CH_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金外文期刊库
|
||||
case kb_config.YJ_FOR_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_FOR_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金OA期刊库
|
||||
case kb_config.YJ_OA_JOURNAL_BASE:
|
||||
doc_to_list(num,kb_config.YJ_OA_JOURNAL_BASE_NAME,doc_list,source_docs)
|
||||
# 新增冶金政策库
|
||||
case kb_config.YJ_POLICYS_BASE:
|
||||
doc_to_list(num,kb_config.YJ_POLICYS_BASE_NAME,doc_list,source_docs)
|
||||
# 新增中国钢铁行业动态库
|
||||
case kb_config.STEEL_KB:
|
||||
doc_to_list(num,kb_config.STEEL_KB_NAME,doc_list,source_docs)
|
||||
# 属于个人知识库分支
|
||||
case _ if kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding":
|
||||
doc_to_list(num,knownledge,doc_list,source_docs)
|
||||
case _:
|
||||
print(f"输入了没有的知识库名称")
|
||||
return "输入了没有的知识库名称",source_docs
|
||||
# num += len(source_docs[knownledge])
|
||||
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
|
||||
# del num
|
||||
# source = utils.get_shared_variable(uid)
|
||||
# print(utils.get_shared_variable(uid))
|
||||
# source["source_docs"]=source_docs
|
||||
# utils.set_shared_variable(uid,source)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in rag_search: {e}")
|
||||
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
|
||||
return "当前工具异常!请换其他工具",source_docs
|
||||
|
||||
return result,source_docs
|
||||
|
||||
|
||||
|
||||
|
||||
def knowledgebase_kgo_search(query: str, uid) -> List[str]:
|
||||
try:
|
||||
res = knowledgebase_kgo_iter(query,uid)
|
||||
try:
|
||||
if type(res[0])==list and type(res[1])==list:
|
||||
return res
|
||||
elif type(res[1])==list and len(res[1])>0:
|
||||
res[0]=[]
|
||||
return res
|
||||
else:
|
||||
temp = []
|
||||
temp[0]=[]
|
||||
temp[1]=[]
|
||||
return temp
|
||||
except Exception as e:
|
||||
temp = {}
|
||||
logging.error(f"No docs: {e}")
|
||||
temp[0]=[]
|
||||
temp[1]=[]
|
||||
return temp
|
||||
except json.JSONDecodeError:
|
||||
# 如果JSON解码失败,则返回错误消息
|
||||
logging.error("Invalid JSON format in query.")
|
||||
return "<关键指令>不需要再调用该工具了</关键指令>"
|
||||
except KeyError:
|
||||
# 如果解析的JSON对象中缺少必要的键,则返回错误消息
|
||||
return "请尝试调用其他工具"
|
||||
except Exception as e:
|
||||
# 捕获其他所有异常,并返回通用错误消息
|
||||
return f"发生错误:{str(e)},请尝试调用其他工具"
|
||||
|
||||
|
||||
def inner_duckduckgo_search(query: str, uuid:str,) :
|
||||
logging.info(f"模型输入: {query}")
|
||||
combined_result = asyncio.run(duckduckgo_search_iter(query, uuid, "y","default" ))
|
||||
# 以标准json格式输出
|
||||
logging.info("返回JSON格式的结果给到模型...")
|
||||
return combined_result
|
||||
|
||||
|
||||
def search_tool(query: str):
|
||||
"""获取到uid并拆分query"""
|
||||
if "<param>"in query:
|
||||
query = query.replace("<param>","").replace("</param>","")
|
||||
matches = re.findall(r'\{.*?\}', query)
|
||||
if len(matches)>=2:
|
||||
query = matches[0]
|
||||
else:
|
||||
return "<关键指令>当前工具不需要再调用</关键指令>"
|
||||
time_based_uuid = json.loads(matches[1])["uuid"]
|
||||
search = json.loads(query)
|
||||
if type(search["query"])==list and len(search["query"])>0:
|
||||
searches = search["query"][0]
|
||||
elif type(search["query"])==list and len(search["query"]) == 0:
|
||||
searches = "无"
|
||||
else:
|
||||
searches = search["query"]
|
||||
"""
|
||||
根据用户输入的query,返回rag搜索结果
|
||||
"""
|
||||
|
||||
try:
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# 提交任务并发执行
|
||||
|
||||
test = {}
|
||||
test["num"]=0
|
||||
test["source_docs"]=[]
|
||||
test["END"] = ""
|
||||
test["title"] = []
|
||||
utils.set_shared_variable(time_based_uuid+"¥",test)
|
||||
# future2 = executor.submit(knowledgebase_kgo_search,search["query"],time_based_uuid+"q")
|
||||
future1 = executor.submit(rag_search,query,time_based_uuid)
|
||||
# if not "type" in utils.get_shared_variable(time_based_uuid):
|
||||
# future2 = executor.submit(knowledgebase_kgo_search,searches,time_based_uuid+"¥")
|
||||
if not "type" in utils.get_shared_variable(time_based_uuid):
|
||||
future2 = executor.submit(knowledgebase_kgo_search,searches,time_based_uuid+"¥")
|
||||
result3 = []
|
||||
# 获取结果
|
||||
result1,sourcedocs = future1.result()
|
||||
result2 = {}
|
||||
if "type" in utils.get_shared_variable(time_based_uuid):
|
||||
result2[0] =[]
|
||||
result2[1] = []
|
||||
else:
|
||||
result2 = future2.result()
|
||||
# if "type" in utils.get_shared_variable(time_based_uuid):
|
||||
# result2[0] =[]
|
||||
# result2[1] = []
|
||||
# else:
|
||||
# result2 = future2.result()
|
||||
# result2[0] = []
|
||||
# result2[1] = []
|
||||
utils.remove_shared_variable(time_based_uuid+"q")
|
||||
if type(result2[1]) == list:
|
||||
if type(sourcedocs) == list:
|
||||
sourcedocs.extend(result2[1])
|
||||
else:
|
||||
sourcedocs = []
|
||||
if type(result1) == list:
|
||||
result1.extend(result2[0])
|
||||
result3 = result1
|
||||
else:
|
||||
result3 = result2[0]
|
||||
|
||||
logging.info(f"result2:{result2[1]}")
|
||||
source = []
|
||||
res=[]
|
||||
sources = utils.get_shared_variable(time_based_uuid)
|
||||
i = sources["num"]
|
||||
num = sources["num"]
|
||||
for result in sourcedocs:
|
||||
try:
|
||||
i+=1
|
||||
res3 = re.sub(r'\[\d+\]', f"[{i}]", result.replace("\n",""), count=1)
|
||||
if res3:
|
||||
source.append(re.sub(r'\[\d+\]', f"[{i}]", result.replace("\n",""), count=1))
|
||||
else:
|
||||
i -= 1
|
||||
except Exception as e:
|
||||
i -= 1
|
||||
pass
|
||||
# internet_search_res = f"参考资料[{len(result1)+1}-{len(source)}]:{result2[0]}"
|
||||
# internet_search_res = f"参考资料:{result2[0]}"
|
||||
j = sources["num"]
|
||||
for result in result3:
|
||||
j+=1
|
||||
res.append(re.sub(r'\[\d+\]', f"[{j}]", result, count=1))
|
||||
|
||||
print(utils.get_shared_variable(time_based_uuid))
|
||||
# sources["source_docs"]=source
|
||||
sources["source_docs"].extend(source)
|
||||
sources["num"]=i
|
||||
# sources["END"] = "ok"
|
||||
utils.set_shared_variable(time_based_uuid,sources)
|
||||
logging.info(f"result1:{result1},sourcedocs:{sourcedocs}")
|
||||
logging.info(f"result2:{result2}")
|
||||
logging.info(f"{res}")
|
||||
if len(res) ==0 and len(source)==0:
|
||||
return f"尝试调整入参重新调用知识库联想工具(同一个问题调用超过三次就不要再使用知识库联想工具了,浪费时间)"
|
||||
return f"<关键指令>如果你在写文章禁止在非规定位置输出参考资料</关键指令>资料:{res}\n资料来源为:{source}\n 注意:如果你在根据大纲撰写文章,撰写中间部分章节禁止输出综上所述之类的影响文风的话,撰写中间部分禁止输出附录引用文献等!!!"
|
||||
except Exception as e:
|
||||
logging.error(f"Error occurred during search_tool execution.{e}")
|
||||
return "同一个问题调用知识库联想工具超过5次就不要再调用知识库联想"
|
||||
|
||||
Reference in New Issue
Block a user