import asyncio import concurrent from datetime import datetime import json import logging import re from typing import List from fastapi import logger from configs import kb_config from configs.model_config import LLM_MODELS from server.agent.tools import duckduckgo_search from server.agent.tools.duckduckgo_search import duckduckgo_search_iter from server.agent.tools.knowledgebase_kgo_search import knowledgebase_kgo_iter from server.agent.tools.rag_search import rag_search1 from server.chat import utils from server.chat.policy_fun_iast import get_llm_model_response from server.chat.utils import doc_to_list, get_similar_documents1, solve_knowledge_map, solve_mental_data from server.knowledge_base.kb_doc_api import search_docs def rag_search(query: str,uid): """ 根据用户输入的query,返回rag搜索结果 """ source_docs = [] try: search = json.loads(query) logging.info(f'模型输入: {search["query"]}') original_query = search["query"] search_query = get_llm_model_response( strategy_name="rag_search_rewrite", llm_model_name=LLM_MODELS[0], template_prompt_name="rag_search_rewrite", prompt_param_dict={"input": search["query"], "year": datetime.now().strftime("%Y")}, temperature=0.3, max_tokens=512 ) logging.info(f'模型改写: {search_query}') search_keywords = [] search_text = f"{search_query}" # if type(search["keywords"]) == list: # search_keywords = search["keywords"] # for keyword in search_keywords: # search_text += f" {keyword}" # else: # search_keywords = search["keywords"].split(",") # for keyword in search_keywords: # search_text += f" {keyword}" self_database = utils.get_shared_variable(uid) result = [] knownledge_name = [] if type(search["knowledge_name"]) == list: knownledge_name=search["knowledge_name"] else: knownledge_name=search["knowledge_name"].split(",") if "美术专业知识库" in knownledge_name: knownledge_name.remove("美术专业知识库") if "database" in self_database: self_database["database"]= self_database["database"].append("p_cafa0101011") else: self_database["database"] = ["p_cafa0101011"] # 添加个人知识库 if "database" in self_database: knownledge_name.extend(self_database["database"]) knownledge_name = [knownledge for knownledge in knownledge_name if (knownledge in kb_config.CH_BASE_NAME or knownledge in kb_config.EN_BASE_NAME or knownledge in getattr(kb_config, "YJ_BASE_NAME", []) or kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding")] if len(knownledge_name)==0: #result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容") return result,source_docs # knownledge_name=kb_config.CH_BASE_NAME knownledge_name=solve_knowledge_map(knownledge_name) #knownledge_name = ["p_c88859a3d06e4265bd01d816ef2650d1"] num = 0 temp=utils.get_shared_variable(uid) for knownledge in knownledge_name: seen_docs = set() duplicate_indices = [] # 针对中国钢铁行业动态库增加日期范围过滤 expr_param = "" if knownledge == kb_config.STEEL_KB: time_today = datetime.now().strftime("%Y-%m-%d") # 调用LLM生成日期表达式,模板沿用 get_policy_time try: expr_candidate = get_llm_model_response( strategy_name="get steel time", llm_model_name=LLM_MODELS[0], template_prompt_name="get_steel_time", prompt_param_dict={"query": original_query, "time": time_today}, temperature=0.01, max_tokens=512 ).replace("None", "").strip() expr_param = expr_candidate if expr_candidate else "" except Exception as _: expr_param = "" doc_list = search_docs( usr_query=original_query, fileName=[], top_k=20, score_threshold=1.0, query=search_text, knowledge_base_name=knownledge, expr=expr_param ) logging.info(f"[RAG诊断] kb={knownledge!r} expr={expr_param!r} 召回 {len(doc_list)} docs") if len(doc_list)==0: # 修 bug: 原代码 return 导致首个 KB 空就放弃全部 KB;改 continue 继续尝试下一个 continue titles = temp["title"] doc_list,title = utils.remove_docs1(titles,doc_list) titles.extend(title) for inum,doc in enumerate(doc_list): solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum) # 从policydocs中删除重复的文档(从后往前删除以防止索引错位) for index in sorted(duplicate_indices, reverse=True): del doc_list[index] # 处理原文来源进入数组。使用开关语句明确各个条件分支 match knownledge: # 属于政策库分支,入参为中文政策库名称 case kb_config.DEFAULT_POLICY_BASE: doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs) # 属于期刊论文库分支,入参为期刊论文库的中文名称 case kb_config.DEFAULT_JOURNAL_BASE: doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs) # 属于报告库分支,入参为报告库中文名称 case kb_config.DEFAULT_REPORT_BASE1: doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs) # 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称 case kb_config.GY_NEWS_BASE: doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs) # 属于冶金行业报告库分支,入参为冶金行业报告库中文名称 case kb_config.GY_REPORT_BASE: doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs) # 属于冶金专业知识库分支,入参为冶金专业知识库中文名称 case kb_config.GY_JOURNAL_BASE: doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs) # 新增冶金新闻库(2024年以及之前) case kb_config.YJ_NEWS_BASE: doc_to_list(num,kb_config.YJ_NEWS_BASE_NAME,doc_list,source_docs) # 新增冶金中文期刊库 case kb_config.YJ_CH_JOURNAL_BASE: doc_to_list(num,kb_config.YJ_CH_JOURNAL_BASE_NAME,doc_list,source_docs) # 新增冶金外文期刊库 case kb_config.YJ_FOR_JOURNAL_BASE: doc_to_list(num,kb_config.YJ_FOR_JOURNAL_BASE_NAME,doc_list,source_docs) # 新增冶金OA期刊库 case kb_config.YJ_OA_JOURNAL_BASE: doc_to_list(num,kb_config.YJ_OA_JOURNAL_BASE_NAME,doc_list,source_docs) # 新增冶金政策库 case kb_config.YJ_POLICYS_BASE: doc_to_list(num,kb_config.YJ_POLICYS_BASE_NAME,doc_list,source_docs) # 新增中国钢铁行业动态库 case kb_config.STEEL_KB: doc_to_list(num,kb_config.STEEL_KB_NAME,doc_list,source_docs) # 属于个人知识库分支 case _ if kb_config.SELF_KNOWLEDGE_BASE.match(knownledge) or knownledge == "coding": doc_to_list(num,knownledge,doc_list,source_docs) case _: print(f"输入了没有的知识库名称") return "输入了没有的知识库名称",source_docs # num += len(source_docs[knownledge]) # 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源 # del num # source = utils.get_shared_variable(uid) # print(utils.get_shared_variable(uid)) # source["source_docs"]=source_docs # utils.set_shared_variable(uid,source) except Exception as e: logging.error(f"Error in rag_search: {e}") # return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具" return "当前工具异常!请换其他工具",source_docs return result,source_docs def knowledgebase_kgo_search(query: str, uid) -> List[str]: try: res = knowledgebase_kgo_iter(query,uid) try: if type(res[0])==list and type(res[1])==list: return res elif type(res[1])==list and len(res[1])>0: res[0]=[] return res else: temp = [] temp[0]=[] temp[1]=[] return temp except Exception as e: temp = {} logging.error(f"No docs: {e}") temp[0]=[] temp[1]=[] return temp except json.JSONDecodeError: # 如果JSON解码失败,则返回错误消息 logging.error("Invalid JSON format in query.") return "<关键指令>不需要再调用该工具了" except KeyError: # 如果解析的JSON对象中缺少必要的键,则返回错误消息 return "请尝试调用其他工具" except Exception as e: # 捕获其他所有异常,并返回通用错误消息 return f"发生错误:{str(e)},请尝试调用其他工具" def inner_duckduckgo_search(query: str, uuid:str,) : logging.info(f"模型输入: {query}") combined_result = asyncio.run(duckduckgo_search_iter(query, uuid, "y","default" )) # 以标准json格式输出 logging.info("返回JSON格式的结果给到模型...") return combined_result def search_tool(query: str): """获取到uid并拆分query""" if ""in query: query = query.replace("","").replace("","") matches = re.findall(r'\{.*?\}', query) if len(matches)>=2: query = matches[0] else: return "<关键指令>当前工具不需要再调用" time_based_uuid = json.loads(matches[1])["uuid"] search = json.loads(query) if type(search["query"])==list and len(search["query"])>0: searches = search["query"][0] elif type(search["query"])==list and len(search["query"]) == 0: searches = "无" else: searches = search["query"] """ 根据用户输入的query,返回rag搜索结果 """ try: with concurrent.futures.ThreadPoolExecutor() as executor: # 提交任务并发执行 test = {} test["num"]=0 test["source_docs"]=[] test["END"] = "" test["title"] = [] utils.set_shared_variable(time_based_uuid+"¥",test) # future2 = executor.submit(knowledgebase_kgo_search,search["query"],time_based_uuid+"q") future1 = executor.submit(rag_search,query,time_based_uuid) # if not "type" in utils.get_shared_variable(time_based_uuid): # future2 = executor.submit(knowledgebase_kgo_search,searches,time_based_uuid+"¥") if not "type" in utils.get_shared_variable(time_based_uuid): future2 = executor.submit(knowledgebase_kgo_search,searches,time_based_uuid+"¥") result3 = [] # 获取结果 result1,sourcedocs = future1.result() # 诊断:看 rag_search 实际召回多少 logging.info(f"[RAG诊断] rag_search 返回 result1={len(result1) if isinstance(result1, list) else type(result1).__name__}, sourcedocs={len(sourcedocs) if isinstance(sourcedocs, list) else type(sourcedocs).__name__}, kb={search.get('knowledge_name')}, query={search.get('query')!r}") result2 = {} if "type" in utils.get_shared_variable(time_based_uuid): result2[0] =[] result2[1] = [] else: result2 = future2.result() logging.info(f"[RAG诊断] zhipu 返回 result2[0]={len(result2[0]) if isinstance(result2[0], list) else type(result2[0]).__name__}, result2[1]={len(result2[1]) if isinstance(result2[1], list) else type(result2[1]).__name__}") # if "type" in utils.get_shared_variable(time_based_uuid): # result2[0] =[] # result2[1] = [] # else: # result2 = future2.result() # result2[0] = [] # result2[1] = [] utils.remove_shared_variable(time_based_uuid+"q") if type(result2[1]) == list: if type(sourcedocs) == list: sourcedocs.extend(result2[1]) else: sourcedocs = [] if type(result1) == list: result1.extend(result2[0]) result3 = result1 else: result3 = result2[0] logging.info(f"result2:{result2[1]}") source = [] res=[] sources = utils.get_shared_variable(time_based_uuid) i = sources["num"] num = sources["num"] for result in sourcedocs: try: i+=1 res3 = re.sub(r'\[\d+\]', f"[{i}]", result.replace("\n",""), count=1) if res3: source.append(re.sub(r'\[\d+\]', f"[{i}]", result.replace("\n",""), count=1)) else: i -= 1 except Exception as e: i -= 1 pass # internet_search_res = f"参考资料[{len(result1)+1}-{len(source)}]:{result2[0]}" # internet_search_res = f"参考资料:{result2[0]}" j = sources["num"] for result in result3: j+=1 res.append(re.sub(r'\[\d+\]', f"[{j}]", result, count=1)) print(utils.get_shared_variable(time_based_uuid)) # sources["source_docs"]=source sources["source_docs"].extend(source) sources["num"]=i # sources["END"] = "ok" utils.set_shared_variable(time_based_uuid,sources) logging.info(f"result1:{result1},sourcedocs:{sourcedocs}") logging.info(f"result2:{result2}") logging.info(f"{res}") if len(res) ==0 and len(source)==0: return f"尝试调整入参重新调用知识库联想工具(同一个问题调用超过三次就不要再使用知识库联想工具了,浪费时间)" return f"<关键指令>如果你在写文章禁止在非规定位置输出参考资料资料:{res}\n资料来源为:{source}\n 注意:如果你在根据大纲撰写文章,撰写中间部分章节禁止输出综上所述之类的影响文风的话,撰写中间部分禁止输出附录引用文献等!!!" except Exception as e: logging.error(f"Error occurred during search_tool execution.{e}") return "同一个问题调用知识库联想工具超过5次就不要再调用知识库联想"