108 lines
5.6 KiB
Python
108 lines
5.6 KiB
Python
|
|
import json
|
|||
|
|
import re
|
|||
|
|
import concurrent
|
|||
|
|
from fastapi.concurrency import run_in_threadpool
|
|||
|
|
from langchain.tools import YouTubeSearchTool
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from server.chat import utils
|
|||
|
|
from server.chat.utils import doc_to_list, solve_knowledge_map,solve_mental_data,shared_variable
|
|||
|
|
from server.knowledge_base.kb_doc_api import search_docs
|
|||
|
|
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
|||
|
|
from configs import kb_config
|
|||
|
|
|
|||
|
|
|
|||
|
|
def rag_search1(query: str):
|
|||
|
|
"""
|
|||
|
|
根据用户输入的query,返回rag搜索结果
|
|||
|
|
"""
|
|||
|
|
try:
|
|||
|
|
matches = re.findall(r'\{.*?\}', query)
|
|||
|
|
if len(matches)>=2:
|
|||
|
|
query = matches[0]
|
|||
|
|
else:
|
|||
|
|
return "<关键指令>不需要再调用该工具了</关键指令>"
|
|||
|
|
time_based_uuid = json.loads(matches[1])["uuid"]
|
|||
|
|
search = json.loads(query)
|
|||
|
|
search_query = search["query"]
|
|||
|
|
search_keywords = []
|
|||
|
|
search_text = f"{search_query}"
|
|||
|
|
if type(search["keywords"]) == list:
|
|||
|
|
search_keywords = search["keywords"]
|
|||
|
|
for keyword in search_keywords:
|
|||
|
|
search_text += f" {keyword}"
|
|||
|
|
else:
|
|||
|
|
search_keywords = search["keywords"].split(",")
|
|||
|
|
for keyword in search_keywords:
|
|||
|
|
search_text += f" {keyword}"
|
|||
|
|
result = []
|
|||
|
|
source_docs = {}
|
|||
|
|
knownledge_name = []
|
|||
|
|
if type(search["knowledge_name"]) == list:
|
|||
|
|
knownledge_name=search["knowledge_name"]
|
|||
|
|
else:
|
|||
|
|
knownledge_name=search["knowledge_name"].split(",")
|
|||
|
|
for knownledge in knownledge_name:
|
|||
|
|
if not knownledge in kb_config.CH_BASE_NAME:
|
|||
|
|
knownledge_name.remove(knownledge)
|
|||
|
|
if len(knownledge_name)==0:
|
|||
|
|
result.append(f"没有找到匹配的知识库,请必须更换联网思索搜索更多知识库内容")
|
|||
|
|
return result
|
|||
|
|
# knownledge_name=kb_config.CH_BASE_NAME
|
|||
|
|
|
|||
|
|
knownledge_name=solve_knowledge_map(knownledge_name)
|
|||
|
|
num = 0
|
|||
|
|
for knownledge in knownledge_name:
|
|||
|
|
source_docs[knownledge] = []
|
|||
|
|
seen_docs = set()
|
|||
|
|
duplicate_indices = []
|
|||
|
|
doc_list = search_docs(usr_query=search_text,fileName= [],top_k=5,score_threshold=0.9,query=search_text, knowledge_base_name=knownledge)
|
|||
|
|
|
|||
|
|
|
|||
|
|
for inum,doc in enumerate(doc_list):
|
|||
|
|
solve_mental_data(knownledge,doc_list,doc=doc,seen_docs=seen_docs,duplicate_indices=duplicate_indices,knowledge=result,inum=inum)
|
|||
|
|
|
|||
|
|
# 从policydocs中删除重复的文档(从后往前删除以防止索引错位)
|
|||
|
|
for index in sorted(duplicate_indices, reverse=True):
|
|||
|
|
del doc_list[index]
|
|||
|
|
# 处理原文来源进入数组。使用开关语句明确各个条件分支
|
|||
|
|
match knownledge:
|
|||
|
|
# 属于政策库分支,入参为中文政策库名称
|
|||
|
|
case kb_config.DEFAULT_POLICY_BASE:
|
|||
|
|
doc_to_list(num,kb_config.DEFAULT_POLICY_BASE_NAME,doc_list,source_docs[knownledge])
|
|||
|
|
# 属于期刊论文库分支,入参为期刊论文库的中文名称
|
|||
|
|
case kb_config.DEFAULT_JOURNAL_BASE:
|
|||
|
|
doc_to_list(num,kb_config.DEFAULT_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
|
|||
|
|
# 属于报告库分支,入参为报告库中文名称
|
|||
|
|
case kb_config.DEFAULT_REPORT_BASE:
|
|||
|
|
doc_to_list(num,kb_config.DEFAULT_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
|
|||
|
|
# 属于冶金行业新闻库分支,入参为冶金行业新闻库中文名称
|
|||
|
|
case kb_config.GY_NEWS_BASE:
|
|||
|
|
doc_to_list(num,kb_config.GY_NEWS_BASE_NAME,doc_list,source_docs[knownledge])
|
|||
|
|
# 属于冶金行业报告库分支,入参为冶金行业报告库中文名称
|
|||
|
|
case kb_config.GY_REPORT_BASE:
|
|||
|
|
doc_to_list(num,kb_config.GY_REPORT_BASE_NAME,doc_list,source_docs[knownledge])
|
|||
|
|
# 属于冶金专业知识库分支,入参为冶金专业知识库中文名称
|
|||
|
|
case kb_config.GY_JOURNAL_BASE:
|
|||
|
|
doc_to_list(num,kb_config.GY_JOURNAL_BASE_NAME,doc_list,source_docs[knownledge])
|
|||
|
|
case _:
|
|||
|
|
print(f"输入了没有的知识库名称")
|
|||
|
|
return("输入了没有的知识库名称")
|
|||
|
|
num += len(source_docs[knownledge])
|
|||
|
|
# 构建缓存对象用于h'per_query_cache'用于传递给其他方法使用uuid获取本轮对话的文献来源
|
|||
|
|
del num
|
|||
|
|
source = utils.get_shared_variable(time_based_uuid)
|
|||
|
|
print(utils.get_shared_variable(time_based_uuid))
|
|||
|
|
source["source_docs"]=source_docs
|
|||
|
|
utils.set_shared_variable(time_based_uuid,source)
|
|||
|
|
if 0<len(result)<3:
|
|||
|
|
return f"当前资料:{result}\n<关键指令>搜索结果较少,更换知识库或联网思索重新搜索!!!</关键指令>"
|
|||
|
|
if len(result)==0:
|
|||
|
|
return "注意:【指令:更换知识库或联网思索继续搜索!!!】"
|
|||
|
|
except:
|
|||
|
|
# return "入参格式需为{\"knowledge_name\":[\"XXX\",\"XXX\"],\"query\":\"XXX\",\"keywords\":[\"XXX\", \"XXX\", \"XXX\", \"XXX\"]}检查输入参数如果没有缺少必要值,当前工具异常请换其他工具"
|
|||
|
|
return "当前工具异常!请换其他工具"
|
|||
|
|
|
|||
|
|
return f"当前资料:{result}\n<关键指令>总结此内容!!!</关键指令>"
|
|||
|
|
|
|||
|
|
class RagSearchInput(BaseModel):
    # Tool argument schema with a single string field, `query`.
    # `default=...` (Ellipsis) marks the field as required in pydantic.
    query: str = Field(default=..., description="查询对象")
|