# Source: gangyan/langchain-chat/server/agent/tools/rag_search.py
# (108 lines, 5.6 KiB, Python; header text left over from the repository web viewer)

import json
import re
import concurrent
from fastapi.concurrency import run_in_threadpool
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
from server.chat import utils
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import kb_config
def rag_search1(query: str):
"""
根据用户输入的query返回rag搜索结果
"""
try:
matches = re.findall(r'\{.*?\}', query)
if len(matches)>=2:
query = matches[0]
else:
return "<关键指令>不需要再调用该工具了</关键指令>"
time_based_uuid = json.loads(matches[1])["uuid"]
search = json.loads(query)
search_query = search["query"]
search_keywords = []
search_text = f"{search_query}"
if type(search["keywords"]) == list:
search_keywords = search["keywords"]
for keyword in search_keywords:
search_text += f" {keyword}"
else:
search_keywords = search["keywords"].split(",")
for keyword in search_keywords:
search_text += f" {keyword}"
result = []
source_docs = {}
knownledge_name = []
if type(search["knowledge_name"]) == list:
knownledge_name=search["knowledge_name"]
else:
knownledge_name=search["knowledge_name"].split(",")
for knownledge in knownledge_name:
if not knownledge in kb_config.CH_BASE_NAME:
knownledge_name.remove(knownledge)
if len(knownledge_name)==0:
result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
return result
# knownledge_name=kb_config.CH_BASE_NAME
knownledge_name=solve_knowledge_map(knownledge_name)
num = 0
for knownledge in knownledge_name:
source_docs[knownledge] = []
seen_docs = set()
duplicate_indices = []
doc_list = search_docs(usr_query=search_text,fileName= [],top_k=5,score_threshold=0.9,query=search_text, knowledge_base_name=knownledge)
for inum,doc in enumerate(doc_list):
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
# 从policydocs中删除重复的文档从后往前删除以防止索引错位
for index in sorted(duplicate_indices, reverse=True):
del doc_list[index]
# 处理原文来源进入数组。使用开关语句明确各个条件分支
match knownledge:
# 属于政策库分支,入参为中文政策库名称
case kb_config.DEFAULT_POLICY_BASE:
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs[knownledge])
# 属于期刊论文库分支,入参为期刊论文库的中文名称
case kb_config.DEFAULT_JOURNAL_BASE:
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
# 属于报告库分支,入参为报告库中文名称
case kb_config.DEFAULT_REPORT_BASE:
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
case kb_config.GY_NEWS_BASE:
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
case kb_config.GY_REPORT_BASE:
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
case kb_config.GY_JOURNAL_BASE:
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
case _:
print(f"输入了没有的知识库名称")
return("输入了没有的知识库名称")
num += len(source_docs[knownledge])
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
del num
source = utils.get_shared_variable(time_based_uuid)
print(utils.get_shared_variable(time_based_uuid))
source["source_docs"]=source_docs
utils.set_shared_variable(time_based_uuid,source)
if 0<len(result)<3:
return f"当前资料:{result}\n<关键指令>搜索结果较少,更换知识库或联网思索重新搜索!!!</关键指令>"
if len(result)==0:
return "注意:【指令:更换知识库或联网思索继续搜索!!!】"
except:
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
return "当前工具异常!请换其他工具"
return f"当前资料:{result}\n<关键指令>总结此内容!!!</关键指令>"
class RagSearchInput(BaseModel):
    # Argument schema with a single required string field, ``query``.
    # Documented with comments rather than a class docstring because pydantic
    # copies a model's docstring into the schema it generates.
    query: str = Field(
        default=...,
        description="查询对象",
    )