Files
gangyan/langchain-chat/server/chat/file_chat.py

358 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body, File, Form, UploadFile
from sse_starlette.sse import EventSourceResponse
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.chat.agent_chat_test import run_sync
from server.chat.policy_fun_iast import get_llm_model_response
from server.utils import (wrap_done, get_ChatOpenAI,
BaseResponse, get_prompt_template, get_temp_dir, run_in_thread_pool)
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History, split_questions
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.knowledge_base.utils import KnowledgeFile
import json
import os
from pathlib import Path
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory,ConversationSummaryMemory, ConversationBufferWindowMemory
from langchain.docstore.document import Document
from langchain_core.prompts import PromptTemplate
from datetime import datetime
from server.knowledge_base.kb_service.base import TextRank
from configs.basic_config import *
# def _parse_files_in_thread(
# files: List[UploadFile],
# dir: str,
# zh_title_enhance: bool,
# chunk_size: int,
# chunk_overlap: int,
# ):
# """
# 通过多线程将上传的文件保存到对应目录内。
# 生成器返回保存结果:[success or error, filename, msg, docs]
# """
# def parse_file(file: UploadFile) -> dict:
# '''
# 保存单个文件。
# '''
# try:
# filename = file.filename
# file_path = os.path.join(dir, filename)
# file_content = file.file.read() # 读取上传文件的内容
# if not os.path.isdir(os.path.dirname(file_path)):
# os.makedirs(os.path.dirname(file_path))
# with open(file_path, "wb") as f:
# f.write(file_content)
# kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
# kb_file.filepath = file_path
# docs = kb_file.file2text(zh_title_enhance=zh_title_enhance,
# chunk_size=chunk_size,
# chunk_overlap=chunk_overlap)
# return True, filename, f"成功上传文件 {filename}", docs
# except Exception as e:
# msg = f"{filename} 文件上传失败,报错信息为: {e}"
# return False, filename, msg, []
# params = [{"file": file} for file in files]
# for result in run_in_thread_pool(parse_file, params=params):
# yield result
# def upload_temp_docs(
# files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
# prev_id: str = Form(None, description="前知识库ID"),
# chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
# chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
# zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
# ) -> BaseResponse:
# '''
# 将文件保存到临时目录,并进行向量化。
# 返回临时目录名称作为ID同时也是临时向量库的ID。
# '''
# if prev_id is not None:
# memo_faiss_pool.pop(prev_id)
# failed_files = []
# documents = []
# path, id = get_temp_dir(prev_id)
# for success, file, msg, docs in _parse_files_in_thread(files=files,
# dir=path,
# zh_title_enhance=zh_title_enhance,
# chunk_size=chunk_size,
# chunk_overlap=chunk_overlap):
# if success:
# documents += docs
# else:
# failed_files.append({file: msg})
# with memo_faiss_pool.load_vector_store(id).acquire() as vs:
# vs.add_documents(documents)
# return BaseResponse(data={"id": id, "failed_files": failed_files})
def _parse_files_in_thread(
        files: List[UploadFile],
        dir: str,
        zh_title_enhance: bool,
        chunk_size: int,
        chunk_overlap: int,
):
    """
    Save uploaded files under ``dir`` via the thread pool and split them into documents.

    This is a generator. For every uploaded file it yields one tuple
    ``(success, filename, msg, docs)``:

    * ``success`` - ``True`` when the file was written and parsed, else ``False``.
    * ``filename`` - the name reported by the upload.
    * ``msg`` - a human readable success / failure message.
    * ``docs`` - the documents produced by ``KnowledgeFile.file2text`` (empty on failure).

    When the generator is exhausted, its return value (``StopIteration.value``)
    is the page content of every successfully parsed document, each followed
    by a newline. A plain ``for`` loop over the generator discards that value.
    """
    # Resolve the upload directory once; every target path is checked against it.
    base_dir = os.path.realpath(dir)

    def parse_file(file: UploadFile) -> tuple:
        """Write one uploaded file to disk and convert it to documents."""
        # Bound before the try block so the except handler can always use it.
        filename = getattr(file, "filename", None)
        try:
            file_path = os.path.join(dir, filename)
            # The file name comes from the client: refuse names such as
            # "../x" or "/abs/path" that would land outside the upload dir.
            target = os.path.realpath(file_path)
            if os.path.commonpath([base_dir, target]) != base_dir:
                raise ValueError("文件名不合法: 目标路径超出上传目录")
            file_content = file.file.read()  # read the uploaded payload
            # exist_ok avoids a check-then-create race between worker threads.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, "wb") as f:
                f.write(file_content)
            kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
            kb_file.filepath = file_path
            docs = kb_file.file2text(zh_title_enhance=zh_title_enhance,
                                     chunk_size=chunk_size,
                                     chunk_overlap=chunk_overlap)
            for doc in docs:
                if isinstance(doc, Document):
                    # Drop the superfluous newlines left by text splitting.
                    doc.page_content = doc.page_content.replace('\n', '')
            return True, filename, f"成功上传文件 {filename}", docs
        except Exception as e:
            # Best effort per file: report the failure instead of aborting the batch.
            msg = f"{filename} 文件上传失败,报错信息为: {e}"
            return False, filename, msg, []

    params = [{"file": file} for file in files]
    pieces = []
    for result in run_in_thread_pool(parse_file, params=params):
        yield result
        if result[0]:  # success
            pieces.extend(doc.page_content + "\n" for doc in result[3])  # docs
    return "".join(pieces)
def generate_summary(text: str) -> str:
    """
    Summarise ``text`` with the project's ``TextRank`` helper.

    The number of summary sentences requested is one per 50 characters of
    input, capped at 300.
    """
    sentence_budget = len(text) // 50
    if sentence_budget > 300:
        sentence_budget = 300
    return TextRank(text, num_sentences=sentence_budget)
@timing_decorator
def upload_temp_docs(
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        prev_id: str = Form(None, description="前知识库ID"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
) -> BaseResponse:
    """
    Save the uploaded files to a temporary directory and return their text.

    When ``prev_id`` is given, it is first removed from ``memo_faiss_pool``.
    The response ``data`` contains:

    * ``id`` - the identifier returned by ``get_temp_dir``.
    * ``context`` - page content of every parsed document, each followed by a newline.
    * ``summary`` - ``generate_summary(context)`` when ``context`` is longer than
      30000 characters, otherwise an empty string.
    * ``failed_files`` - a list of ``{filename: error message}`` for files that
      could not be saved or parsed.
    """
    if prev_id is not None:
        memo_faiss_pool.pop(prev_id)

    failed_files = []
    path, temp_id = get_temp_dir(prev_id)
    pieces = []
    for success, file, msg, docs in _parse_files_in_thread(files=files,
                                                           dir=path,
                                                           zh_title_enhance=zh_title_enhance,
                                                           chunk_size=chunk_size,
                                                           chunk_overlap=chunk_overlap):
        if success:
            pieces.extend(doc.page_content + "\n" for doc in docs)
        else:
            failed_files.append({file: msg})

    context = "".join(pieces)
    # Summarise once, over the complete text, instead of after every file.
    summary = generate_summary(context) if len(context) > 30000 else ""
    return BaseResponse(data={"id": temp_id,
                              "context": context,
                              "summary": summary,
                              "failed_files": failed_files})
async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                    file_name: str = Body("", description="文件名称", examples=["123.txt"]),
                    history: List[History] = Body([],
                                                  description="历史对话",
                                                  examples=[[
                                                      {"role": "user",
                                                       "content": "现在开始针对我上传的文件回答我的问题"},
                                                      {"role": "assistant",
                                                       "content": "好的,让我们开始看吧"}]]
                                                  ),
                    stream: bool = Body(False, description="流式输出"),
                    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                    temperature: float = Body(0.5, description="LLM 采样温度", ge=0.0, le=1.0),
                    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
                    prompt_name: str = Body("file_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                    context: str = Body("", description="文件内容")):
    """
    Answer a question about an uploaded file and stream the result as server-sent events.

    The full file text is supplied by the caller in ``context`` and handed to a
    "stuff" question-answering chain as a single document; no vector store
    lookup is performed here.

    Prompt selection:

    * with ``history``: the ``file_chat_history`` template is used (``prompt_name``
      is ignored) and the user / assistant turns are loaded into a
      ``ConversationBufferWindowMemory(k=1)``; turns with other roles are skipped.
    * without ``history``: the template named by ``prompt_name`` becomes the system
      message and ``query`` the user message.

    Emitted events (each event's data is a JSON string):

    * ``{"answer": ...}`` - one per token when ``stream`` is true, otherwise a
      single event carrying the whole answer.
    * ``{"question": ...}`` - follow-up questions produced by the
      ``question_recommend`` strategy from the current query and answer.

    Returns an ``EventSourceResponse`` wrapping the event generator.
    """

    async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()
        # Current date formatted as YYYYMMDD; passed to the chain under the "time" key.
        today = datetime.now().strftime("%Y%m%d")

        # A non-positive limit is treated the same as "no limit".
        token_limit = max_tokens
        if isinstance(token_limit, int) and token_limit <= 0:
            token_limit = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=token_limit,
            callbacks=[callback],
        )

        memory = None
        if history:
            past_messages = [History.from_data(h) for h in history]
            prompt_template = get_prompt_template("knowledge_base_chat", "file_chat_history")
            chat_prompt = PromptTemplate.from_template(prompt_template)
            # Replay the supplied history into the chain memory.
            memory = ConversationBufferWindowMemory(k=1, input_key="question")
            for message in past_messages:
                if message.role == 'user':
                    memory.chat_memory.add_user_message(message.content)
                elif message.role == 'assistant':
                    memory.chat_memory.add_ai_message(message.content)
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
            input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
            input_msg = History(role="user", content=query).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages([input_prompt] + [input_msg])

        chain = load_qa_chain(model, chain_type="stuff", memory=memory, prompt=chat_prompt, verbose=True)

        # input_documents must be a list of Document objects.
        input_documents = [Document(page_content=context)]

        # Run the chain in the background; tokens arrive through `callback`.
        task = asyncio.create_task(wrap_done(
            chain.ainvoke({"input_documents": input_documents,
                           "question": query,
                           "title": file_name,
                           "time": today},
                          return_only_outputs=True),
            callback.done),
        )

        # Collect the answer in both modes: the follow-up question prompt needs it.
        answer_tokens = []
        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                answer_tokens.append(token)
                yield json.dumps({"answer": token}, ensure_ascii=False)
            answer = "".join(answer_tokens)
        else:
            async for token in callback.aiter():
                answer_tokens.append(token)
            answer = "".join(answer_tokens)
            yield json.dumps({"answer": answer},
                             ensure_ascii=False)

        question_history = [
            {"role": "user", "content": query},
            {"role": "assistant", "content": answer}
        ]
        question = (await run_sync(
            get_llm_model_response,
            strategy_name="question_recommend",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="question_recommend",
            prompt_param_dict={"history": question_history},
            temperature=0.3,
            max_tokens=512,
        )).strip()
        formatted = split_questions(question)
        logger.info(f"推荐问题: \n{formatted}")
        yield json.dumps({"question": formatted}, ensure_ascii=False)

        await task

    return EventSourceResponse(knowledge_base_chat_iterator())