Files
gangyan/langchain-chat/server/chat/word_explain.py

320 lines
12 KiB
Python
Raw Normal View History

from datetime import datetime
import hashlib
import re
from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (
LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
SELF_TOP_K,
SELF_SCORE_THRESHOLD,
TEMPERATURE,
SELF_TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
SELF_MAX_TOKENS,
SELF_USE_RERANKER,
MODEL_PATH
)
from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
from server.chat.policy_fun_iast import get_llm_model_response
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional, Dict, Tuple, Union, AsyncGenerator
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History, get_similar_documents, get_text_by_regex
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from server.knowledge_base.kb_doc_api import search_self_docs
from configs.basic_config import *
from langchain.chains.question_answering import load_qa_chain
from typing import Generator
# Module-level logger for this chat endpoint.
# NOTE(review): `logging` is not imported explicitly in this file; it appears to
# be provided by `from configs.basic_config import *` — confirm.
logger = logging.getLogger(__name__)
async def word_explain(
    query: str = Body(..., description="用户输入", examples=["智慧科协"]),
    fileNames: List = Body([], description="文件名称", examples=[["孟庆海同志在“智慧科协2.0”5·30场景建设工作部署会议上的讲话.docx"]]),
    knowledge_base_name_list: List[str] = Body(..., description="知识库列表", examples=[["p_cast0101011"]]),
    history: List[History] = Body(
        [],
        description="历史对话",
        examples=[[
            {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
            {"role": "assistant", "content": "虎头虎脑"}]
        ]
    ),
) -> EventSourceResponse:
    """
    Term-explanation API, streamed as server-sent events.

    Parameters:
        query: the user's input.
        fileNames: file names forwarded to the knowledge-base search.
        knowledge_base_name_list: knowledge bases to search.
        history: prior conversation turns (only its length is used here, for logging).

    Flow:
        1. Every knowledge base must exist; the first missing one yields a
           ``BaseResponse`` with code 404.
        2. The ``word_check`` LLM strategy classifies the query. If its answer
           contains "0", the ``word_explain_reject`` prompt is used and no
           documents are retrieved.
        3. Otherwise knowledge-base and web results are retrieved concurrently
           and the ``word_explain`` prompt is used.
        4. The answer is streamed via ``generate_chat_response``.

    Any exception raised before streaming starts is logged and turned into a
    ``BaseResponse`` with code 500.
    """
    logger.info(
        f"名词解释请求: query={query}, 文件={fileNames}, "
        f"知识库数={knowledge_base_name_list}, 历史记录数={len(history)}, "
    )
    try:
        # First knowledge base that cannot be resolved, if any (short-circuits
        # in list order, like a plain loop with an early return).
        unknown_kb = next(
            (
                name
                for name in knowledge_base_name_list
                if not KBServiceFactory.get_service_by_name(name)
            ),
            None,
        )
        if unknown_kb is not None:
            return BaseResponse(code=404, msg=f"知识库 {unknown_kb} 不存在")

        # Ask the LLM whether the query is a standalone term.
        verdict = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="word_check",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="word_check",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512,
        )
        logger.info(f"名词检查结果: {verdict}")

        is_rejected = '0' in verdict
        if is_rejected:
            # Not a term: answer with the reject prompt and no retrieved context.
            prompt_name = "word_explain_reject"
            docs, internet_context, internet_doc = [], "", []
        else:
            prompt_name = "word_explain"
            # Knowledge-base and web retrieval run concurrently inside the helper.
            docs, internet_context, internet_doc = await retrieve_documents(
                query=query,
                fileNames=fileNames,
                knowledge_base_names=knowledge_base_name_list,
                top_k=SELF_TOP_K,
                score_threshold=SELF_SCORE_THRESHOLD,
            )

        stream = generate_chat_response(
            query=query,
            docs=docs,
            internet_context=internet_context,
            internet_docs=internet_doc,
            fileNames=fileNames,
            model_name=LLM_MODELS[0],
            prompt_name=prompt_name,
        )
        return EventSourceResponse(stream)
    except Exception as exc:
        logger.error(f"处理请求异常: {str(exc)}", exc_info=True)
        return BaseResponse(code=500, msg=f"处理请求时发生错误: {str(exc)}")
async def retrieve_documents(
    query: str,
    fileNames: List[str],
    knowledge_base_names: List[str],
    top_k: int,
    score_threshold: float
) -> Tuple[List[DocumentWithVSId], str, List[str]]:
    """
    Concurrently retrieve knowledge-base documents and web search results.

    Each knowledge base in ``knowledge_base_names`` is searched with
    ``search_self_docs`` while a single Zhipu web search runs for ``query``.
    The web hits are then narrowed by an LLM relevance pass (``default_similar``
    prompt) followed by ``get_similar_documents``.

    Args:
        query: the user query.
        fileNames: file names forwarded to ``search_self_docs``.
        knowledge_base_names: knowledge bases to search (may be empty).
        top_k: passed to ``search_self_docs`` and ``get_similar_documents``.
        score_threshold: passed to ``search_self_docs``.

    Returns:
        ``(kb_documents, web_context, web_references)`` where ``web_context`` is
        the newline-joined content of the selected web hits and
        ``web_references`` are ``"[n] [title](link)"`` markdown strings.
        Any unexpected failure is logged and ``([], "", [])`` is returned.
    """
    async def fetch_kb_docs(kb_name: str) -> List[DocumentWithVSId]:
        """Search one knowledge base; log and return [] on failure."""
        try:
            return await run_in_threadpool(
                search_self_docs,
                query=query,
                fileNames=fileNames,
                knowledge_base_name=kb_name,
                top_k=top_k,
                score_threshold=score_threshold
            )
        except Exception as e:
            logger.error(f"知识库 {kb_name} 检索失败: {str(e)}")
            return []

    # TODO: switch the web search to DuckDuckGo.
    async def fetch_web_results() -> List[dict]:
        """Run the Zhipu web search; log and return [] on failure."""
        try:
            wrapper = ZhipuSearchAPIWrapper()
            return await run_in_threadpool(
                wrapper.zhipu_search,
                origin_query=query
            )
        except Exception as e:
            logger.error(f"网络检索失败: {str(e)}")
            return []

    try:
        # Task order is fixed: one task per knowledge base, web search last.
        # Unpack by position instead of guessing from element types; the old
        # code read ``results[1]``, which is the web result only when exactly
        # one knowledge base is given.
        tasks = [fetch_kb_docs(kb) for kb in knowledge_base_names]
        tasks.append(fetch_web_results())
        results = await asyncio.gather(*tasks, return_exceptions=True)
        *kb_results, web_result = results

        self_docs: List[DocumentWithVSId] = []
        for result in kb_results:
            if isinstance(result, BaseException):
                logger.error(f"并发任务出现异常: {str(result)}")
            elif result:
                self_docs.extend(result)

        internet_context = ""
        internet_docs: List[str] = []
        if isinstance(web_result, BaseException):
            logger.error(f"并发任务出现异常: {str(web_result)}")
        elif isinstance(web_result, list):
            # Only dict-shaped hits can be filtered/formatted below.
            web_docs = [item for item in web_result if isinstance(item, dict)]
            if web_docs:
                similar_documents = await _select_relevant_web_results(query, web_docs, top_k)
                internet_context, internet_docs = _format_web_results(similar_documents)
        return self_docs, internet_context, internet_docs
    except Exception as e:
        logger.error(f"文档检索异常: {str(e)}", exc_info=True)
        return [], "", []


# Strips a "(发布时间:...)" publish-time annotation from a web-result title, in
# half- or full-width punctuation, up to the closing bracket (or end of title).
# The previous pattern r'(发布时间:.*?' had an unbalanced "(" and raised re.error.
_PUBLISH_TIME_RE = re.compile(r"\s*[（(]发布时间[:：].*?(?:[）)]|$)")


def _parse_similar_indices(response: Optional[str], count: int) -> List[int]:
    """Turn the LLM's 1-based answer (e.g. "2,5" or "2，5") into 0-based indices in [0, count)."""
    if not isinstance(response, str):
        return []
    indices = [int(token) - 1 for token in re.findall(r"\d+", response)]
    return [i for i in indices if 0 <= i < count]


async def _select_relevant_web_results(query: str, web_docs: List[dict], top_k: int) -> List[dict]:
    """Ask the LLM which web hits are relevant, then narrow them with get_similar_documents."""
    # Prefer each hit's title; hits without a title fall back to their source.
    sentences = [doc["title"] if "title" in doc else doc.get("source", "") for doc in web_docs]
    numbered = [
        f"{i}:【{sentence}{doc.get('content', '')}】"
        for i, (sentence, doc) in enumerate(zip(sentences, web_docs), start=1)
    ]
    try:
        # Run the blocking LLM call off the event loop, like the word_check call.
        res = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="default_similar",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="default_similar",
            prompt_param_dict={
                "input": query,
                "title": f"{numbered}",
                "time": datetime.now().strftime("%Y%m%d"),
            },
            temperature=0.01,
            max_tokens=512,
        )
        index = _parse_similar_indices(res, len(web_docs))
    except Exception as e:
        # A failed relevance pass should not discard the knowledge-base results.
        logger.warning(f"网页相关性判断失败: {str(e)}")
        index = []
    try:
        similar = get_similar_documents(index=index, sentences=sentences, query=query, docs=web_docs, top_k=top_k)
    except Exception as e:
        logger.warning(f"相似文档筛选失败, 使用默认筛选: {str(e)}")
        similar = get_similar_documents(index=[], sentences=sentences, query=query, docs=web_docs, top_k=top_k)
    return similar or []


def _format_web_results(similar_documents: List[dict]) -> Tuple[str, List[str]]:
    """Return (newline-joined contents, "[n] [title](link)" references) for the selected web hits."""
    contents: List[str] = []
    references: List[str] = []
    for num, item in enumerate(similar_documents, start=1):
        contents.append(str(item.get("content") or ""))
        title = _PUBLISH_TIME_RE.sub("", str(item.get("title") or ""))
        references.append(f"[{num}] [{title}]({item.get('link', '')})")
    # Separate hits so adjacent documents do not run together in the prompt.
    return "\n".join(contents), references
def generate_chat_response(
    query: str,
    docs: List[DocumentWithVSId],
    internet_context: str,
    fileNames: List[str],
    internet_docs: List[str],
    model_name: str,
    prompt_name: str
) -> AsyncIterable[str]:
    """
    Build an async generator that streams an answer from local docs plus web context.

    The ``knowledge_base_chat``/``prompt_name`` template becomes the system
    message of a "stuff" QA chain over ``docs``. The generator yields JSON strings:
        * ``{"text": token}`` for each streamed model token;
        * ``{"docs": internet_docs}`` once at the end, when ``docs`` or
          ``internet_docs`` is non-empty;
        * ``{"error": ...}`` if generation fails.

    Args:
        query: the user question.
        docs: local documents stuffed into the prompt.
        internet_context: web content exposed to the prompt as ``internet_context``.
        fileNames: file names; extensions are stripped and they are joined into ``fileName``.
        internet_docs: web reference strings sent after the answer.
        model_name: chat model name for ``get_ChatOpenAI``.
        prompt_name: prompt template name under ``knowledge_base_chat``.

    Returns:
        The (not yet started) async generator; wrap it in an ``EventSourceResponse``.
    """
    callback = AsyncIteratorCallbackHandler()
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=SELF_TEMPERATURE,
        max_tokens=SELF_MAX_TOKENS,
        callbacks=[callback],
    )
    # Build the prompt template.
    prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
    chat_prompt = ChatPromptTemplate.from_messages([
        History(role="system", content=prompt_template).to_msg_template(False)
    ])
    # Configure the QA chain.
    qa_chain = load_qa_chain(
        model,
        chain_type="stuff",
        prompt=chat_prompt,
        verbose=True
    )

    async def generation_wrapper() -> AsyncGenerator[str, None]:
        task: Optional[asyncio.Task] = None
        try:
            # Drop the trailing extension from each file name.
            clean_files = [re.sub(r"\.\w+$", "", f) for f in fileNames]
            context_data = {
                "input_documents": docs,
                "question": query,
                "fileName": ", ".join(clean_files),
                "internet_context": internet_context
            }
            # The chain runs in the background; tokens arrive via the callback.
            task = asyncio.create_task(wrap_done(
                qa_chain.acall(context_data),
                callback.done
            ))
            async for token in callback.aiter():
                yield json.dumps({"text": token}, ensure_ascii=False)
            # Append the web references once the answer is complete.
            if docs or internet_docs:
                yield json.dumps({"docs": internet_docs}, ensure_ascii=False)
            await task
        except Exception as e:
            logger.error(f"生成响应异常: {str(e)}", exc_info=True)
            yield json.dumps({"error": "生成响应时发生错误"}, ensure_ascii=False)
        finally:
            # If the client disconnects (generator closed/cancelled) or streaming
            # failed, stop the background chain instead of leaking a running LLM call.
            if task is not None and not task.done():
                task.cancel()

    return generation_wrapper()
def process_source_documents(docs: List[DocumentWithVSId]) -> List[str]:
    """
    Format local source documents into short, de-duplicated citation strings.

    ``docs`` are document objects exposing a ``metadata`` dict and
    ``page_content`` (not plain dicts). For each one, the snippet is
    ``get_text_by_regex(metadata["summary"])``, falling back to the first 200
    characters of ``page_content``. Documents whose snippet was already seen
    are skipped.

    Each entry has the form ``"[idx] <first 10 chars>...(来源:<source>)"``.
    ``idx`` is the document's 1-based position in ``docs``, so skipped
    duplicates leave gaps. ``source`` defaults to "未知来源".
    """
    seen = set()
    processed = []
    for idx, doc in enumerate(docs, 1):
        # Extract the document snippet.
        summary = doc.metadata.get("summary", "")
        text = get_text_by_regex(summary) or doc.page_content[:200]
        # Deduplicate on the snippet text itself (no hashing needed).
        if text in seen:
            continue
        seen.add(text)
        source = doc.metadata.get("source", "未知来源")
        # Closing parenthesis restored; it was missing from the citation.
        processed.append(f"[{idx}] {text[:10]}...(来源:{source})")
    logger.info(f"处理后的引用文档数: {len(processed)}")
    return processed