import asyncio
import concurrent
from datetime import datetime
import json
import logging
import re
from typing import List
from fastapi import logger
from configs import kb_config
from configs.model_config import LLM_MODELS
from server.agent.tools import duckduckgo_search
from server.agent.tools.duckduckgo_search import duckduckgo_search_iter
from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_iter
from server.agent.tools.rag_search import rag_search1
from server.chat import utils
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.utils import doc_to_list, get_similar_documents1, solve_knowledge_map, solve_mental_data
from server.knowledge_base.kb_doc_api import search_docs
def rag_search(query: str,uid):
"""
根据用户输入的query,返回rag搜索结果
"""
source_docs = []
try:
search = json.loads(query)
logging.info(f'模型输入: {search["query"]}')
original_query = search["query"]
search_query = get_llm_model_response(
strategy_name="rag_search_rewrite",
llm_model_name=LLM_MODELS[0],
template_prompt_name="rag_search_rewrite",
prompt_param_dict={"input": search["query"], "year": datetime.now().strftime("%Y")},
temperature=0.3,
max_tokens=512
)
logging.info(f'模型改写: {search_query}')
search_keywords = []
search_text = f"{search_query}"
# if type(search["keywords"]) == list:
# search_keywords = search["keywords"]
# for keyword in search_keywords:
# search_text += f" {keyword}"
# else:
# search_keywords = search["keywords"].split(",")
# for keyword in search_keywords:
# search_text += f" {keyword}"
self_database = utils.get_shared_variable(uid)
result = []
knownledge_name = []
if type(search["knowledge_name"]) == list:
knownledge_name=search["knowledge_name"]
else:
knownledge_name=search["knowledge_name"].split(",")
if "美术专业知识库" in knownledge_name:
knownledge_name.remove("美术专业知识库")
if "database" in self_database:
self_database["database"]= self_database["database"].append("p_cafa0101011")
else:
self_database["database"] = ["p_cafa0101011"]
# 添加个人知识库
if "database" in self_database:
knownledge_name.extend(self_database["database"])
knownledge_name = [knownledge for knownledge in knownledge_name
if (knownledge in kb_config.CH_BASE_NAME
or knownledge in kb_config.EN_BASE_NAME
or knownledge in getattr(kb_config, "YJ_BASE_NAME", [])
or kb_config.SELF_KNOWLEDGE_BASE.match(knownledge)
or knownledge == "coding")]
if len(knownledge_name)==0:
#result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
return result,source_docs
# knownledge_name=kb_config.CH_BASE_NAME
knownledge_name=solve_knowledge_map(knownledge_name)
#knownledge_name = ["p_c88859a3d06e4265bd01d816ef2650d1"]
num = 0
temp=utils.get_shared_variable(uid)
for knownledge in knownledge_name:
seen_docs = set()
duplicate_indices = []
# 针对中国钢铁行业动态库增加日期范围过滤
expr_param = ""
if knownledge == kb_config.STEEL_KB:
time_today = datetime.now().strftime("%Y-%m-%d")
# 调用LLM生成日期表达式,模板沿用 get_policy_time
try:
expr_candidate = get_llm_model_response(
strategy_name="get steel time",
llm_model_name=LLM_MODELS[0],
template_prompt_name="get_steel_time",
prompt_param_dict={"query": original_query, "time": time_today},
temperature=0.01,
max_tokens=512
).replace("None", "").strip()
expr_param = expr_candidate if expr_candidate else ""
except Exception as _:
expr_param = ""
doc_list = search_docs(
usr_query=original_query,
fileName=[],
top_k=20,
score_threshold=1.0,
query=search_text,
knowledge_base_name=knownledge,
expr=expr_param
)
logging.info(f"[RAG诊断] kb={knownledge!r} expr={expr_param!r} 召回 {len(doc_list)} docs")
if len(doc_list)==0:
# 修 bug: 原代码 return 导致首个 KB 空就放弃全部 KB;改 continue 继续尝试下一个
continue
titles = temp["title"]
doc_list,title = utils.remove_docs1(titles,doc_list)
titles.extend(title)
for inum,doc in enumerate(doc_list):
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
# 从policydocs中删除重复的文档(从后往前删除以防止索引错位)
for index in sorted(duplicate_indices, reverse=True):
del doc_list[index]
# 处理原文来源进入数组。使用开关语句明确各个条件分支
match knownledge:
# 属于政策库分支,入参为中文政策库名称
case kb_config.DEFAULT_POLICY_BASE:
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs)
# 属于期刊论文库分支,入参为期刊论文库的中文名称
case kb_config.DEFAULT_JOURNAL_BASE:
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs)
# 属于报告库分支,入参为报告库中文名称
case kb_config.DEFAULT_REPORT_BASE1:
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs)
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
case kb_config.GY_NEWS_BASE:
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs)
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
case kb_config.GY_REPORT_BASE:
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs)
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
case kb_config.GY_JOURNAL_BASE:
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金新闻库(2024年以及之前)
case kb_config.YJ_NEWS_BASE:
doc_to_list(num,kb_config.YJ_NEWS_BASE_NAME,doc_list,source_docs)
# 新增冶金中文期刊库
case kb_config.YJ_CH_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_CH_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金外文期刊库
case kb_config.YJ_FOR_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_FOR_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金OA期刊库
case kb_config.YJ_OA_JOURNAL_BASE:
doc_to_list(num,kb_config.YJ_OA_JOURNAL_BASE_NAME,doc_list,source_docs)
# 新增冶金政策库
case kb_config.YJ_POLICYS_BASE:
doc_to_list(num,kb_config.YJ_POLICYS_BASE_NAME,doc_list,source_docs)
# 新增中国钢铁行业动态库
case kb_config.STEEL_KB:
doc_to_list(num,kb_config.STEEL_KB_NAME,doc_list,source_docs)
# 属于个人知识库分支
case _ if kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding":
doc_to_list(num,knownledge,doc_list,source_docs)
case _:
print(f"输入了没有的知识库名称")
return "输入了没有的知识库名称",source_docs
# num += len(source_docs[knownledge])
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
# del num
# source = utils.get_shared_variable(uid)
# print(utils.get_shared_variable(uid))
# source["source_docs"]=source_docs
# utils.set_shared_variable(uid,source)
except Exception as e:
logging.error(f"Error in rag_search: {e}")
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
return "当前工具异常!请换其他工具",source_docs
return result,source_docs
def knowledgebase_kgo_search(query: str, uid) -> List[str]:
    """Call ``knowledgebase_kgo_iter`` and normalise what it returns.

    Args:
        query: Passed unchanged to ``knowledgebase_kgo_iter``.
        uid: Passed unchanged to ``knowledgebase_kgo_iter``.

    Returns:
        * ``res`` itself when both ``res[0]`` and ``res[1]`` are lists.
        * ``res`` with ``res[0]`` replaced by ``[]`` when only ``res[1]`` is
          a non-empty list.
        * ``[[], []]`` when the result has any other shape or cannot be
          inspected.
        * A message string when ``knowledgebase_kgo_iter`` itself raises.

        NOTE(review): the ``List[str]`` annotation does not describe these
        values; it is kept so the signature stays unchanged.
    """
    try:
        res = knowledgebase_kgo_iter(query, uid)
        try:
            if isinstance(res[0], list) and isinstance(res[1], list):
                return res
            if isinstance(res[1], list) and len(res[1]) > 0:
                # Only the second part is usable: blank the first part.
                res[0] = []
                return res
        except Exception as e:
            logging.error(f"No docs: {e}")
            return [[], []]
        # The previous code built this result with ``temp = []; temp[0] = []``,
        # which always raised IndexError and only produced a (dict) result via
        # the except handler. Return the empty pair directly instead.
        logging.info("No docs: knowledgebase_kgo_iter returned no usable result")
        return [[], []]
    except json.JSONDecodeError:
        # JSON decoding failed.
        logging.error("Invalid JSON format in query.")
        return "<关键指令>不需要再调用该工具了关键指令>"
    except KeyError:
        # A required key was missing.
        return "请尝试调用其他工具"
    except Exception as e:
        # Any other failure is reported with a generic message.
        return f"发生错误:{str(e)},请尝试调用其他工具"
def inner_duckduckgo_search(query: str, uuid: str):
    """Run ``duckduckgo_search_iter`` to completion and return its result.

    The coroutine is called with ``query``, ``uuid`` and the fixed arguments
    ``"y"`` and ``"default"``, then driven synchronously by ``asyncio.run``.
    """
    logging.info(f"模型输入: {query}")
    search_coroutine = duckduckgo_search_iter(query, uuid, "y", "default")
    outcome = asyncio.run(search_coroutine)
    # The result is handed back to the caller unchanged.
    logging.info("返回JSON格式的结果给到模型...")
    return outcome
def _split_kgo_result(raw):
    """Return ``(texts, sources)`` lists taken from ``raw[0]`` and ``raw[1]``.

    A value that cannot be indexed at 0 and 1, or whose entries are not lists
    (for example an error-message string), yields empty lists.
    """
    try:
        texts, sources = raw[0], raw[1]
    except (IndexError, KeyError, TypeError):
        return [], []
    return (
        texts if isinstance(texts, list) else [],
        sources if isinstance(sources, list) else [],
    )


def search_tool(query: str):
    """Extract the uuid and the search payload from ``query`` and run the searches.

    ``query`` must contain at least two ``{...}`` fragments: the first is the
    JSON search payload (its ``"query"`` key is read), the second is a JSON
    object with a ``"uuid"`` key. ``rag_search`` and, unless the shared state
    for the uuid contains ``"type"``, ``knowledgebase_kgo_search`` are run in
    a thread pool. Their results are merged, the first ``[n]`` marker of each
    entry is renumbered continuing from the shared ``"num"`` counter, and the
    shared state (``"source_docs"``, ``"num"``) is updated.

    Returns:
        A message string for the model: the merged material and sources, a
        retry hint when nothing was found, or a stop hint when fewer than two
        ``{...}`` fragments are present or the search fails.

    Raises:
        json.JSONDecodeError, KeyError: when the fragments are not valid JSON
            or lack ``"uuid"`` / ``"query"`` (parsing happens outside the
            ``try`` block, as before).
    """
    # concurrent.futures is a submodule; import it explicitly rather than
    # relying on another module having loaded it.
    from concurrent.futures import ThreadPoolExecutor

    # NOTE(review): as written both literals are empty, so this is a no-op;
    # the original markers may have been lost - confirm against history.
    if ""in query:
        query = query.replace("","").replace("","")
    matches = re.findall(r'\{.*?\}', query)
    if len(matches) < 2:
        return "<关键指令>当前工具不需要再调用关键指令>"
    query = matches[0]
    time_based_uuid = json.loads(matches[1])["uuid"]
    search = json.loads(query)

    # "query" may be a list (use its first entry, or "无" when empty) or a
    # plain value.
    raw_query = search["query"]
    if isinstance(raw_query, list):
        searches = raw_query[0] if len(raw_query) > 0 else "无"
    else:
        searches = raw_query

    try:
        with ThreadPoolExecutor() as executor:
            # Fresh state for the knowledge-graph search, stored under uuid + "¥".
            utils.set_shared_variable(
                time_based_uuid + "¥",
                {"num": 0, "source_docs": [], "END": "", "title": []},
            )
            future1 = executor.submit(rag_search, query, time_based_uuid)
            # Bind future2 unconditionally so the later check cannot hit an
            # unbound name if the shared state changes in between.
            future2 = None
            if "type" not in utils.get_shared_variable(time_based_uuid):
                future2 = executor.submit(knowledgebase_kgo_search, searches, time_based_uuid + "¥")

            result1, sourcedocs = future1.result()
            logging.info(f"[RAG诊断] rag_search 返回 result1={len(result1) if isinstance(result1, list) else type(result1).__name__}, sourcedocs={len(sourcedocs) if isinstance(sourcedocs, list) else type(sourcedocs).__name__}, kb={search.get('knowledge_name')}, query={search.get('query')!r}")

            if future2 is None or "type" in utils.get_shared_variable(time_based_uuid):
                result2 = [[], []]
            else:
                result2 = future2.result()

        # knowledgebase_kgo_search may return an error string; indexing it
        # would yield single characters, so reduce it to two real lists.
        kgo_texts, kgo_sources = _split_kgo_result(result2)
        logging.info(f"[RAG诊断] zhipu 返回 result2[0]={len(kgo_texts)}, result2[1]={len(kgo_sources)}")

        # NOTE(review): the state above is stored under uuid + "¥" but the key
        # removed here ends in "q" - looks stale, confirm before changing.
        utils.remove_shared_variable(time_based_uuid+"q")

        if not isinstance(sourcedocs, list):
            sourcedocs = []
        sourcedocs.extend(kgo_sources)
        if isinstance(result1, list):
            result1.extend(kgo_texts)
            result3 = result1
        else:
            # rag_search reported an error string; use only the other results.
            result3 = kgo_texts
        logging.info(f"result2:{kgo_sources}")

        sources = utils.get_shared_variable(time_based_uuid)

        # Renumber the first "[n]" marker of every source. The counter only
        # advances for entries that are actually kept.
        source = []
        i = sources["num"]
        for doc in sourcedocs:
            try:
                renumbered = re.sub(r'\[\d+\]', f"[{i + 1}]", doc.replace("\n", ""), count=1)
            except Exception as e:
                # Best effort: skip entries that cannot be processed.
                logging.warning(f"skip source entry: {e}")
                continue
            if renumbered:
                i += 1
                source.append(renumbered)

        # Renumber the material the same way, starting from the same counter.
        res = [
            re.sub(r'\[\d+\]', f"[{j}]", text, count=1)
            for j, text in enumerate(result3, start=sources["num"] + 1)
        ]

        logging.debug("%s", utils.get_shared_variable(time_based_uuid))
        sources["source_docs"].extend(source)
        sources["num"] = i
        utils.set_shared_variable(time_based_uuid, sources)
        logging.info(f"result1:{result1},sourcedocs:{sourcedocs}")
        logging.info(f"result2:{result2}")
        logging.info(f"{res}")
        if len(res) == 0 and len(source) == 0:
            return f"尝试调整入参重新调用知识库联想工具(同一个问题调用超过三次就不要再使用知识库联想工具了,浪费时间)"
        return f"<关键指令>如果你在写文章禁止在非规定位置输出参考资料关键指令>资料:{res}\n资料来源为:{source}\n 注意:如果你在根据大纲撰写文章,撰写中间部分章节禁止输出综上所述之类的影响文风的话,撰写中间部分禁止输出附录引用文献等!!!"
    except Exception as e:
        logging.error(f"Error occurred during search_tool execution.{e}")
        return "同一个问题调用知识库联想工具超过5次就不要再调用知识库联想"