320 lines
12 KiB
Python
320 lines
12 KiB
Python
|
|
from datetime import datetime
|
|||
|
|
import hashlib
|
|||
|
|
import re
|
|||
|
|
from fastapi import Body, Request
|
|||
|
|
from sse_starlette.sse import EventSourceResponse
|
|||
|
|
from fastapi.concurrency import run_in_threadpool
|
|||
|
|
from configs import (
|
|||
|
|
LLM_MODELS,
|
|||
|
|
VECTOR_SEARCH_TOP_K,
|
|||
|
|
SCORE_THRESHOLD,
|
|||
|
|
SELF_TOP_K,
|
|||
|
|
SELF_SCORE_THRESHOLD,
|
|||
|
|
TEMPERATURE,
|
|||
|
|
SELF_TEMPERATURE,
|
|||
|
|
USE_RERANKER,
|
|||
|
|
RERANKER_MODEL,
|
|||
|
|
RERANKER_MAX_LENGTH,
|
|||
|
|
SELF_MAX_TOKENS,
|
|||
|
|
SELF_USE_RERANKER,
|
|||
|
|
MODEL_PATH
|
|||
|
|
)
|
|||
|
|
from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
|
|||
|
|
from server.chat.policy_fun_iast import get_llm_model_response
|
|||
|
|
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
|||
|
|
from server.utils import wrap_done, get_ChatOpenAI
|
|||
|
|
from server.utils import BaseResponse, get_prompt_template
|
|||
|
|
from langchain.chains import LLMChain
|
|||
|
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
|||
|
|
from typing import AsyncIterable, List, Optional, Dict, Tuple, Union, AsyncGenerator
|
|||
|
|
import asyncio
|
|||
|
|
from langchain.prompts.chat import ChatPromptTemplate
|
|||
|
|
from server.chat.utils import History, get_similar_documents, get_text_by_regex
|
|||
|
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
|||
|
|
import json
|
|||
|
|
from server.knowledge_base.kb_doc_api import search_self_docs
|
|||
|
|
from configs.basic_config import *
|
|||
|
|
from langchain.chains.question_answering import load_qa_chain
|
|||
|
|
from typing import Generator
|
|||
|
|
|
|||
|
|
# Module-level logger for this module.
# NOTE(review): `logging` is not imported explicitly among the visible imports;
# presumably it is provided by `from configs.basic_config import *` — confirm.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
async def word_explain(
        query: str = Body(..., description="用户输入", examples=["智慧科协"]),
        fileNames: List = Body([], description="文件名称", examples=[["孟庆海同志在“智慧科协2.0”5·30场景建设工作部署会议上的讲话.docx"]]),
        knowledge_base_name_list: List[str] = Body(..., description="知识库列表", examples=[["p_cast0101011"]]),
        history: List[History] = Body(
            [],
            description="历史对话",
            examples=[[
                {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant", "content": "虎头虎脑"}]
            ]
        ),
        # stream: bool = Body(False, description="流式输出"),
) -> EventSourceResponse:
    """
    Term-explanation endpoint.

    Parameters:
        query: the user's input (the term to explain).
        fileNames: file names passed through to the knowledge-base search.
        knowledge_base_name_list: names of the knowledge bases to search.
        history: previous dialogue turns; only its length is logged here.

    Flow: a first LLM pass ("word_check") inspects the query. If its answer
    contains "0", the "word_explain_reject" prompt is used and nothing is
    retrieved. Otherwise the knowledge bases and the web are searched
    concurrently and the "word_explain" prompt is used. The answer is
    streamed as server-sent events.

    Returns a BaseResponse with code 404 when a knowledge base does not
    exist, or code 500 when an error occurs before streaming starts.
    """
    logger.info(
        f"名词解释请求: query={query}, 文件={fileNames}, "
        f"知识库数={knowledge_base_name_list}, 历史记录数={len(history)}, "
    )

    try:
        # Stop at the first knowledge base that cannot be resolved.
        unknown_kb = next(
            (name for name in knowledge_base_name_list
             if not KBServiceFactory.get_service_by_name(name)),
            None,
        )
        if unknown_kb is not None:
            return BaseResponse(code=404, msg=f"知识库 {unknown_kb} 不存在")

        # Ask the model whether the query is a term that can be explained.
        check_result = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="word_check",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="word_check",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512,
        )
        logger.info(f"名词检查结果: {check_result}")

        if '0' not in check_result:
            prompt_name = "word_explain"
            # Knowledge-base and web retrieval run concurrently inside the helper.
            docs, internet_context, internet_doc = await retrieve_documents(
                query=query,
                fileNames=fileNames,
                knowledge_base_names=knowledge_base_name_list,
                top_k=SELF_TOP_K,
                score_threshold=SELF_SCORE_THRESHOLD,
            )
        else:
            # Rejected queries still get a (reject) answer, just without context.
            prompt_name = "word_explain_reject"
            docs, internet_context, internet_doc = [], "", []

        event_stream = generate_chat_response(
            query=query,
            docs=docs,
            internet_context=internet_context,
            internet_docs=internet_doc,
            fileNames=fileNames,
            model_name=LLM_MODELS[0],
            prompt_name=prompt_name,
        )
        return EventSourceResponse(event_stream)

    except Exception as exc:
        logger.error(f"处理请求异常: {str(exc)}", exc_info=True)
        return BaseResponse(code=500, msg=f"处理请求时发生错误: {str(exc)}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Publish-time annotation appended to web-result titles. The first alternative
# removes a whole parenthesised "(发布时间:...)" annotation, with ASCII or
# full-width brackets and colon. The second still strips a bare "发布时间:"
# label, which is all the previous pattern ever removed: its unescaped "( )"
# were a regex group, not literal brackets.
_PUBLISH_TIME_RE = re.compile(r"[(（]\s*发布时间[:：].*?[)）]|发布时间[:：]")


def _build_web_candidates(web_items: List[dict]) -> Tuple[List[str], List[str]]:
    """Build (titles, numbered "i:【title+content】" strings) for LLM relevance filtering.

    Uses each item's "title". If that fails for any item (missing key or
    non-string value), the whole batch falls back to "source".
    """
    try:
        sentences = [doc["title"] for doc in web_items]
        page_contents = [str(i + 1) + ":【" + doc["title"] + doc["content"] + "】"
                         for i, doc in enumerate(web_items)]
    except (KeyError, TypeError):
        sentences = [doc["source"] for doc in web_items]
        page_contents = [str(i + 1) + ":【" + doc["source"] + doc["content"] + "】"
                         for i, doc in enumerate(web_items)]
    return sentences, page_contents


def _parse_llm_indices(answer, n_items: int) -> List[int]:
    """Convert the model's 1-based "1,3" style answer into 0-based indices.

    "无" (none), a non-string answer, or any unparsable number yields [].
    A full-width comma is accepted as a separator. Empty parts are skipped,
    and numbers outside 1..n_items are dropped; a "0" would otherwise
    become index -1 and silently select the last item.
    """
    if not isinstance(answer, str):
        return []
    text = answer.strip()
    if not text or text == "无":
        return []
    indices: List[int] = []
    for part in re.split(r"[,，]", text):
        part = part.strip()
        if not part:
            continue
        try:
            idx = int(part) - 1
        except ValueError:
            return []
        if 0 <= idx < n_items:
            indices.append(idx)
    return indices


def _format_web_documents(similar_documents) -> Tuple[str, List[str]]:
    """Return (concatenated contents, ["[n] [title](link)", ...]) for the selected web results."""
    contents: List[str] = []
    citations: List[str] = []
    for num, item in enumerate(similar_documents, start=1):
        title = item.get('title', '') or ''
        cleaned_title = _PUBLISH_TIME_RE.sub('', title).strip()
        link = item.get('link', '')
        citations.append(f"[{num}] [{cleaned_title}]({link})")
        contents.append(item.get('content', '') or '')
    return "".join(contents), citations


async def _filter_web_results(query: str, web_items: List[dict], top_k: int) -> Tuple[str, List[str]]:
    """Let the LLM pick relevant web results, select them, and format them.

    If the LLM call fails, or its index-based selection fails, this falls back
    to get_similar_documents with an empty index list.
    """
    sentences, page_contents = _build_web_candidates(web_items)
    try:
        # Run the blocking model call off the event loop, as word_explain does.
        answer = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="default_similar",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="default_similar",
            prompt_param_dict={"input": query, "title": f"{page_contents}",
                               "time": datetime.now().strftime("%Y%m%d")},
            temperature=0.01,
            max_tokens=512,
        )
    except Exception as e:
        logger.warning(f"互联网结果相关性判断失败: {str(e)}")
        answer = None

    index = _parse_llm_indices(answer, len(web_items))
    try:
        similar_document = get_similar_documents(index=index, sentences=sentences, query=query,
                                                 docs=web_items, top_k=top_k)
    except Exception as e:
        logger.warning(f"按序号筛选互联网结果失败: {str(e)}")
        similar_document = get_similar_documents(index=[], sentences=sentences, query=query,
                                                 docs=web_items, top_k=top_k)
    return _format_web_documents(similar_document)


async def retrieve_documents(
        query: str,
        fileNames: List[str],
        knowledge_base_names: List[str],
        top_k: int,
        score_threshold: float
) -> Tuple[List[DocumentWithVSId], str, List[str]]:
    """
    Search the knowledge bases and the web concurrently.

    Args:
        query: the user's query, used for both searches.
        fileNames: passed through to search_self_docs for each knowledge base.
        knowledge_base_names: knowledge bases to search (one task per name).
        top_k: passed to search_self_docs and to get_similar_documents.
        score_threshold: passed to search_self_docs.

    Returns:
        (knowledge-base documents, concatenated web-result contents,
         "[n] [title](link)" citation strings for the selected web results).
        If web post-processing fails, the knowledge-base documents are still
        returned. An unexpected error returns ([], "", []).
    """
    async def fetch_kb_docs(kb_name: str) -> List[DocumentWithVSId]:
        """Search one knowledge base; returns [] on failure."""
        try:
            return await run_in_threadpool(
                search_self_docs,
                query=query,
                fileNames=fileNames,
                knowledge_base_name=kb_name,
                top_k=top_k,
                score_threshold=score_threshold
            )
        except Exception as e:
            logger.error(f"知识库 {kb_name} 检索失败: {str(e)}")
            return []

    # TODO: switch the web search to DuckDuckGo
    async def fetch_web_results() -> List[dict]:
        """Run the web search; returns [] on failure."""
        try:
            wrapper = ZhipuSearchAPIWrapper()
            return await run_in_threadpool(
                wrapper.zhipu_search,
                origin_query=query
            )
        except Exception as e:
            logger.error(f"网络检索失败: {str(e)}")
            return []

    try:
        # The web search is scheduled last, so its result is always results[-1]
        # regardless of how many knowledge bases are searched. The old code
        # hard-coded results[1], which was a KB result whenever there were 2+ KBs.
        results = await asyncio.gather(
            *(fetch_kb_docs(kb) for kb in knowledge_base_names),
            fetch_web_results(),
            return_exceptions=True,
        )
        kb_results, web_result = results[:-1], results[-1]

        self_docs: List[DocumentWithVSId] = []
        for result in kb_results:
            if isinstance(result, BaseException):
                logger.error(f"并发任务出现异常: {str(result)}")
            elif isinstance(result, list) and all(isinstance(doc, DocumentWithVSId) for doc in result):
                self_docs.extend(result)

        internet_context = ""
        internet_docs: List[str] = []
        if isinstance(web_result, BaseException):
            logger.error(f"并发任务出现异常: {str(web_result)}")
        elif isinstance(web_result, list) and web_result and all(isinstance(item, dict) for item in web_result):
            try:
                internet_context, internet_docs = await _filter_web_results(query, web_result, top_k)
            except Exception as e:
                # Keep the knowledge-base hits even if web post-processing fails.
                logger.error(f"网络结果处理异常: {str(e)}", exc_info=True)

        return self_docs, internet_context, internet_docs

    except Exception as e:
        logger.error(f"文档检索异常: {str(e)}", exc_info=True)
        return [], "", []
|
|||
|
|
|
|||
|
|
|
|||
|
|
def generate_chat_response(
        query: str,
        docs: List[Dict],
        internet_context: str,
        fileNames: List[str],
        internet_docs: List[str],
        # stream: bool,
        model_name: str,
        prompt_name: str
) -> AsyncGenerator[str, None]:
    """
    Build a "stuff" QA chain and return an async generator that streams its answer.

    Args:
        query: the user's question (the chain's "question" input).
        docs: knowledge-base documents (the chain's "input_documents").
        internet_context: concatenated web-result text (the "internet_context" input).
        fileNames: file names; their extensions are stripped and they are joined
                   with ", " into the "fileName" input.
        internet_docs: web-result citation strings sent after the answer.
        model_name: model passed to get_ChatOpenAI.
        prompt_name: template name under "knowledge_base_chat".

    Returns:
        An async generator of JSON strings: one {"text": token} per streamed
        token, then {"docs": internet_docs} if any docs or internet_docs were
        given. If an error occurs it yields {"error": ...} instead of raising.
        (Previously annotated as returning a response object, which it never did.)
    """
    callback = AsyncIteratorCallbackHandler()
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=SELF_TEMPERATURE,
        max_tokens=SELF_MAX_TOKENS,
        callbacks=[callback],
    )

    # Build the prompt: the whole template is a single system message.
    prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
    chat_prompt = ChatPromptTemplate.from_messages([
        History(role="system", content=prompt_template).to_msg_template(False)
    ])

    # "stuff" chain: all documents are inserted into one prompt.
    qa_chain = load_qa_chain(
        model,
        chain_type="stuff",
        prompt=chat_prompt,
        verbose=True
    )

    async def generation_wrapper() -> AsyncGenerator[str, None]:
        task: Optional[asyncio.Task] = None
        try:
            # Strip file extensions before showing names to the model.
            clean_files = [re.sub(r"\.\w+$", "", f) for f in fileNames]

            context_data = {
                "input_documents": docs,
                "question": query,
                "fileName": ", ".join(clean_files),
                "internet_context": internet_context
            }

            # The chain runs in the background; tokens arrive via the callback.
            task = asyncio.create_task(wrap_done(
                qa_chain.acall(context_data),
                callback.done
            ))

            async for token in callback.aiter():
                yield json.dumps({"text": token}, ensure_ascii=False)

            # Append the source citations after the answer.
            if docs or internet_docs:
                yield json.dumps({"docs": internet_docs}, ensure_ascii=False)

            await task
        except Exception as e:
            logger.error(f"生成响应异常: {str(e)}", exc_info=True)
            yield json.dumps({"error": "生成响应时发生错误"}, ensure_ascii=False)
        finally:
            # If the consumer stops early (e.g. the SSE client disconnects and the
            # generator is closed/cancelled) or an error interrupted streaming,
            # stop the background LLM call instead of letting it run to completion.
            if task is not None and not task.done():
                task.cancel()

    return generation_wrapper()
|
|||
|
|
|
|||
|
|
|
|||
|
|
def process_source_documents(docs: List[Dict]) -> List[str]:
    """
    Format de-duplicated source citations for a list of documents.

    Each element is accessed as an object with ``.metadata`` (a mapping) and
    ``.page_content`` attributes, despite the ``List[Dict]`` annotation.
    The citation text is ``get_text_by_regex(metadata["summary"])`` or, when
    that is falsy, the first 200 characters of ``page_content``. Documents
    whose citation text has already been seen (compared by MD5 digest) are
    skipped.

    Note: the "[n]" number is the document's 1-based position in ``docs``, so
    numbers skip over dropped duplicates.

    Returns a list of strings of the form "[n] <first 10 chars>...(来源:<source>)".
    """
    formatted: List[str] = []
    known_digests = set()

    position = 0
    for doc in docs:
        position += 1
        meta = doc.metadata
        text = get_text_by_regex(meta.get("summary", "")) or doc.page_content[:200]

        digest = hashlib.md5(text.encode()).hexdigest()
        if digest not in known_digests:
            known_digests.add(digest)
            formatted.append(f"[{position}] {text[:10]}...(来源:{meta.get('source', '未知来源')})")

    logger.info(f"处理后的引用文档数: {len(formatted)}")
    return formatted
|