# Python source file (320 lines, 12 KiB)
from datetime import datetime
|
||
import hashlib
|
||
import re
|
||
from fastapi import Body, Request
|
||
from sse_starlette.sse import EventSourceResponse
|
||
from fastapi.concurrency import run_in_threadpool
|
||
from configs import (
|
||
LLM_MODELS,
|
||
VECTOR_SEARCH_TOP_K,
|
||
SCORE_THRESHOLD,
|
||
SELF_TOP_K,
|
||
SELF_SCORE_THRESHOLD,
|
||
TEMPERATURE,
|
||
SELF_TEMPERATURE,
|
||
USE_RERANKER,
|
||
RERANKER_MODEL,
|
||
RERANKER_MAX_LENGTH,
|
||
SELF_MAX_TOKENS,
|
||
SELF_USE_RERANKER,
|
||
MODEL_PATH
|
||
)
|
||
from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
|
||
from server.chat.policy_fun_iast import get_llm_model_response
|
||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||
from server.utils import wrap_done, get_ChatOpenAI
|
||
from server.utils import BaseResponse, get_prompt_template
|
||
from langchain.chains import LLMChain
|
||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||
from typing import AsyncIterable, List, Optional, Dict, Tuple, Union, AsyncGenerator
|
||
import asyncio
|
||
from langchain.prompts.chat import ChatPromptTemplate
|
||
from server.chat.utils import History, get_similar_documents, get_text_by_regex
|
||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||
import json
|
||
from server.knowledge_base.kb_doc_api import search_self_docs
|
||
from configs.basic_config import *
|
||
from langchain.chains.question_answering import load_qa_chain
|
||
from typing import Generator
|
||
|
||
# Module-level logger named after this module; used by every function below.
# NOTE(review): no explicit `import logging` is visible in the imports above —
# presumably the name arrives via `from configs.basic_config import *`; confirm.
logger = logging.getLogger(__name__)
|
||
|
||
async def word_explain(
        query: str = Body(..., description="用户输入", examples=["智慧科协"]),
        fileNames: List = Body([], description="文件名称", examples=[["孟庆海同志在“智慧科协2.0”5·30场景建设工作部署会议上的讲话.docx"]]),
        knowledge_base_name_list: List[str] = Body(..., description="知识库列表", examples=[["p_cast0101011"]]),
        history: List[History] = Body(
            [],
            description="历史对话",
            examples=[[
                {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant", "content": "虎头虎脑"}]
            ]
        ),
) -> EventSourceResponse:
    """Term-explanation API.

    Flow:
      1. Check that every requested knowledge base exists.
      2. Ask the first configured LLM (``word_check`` strategy) about the query.
      3. If the check result contains ``'0'``, answer with the
         ``word_explain_reject`` prompt and no retrieved material; otherwise
         retrieve knowledge-base and web material and answer with the
         ``word_explain`` prompt.
      4. Stream the generated answer as server-sent events.

    Args:
        query: User input (the term to explain).
        fileNames: File names forwarded to retrieval and to the prompt.
        knowledge_base_name_list: Names of the knowledge bases to search.
        history: Conversation history; only its length is logged here.

    Returns:
        An ``EventSourceResponse`` on success. A ``BaseResponse`` with code 404
        when a knowledge base is missing, or with code 500 when any exception
        is raised while preparing the response.
    """
    logger.info(
        f"名词解释请求: query={query}, 文件={fileNames}, "
        f"知识库数={knowledge_base_name_list}, 历史记录数={len(history)}, "
    )

    try:
        # Stop at the first knowledge base that cannot be resolved.
        missing_kb = next(
            (name for name in knowledge_base_name_list
             if not KBServiceFactory.get_service_by_name(name)),
            None,
        )
        if missing_kb is not None:
            return BaseResponse(code=404, msg=f"知识库 {missing_kb} 不存在")

        # The LLM helper is synchronous, so run it off the event loop.
        verdict = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="word_check",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="word_check",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512
        )
        logger.info(f"名词检查结果: {verdict}")

        # Defaults for the rejected case: nothing is retrieved.
        docs, internet_context, internet_doc = [], "", []
        rejected = '0' in verdict
        prompt_name = "word_explain_reject" if rejected else "word_explain"

        if not rejected:
            # Knowledge-base and web lookups run concurrently inside the helper.
            docs, internet_context, internet_doc = await retrieve_documents(
                query=query,
                fileNames=fileNames,
                knowledge_base_names=knowledge_base_name_list,
                top_k=SELF_TOP_K,
                score_threshold=SELF_SCORE_THRESHOLD
            )

        event_stream = generate_chat_response(
            query=query,
            docs=docs,
            internet_context=internet_context,
            internet_docs=internet_doc,
            fileNames=fileNames,
            model_name=LLM_MODELS[0],
            prompt_name=prompt_name
        )
        return EventSourceResponse(event_stream)

    except Exception as e:
        detail = str(e)
        logger.error(f"处理请求异常: {detail}", exc_info=True)
        return BaseResponse(code=500, msg=f"处理请求时发生错误: {detail}")
|
||
|
||
|
||
async def retrieve_documents(
        query: str,
        fileNames: List[str],
        knowledge_base_names: List[str],
        top_k: int,
        score_threshold: float
) -> Tuple[List[DocumentWithVSId], str, List[str]]:
    """Concurrently search the knowledge bases and the web for ``query``.

    One search task is started per knowledge base plus one web-search task.
    Web hits are then narrowed down: an LLM (``default_similar`` strategy) is
    asked which numbered hits are relevant, and ``get_similar_documents``
    produces the final selection.

    Args:
        query: The user query.
        fileNames: File names forwarded to ``search_self_docs``.
        knowledge_base_names: Knowledge bases to search.
        top_k: Forwarded to ``search_self_docs`` and ``get_similar_documents``.
        score_threshold: Forwarded to ``search_self_docs``.

    Returns:
        ``(kb_docs, internet_context, internet_docs)`` where ``kb_docs`` is the
        combined knowledge-base hits, ``internet_context`` is the concatenated
        ``content`` of the selected web hits, and ``internet_docs`` is a list
        of ``"[n] [title](link)"`` strings for those hits. On any unexpected
        error the function logs it and returns ``([], "", [])``.
    """

    async def fetch_kb_docs(kb_name: str) -> List[DocumentWithVSId]:
        """Search one knowledge base; log and return ``[]`` on failure."""
        try:
            return await run_in_threadpool(
                search_self_docs,
                query=query,
                fileNames=fileNames,
                knowledge_base_name=kb_name,
                top_k=top_k,
                score_threshold=score_threshold
            )
        except Exception as e:
            logger.error(f"知识库 {kb_name} 检索失败: {str(e)}")
            return []

    # TODO: switch to duckduckgo
    async def fetch_web_results() -> List[dict]:
        """Run the web search; log and return ``[]`` on failure."""
        try:
            wrapper = ZhipuSearchAPIWrapper()
            return await run_in_threadpool(
                wrapper.zhipu_search,
                origin_query=query
            )
        except Exception as e:
            logger.error(f"网络检索失败: {str(e)}")
            return []

    def label_hits(hits: List[dict], key: str) -> Tuple[List[str], List[str]]:
        """Return (labels, numbered "label+content" strings) using ``key`` as label."""
        labels = [hit[key] for hit in hits]
        numbered = [
            str(i + 1) + ":【" + hit[key] + hit["content"] + "】"
            for i, hit in enumerate(hits)
        ]
        return labels, numbered

    try:
        # The web task is always scheduled last, so it is split off by position.
        # The previous code read `results[1]`, which is only the web result
        # when exactly one knowledge base was requested.
        *kb_results, web_hits = await asyncio.gather(
            *[fetch_kb_docs(kb) for kb in knowledge_base_names],
            fetch_web_results(),
            return_exceptions=True
        )

        self_docs: List[DocumentWithVSId] = []
        for result in kb_results:
            if isinstance(result, Exception):
                logger.error(f"并发任务出现异常: {str(result)}")
                continue
            if isinstance(result, list) and all(isinstance(doc, DocumentWithVSId) for doc in result):
                self_docs.extend(result)

        internet_context = ""
        internet_docs: List[str] = []

        if isinstance(web_hits, Exception):
            logger.error(f"并发任务出现异常: {str(web_hits)}")
        elif isinstance(web_hits, list) and web_hits and all(isinstance(item, dict) for item in web_hits):
            # Prefer "title" as the label; fall back to "source" for hits
            # that do not carry a usable title.
            try:
                sentences, sentences_page_content = label_hits(web_hits, "title")
            except (KeyError, TypeError):
                sentences, sentences_page_content = label_hits(web_hits, "source")

            # The LLM helper is synchronous; run it in the thread pool so the
            # event loop is not blocked while it answers.
            res = await run_in_threadpool(
                get_llm_model_response,
                strategy_name="default_similar",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="default_similar",
                prompt_param_dict={"input": query, "title": f"{sentences_page_content}", "time": datetime.now().strftime("%Y%m%d")},
                temperature=0.01,
                max_tokens=512
            )

            try:
                # The reply is either "无" (none) or comma-separated 1-based numbers.
                index = [] if res == "无" else [int(i) - 1 for i in res.split(",")]
                similar_document = get_similar_documents(index=index, sentences=sentences, query=query, docs=web_hits, top_k=top_k)
            except Exception as e:
                # Unparseable reply (or a failure using it): retry with no
                # LLM-selected indices.
                logger.warning(f"互联网检索结果相似度过滤失败: {str(e)}")
                similar_document = get_similar_documents(index=[], sentences=sentences, query=query, docs=web_hits, top_k=top_k)

            context_parts = []
            for num, item in enumerate(similar_document, start=1):
                cleaned_title = re.sub(r'(发布时间:.*?)', '', item.get('title', ''))
                link = item.get('link', '')
                internet_docs.append(f"[{num}] [{cleaned_title}]({link})")
                context_parts.append(item.get('content', ''))
            internet_context = "".join(context_parts)

        return self_docs, internet_context, internet_docs

    except Exception as e:
        logger.error(f"文档检索异常: {str(e)}", exc_info=True)
        return [], "", []
|
||
|
||
|
||
def generate_chat_response(
        query: str,
        docs: List[Dict],
        internet_context: str,
        fileNames: List[str],
        internet_docs: List[str],
        model_name: str,
        prompt_name: str
) -> AsyncGenerator[str, None]:
    """Build an async generator that streams the LLM answer as JSON strings.

    The model, prompt and "stuff" QA chain are created eagerly; the chain is
    only started once the returned generator is iterated.

    Args:
        query: The user question, passed to the chain as ``question``.
        docs: Documents passed to the chain as ``input_documents``.
        internet_context: Web text passed to the chain as ``internet_context``.
        fileNames: File names; a trailing ``.ext`` is stripped from each and
            the results are joined with ", " into the ``fileName`` prompt value.
        internet_docs: Source strings emitted after the answer.
        model_name: Name of the chat model to use.
        prompt_name: Template name looked up under ``knowledge_base_chat``.

    Returns:
        An async generator yielding JSON-encoded strings: one ``{"text": ...}``
        per streamed token, then ``{"docs": internet_docs}`` when either
        ``docs`` or ``internet_docs`` is non-empty, or ``{"error": ...}`` if
        generation fails.
    """
    callback = AsyncIteratorCallbackHandler()
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=SELF_TEMPERATURE,
        max_tokens=SELF_MAX_TOKENS,
        callbacks=[callback],
    )

    # Build the prompt template.
    prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
    chat_prompt = ChatPromptTemplate.from_messages([
        History(role="system", content=prompt_template).to_msg_template(False)
    ])

    # Configure the question-answering chain.
    qa_chain = load_qa_chain(
        model,
        chain_type="stuff",
        prompt=chat_prompt,
        verbose=True
    )

    async def generation_wrapper() -> AsyncGenerator[str, None]:
        task = None
        try:
            # Drop the file extension from each name.
            clean_files = [re.sub(r"\.\w+$", "", f) for f in fileNames]

            context_data = {
                "input_documents": docs,
                "question": query,
                "fileName": ", ".join(clean_files),
                "internet_context": internet_context
            }

            # Run the chain in the background; tokens arrive via `callback`.
            task = asyncio.create_task(wrap_done(
                qa_chain.acall(context_data),
                callback.done
            ))

            async for token in callback.aiter():
                yield json.dumps({"text": token}, ensure_ascii=False)

            # Append the source list after the answer text.
            if docs or internet_docs:
                yield json.dumps({"docs": internet_docs}, ensure_ascii=False)

            await task
        except Exception as e:
            logger.error(f"生成响应异常: {str(e)}", exc_info=True)
            yield json.dumps({"error": "生成响应时发生错误"}, ensure_ascii=False)
        finally:
            # If the consumer stops early (e.g. the generator is closed) or an
            # error occurred, do not leave the chain task running unattended.
            if task is not None and not task.done():
                task.cancel()

    return generation_wrapper()
|
||
|
||
|
||
def process_source_documents(docs: List[Dict]) -> List[str]:
    """Format source documents as de-duplicated citation strings.

    For each document the display text is ``get_text_by_regex`` applied to its
    ``summary`` metadata, or the first 200 characters of ``page_content`` when
    that yields nothing. Documents whose display text was already seen are
    skipped. Each citation carries the document's 1-based position in ``docs``,
    so numbers are not renumbered after duplicates are dropped.

    NOTE(review): items are accessed via ``.metadata`` / ``.page_content``, so
    they are document-like objects rather than plain dicts despite the
    annotation — confirm the intended type.
    """
    fingerprints = set()
    citations = []

    for position, document in enumerate(docs, 1):
        excerpt = (
            get_text_by_regex(document.metadata.get("summary", ""))
            or document.page_content[:200]
        )

        # An MD5 digest of the display text serves as the de-duplication key.
        digest = hashlib.md5(excerpt.encode()).hexdigest()
        if digest not in fingerprints:
            fingerprints.add(digest)
            origin = document.metadata.get("source", "未知来源")
            citations.append(f"[{position}] {excerpt[:10]}...(来源:{origin})")

    logger.info(f"处理后的引用文档数: {len(citations)}")
    return citations