Files
gangyan/langchain-chat/server/agent/tools/knowledgebase_kgo_search.py

171 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import re
from typing import List, Any, Union
import concurrent
from pydantic import BaseModel, Field
from difflib import SequenceMatcher
from configs import (VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
DEFAULT_POLICY_BASE)
from server.agent.tools import search_internet
from server.chat import utils
from server.knowledge_base.kb_doc_api import search_docs
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.utils import BaseResponse
class KnowledgeKgoInput(BaseModel):
    # Input schema with a single required string field.
    # NOTE(review): the field is named `location`, yet its description says it
    # carries the Internet-search query -- confirm the name is intentional
    # before renaming, since callers may depend on it.
    location: str = Field(
        description="Query for Internet search",
    )
def preprocess_text(text: str) -> str:
    """Return *text* with all whitespace and non-word characters removed.

    Word characters (letters, digits, underscore, including non-ASCII
    letters such as CJK characters) are kept; everything else is dropped.
    """
    trimmed = text.strip()
    unwanted = re.compile(r'[\s\W]')
    return unwanted.sub('', trimmed)
def knowledge_temperature(a: str, b: str) -> float:
    """Return the similarity ratio of *a* and *b*.

    The score comes from ``difflib.SequenceMatcher.ratio`` with no junk
    filter, so it lies in ``[0.0, 1.0]`` and is ``1.0`` for identical inputs.
    """
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    similarity = matcher.ratio()
    return similarity
# def knowledgebase_kgo_iter(query: str,
# fileName: List = [],
# knowledge_base_name: str = DEFAULT_POLICY_BASE,
# top_k: int = VECTOR_SEARCH_TOP_K,
# score_threshold: float = SCORE_THRESHOLD) -> BaseResponse | list[str] | Any:
# kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
# if kb is None:
# return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
# query = query.strip()
# docs = search_docs(fileName=fileName,
# query=query,
# knowledge_base_name=knowledge_base_name,
# top_k=top_k,
# score_threshold=score_threshold)
# # 预处理查询文本
# processed_query = preprocess_text(query).replace("Observ","")
# print("processed_query:", processed_query)
# knowledge_docs = []
# knowledge_content = []
# # 知识库返回的文档与query的相似度
# if docs:
# for enum, doc in enumerate(docs):
# filename = doc.metadata.get("title")
# detail_url = "https://policy.ckcest.cn/detail/" + doc.metadata.get("primary_key") + ".html"
# if filename:
# text = f"""政策类资料[{enum + 1}]: [{filename}]({detail_url})\n"""
# else:
# text = f"""政策类资料[{enum + 1}]: [{"原文地址"}]({detail_url})\n"""
# knowledge_docs.append(text)
# # print("knowledge_docs:", knowledge_docs)
# knowledge_content = [doc.page_content for doc in docs]
# # print("knowledge_content:", knowledge_content)
# # 计算知识库返回的文档与query的相似度
# titles = [doc.metadata.get("title") for doc in docs]
# print("titles:", titles)
# def check_similarity_threshold(titles: List[str], query: str, knowledge_docs: List[str], knowledge_content: List[str]) -> Union[
# List[str], None]:
# # 用于记录是否存在相似度大于0.55的标题
# has_similar_title = False
# for title in titles:
# processed_title = preprocess_text(title)
# similarity = knowledge_temperature(processed_query, processed_title)
# print("processed_title:", processed_title)
# print("similarity:", similarity)
# if similarity >= 0.55:
# has_similar_title = True
# break
# # 如果存在相似度大于0.55的标题,则直接返回 knowledge_docs
# if has_similar_title:
# knowledge = knowledge_content + knowledge_docs
# return knowledge
# # 如果所有标题的相似度都不大于0.55,则返回 None
# return None
# # 在原函数中使用新的函数进行相似度阈值的判断
# similar_docs = check_similarity_threshold(titles, query, knowledge_docs, knowledge_content)
# if similar_docs is None:
# # 如果所有标题的相似度都不大于0.55,则执行搜索引擎查询
# kgo_docs = search_internet(processed_query)
# # print("kgo_docs", kgo_docs)
# return kgo_docs
# else:
# kgo_docs = search_internet(processed_query)
# # print("similar_docs", similar_docs)
# # print("kgo_docs", kgo_docs)
# similar_docs.extend(kgo_docs)
# return similar_docs
# else:
# # 执行搜索引擎查询
# kgo_docs = search_internet(query)
# return kgo_docs
def knowledgebase_kgo_iter(query: str, uid: str) -> BaseResponse | list[str] | Any:
    """Forward *query* and *uid* to ``search_internet`` and return its result unchanged."""
    return search_internet(query, uid)
# Message handed back to the caller whenever the search cannot produce usable
# output (bad input, missing keys, empty or malformed search result).
_KGO_RETRY_HINT = "尝试调整入参重新调用联网思索工具(同一个问题调用超过三次就不要再使用该工具了,浪费时间)"


def _format_kgo_result(res: Any) -> str:
    """Render a search result as a single string.

    ``res`` is read as an indexable pair: ``res[0]`` holds content strings and
    ``res[1]`` holds source strings (this is how the caller has always read
    it -- the producer, ``search_internet``, is defined elsewhere).

    Returns the contents followed by the sources when contents exist, a
    title-only message when only sources exist, and the retry hint otherwise.
    Raises whatever indexing/joining raises if ``res`` has another shape; the
    caller converts that into the retry hint.
    """
    contents, sources = res[0], res[1]
    if isinstance(contents, list) and contents:
        return "资料内容" + "".join(contents) + "资料来源" + "".join(sources)
    if isinstance(sources, list) and sources:
        # The original branch appended to an unbound variable and returned a
        # non-f-string, so this case could never yield the sources.
        titles = "资料来源" + "".join(sources)
        return f"只有标题没有内容,标题为:{titles}"
    return _KGO_RETRY_HINT


def knowledgebase_kgo_search(query: str) -> str:
    """Run an Internet search described by two JSON objects embedded in *query*.

    *query* must contain at least two ``{...}`` fragments: the first is a JSON
    object with a ``"query"`` key (the search text), the second a JSON object
    with a ``"uuid"`` key that is passed along as the search uid.

    Returns the formatted search result, or a fixed retry hint when the input
    cannot be parsed, a required key is missing, or the search fails or comes
    back empty. This function never raises.
    """
    # A bare `import concurrent` does not load the `futures` submodule, so
    # bring the executor into scope explicitly.
    from concurrent.futures import ThreadPoolExecutor

    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            logging.error("Invalid JSON format in query.")
            return _KGO_RETRY_HINT

        parsed_query = json.loads(matches[0])
        time_based_uuid = json.loads(matches[1])["uuid"]
        search_text = parsed_query["query"]

        # The search runs on a worker thread and is awaited immediately.
        with ThreadPoolExecutor() as executor:
            res = executor.submit(knowledgebase_kgo_iter, search_text, time_based_uuid).result()

        try:
            return _format_kgo_result(res)
        except Exception as e:
            # Unexpected result shape (e.g. not indexable, non-string items).
            logging.error("Error occurred while processing query: %s", e)
            return _KGO_RETRY_HINT
    except json.JSONDecodeError:
        # One of the extracted fragments is not valid JSON.
        logging.error("Invalid JSON format in query.")
        return _KGO_RETRY_HINT
    except KeyError as e:
        # The parsed JSON lacks "query" or "uuid".
        logging.error("Missing required key in query: %s", e)
        return _KGO_RETRY_HINT
    except Exception as e:
        # Top-level boundary: any other failure becomes the retry hint.
        logging.error("Error occurred while processing query: %s", e)
        return _KGO_RETRY_HINT
if __name__ == "__main__":
    # Manual smoke test for the search pass-through.
    import uuid

    # knowledgebase_kgo_iter takes (query, uid); the original call omitted the
    # uid and raised TypeError. A time-based UUID string is used here.
    # NOTE(review): presumably search_internet accepts any string uid -- confirm.
    result = knowledgebase_kgo_iter("《区块链和分布式记账技术标准体系建设指南》", str(uuid.uuid1()))
    print("检索结果:", result)