Files
gangyan/langchain-chat/server/chat/self_kb_chat.py

279 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
SELF_TOP_K,
SELF_SCORE_THRESHOLD,
TEMPERATURE,
SELF_TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
SELF_MAX_TOKENS,
SELF_USE_RERANKER,
MODEL_PATH)
from server.chat.policy_fun_iast import get_llm_model_response
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History, get_first_sentence_by_regex, get_text_by_regex
from server.knowledge_base.kb_service.base import KBServiceFactory, TextRank
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_self_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
from configs.basic_config import *
from langchain.memory import ConversationBufferMemory
from langchain.chains.question_answering import load_qa_chain
def _strip_code_fence(raw: str) -> str:
    """Return *raw* without a surrounding Markdown code fence.

    Handles replies shaped like "```json\\n{...}\\n```" as well as bare JSON.
    ``str.strip(chars)`` is intentionally not used here: it removes a *set of
    characters* from both ends rather than a literal prefix/suffix.
    """
    text = (raw or "").strip()
    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def _format_web_results(results: list) -> str:
    """Render web search hits (dicts with title/content/url) as a context section.

    Returns an empty string when *results* is empty. Each hit's content is
    truncated to 300 characters.
    """
    if not results:
        return ""
    parts = []
    for index, hit in enumerate(results, 1):
        title = hit.get("title", "")
        content = (hit.get("content") or "")[:300]
        url = hit.get("url", "")
        parts.append(f"[{index}] {title}\n{content}\n来源: {url}")
    return "\n\n【联网搜索结果】\n" + "\n\n".join(parts)


async def self_kb_chat(
        query: str = Body(..., description="用户输入", examples=["智慧科协是什么"]),
        quote: str = Body(..., description="用户引用的文段,引用问答时传该参数", examples=["今年“智慧科协2.0”要持之以恒深入贯彻落实习近平总书记的重要指示精神"]),
        fileNames: List = Body([], description="文件名称", examples=[["孟庆海同志在“智慧科协2.0”5·30场景建设工作部署会议上的讲话.docx"]]),
        knowledge_base_name_list: list = Body(..., description="知识库列表",
                                              examples=[["p_cast0101011"]]),
        history: List[History] = Body(
            [],
            description="历史对话",
            examples=[[
                {"role": "user",
                 "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant",
                 "content": "虎头虎脑"}]]
        ),
        stream: bool = Body(True, description="流式输出"),
        web_search: bool = Body(False, description="是否开启联网搜索"),
):
    """Personal knowledge-base chat endpoint (server-sent events).

    Parameters:
        query: The user's question.
        quote: Passage quoted by the user; when non-empty and documents were
            found, the "self_quote" prompt template is used.
        fileNames: File names forwarded to the document search.
        knowledge_base_name_list: Knowledge bases to validate. Every entry must
            exist; retrieval itself runs against the last entry only.
        history: Previous conversation turns.
        stream: Emit one event per token when true, otherwise a single event
            carrying the whole answer.
        web_search: When true, web search results are added to the context.

    Returns:
        A ``BaseResponse`` with code 404 when the list is empty or a knowledge
        base is missing; otherwise an ``EventSourceResponse`` that yields JSON
        strings ``{"text": ...}`` followed by a final ``{"docs": [...]}``.
    """
    logger.info(f"个人知识库对话入参:\nquery:{query}\nquote:{quote}\nfileNames:{fileNames}\nknowledge_base_name_list:{knowledge_base_name_list}\nhistory:{history}\nstream:{stream}")

    # An empty list would leave the retrieval target undefined and only fail
    # later, in the middle of the stream.
    if not knowledge_base_name_list:
        return BaseResponse(code=404, msg="知识库列表不能为空")

    for knowledge_base_name in knowledge_base_name_list:
        kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
        if kb is None:
            return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # NOTE(review): only the last knowledge base in the list is searched. This
    # preserves the existing behaviour (the search used the validation loop's
    # final loop variable) -- confirm whether all entries should be searched.
    target_kb_name = knowledge_base_name_list[-1]

    history = [History.to_msg_tuple(h) for h in history]

    async def knowledge_base_chat_iterator(
            query: str,
            model_name: str = LLM_MODELS[0],
            model_name1: str = LLM_MODELS[0],
            prompt_name: str = "self_default",
    ) -> AsyncIterable[str]:
        """Rewrite the query, retrieve documents, run the chain, stream the answer."""
        callback = AsyncIteratorCallbackHandler()

        def build_llm(name: str):
            # Every model shares the callback so tokens reach the stream below.
            return get_ChatOpenAI(
                model_name=name,
                temperature=SELF_TEMPERATURE,
                max_tokens=SELF_MAX_TOKENS,
                callbacks=[callback],
            )

        # --- 1. Rewrite the question using the user's previous messages -------
        user_queries = [content for role, content in history if role == 'user']
        # get_llm_model_response is synchronous; run it off the event loop so
        # other requests are not stalled while the LLM answers.
        rewrite_reply = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="self_query_rewrite",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="self_query_rewrite",
            prompt_param_dict={"query": query, "history": user_queries, "quote": quote},
            temperature=0.01,
            max_tokens=512,
        )
        logger.info(f"个人知识库问答query: {query}")
        logger.info(f"个人知识库问答query_history: {user_queries}")

        original_query = query
        search_query = original_query
        try:
            data = json.loads(_strip_code_fence(rewrite_reply))
            rewritten = data['query']
            # A list of strings is concatenated; a plain string is unchanged.
            joined = "".join(rewritten)
            if joined.strip():
                search_query = joined
                # NOTE(review): the rewritten value (possibly a list) replaces
                # `query` for routing and for the final prompt, as before.
                query = rewritten
        except (ValueError, KeyError, TypeError):
            # Malformed or unexpected JSON: keep the caller's original query.
            query = original_query
            search_query = original_query
        logger.info(f"个人知识库问答search_query: {search_query}")

        # --- 2. Route: "0" = whole-document question, otherwise local ---------
        route_reply = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="self_kb_route",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="self_kb_route",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512,
        )
        self_kb_route = str(route_reply).strip() if route_reply is not None else ""
        # One decision drives both retrieval and chain selection, so `docs` and
        # `task` are always defined even if the router reply has extra
        # whitespace or is unrecognised.
        is_global = '0' in self_kb_route
        if not is_global and '1' not in self_kb_route:
            logger.warning(f"个人知识库问答路由结果无法识别: {self_kb_route!r},按【局部问题】处理")

        # --- 3. Retrieve documents ------------------------------------------
        try:
            if is_global:
                logger.info("个人知识库问答路由结果:【全局问题】")
                docs = await run_in_threadpool(
                    search_self_docs,
                    query="",
                    fileNames=fileNames,
                    knowledge_base_name=target_kb_name,
                    top_k=999,
                    score_threshold=2,
                )
            else:
                logger.info("个人知识库问答路由结果:【局部问题】")
                docs = await run_in_threadpool(
                    search_self_docs,
                    query=search_query,
                    fileNames=fileNames,
                    knowledge_base_name=target_kb_name,
                    top_k=SELF_TOP_K,
                    score_threshold=SELF_SCORE_THRESHOLD,
                )
        except Exception:
            logger.error(f"个人知识库问答路由错误: {self_kb_route}", exc_info=True)
            docs = []
        logger.info(f"个人知识库问答source_documents: {len(docs)}")

        # --- 4. Optional web search (best effort) -----------------------------
        web_search_context = ""
        web_search_results = []  # kept for the citation list at the end
        if web_search:
            try:
                from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
                searcher = ZhipuSearchAPIWrapper()
                web_results = await run_in_threadpool(searcher.zhipu_search, search_query)
                web_search_results = web_results[:5] if web_results else []
                if web_results:
                    web_search_context = _format_web_results(web_search_results)
                    logger.info(f"联网搜索获取到 {len(web_results)} 条结果")
            except Exception as e:
                logger.error(f"联网搜索失败: {e}")

        # --- 5. Pick the prompt template ------------------------------------
        if not docs:
            prompt_name = "self_empty"  # nothing retrieved
        elif quote:
            prompt_name = "self_quote"

        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_msg])

        # --- 6. Build the chain and start generation --------------------------
        if is_global:
            # Strip non-breaking spaces from the ends of the joined context.
            context = "\n".join(doc.page_content for doc in docs).strip("\xa0")
            logger.info(f"个人知识库问答 context 长度:{len(context)}")
            context = context[:30000]
            if web_search_context:
                context += web_search_context
                logger.info(f"最终 context 长度:{len(context)}")
            chat_history = history
            if history:
                # NOTE(review): this compares the number of messages, not the
                # number of characters, with 20000 -- confirm the intent.
                chat_history = history if len(history) < 20000 else TextRank(history, num_sentences=1)
            chain = LLMChain(prompt=chat_prompt, llm=build_llm(model_name1), verbose=True)
            chain_inputs = {"context": context, "question": query, "history": chat_history,
                            "quote": quote, "fileName": fileNames}
        else:
            # Work on a copy so the synthetic web-search document is not later
            # cited as if it were a knowledge-base source.
            qa_docs = list(docs)
            if web_search_context:
                from langchain.docstore.document import Document as LCDocument
                qa_docs.append(LCDocument(page_content=web_search_context, metadata={"source": "web_search"}))
            chain = load_qa_chain(
                build_llm(model_name),
                chain_type="stuff",
                prompt=chat_prompt,
                verbose=True,
            )
            chain_inputs = {"input_documents": qa_docs, "question": query, "history": history,
                            "quote": quote, "fileName": fileNames}

        task = asyncio.create_task(wrap_done(chain.acall(chain_inputs), callback.done))

        # --- 7. Emit the answer ---------------------------------------------
        if stream:
            async for token in callback.aiter():
                # One server-sent event per token.
                yield json.dumps({"text": token}, ensure_ascii=False)
        else:
            tokens = []
            async for token in callback.aiter():
                tokens.append(token)
            yield json.dumps({"text": "".join(tokens)}, ensure_ascii=False)
        await task

        # --- 8. Emit the citation list --------------------------------------
        source_documents = []
        if not docs and not web_search_context:
            source_documents.append("暂未从本篇文献中找到答案,该回答为大模型自身能力解答!")
        else:
            if docs:
                first_source = docs[0].metadata.get("source")
                source_documents.append(f"[{len(source_documents) + 1}] [{first_source}]()\n")
            for hit in web_search_results:
                title = hit.get("title", "").replace("\n", "")
                url = hit.get("url", "")
                if title and url:
                    source_documents.append(f"[{len(source_documents) + 1}] [{title}]({url})\n")
        yield json.dumps({"docs": source_documents}, ensure_ascii=False)

    return EventSourceResponse(knowledge_base_chat_iterator(query))