171 lines
8.2 KiB
Python
171 lines
8.2 KiB
Python
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
import re
|
|||
|
|
from typing import List, Any, Union
|
|||
|
|
import concurrent
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from difflib import SequenceMatcher
|
|||
|
|
from configs import (VECTOR_SEARCH_TOP_K,
|
|||
|
|
SCORE_THRESHOLD,
|
|||
|
|
DEFAULT_POLICY_BASE)
|
|||
|
|
from server.agent.tools import search_internet
|
|||
|
|
from server.chat import utils
|
|||
|
|
from server.knowledge_base.kb_doc_api import search_docs
|
|||
|
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
|||
|
|
from server.utils import BaseResponse
|
|||
|
|
|
|||
|
|
|
|||
|
|
class KnowledgeKgoInput(BaseModel):
    # Pydantic input model with a single required string field; per the field
    # description it carries the text of an internet-search query.
    # Documented with comments rather than a class docstring because pydantic
    # copies a class docstring into the model's JSON-schema description.
    # NOTE(review): the field is named ``location`` although its description says
    # it is the search query; both the name and the description are part of the
    # generated schema, so they are left unchanged here — confirm which is intended.
    location: str = Field(description="Query for Internet search")
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
def preprocess_text(text: str) -> str:
    """Strip *text* down to its word characters.

    Every whitespace character and every non-word character (punctuation,
    symbols, brackets such as 《》) is removed. Letters, digits, underscores
    and CJK characters are kept, because ``\\w`` is Unicode-aware for ``str``
    patterns.

    Args:
        text: Raw text to normalise.

    Returns:
        The text with all whitespace and non-word characters removed
        (``""`` for empty input).
    """
    trimmed = text.strip()
    cleaned = re.sub(r'[\s\W]', '', trimmed)
    return cleaned
|
|||
|
|
|
|||
|
|
|
|||
|
|
def knowledge_temperature(a: str, b: str) -> float:
    """Return the difflib similarity ratio between *a* and *b*.

    Despite its name, this is a plain string-similarity score in ``[0.0, 1.0]``,
    computed by ``SequenceMatcher.ratio()`` as ``2*M/T``. Here ``M`` is the number
    of matched characters and ``T`` is ``len(a) + len(b)``. Two empty strings
    score ``1.0``. Default matcher settings (no junk filter, autojunk on) are used.
    """
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    return matcher.ratio()
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
|
|||
|
|
# def knowledgebase_kgo_iter(query: str,
|
|||
|
|
# fileName: List = [],
|
|||
|
|
# knowledge_base_name: str = DEFAULT_POLICY_BASE,
|
|||
|
|
# top_k: int = VECTOR_SEARCH_TOP_K,
|
|||
|
|
# score_threshold: float = SCORE_THRESHOLD) -> BaseResponse | list[str] | Any:
|
|||
|
|
# kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
|||
|
|
# if kb is None:
|
|||
|
|
# return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
|||
|
|
# query = query.strip()
|
|||
|
|
|
|||
|
|
# docs = search_docs(fileName=fileName,
|
|||
|
|
# query=query,
|
|||
|
|
# knowledge_base_name=knowledge_base_name,
|
|||
|
|
# top_k=top_k,
|
|||
|
|
# score_threshold=score_threshold)
|
|||
|
|
|
|||
|
|
# # 预处理查询文本
|
|||
|
|
# processed_query = preprocess_text(query).replace("Observ","")
|
|||
|
|
# print("processed_query:", processed_query)
|
|||
|
|
# knowledge_docs = []
|
|||
|
|
# knowledge_content = []
|
|||
|
|
# # 知识库返回的文档与query的相似度
|
|||
|
|
# if docs:
|
|||
|
|
# for enum, doc in enumerate(docs):
|
|||
|
|
# filename = doc.metadata.get("title")
|
|||
|
|
# detail_url = "https://policy.ckcest.cn/detail/" + doc.metadata.get("primary_key") + ".html"
|
|||
|
|
# if filename:
|
|||
|
|
# text = f"""政策类资料[{enum + 1}]: [{filename}]({detail_url})\n"""
|
|||
|
|
# else:
|
|||
|
|
# text = f"""政策类资料[{enum + 1}]: [{"原文地址"}]({detail_url})\n"""
|
|||
|
|
# knowledge_docs.append(text)
|
|||
|
|
# # print("knowledge_docs:", knowledge_docs)
|
|||
|
|
# knowledge_content = [doc.page_content for doc in docs]
|
|||
|
|
# # print("knowledge_content:", knowledge_content)
|
|||
|
|
# # 计算知识库返回的文档与query的相似度
|
|||
|
|
# titles = [doc.metadata.get("title") for doc in docs]
|
|||
|
|
# print("titles:", titles)
|
|||
|
|
# def check_similarity_threshold(titles: List[str], query: str, knowledge_docs: List[str], knowledge_content: List[str]) -> Union[
|
|||
|
|
# List[str], None]:
|
|||
|
|
# # 用于记录是否存在相似度大于0.55的标题
|
|||
|
|
# has_similar_title = False
|
|||
|
|
# for title in titles:
|
|||
|
|
# processed_title = preprocess_text(title)
|
|||
|
|
# similarity = knowledge_temperature(processed_query, processed_title)
|
|||
|
|
# print("processed_title:", processed_title)
|
|||
|
|
# print("similarity:", similarity)
|
|||
|
|
# if similarity >= 0.55:
|
|||
|
|
# has_similar_title = True
|
|||
|
|
# break
|
|||
|
|
# # 如果存在相似度大于0.55的标题,则直接返回 knowledge_docs
|
|||
|
|
# if has_similar_title:
|
|||
|
|
# knowledge = knowledge_content + knowledge_docs
|
|||
|
|
# return knowledge
|
|||
|
|
# # 如果所有标题的相似度都不大于0.55,则返回 None
|
|||
|
|
# return None
|
|||
|
|
|
|||
|
|
# # 在原函数中使用新的函数进行相似度阈值的判断
|
|||
|
|
# similar_docs = check_similarity_threshold(titles, query, knowledge_docs, knowledge_content)
|
|||
|
|
# if similar_docs is None:
|
|||
|
|
# # 如果所有标题的相似度都不大于0.55,则执行搜索引擎查询
|
|||
|
|
# kgo_docs = search_internet(processed_query)
|
|||
|
|
# # print("kgo_docs", kgo_docs)
|
|||
|
|
# return kgo_docs
|
|||
|
|
# else:
|
|||
|
|
# kgo_docs = search_internet(processed_query)
|
|||
|
|
# # print("similar_docs", similar_docs)
|
|||
|
|
# # print("kgo_docs", kgo_docs)
|
|||
|
|
# similar_docs.extend(kgo_docs)
|
|||
|
|
# return similar_docs
|
|||
|
|
# else:
|
|||
|
|
# # 执行搜索引擎查询
|
|||
|
|
# kgo_docs = search_internet(query)
|
|||
|
|
# return kgo_docs
|
|||
|
|
|
|||
|
|
def knowledgebase_kgo_iter(query: str, uid: str) -> BaseResponse | list[str] | Any:
    """Run ``search_internet`` for *query* and return its result unchanged.

    Args:
        query: Search text, passed through as-is.
        uid: Identifier forwarded unchanged as the second argument of
            ``search_internet``.

    Returns:
        Whatever ``search_internet`` returns. ``knowledgebase_kgo_search``
        indexes it as ``res[0]`` / ``res[1]``.
    """
    return search_internet(query, uid)
|
|||
|
|
# Returned verbatim to the caller on every failure path of
# knowledgebase_kgo_search: unusable input, search error, or empty result.
_KGO_RETRY_HINT = "尝试调整入参重新调用联网思索工具(同一个问题调用超过三次就不要再使用该工具了,浪费时间)"


def _format_kgo_result(res: Any) -> str:
    """Render a ``knowledgebase_kgo_iter`` result as a single string.

    ``res[0]`` is treated as a list of content strings and ``res[1]`` as a list
    of source strings, labelled "资料内容" and "资料来源" respectively.
    NOTE(review): this layout is inferred from how the result is consumed here;
    confirm it against ``search_internet``.

    Returns:
        Contents followed by sources when there is content. A titles-only
        message when only sources exist. Otherwise the retry hint.

    Raises:
        Whatever indexing or joining ``res`` raises (e.g. IndexError, TypeError).
        The caller converts these into the retry hint.
    """
    contents = res[0]
    if isinstance(contents, list) and contents:
        return "资料内容" + "".join(contents) + "资料来源" + "".join(res[1])

    sources = res[1]
    if isinstance(sources, list) and sources:
        # Only titles/links came back. This text was previously built from an
        # unbound variable and lacked the f-prefix, so it never reached the caller.
        source_text = "资料来源" + "".join(sources)
        return f"只有标题没有内容,标题为:{source_text}"

    return _KGO_RETRY_HINT


def knowledgebase_kgo_search(query: str) -> str:
    """Parse a tool-input string, run the internet search and format the result.

    *query* must contain at least two ``{...}`` JSON objects. The first needs a
    ``"query"`` key (the search text) and the second a ``"uuid"`` key. Both
    values are forwarded to ``knowledgebase_kgo_iter``. Objects are located with
    a non-greedy, non-nested regex, so nested braces or a ``}`` inside the
    search text break parsing.

    Returns:
        The formatted search result, or the fixed retry hint on any failure.
        Errors are logged and converted to the hint rather than raised.
    """
    # ``concurrent.futures`` is a submodule: a bare ``import concurrent`` does
    # not guarantee it is loaded, so import it explicitly here.
    from concurrent.futures import ThreadPoolExecutor

    try:
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            logging.error("Invalid JSON format in query.")
            return _KGO_RETRY_HINT

        parsed_query = json.loads(matches[0])
        time_based_uuid = json.loads(matches[1])["uuid"]
        search_text = parsed_query["query"]

        # The search runs on a worker thread and is awaited immediately.
        # NOTE(review): presumably this isolates search_internet from the caller's
        # thread (e.g. a running event loop); if not needed, a direct call is
        # equivalent.
        with ThreadPoolExecutor(max_workers=1) as executor:
            res = executor.submit(knowledgebase_kgo_iter, search_text, time_based_uuid).result()
    except json.JSONDecodeError:
        logging.error("Invalid JSON format in query.")
        return _KGO_RETRY_HINT
    except KeyError as e:
        # A required key ("query" or "uuid") is missing from the parsed JSON.
        logging.error("Missing key in query: %s", e)
        return _KGO_RETRY_HINT
    except Exception as e:
        logging.error("Error occurred while processing query: %s", e)
        return _KGO_RETRY_HINT

    try:
        return _format_kgo_result(res)
    except Exception as e:
        # Unexpected result shape from the search backend.
        logging.error("Error occurred while processing query: %s", e)
        return _KGO_RETRY_HINT
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    import uuid

    # knowledgebase_kgo_iter requires a uid (previously omitted here, which raised
    # TypeError). knowledgebase_kgo_search forwards a client-supplied "uuid"
    # value stored as ``time_based_uuid``, so a time-based uuid is generated for
    # this manual run.
    # NOTE(review): confirm search_internet accepts an arbitrary uuid string.
    result = knowledgebase_kgo_iter("《区块链和分布式记账技术标准体系建设指南》", str(uuid.uuid1()))
    print("检索结果:", result)
|