171 lines
8.2 KiB
Python
171 lines
8.2 KiB
Python
import json
|
||
import logging
|
||
import re
|
||
from typing import List, Any, Union
|
||
import concurrent
|
||
from pydantic import BaseModel, Field
|
||
from difflib import SequenceMatcher
|
||
from configs import (VECTOR_SEARCH_TOP_K,
|
||
SCORE_THRESHOLD,
|
||
DEFAULT_POLICY_BASE)
|
||
from server.agent.tools import search_internet
|
||
from server.chat import utils
|
||
from server.knowledge_base.kb_doc_api import search_docs
|
||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||
from server.utils import BaseResponse
|
||
|
||
|
||
class KnowledgeKgoInput(BaseModel):
    """Pydantic model with a single required string field.

    NOTE(review): the field is called ``location`` but its description says
    it holds the Internet search query - confirm the name with the callers
    before renaming it.
    """

    location: str = Field(..., description="Query for Internet search")
|
||
|
||
|
||
|
||
def preprocess_text(text: str) -> str:
    """Return *text* with whitespace and special symbols removed.

    Leading and trailing whitespace is trimmed first, then every character
    that is whitespace or a non-word character (anything other than a
    letter, digit or underscore) is dropped.
    """
    trimmed = text.strip()
    # Remove whitespace and special symbols.
    return re.sub(r'[\s\W]', '', trimmed)
|
||
|
||
|
||
def knowledge_temperature(a: str, b: str) -> float:
    """Return the similarity ratio of *a* and *b* as a float in [0, 1].

    The score is ``difflib.SequenceMatcher(None, a, b).ratio()``: 1.0 for
    identical strings and 0.0 when nothing matches.
    """
    # Compute the similarity with difflib's SequenceMatcher (no junk filter).
    matcher = SequenceMatcher(None, a, b)
    return matcher.ratio()
|
||
|
||
|
||
|
||
|
||
|
||
# def knowledgebase_kgo_iter(query: str,
|
||
# fileName: List = [],
|
||
# knowledge_base_name: str = DEFAULT_POLICY_BASE,
|
||
# top_k: int = VECTOR_SEARCH_TOP_K,
|
||
# score_threshold: float = SCORE_THRESHOLD) -> BaseResponse | list[str] | Any:
|
||
# kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||
# if kb is None:
|
||
# return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
|
||
# query = query.strip()
|
||
|
||
# docs = search_docs(fileName=fileName,
|
||
# query=query,
|
||
# knowledge_base_name=knowledge_base_name,
|
||
# top_k=top_k,
|
||
# score_threshold=score_threshold)
|
||
|
||
# # 预处理查询文本
|
||
# processed_query = preprocess_text(query).replace("Observ","")
|
||
# print("processed_query:", processed_query)
|
||
# knowledge_docs = []
|
||
# knowledge_content = []
|
||
# # 知识库返回的文档与query的相似度
|
||
# if docs:
|
||
# for enum, doc in enumerate(docs):
|
||
# filename = doc.metadata.get("title")
|
||
# detail_url = "https://policy.ckcest.cn/detail/" + doc.metadata.get("primary_key") + ".html"
|
||
# if filename:
|
||
# text = f"""政策类资料[{enum + 1}]: [{filename}]({detail_url})\n"""
|
||
# else:
|
||
# text = f"""政策类资料[{enum + 1}]: [{"原文地址"}]({detail_url})\n"""
|
||
# knowledge_docs.append(text)
|
||
# # print("knowledge_docs:", knowledge_docs)
|
||
# knowledge_content = [doc.page_content for doc in docs]
|
||
# # print("knowledge_content:", knowledge_content)
|
||
# # 计算知识库返回的文档与query的相似度
|
||
# titles = [doc.metadata.get("title") for doc in docs]
|
||
# print("titles:", titles)
|
||
# def check_similarity_threshold(titles: List[str], query: str, knowledge_docs: List[str], knowledge_content: List[str]) -> Union[
|
||
# List[str], None]:
|
||
# # 用于记录是否存在相似度大于0.55的标题
|
||
# has_similar_title = False
|
||
# for title in titles:
|
||
# processed_title = preprocess_text(title)
|
||
# similarity = knowledge_temperature(processed_query, processed_title)
|
||
# print("processed_title:", processed_title)
|
||
# print("similarity:", similarity)
|
||
# if similarity >= 0.55:
|
||
# has_similar_title = True
|
||
# break
|
||
# # 如果存在相似度大于0.55的标题,则直接返回 knowledge_docs
|
||
# if has_similar_title:
|
||
# knowledge = knowledge_content + knowledge_docs
|
||
# return knowledge
|
||
# # 如果所有标题的相似度都不大于0.55,则返回 None
|
||
# return None
|
||
|
||
# # 在原函数中使用新的函数进行相似度阈值的判断
|
||
# similar_docs = check_similarity_threshold(titles, query, knowledge_docs, knowledge_content)
|
||
# if similar_docs is None:
|
||
# # 如果所有标题的相似度都不大于0.55,则执行搜索引擎查询
|
||
# kgo_docs = search_internet(processed_query)
|
||
# # print("kgo_docs", kgo_docs)
|
||
# return kgo_docs
|
||
# else:
|
||
# kgo_docs = search_internet(processed_query)
|
||
# # print("similar_docs", similar_docs)
|
||
# # print("kgo_docs", kgo_docs)
|
||
# similar_docs.extend(kgo_docs)
|
||
# return similar_docs
|
||
# else:
|
||
# # 执行搜索引擎查询
|
||
# kgo_docs = search_internet(query)
|
||
# return kgo_docs
|
||
|
||
def knowledgebase_kgo_iter(query: str, uid: str) -> BaseResponse | list[str] | Any:
    """Run an Internet search for *query* and return the result unchanged.

    Both arguments are forwarded as-is to ``search_internet``; whatever that
    call returns (or raises) is passed straight back to the caller.

    Args:
        query: Search text.
        uid: Identifier handed to ``search_internet`` as its second argument.
    """
    return search_internet(query, uid)
|
||
# Fixed reply that tells the calling agent to adjust its arguments and retry
# (and to give up after three attempts on the same question).
_KGO_RETRY_MESSAGE = "尝试调整入参重新调用联网思索工具(同一个问题调用超过三次就不要再使用该工具了,浪费时间)"


def _format_kgo_result(res) -> str:
    """Render a search result whose items 0 and 1 are contents and sources.

    Returns the concatenated contents followed by the sources when there is
    content, a "titles only" message when only sources are present, and the
    retry message when both are empty or not lists.

    Raises:
        TypeError, IndexError: if *res* is not indexable at 0 and 1, or an
            element cannot be joined as a string. The caller handles these.
    """
    contents, sources = res[0], res[1]
    if isinstance(contents, list) and contents:
        return "资料内容" + "".join(contents) + "资料来源" + "".join(sources)
    if isinstance(sources, list) and sources:
        # The original branch appended to an unbound variable and returned a
        # non-f-string, so it could never produce this message.
        titles = "资料来源" + "".join(sources)
        return f"只有标题没有内容,标题为:{titles}"
    return _KGO_RETRY_MESSAGE


def knowledgebase_kgo_search(query: str) -> str:
    """Parse a tool-call string, run the Internet search and format the result.

    *query* must contain at least two brace-delimited JSON objects: the first
    with a ``"query"`` key (the search text) and the second with a ``"uuid"``
    key, which is passed to ``knowledgebase_kgo_iter`` as its ``uid``.

    Args:
        query: Raw tool input containing the two JSON objects.

    Returns:
        The formatted search result, or a fixed retry message when the input
        is malformed, the search fails, or nothing was found. This function
        never raises; every failure is logged and mapped to the retry message.
    """
    # `import concurrent` alone does not load the `futures` submodule, so it
    # is imported explicitly here.
    from concurrent.futures import ThreadPoolExecutor

    try:
        # Non-greedy match: flat (non-nested) JSON objects on a single line.
        matches = re.findall(r'\{.*?\}', query)
        if len(matches) < 2:
            logging.error("Invalid JSON format in query.")
            return _KGO_RETRY_MESSAGE

        parsed_query = json.loads(matches[0])
        time_based_uuid = json.loads(matches[1])["uuid"]
        search_text = parsed_query["query"]

        # The search runs on a worker thread, as in the original code.
        with ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(knowledgebase_kgo_iter, search_text, time_based_uuid)
            res = future.result()
    except json.JSONDecodeError:
        # One of the extracted objects is not valid JSON.
        logging.error("Invalid JSON format in query.")
        return _KGO_RETRY_MESSAGE
    except KeyError as e:
        # A required key ("query" or "uuid") is missing from the parsed JSON.
        logging.error("Missing required key in query: %s", e)
        return _KGO_RETRY_MESSAGE
    except Exception as e:
        # Top-level boundary: any other failure becomes the retry message.
        logging.error("Error occurred while processing query: %s", e)
        return _KGO_RETRY_MESSAGE

    try:
        return _format_kgo_result(res)
    except Exception as e:
        # The result did not have the expected (contents, sources) shape.
        logging.error("Error occurred while processing query: %s", e)
        return _KGO_RETRY_MESSAGE
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: run one search and print whatever comes back.
    import uuid

    # knowledgebase_kgo_iter requires a uid as its second argument; the
    # original call omitted it and failed with TypeError.
    # NOTE(review): a time-based UUID is assumed to be an acceptable uid here
    # - confirm against search_internet.
    demo_uid = str(uuid.uuid1())
    result = knowledgebase_kgo_iter("《区块链和分布式记账技术标准体系建设指南》", demo_uid)
    print("检索结果:", result)
|