Files
gangyan/langchain-chat/server/agent/tools/rag_search.py

108 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import re
import concurrent
from fastapi.concurrency import run_in_threadpool
from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
from server.chat import utils
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import kb_config
def rag_search1(query: str):
"""
根据用户输入的query返回rag搜索结果
"""
try:
matches = re.findall(r'\{.*?\}', query)
if len(matches)>=2:
query = matches[0]
else:
return "<关键指令>不需要再调用该工具了</关键指令>"
time_based_uuid = json.loads(matches[1])["uuid"]
search = json.loads(query)
search_query = search["query"]
search_keywords = []
search_text = f"{search_query}"
if type(search["keywords"]) == list:
search_keywords = search["keywords"]
for keyword in search_keywords:
search_text += f" {keyword}"
else:
search_keywords = search["keywords"].split(",")
for keyword in search_keywords:
search_text += f" {keyword}"
result = []
source_docs = {}
knownledge_name = []
if type(search["knowledge_name"]) == list:
knownledge_name=search["knowledge_name"]
else:
knownledge_name=search["knowledge_name"].split(",")
for knownledge in knownledge_name:
if not knownledge in kb_config.CH_BASE_NAME:
knownledge_name.remove(knownledge)
if len(knownledge_name)==0:
result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
return result
# knownledge_name=kb_config.CH_BASE_NAME
knownledge_name=solve_knowledge_map(knownledge_name)
num = 0
for knownledge in knownledge_name:
source_docs[knownledge] = []
seen_docs = set()
duplicate_indices = []
doc_list = search_docs(usr_query=search_text,fileName= [],top_k=5,score_threshold=0.9,query=search_text, knowledge_base_name=knownledge)
for inum,doc in enumerate(doc_list):
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
# 从policydocs中删除重复的文档从后往前删除以防止索引错位
for index in sorted(duplicate_indices, reverse=True):
del doc_list[index]
# 处理原文来源进入数组。使用开关语句明确各个条件分支
match knownledge:
# 属于政策库分支,入参为中文政策库名称
case kb_config.DEFAULT_POLICY_BASE:
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs[knownledge])
# 属于期刊论文库分支,入参为期刊论文库的中文名称
case kb_config.DEFAULT_JOURNAL_BASE:
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
# 属于报告库分支,入参为报告库中文名称
case kb_config.DEFAULT_REPORT_BASE:
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
case kb_config.GY_NEWS_BASE:
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
case kb_config.GY_REPORT_BASE:
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
case kb_config.GY_JOURNAL_BASE:
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
case _:
print(f"输入了没有的知识库名称")
return("输入了没有的知识库名称")
num += len(source_docs[knownledge])
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
del num
source = utils.get_shared_variable(time_based_uuid)
print(utils.get_shared_variable(time_based_uuid))
source["source_docs"]=source_docs
utils.set_shared_variable(time_based_uuid,source)
if 0<len(result)<3:
return f"当前资料:{result}\n<关键指令>搜索结果较少,更换知识库或联网思索重新搜索!!!</关键指令>"
if len(result)==0:
return "注意:【指令:更换知识库或联网思索继续搜索!!!】"
except:
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
return "当前工具异常!请换其他工具"
return f"当前资料:{result}\n<关键指令>总结此内容!!!</关键指令>"
# Input model with a single required string field; presumably the argument
# schema for the rag_search1 tool (confirm at the tool registration site).
# Documented with comments rather than a docstring so the generated pydantic
# schema is left unchanged.
class RagSearchInput(BaseModel):
    # Required (no default); the description text is emitted in the schema.
    query: str = Field(
        ...,
        description="查询对象",
    )