Files
gangyan/langchain-chat/server/chat/report_chat.py

232 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH,
MAX_TOKENS,
MAX_CUT_TOKENS)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template, get_format_template
from server.utils import get_strategy_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
from server.chat.policy_fun import add_summary_retrieved_results, get_llm_model_response
import json
def _extract_rewritten_query(llm_output, fallback):
    """Return the rewritten search query contained in *llm_output*.

    The query-rewrite reply is expected to contain a JSON object with a
    ``"report"`` entry holding the rewritten fragments, optionally wrapped in
    a Markdown code fence. The fragments are concatenated into one string.

    Args:
        llm_output: Raw text returned by the query-rewrite call.
        fallback: Value returned when no usable rewrite can be extracted
            (normally the user's original query).

    Returns:
        The concatenated ``"report"`` fragments, or *fallback* when the reply
        is not a string, contains no JSON object, is malformed, lacks the
        ``"report"`` key, or yields an empty query.
    """
    if not isinstance(llm_output, str):
        return fallback
    # Slice out the outermost {...} span rather than using str.strip(), whose
    # argument is a *set of characters*, not a prefix/suffix. This also copes
    # with code fences and any text around the JSON object.
    start = llm_output.find("{")
    end = llm_output.rfind("}")
    if start == -1 or end <= start:
        return fallback
    try:
        data = json.loads(llm_output[start:end + 1])
        rewritten = "".join(data["report"])
    except (ValueError, KeyError, TypeError):
        # Malformed JSON, missing key, or non-string fragments.
        return fallback
    return rewritten or fallback


async def report_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                      fileName: List = Body([], description="文件名称", examples=[[]]),
                      knowledge_base_name: str = Body(..., description="知识库名称",
                                                      examples=["t_strategy_report_bge_v1"]),
                      top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                      score_threshold: float = Body(
                          SCORE_THRESHOLD,
                          description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右",
                          ge=0,
                          le=2
                      ),
                      history: List[History] = Body(
                          [],
                          description="历史对话",
                          examples=[[]]
                      ),
                      stream: bool = Body(False, description="流式输出"),
                      model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                      temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                      max_tokens: Optional[int] = Body(
                          MAX_TOKENS,
                          description="限制LLM生成Token数量默认None代表模型最大值"
                      ),
                      prompt_name: str = Body(
                          "default",
                          description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                      ),
                      request: Request = None,
                      use_summary: bool = True,
                      chunk_size: int = 20000,
                      min_chunk_size: int = 2000,
                      summary_model_name: str = LLM_MODELS[0],
                      query_rewrite_model_name: str = LLM_MODELS[0]
                      ):
    """Answer a question against a knowledge base and stream the result.

    The user query is first rewritten by an LLM into a search query, the
    knowledge base is searched with it, the hits are optionally summarised,
    and the resulting context is handed to the chat model together with the
    conversation history.

    Args:
        query: The user's question.
        fileName: File names forwarded to ``search_docs``.
        knowledge_base_name: Name of the knowledge base to query.
        top_k: Number of vectors to retrieve.
        score_threshold: Relevance threshold forwarded to ``search_docs``.
        history: Previous conversation turns.
        stream: Emit the answer token by token when true, otherwise as one
            message.
        model_name: Chat model used to produce the answer.
        temperature: Sampling temperature of the chat model.
        max_tokens: Generation limit; a non-positive integer is treated as
            ``None``.
        prompt_name: Name of the prompt template to use.
        request: Incoming request; its ``base_url`` is used to build
            download links for non-report knowledge bases.
        use_summary: When true, retrieved documents are passed through
            ``add_summary_retrieved_results`` and their ``summary`` metadata
            is used as context.
        chunk_size: Forwarded to ``add_summary_retrieved_results``.
        min_chunk_size: Forwarded to ``add_summary_retrieved_results``.
        summary_model_name: Model used for summarisation.
        query_rewrite_model_name: Model used for the query rewrite.

    Returns:
        A ``BaseResponse`` with code 404 when the knowledge base does not
        exist; otherwise an ``EventSourceResponse`` whose events are JSON
        strings: ``{"answer": ...}`` (one or many) followed by a final
        ``{"docs": [...]}`` listing the source documents.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    history = [History.from_data(h) for h in history]
    # The strategy-report knowledge base gets its own context/citation format.
    is_report_kb = knowledge_base_name == 't_strategy_report_bge_v1'

    async def knowledge_base_chat_iterator(
            query: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        """Yield the answer event(s) followed by the source-document event."""
        callback = AsyncIteratorCallbackHandler()

        # A non-positive limit is treated as "no explicit limit".
        effective_max_tokens = max_tokens
        if isinstance(effective_max_tokens, int) and effective_max_tokens <= 0:
            effective_max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=effective_max_tokens,
            callbacks=[callback],
        )

        # Rewrite the user query into a retrieval query. The helper is a
        # synchronous call (presumably a blocking LLM request), so it is run
        # in the thread pool like search_docs below instead of stalling the
        # event loop.
        rewrite_reply = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="query rewrite",
            llm_model_name=query_rewrite_model_name,
            template_prompt_name="query_rewrite_report",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512,
        )
        # Fall back to the original query when the reply is not usable JSON.
        search_query = _extract_rewritten_query(rewrite_reply, query)
        print('search query', search_query)

        docs = await run_in_threadpool(search_docs,
                                       fileName=fileName,
                                       query=search_query,
                                       knowledge_base_name=knowledge_base_name,
                                       top_k=top_k,
                                       score_threshold=score_threshold)

        # Attach a "summary" entry to each document's metadata.
        if use_summary:
            docs = await add_summary_retrieved_results(
                docs, query, 512, chunk_size, min_chunk_size, summary_model_name)
            print(docs)

        # Prompt names that require a dedicated output format; everything
        # else uses the default format template.
        format_list = ["Abstract Assistant", "Outline Assistant"]
        if prompt_name in format_list:
            format_template = get_format_template("knowledge_base_chat", "abstract_format")
        else:
            format_template = get_format_template("knowledge_base_chat", "default")

        # (doc, text) pairs that actually end up in the prompt for the report
        # knowledge base; reused below so citations match the prompt numbering.
        cited = []
        if is_report_kb:
            # Merge each report's source and content into one context entry.
            for doc in docs:
                if use_summary:
                    body = doc.metadata['summary']
                    # Summaries of 15 characters or fewer are skipped
                    # (presumably they carry no relevant content).
                    if len(body) <= 15:
                        continue
                else:
                    body = doc.page_content
                cited.append((doc, body))
            context = "\n\n".join(
                f"""参考报告[{idx}] 报告来源: {doc.metadata['source']} \n报告内容: {body}"""
                for idx, (doc, body) in enumerate(cited, start=1)
            )
            # Keep every document that was put into the context. Previously
            # the non-summary path discarded all of them here.
            docs = [doc for doc, _ in cited]
        else:
            context = "\n".join([doc.page_content for doc in docs])

        if not docs and fileName:
            # Nothing retrieved but files were specified: keep the requested
            # template.
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        elif not docs and not fileName and prompt_name != "Abstract Assistant":
            # Nothing retrieved at all: use the "empty" template.
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        elif prompt_name == 'iast_report_chat' or (is_report_kb and prompt_name == 'default'):
            print("use report prompt_template")
            prompt_template = get_strategy_prompt_template("knowledge_base_chat", 'iast_report_chat')
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        print("prompt_template", prompt_template)

        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_template() for i in history] + [input_msg])
        chain = LLMChain(prompt=chat_prompt, llm=model)
        print(
            f"\n知识库问答开始调用:参数:\nkb:{knowledge_base_name}\nquery:{query}\nhistory:{history}\ncontext:{context}\nfile_name:{fileName}\nformat_template:{format_template}\n\n")

        query = query.replace("原文", "")
        # Start generation in the background; tokens arrive through `callback`.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context,
                         "history": history,
                         "question": query,
                         "file_name": str(fileName),
                         "format_template": format_template}),
            callback.done),
        )

        source_documents = []
        if is_report_kb:
            # Report knowledge base: cite the source and the text that was
            # used as context.
            for idx, (doc, body) in enumerate(cited, start=1):
                filename = doc.metadata.get("source")
                print("filename", filename)
                if filename:
                    text = f"""[{idx}] 报告出处: [{filename}]\n\n{body}\n\n"""
                else:
                    text = f"""[{idx}] \n\n{body}\n\n"""
                source_documents.append(text)
        else:
            # Other knowledge bases: link to the document download endpoint.
            base_url = request.base_url if request is not None else ""
            for doc in docs:
                filename = doc.metadata.get("source")
                parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
                url = f"{base_url}knowledge_base/download_doc?" + parameters
                label = filename if filename else "原文地址"
                source_documents.append(f"""出处: [{label}]({url}) \n\n""")

        if not source_documents:
            # No documents found: tell the user the answer comes from the
            # model alone.
            source_documents.append("<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps({"answer": token}, ensure_ascii=False)
        else:
            tokens = []
            async for token in callback.aiter():
                tokens.append(token)
            # ensure_ascii=False keeps this payload consistent with the
            # streaming and docs payloads.
            yield json.dumps({"answer": "".join(tokens)}, ensure_ascii=False)
        await task
        yield json.dumps({"docs": source_documents}, ensure_ascii=False)

    return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history, model_name, prompt_name))