Files
gangyan/langchain-chat/server/chat/word_explain.py

320 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
import hashlib
import re
from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (
LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
SELF_TOP_K,
SELF_SCORE_THRESHOLD,
TEMPERATURE,
SELF_TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
SELF_MAX_TOKENS,
SELF_USE_RERANKER,
MODEL_PATH
)
from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
from server.chat.policy_fun_iast import get_llm_model_response
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional, Dict, Tuple, Union, AsyncGenerator
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History, get_similar_documents, get_text_by_regex
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from server.knowledge_base.kb_doc_api import search_self_docs
from configs.basic_config import *
from langchain.chains.question_answering import load_qa_chain
from typing import Generator
# `logging` is not imported explicitly anywhere above; it was only reachable
# through the wildcard `from configs.basic_config import *`. Import it here so
# this module does not break if that config module changes its exports.
import logging

logger = logging.getLogger(__name__)
async def word_explain(
    query: str = Body(..., description="用户输入", examples=["智慧科协"]),
    fileNames: List = Body([], description="文件名称", examples=[["孟庆海同志在“智慧科协2.0”5·30场景建设工作部署会议上的讲话.docx"]]),
    knowledge_base_name_list: List[str] = Body(..., description="知识库列表", examples=[["p_cast0101011"]]),
    history: List[History] = Body(
        [],
        description="历史对话",
        examples=[[
            {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
            {"role": "assistant", "content": "虎头虎脑"}]
        ]
    ),
) -> EventSourceResponse:
    """
    Term-explanation API.\n
    Parameters:\n
    query: str  the user's input (the term to explain)\n
    fileNames: List[str]  file names, passed to the knowledge-base search and the prompt\n
    knowledge_base_name_list: List[str]  knowledge bases to search\n
    history: List[History]  chat history (only its length is logged here)\n
    Returns an SSE stream; a BaseResponse with code 404 (unknown knowledge base)
    or 500 (any other failure) is returned instead on error.\n
    """
    logger.info(
        f"名词解释请求: query={query}, "
        f"文件={fileNames}, "
        f"知识库数={knowledge_base_name_list}, "
        f"历史记录数={len(history)}, "
    )
    try:
        # Stop at the first knowledge base that cannot be resolved.
        for name in knowledge_base_name_list:
            if KBServiceFactory.get_service_by_name(name):
                continue
            return BaseResponse(code=404, msg=f"知识库 {name} 不存在")

        # Ask the LLM whether the query is a single term worth explaining.
        verdict = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="word_check",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="word_check",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512,
        )
        logger.info(f"名词检查结果: {verdict}")

        local_docs, web_context, web_links = [], "", []
        # A '0' anywhere in the verdict means "not a term": answer with the
        # reject prompt and skip retrieval entirely.
        if '0' not in verdict:
            template_name = "word_explain"
            local_docs, web_context, web_links = await retrieve_documents(
                query=query,
                fileNames=fileNames,
                knowledge_base_names=knowledge_base_name_list,
                top_k=SELF_TOP_K,
                score_threshold=SELF_SCORE_THRESHOLD,
            )
        else:
            template_name = "word_explain_reject"

        token_stream = generate_chat_response(
            query=query,
            docs=local_docs,
            internet_context=web_context,
            internet_docs=web_links,
            fileNames=fileNames,
            model_name=LLM_MODELS[0],
            prompt_name=template_name,
        )
        return EventSourceResponse(token_stream)
    except Exception as exc:
        logger.error(f"处理请求异常: {str(exc)}", exc_info=True)
        return BaseResponse(code=500, msg=f"处理请求时发生错误: {str(exc)}")
# Strips a publish-time annotation from web-search titles. ASCII and full-width
# brackets/colons are both accepted, and the closing bracket is optional.
# NOTE(review): titles are presumably of the form "<title>(发布时间:<date>)";
# confirm against the search provider's actual output.
_PUBLISH_TIME_PATTERN = re.compile(r"[(（]发布时间[:：][^)）]*[)）]?")


async def _select_similar_web_docs(query: str, web_docs: List[dict], top_k: int) -> List[dict]:
    """Return the subset of *web_docs* the LLM judges relevant to *query*.

    The "default_similar" prompt receives a numbered list of the results and its
    answer is parsed as comma-separated 1-based numbers. Numbers outside
    ``1..len(web_docs)`` are dropped. If the answer cannot be parsed (or the
    selection fails), ``get_similar_documents`` is retried with an empty index.
    """
    try:
        sentences = [doc["title"] for doc in web_docs]
        # NOTE(review): the opening 【 is never closed; confirm whether the
        # prompt expects a matching 】.
        numbered = [str(i + 1) + ":【" + doc["title"] + doc["content"] for i, doc in enumerate(web_docs)]
    except (KeyError, TypeError):
        # Some results carry "source" instead of "title".
        sentences = [doc["source"] for doc in web_docs]
        numbered = [str(i + 1) + ":【" + doc["source"] + doc["content"] for i, doc in enumerate(web_docs)]

    # The LLM call is blocking; keep it off the event loop.
    answer = await run_in_threadpool(
        get_llm_model_response,
        strategy_name="default_similar",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="default_similar",
        prompt_param_dict={
            "input": query,
            "title": f"{numbered}",
            "time": datetime.now().strftime("%Y%m%d"),
        },
        temperature=0.01,
        max_tokens=512,
    )
    try:
        raw = (answer or "").strip()
        picked = [int(part) - 1 for part in raw.split(",")] if raw else []
        # Reject 0 / out-of-range answers, which would otherwise become
        # negative (wrap-around) or invalid indices.
        picked = [i for i in picked if 0 <= i < len(web_docs)]
        return get_similar_documents(index=picked, sentences=sentences, query=query, docs=web_docs, top_k=top_k)
    except Exception as e:
        logger.warning(f"相似文档筛选失败, 使用默认结果: {e}")
        return get_similar_documents(index=[], sentences=sentences, query=query, docs=web_docs, top_k=top_k)


def _format_web_docs(similar_docs: List[dict]) -> Tuple[str, List[str]]:
    """Return (all ``content`` fields concatenated, ``"[n] [title](link)"`` lines)."""
    contents: List[str] = []
    links: List[str] = []
    for num, item in enumerate(similar_docs, start=1):
        title = _PUBLISH_TIME_PATTERN.sub("", item.get("title") or "")
        links.append(f"[{num}] [{title}]({item.get('link', '')})")
        contents.append(item.get("content") or "")
    return "".join(contents), links


async def retrieve_documents(
    query: str,
    fileNames: List[str],
    knowledge_base_names: List[str],
    top_k: int,
    score_threshold: float
) -> Tuple[List[DocumentWithVSId], str, List[str]]:
    """
    Search every knowledge base and the web concurrently.

    Web results are filtered for relevance by the LLM before use.

    Returns:
        (self_docs, internet_context, internet_docs):
        - knowledge-base hits;
        - the concatenated ``content`` of the selected web results;
        - one ``"[n] [title](link)"`` line per selected web result.
        Failed knowledge-base or web searches are logged and contribute nothing.
        If web post-processing fails, the knowledge-base hits are still returned.
    """
    async def fetch_kb_docs(kb_name: str) -> List[DocumentWithVSId]:
        """Search one knowledge base; log and return [] on failure."""
        try:
            return await run_in_threadpool(
                search_self_docs,
                query=query,
                fileNames=fileNames,
                knowledge_base_name=kb_name,
                top_k=top_k,
                score_threshold=score_threshold
            )
        except Exception as e:
            logger.error(f"知识库 {kb_name} 检索失败: {str(e)}")
            return []

    # TODO: replace with DuckDuckGo search.
    async def fetch_web_results() -> List[dict]:
        """Run the Zhipu web search; log and return [] on failure."""
        try:
            wrapper = ZhipuSearchAPIWrapper()
            return await run_in_threadpool(
                wrapper.zhipu_search,
                origin_query=query
            )
        except Exception as e:
            logger.error(f"网络检索失败: {str(e)}")
            return []

    try:
        results = await asyncio.gather(
            *[fetch_kb_docs(kb) for kb in knowledge_base_names],
            fetch_web_results(),
            return_exceptions=True,
        )
        # gather() preserves submission order: one entry per knowledge base,
        # then the web result last. Index by position, not by a fixed slot.
        *kb_results, web_result = results

        self_docs: List[DocumentWithVSId] = []
        for result in kb_results:
            if isinstance(result, BaseException):
                logger.error(f"并发任务出现异常: {str(result)}")
            elif isinstance(result, list) and all(isinstance(doc, DocumentWithVSId) for doc in result):
                self_docs.extend(result)
    except Exception as e:
        logger.error(f"文档检索异常: {str(e)}")
        return [], "", []

    internet_context = ""
    internet_docs: List[str] = []
    if isinstance(web_result, BaseException):
        logger.error(f"并发任务出现异常: {str(web_result)}")
    elif isinstance(web_result, list) and web_result and all(isinstance(item, dict) for item in web_result):
        try:
            similar_docs = await _select_similar_web_docs(query, web_result, top_k)
            internet_context, internet_docs = _format_web_docs(similar_docs)
        except Exception as e:
            # Web post-processing is best-effort; keep the knowledge-base hits.
            logger.error(f"文档检索异常: {str(e)}")
    return self_docs, internet_context, internet_docs
def generate_chat_response(
    query: str,
    docs: List[Dict],
    internet_context: str,
    fileNames: List[str],
    internet_docs: List[str],
    model_name: str,
    prompt_name: str
) -> AsyncGenerator[str, None]:
    """
    Build a "stuff" QA chain over local docs plus web context and stream its answer.

    Args:
        query: the user question.
        docs: documents passed to the chain as ``input_documents``. These are
            presumably langchain documents, not dicts; confirm against callers.
        internet_context: concatenated web-search text for the prompt.
        fileNames: file names; extensions are stripped and the names joined
            into the prompt's ``fileName`` field.
        internet_docs: citation lines sent to the client after the answer.
        model_name: LLM passed to ``get_ChatOpenAI``.
        prompt_name: template name under "knowledge_base_chat".

    Returns:
        An async generator of JSON strings:
        - ``{"text": token}`` for each streamed token;
        - ``{"docs": internet_docs}`` when ``docs`` or ``internet_docs`` is
          non-empty;
        - ``{"error": ...}`` if generation fails.
    """
    callback = AsyncIteratorCallbackHandler()
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=SELF_TEMPERATURE,
        max_tokens=SELF_MAX_TOKENS,
        callbacks=[callback],
    )
    # Single system message built from the configured prompt template.
    prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
    chat_prompt = ChatPromptTemplate.from_messages([
        History(role="system", content=prompt_template).to_msg_template(False)
    ])
    qa_chain = load_qa_chain(
        model,
        chain_type="stuff",
        prompt=chat_prompt,
        verbose=True
    )

    async def generation_wrapper() -> AsyncGenerator[str, None]:
        task = None
        try:
            # Drop file extensions so the prompt sees bare document names.
            clean_files = [re.sub(r"\.\w+$", "", f) for f in fileNames]
            context_data = {
                "input_documents": docs,
                "question": query,
                "fileName": ", ".join(clean_files),
                "internet_context": internet_context
            }
            # The chain runs in the background; tokens arrive via the callback,
            # and wrap_done signals callback.done when the chain finishes.
            task = asyncio.create_task(wrap_done(
                qa_chain.acall(context_data),
                callback.done
            ))
            async for token in callback.aiter():
                yield json.dumps({"text": token}, ensure_ascii=False)
            if docs or internet_docs:
                yield json.dumps({"docs": internet_docs}, ensure_ascii=False)
            await task
        except Exception as e:
            logger.error(f"生成响应异常: {str(e)}", exc_info=True)
            yield json.dumps({"error": "生成响应时发生错误"}, ensure_ascii=False)
        finally:
            # A client disconnect closes/cancels this generator (BaseException,
            # not caught above), and errors can interrupt streaming. In both
            # cases, stop the LLM call rather than leaving it running unobserved.
            if task is not None and not task.done():
                task.cancel()

    return generation_wrapper()
def process_source_documents(docs: List[Dict]) -> List[str]:
    """
    Format retrieved documents as de-duplicated citation strings.

    Each element is read via ``doc.metadata.get(...)`` and ``doc.page_content``.
    So, despite the ``List[Dict]`` annotation, these are presumably document
    objects rather than dicts; TODO confirm against callers.

    The excerpt for each document is ``get_text_by_regex(metadata["summary"])``,
    falling back to the first 200 characters of ``page_content``. Documents
    whose excerpt was already seen are skipped.

    Returns:
        One ``"[<n>] <first 10 chars of excerpt>...(来源:<source>"`` string per
        distinct excerpt. ``<n>`` is the document's 1-based position in *docs*,
        so numbers skip over duplicates.
    """
    seen_texts = set()
    processed = []
    for idx, doc in enumerate(docs, 1):
        summary = doc.metadata.get("summary", "")
        text = get_text_by_regex(summary) or doc.page_content[:200]
        # Deduplicate on the excerpt itself. A digest adds nothing here, and
        # hashlib.md5 may be unavailable on FIPS-restricted Python builds.
        if text in seen_texts:
            continue
        seen_texts.add(text)
        source = doc.metadata.get("source", "未知来源")
        # NOTE(review): the "(" before 来源 is never closed in this format
        # string; confirm whether a closing bracket was lost.
        processed.append(f"[{idx}] {text[:10]}...(来源:{source}")
    logger.info(f"处理后的引用文档数: {len(processed)}")
    return processed