Files
gangyan/langchain-chat/server/agent/tools/search_tool.py

331 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import concurrent
import concurrent.futures
import json
import logging
import re
from datetime import datetime
from typing import List

from fastapi import logger

from configs import kb_config
from configs.model_config import LLM_MODELS
from server.agent.tools import duckduckgo_search
from server.agent.tools.duckduckgo_search import duckduckgo_search_iter
from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_iter
from server.agent.tools.rag_search import rag_search1
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.utils import doc_to_list, get_similar_documents1, solve_knowledge_map, solve_mental_data
from server.knowledge_base.kb_doc_api import search_docs
def rag_search(query: str,uid):
"""
根据用户输入的query返回rag搜索结果
"""
source_docs = []
try:
search = json.loads(query)
logging.info(f'模型输入: {search["query"]}')
original_query = search["query"]
search_query = get_llm_model_response(
strategy_name="rag_search_rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="rag_search_rewrite",
prompt_param_dict={"input": search["query"], "year": datetime.now().strftime("%Y")},
temperature=0.3,
max_tokens=512
)
logging.info(f'模型改写: {search_query}')
search_keywords = []
search_text = f"{search_query}"
# if type(search["keywords"]) == list:
# search_keywords = search["keywords"]
# for keyword in search_keywords:
# search_text += f" {keyword}"
# else:
# search_keywords = search["keywords"].split(",")
# for keyword in search_keywords:
# search_text += f" {keyword}"
self_database = utils.get_shared_variable(uid)
result = []
knownledge_name = []
if type(search["knowledge_name"]) == list:
knownledge_name=search["knowledge_name"]
else:
knownledge_name=search["knowledge_name"].split(",")
if "美术专业知识库" in knownledge_name:
knownledge_name.remove("美术专业知识库")
if "database" in self_database:
self_database["database"]= self_database["database"].append("p_cafa0101011")
else:
self_database["database"] = ["p_cafa0101011"]
# 添加个人知识库
if "database" in self_database:
knownledge_name.extend(self_database["database"])
knownledge_name = [knownledge for knownledge in knownledge_name
if (knownledge in kb_config.CH_BASE_NAME
or knownledge in kb_config.EN_BASE_NAME
or knownledge in getattr(kb_config, "YJ_BASE_NAME", [])
or kb_config.SELF_KNOWLEDGE_BASE.match(knownledge)
or knownledge == "coding")]
if len(knownledge_name)==0:
#result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
return result,source_docs
# knownledge_name=kb_config.CH_BASE_NAME
knownledge_name=solve_knowledge_map(knownledge_name)
#knownledge_name = ["p_c88859a3d06e4265bd01d816ef2650d1"]
num = 0
temp=utils.get_shared_variable(uid)
for knownledge in knownledge_name:
seen_docs = set()
duplicate_indices = []
# 针对中国钢铁行业动态库增加日期范围过滤
expr_param = ""
if knownledge == kb_config.STEEL_KB:
time_today = datetime.now().strftime("%Y-%m-%d")
# 调用LLM生成日期表达式模板沿用 get_policy_time
try:
expr_candidate = get_llm_model_response(
strategy_name="get steel time",
llm_model_name=LLM_MODELS[0],
template_prompt_name="get_steel_time",
prompt_param_dict={"query": original_query, "time": time_today},
temperature=0.01,
max_tokens=512
).replace("None", "").strip()
expr_param = expr_candidate if expr_candidate else ""
except Exception as _:
expr_param = ""
doc_list = search_docs(
usr_query=original_query,
fileName=[],
top_k=20,
score_threshold=1.0,
query=search_text,
knowledge_base_name=knownledge,
expr=expr_param
)
if len(doc_list)==0:
return result,source_docs
titles = temp["title"]
doc_list,title = utils.remove_docs1(titles,doc_list)
titles.extend(title)
for inum,doc in enumerate(doc_list):
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
# 从policydocs中删除重复的文档从后往前删除以防止索引错位
for index in sorted(duplicate_indices, reverse=True):
del doc_list[index]
# 处理原文来源进入数组。使用开关语句明确各个条件分支
match knownledge:
# 属于政策库分支,入参为中文政策库名称
case kb_config.DEFAULT_POLICY_BASE:
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs)
# 属于期刊论文库分支,入参为期刊论文库的中文名称
case kb_config.DEFAULT_JOURNAL_BASE:
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs)
# 属于报告库分支,入参为报告库中文名称
case kb_config.DEFAULT_REPORT_BASE1:
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs)
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
case kb_config.GY_NEWS_BASE:
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs)
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
case kb_config.GY_REPORT_BASE:
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs)
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
case kb_config.GY_JOURNAL_BASE:
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金新闻库2024年以及之前
case kb_config.YJ_NEWS_BASE:
doc_to_list(num,kb_config.YJ_NEWS_BASE_NAME,doc_list,source_docs)
# 新增冶金中文期刊库
case kb_config.YJ_CH_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_CH_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金外文期刊库
case kb_config.YJ_FOR_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_FOR_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金OA期刊库
case kb_config.YJ_OA_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_OA_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金政策库
case kb_config.YJ_POLICYS_BASE:
doc_to_list(num,kb_config.YJ_POLICYS_BASE_NAME,doc_list,source_docs)
# 新增中国钢铁行业动态库
case kb_config.STEEL_KB:
doc_to_list(num,kb_config.STEEL_KB_NAME,doc_list,source_docs)
# 属于个人知识库分支
case _ if kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding":
doc_to_list(num,knownledge,doc_list,source_docs)
case _:
print(f"输入了没有的知识库名称")
return "输入了没有的知识库名称",source_docs
# num += len(source_docs[knownledge])
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
# del num
# source = utils.get_shared_variable(uid)
# print(utils.get_shared_variable(uid))
# source["source_docs"]=source_docs
# utils.set_shared_variable(uid,source)
except Exception as e:
logging.error(f"Error in rag_search: {e}")
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
return "当前工具异常!请换其他工具",source_docs
return result,source_docs
def knowledgebase_kgo_search(query: str, uid) -> List[str]:
    """Query the KGO knowledge base and normalise the shape of its answer.

    Calls ``knowledgebase_kgo_iter(query, uid)`` and inspects items ``0`` and
    ``1`` of what it returns.

    NOTE(review): the ``List[str]`` annotation is kept for compatibility but
    does not describe the values listed under Returns.

    Args:
        query: Search text forwarded unchanged to ``knowledgebase_kgo_iter``.
        uid: Identifier forwarded unchanged to ``knowledgebase_kgo_iter``.

    Returns:
        * the raw result when items 0 and 1 are both lists;
        * the raw result with item 0 replaced by ``[]`` when only item 1 is a
          non-empty list;
        * the empty placeholder ``{0: [], 1: []}`` when the result has any
          other shape or cannot be indexed;
        * an instruction/error message string when ``knowledgebase_kgo_iter``
          itself raises.
    """
    try:
        res = knowledgebase_kgo_iter(query, uid)
    except json.JSONDecodeError:
        # The query could not be decoded as JSON.
        logging.error("Invalid JSON format in query.")
        return "<关键指令>不需要再调用该工具了</关键指令>"
    except KeyError:
        # A required key was missing from the parsed input.
        return "请尝试调用其他工具"
    except Exception as e:
        # Any other failure is reported back as a generic message.
        return f"发生错误:{str(e)},请尝试调用其他工具"

    try:
        if isinstance(res[0], list) and isinstance(res[1], list):
            return res
        if isinstance(res[1], list) and len(res[1]) > 0:
            res[0] = []
            return res
    except Exception as e:
        logging.error(f"No docs: {e}")
    # Built directly: assigning into an empty list by index raises IndexError,
    # which previously reached this same value only through the error handler.
    return {0: [], 1: []}
def inner_duckduckgo_search(query: str, uuid: str,):
    """Run ``duckduckgo_search_iter`` to completion and return its result.

    The coroutine is driven with ``asyncio.run``.  The trailing ``"y"`` and
    ``"default"`` arguments are passed through as-is; their meaning is defined
    by ``duckduckgo_search_iter``.

    Args:
        query: Search input, logged and forwarded unchanged.
        uuid: Identifier forwarded unchanged to the search coroutine.

    Returns:
        Whatever ``duckduckgo_search_iter`` produces.
    """
    logging.info(f"模型输入: {query}")
    search_coroutine = duckduckgo_search_iter(query, uuid, "y", "default")
    outcome = asyncio.run(search_coroutine)
    # The result is handed back to the model as-is.
    logging.info("返回JSON格式的结果给到模型...")
    return outcome
def search_tool(query: str):
    """Extract the uid, split the query and merge RAG and KGO search results.

    ``query`` must contain at least two flat ``{...}`` JSON objects (optionally
    wrapped in ``<param>`` tags): the first is the search request (keys
    ``"query"`` and those read by ``rag_search``), the second carries
    ``"uuid"``.  ``rag_search`` and ``knowledgebase_kgo_search`` run
    concurrently; their texts and sources are merged, the first ``[n]`` marker
    of every entry is renumbered continuing from the ``"num"`` counter kept in
    the shared state, and the shared state is updated.

    Args:
        query: Raw tool input as described above.

    Returns:
        A message string for the model: the merged material and its sources,
        a hint to retry when nothing was found, or a stop instruction when the
        input is unusable or the search fails.

    Raises:
        json.JSONDecodeError: If either extracted object is not valid JSON.
        KeyError: If ``"uuid"`` or ``"query"`` is missing.
    """
    if "<param>" in query:
        query = query.replace("<param>", "").replace("</param>", "")
    # NOTE(review): the non-greedy pattern only handles flat objects; a "}"
    # inside a value cuts the match short.
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return "<关键指令>当前工具不需要再调用</关键指令>"
    query = matches[0]
    time_based_uuid = json.loads(matches[1])["uuid"]
    search = json.loads(query)

    # The KGO search takes one string: the first entry of a list query, ""
    # for an empty list, or the value itself otherwise.
    raw_query = search["query"]
    if isinstance(raw_query, list):
        searches = raw_query[0] if raw_query else ""
    else:
        searches = raw_query

    try:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            # Fresh per-conversation state, read by the search functions.
            utils.set_shared_variable(
                time_based_uuid,
                {"num": 0, "source_docs": [], "END": "", "title": []},
            )
            future1 = executor.submit(rag_search, query, time_based_uuid)
            # Bound up front so it can never be undefined if the "type" flag
            # changes between the two checks of the shared state.
            future2 = None
            if "type" not in utils.get_shared_variable(time_based_uuid):
                future2 = executor.submit(knowledgebase_kgo_search, searches, time_based_uuid)

            result1, sourcedocs = future1.result()
            if future2 is None or "type" in utils.get_shared_variable(time_based_uuid):
                kgo_texts, kgo_sources = [], []
            else:
                kgo_texts, kgo_sources = _split_kgo_result(future2.result())
            utils.remove_shared_variable(time_based_uuid + "q")

            if isinstance(sourcedocs, list):
                sourcedocs.extend(kgo_sources)
            else:
                sourcedocs = []
            # rag_search returns an error string instead of a list on failure;
            # in that case only the KGO texts are used.
            if isinstance(result1, list):
                result1.extend(kgo_texts)
                result3 = result1
            else:
                result3 = kgo_texts
            logging.info(f"result2:{kgo_sources}")

            sources = utils.get_shared_variable(time_based_uuid)
            first_free = sources["num"]
            source, last_used = _renumber_sources(sourcedocs, first_free)
            res = [
                re.sub(r'\[\d+\]', f"[{position}]", text, count=1)
                for position, text in enumerate(result3, start=first_free + 1)
            ]

            logging.debug(f"{utils.get_shared_variable(time_based_uuid)}")
            sources["source_docs"].extend(source)
            sources["num"] = last_used
            utils.set_shared_variable(time_based_uuid, sources)
            logging.info(f"result1:{result1},sourcedocs:{sourcedocs}")
            logging.info(f"result2:{(kgo_texts, kgo_sources)}")
            logging.info(f"{res}")
            if not res and not source:
                return "尝试调整入参重新调用知识库联想工具(同一个问题调用超过三次就不要再使用知识库联想工具了,浪费时间)"
            return f"<关键指令>如果你在写文章禁止在非规定位置输出参考资料</关键指令>资料:{res}\n资料来源为:{source}\n 注意:如果你在根据大纲撰写文章,撰写中间部分章节禁止输出综上所述之类的影响文风的话,撰写中间部分禁止输出附录引用文献等!!!"
    except Exception as e:
        logging.error(f"Error occurred during search_tool execution.{e}")
        return "同一个问题调用知识库联想工具超过5次就不要再调用知识库联想"


def _split_kgo_result(kgo_result):
    """Return ``(texts, sources)`` lists from a ``knowledgebase_kgo_search`` result."""
    # An error message string must not be indexed: its items 0 and 1 would be
    # single characters that end up merged into the results.
    if isinstance(kgo_result, str):
        return [], []
    try:
        texts, origins = kgo_result[0], kgo_result[1]
    except (KeyError, IndexError, TypeError):
        return [], []
    return (texts if isinstance(texts, list) else [],
            origins if isinstance(origins, list) else [])


def _renumber_sources(entries, start):
    """Strip newlines and renumber the first ``[n]`` of each entry after ``start``."""
    renumbered = []
    current = start
    for entry in entries:
        try:
            text = re.sub(r'\[\d+\]', f"[{current + 1}]", entry.replace("\n", ""), count=1)
        except Exception as exc:
            # Best effort: an entry that cannot be processed is skipped and
            # does not consume a number.
            logging.warning(f"skipping source entry: {exc}")
            continue
        if text:
            current += 1
            renumbered.append(text)
    return renumbered, current