"""File-chat endpoints: upload files to a temp directory and chat over their text.

``upload_temp_docs`` saves and parses uploaded files and returns their full
text (plus a TextRank summary for long texts).  ``file_chat`` answers a query
against a caller-supplied text via a "stuff" QA chain and streams the result
as server-sent events.
"""
from fastapi import Body, File, Form, UploadFile
from sse_starlette.sse import EventSourceResponse
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
                     CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
from server.chat.agent_chat_test import run_sync
from server.chat.policy_fun_iast import get_llm_model_response
from server.utils import (wrap_done, get_ChatOpenAI, BaseResponse, get_prompt_template,
                          get_temp_dir, run_in_thread_pool)
from server.knowledge_base.kb_cache.faiss_cache import memo_faiss_pool
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History, split_questions
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.knowledge_base.utils import KnowledgeFile
import json
import os
from pathlib import Path
from langchain.chains.question_answering import load_qa_chain
from langchain.memory import ConversationBufferMemory, ConversationSummaryMemory, ConversationBufferWindowMemory
from langchain.docstore.document import Document
from langchain_core.prompts import PromptTemplate
from datetime import datetime
from server.knowledge_base.kb_service.base import TextRank
# NOTE: kept as a star import (and kept last) because names used below such as
# ``timing_decorator`` and ``logger`` are not imported anywhere else in this file.
from configs.basic_config import *


# Combined text longer than this many characters also gets a TextRank summary.
_SUMMARY_MIN_CHARS = 30000


def _parse_files_in_thread(
        files: List[UploadFile],
        dir: str,
        zh_title_enhance: bool,
        chunk_size: int,
        chunk_overlap: int,
):
    """Save uploaded files into ``dir`` and parse them, using the thread pool.

    Generator. Yields one tuple per file:
    ``(success, filename, msg, docs)`` where ``docs`` is the list produced by
    ``KnowledgeFile.file2text`` (empty on failure).

    The generator's return value (``StopIteration.value``) is the
    newline-joined ``page_content`` of every successfully parsed document;
    a plain ``for`` loop over the generator discards it.
    """

    def parse_file(file: UploadFile) -> tuple:
        """Save a single file to disk and split it into documents."""
        # Bound before the try so the except handler can always reference it.
        filename = file.filename
        try:
            file_path = os.path.join(dir, filename)

            # The filename comes from the client: refuse anything that would
            # resolve outside the target directory (e.g. "../../etc/passwd").
            base_dir = os.path.abspath(dir)
            target = os.path.abspath(file_path)
            if os.path.commonpath([base_dir, target]) != base_dir:
                raise ValueError(f"非法的文件名: {filename}")

            file_content = file.file.read()  # read the uploaded file's content

            # exist_ok avoids a check-then-create race between pool threads
            # that write into the same sub-directory.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, "wb") as f:
                f.write(file_content)

            kb_file = KnowledgeFile(filename=filename, knowledge_base_name="temp")
            kb_file.filepath = file_path
            docs = kb_file.file2text(zh_title_enhance=zh_title_enhance,
                                     chunk_size=chunk_size,
                                     chunk_overlap=chunk_overlap)
            for doc in docs:
                if isinstance(doc, Document):
                    # Strip the extra newlines left over from text splitting.
                    doc.page_content = doc.page_content.replace('\n', '')
            return True, filename, f"成功上传文件 {filename}", docs
        except Exception as e:
            # Per-file boundary: one bad file is reported, not fatal for the batch.
            msg = f"{filename} 文件上传失败,报错信息为: {e}"
            return False, filename, msg, []

    params = [{"file": file} for file in files]
    pieces = []
    for result in run_in_thread_pool(parse_file, params=params):
        yield result
        if result[0]:  # success
            pieces.extend(doc.page_content + "\n" for doc in result[3])  # docs
    return "".join(pieces)


def generate_summary(text: str, chars_per_sentence: int = 50, max_sentences: int = 300) -> str:
    """Summarise ``text`` with TextRank.

    One summary sentence is requested per ``chars_per_sentence`` characters of
    input, capped at ``max_sentences``. The defaults (50 / 300) reproduce the
    previously hard-coded behaviour.
    """
    num_sentences = min(len(text) // chars_per_sentence, max_sentences)
    return TextRank(text, num_sentences=num_sentences)


@timing_decorator
def upload_temp_docs(
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        prev_id: str = Form(None, description="前知识库ID"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
) -> BaseResponse:
    """Save the uploaded files to a temp directory and return their full text.

    Response ``data`` keys:
      * ``id``           - second value returned by ``get_temp_dir`` (temp dir id).
      * ``context``      - newline-joined text of all successfully parsed docs.
      * ``summary``      - TextRank summary of ``context`` when it is longer than
                           ``_SUMMARY_MIN_CHARS`` characters, else ``""``.
      * ``failed_files`` - list of ``{filename: error message}`` for files that
                           could not be saved/parsed.
    """
    if prev_id is not None:
        memo_faiss_pool.pop(prev_id)

    failed_files = []
    pieces = []
    path, temp_id = get_temp_dir(prev_id)

    for success, file, msg, docs in _parse_files_in_thread(files=files,
                                                           dir=path,
                                                           zh_title_enhance=zh_title_enhance,
                                                           chunk_size=chunk_size,
                                                           chunk_overlap=chunk_overlap):
        if success:
            pieces.extend(doc.page_content + "\n" for doc in docs)
        else:
            failed_files.append({file: msg})

    context = "".join(pieces)
    # Summarise once over the final text instead of re-running TextRank on
    # every iteration after the threshold is crossed; the result is the same.
    summary = generate_summary(context) if len(context) > _SUMMARY_MIN_CHARS else ""

    return BaseResponse(data={"id": temp_id,
                              "context": context,
                              "summary": summary,
                              "failed_files": failed_files})


def _build_prompt_and_memory(query: str, history, prompt_name: str):
    """Build the QA-chain prompt and, when there is history, its memory.

    With history: uses the "file_chat_history" template as a plain
    ``PromptTemplate`` and replays user/assistant turns into a
    ``ConversationBufferWindowMemory(k=1)``.
    Without history: uses the ``prompt_name`` template as the system message
    followed by the query as the user message, and no memory.

    Returns ``(chat_prompt, memory)``; ``memory`` is ``None`` without history.
    """
    if history:
        messages = [History.from_data(h) for h in history]
        prompt_template = get_prompt_template("knowledge_base_chat", "file_chat_history")
        chat_prompt = PromptTemplate.from_template(prompt_template)

        # Convert the history into chain memory.
        memory = ConversationBufferWindowMemory(k=1, input_key="question")
        for message in messages:
            if message.role == 'user':
                memory.chat_memory.add_user_message(message.content)
            elif message.role == 'assistant':
                memory.chat_memory.add_ai_message(message.content)
            # Other roles are intentionally not replayed into memory.
        return chat_prompt, memory

    prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
    input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
    # NOTE(review): the raw user query is turned into a message *template* here;
    # presumably template syntax inside the query gets interpreted - confirm
    # against History.to_msg_template.
    input_msg = History(role="user", content=query).to_msg_template(False)
    return ChatPromptTemplate.from_messages([input_prompt, input_msg]), None


async def file_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                    file_name: str = Body("", description="文件名称", examples=["123.txt"]),
                    history: List[History] = Body([],
                                                  description="历史对话",
                                                  examples=[[
                                                      {"role": "user",
                                                       "content": "现在开始针对我上传的文件回答我的问题"},
                                                      {"role": "assistant",
                                                       "content": "好的,让我们开始看吧"}]]
                                                  ),
                    stream: bool = Body(False, description="流式输出"),
                    model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                    temperature: float = Body(0.5, description="LLM 采样温度", ge=0.0, le=1.0),
                    max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                    prompt_name: str = Body("file_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                    context: str = Body("", description="文件内容")):
    """Answer ``query`` against the supplied file text and stream the result.

    The whole ``context`` string is wrapped in a single ``Document`` and passed
    to a "stuff" QA chain (no vector retrieval is performed here).

    Returns an ``EventSourceResponse`` whose events are JSON strings:
      * ``{"answer": ...}``   - one event per token when ``stream`` is true,
                                otherwise a single event with the full answer.
      * ``{"question": ...}`` - follow-up questions recommended from the
                                query/answer pair, emitted after the answer.
    """

    async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()
        # Current date formatted as YYYYMMDD, passed to the prompt as "time".
        today = datetime.now().strftime("%Y%m%d")

        # A non-positive limit means "no limit".
        token_limit = max_tokens
        if isinstance(token_limit, int) and token_limit <= 0:
            token_limit = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=token_limit,
            callbacks=[callback],
        )

        chat_prompt, memory = _build_prompt_and_memory(query, history, prompt_name)
        chain = load_qa_chain(model, chain_type="stuff", memory=memory,
                              prompt=chat_prompt, verbose=True)

        # input_documents must be a list of Document objects.
        documents = [Document(page_content=context)]

        # Run the chain in the background; tokens arrive through the callback.
        task = asyncio.create_task(wrap_done(
            chain.ainvoke({"input_documents": documents,
                           "question": query,
                           "title": file_name,
                           "time": today},
                          return_only_outputs=True),
            callback.done),
        )

        # Collect the answer in both modes so the follow-up question
        # recommender always sees the real assistant reply.
        tokens = []
        async for token in callback.aiter():
            tokens.append(token)
            if stream:
                # Use server-sent-events to stream the response.
                yield json.dumps({"answer": token}, ensure_ascii=False)
        answer = "".join(tokens)
        if not stream:
            yield json.dumps({"answer": answer}, ensure_ascii=False)

        question_history = [
            {"role": "user", "content": query},
            {"role": "assistant", "content": answer},
        ]
        question = (await run_sync(
            get_llm_model_response,
            strategy_name="question_recommend",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="question_recommend",
            prompt_param_dict={"history": question_history},
            temperature=0.3,
            max_tokens=512,
        )).strip()
        formatted = split_questions(question)
        logger.info(f"推荐问题: \n{formatted}")
        yield json.dumps({"question": formatted}, ensure_ascii=False)

        await task

    return EventSourceResponse(knowledge_base_chat_iterator())