import asyncio
import json
import logging
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode

from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse

from configs import (TEMPERATURE, USE_RERANKER, RERANKER_MODEL, RERANKER_MAX_LENGTH,
                     LLM_MODELS, MODEL_PATH, MAX_TOKENS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI

logger = logging.getLogger(__name__)

# The "policy" knowledge base stores each document as separate "title" and
# "content" records (distinguished by metadata["_type"]) and carries its own
# detail URL in metadata["source"]; every other knowledge base is treated as
# plain uploaded files served through /knowledge_base/download_doc.
_POLICY_KB_NAME = "t_policy_total_bce_v1"


def _rerank(docs, query):
    """Re-order ``docs`` against ``query`` with the configured reranker, keeping the top 3.

    Loading the model and scoring are synchronous and potentially slow, so this is
    meant to be called through ``run_in_threadpool`` rather than on the event loop.
    """
    reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
    logger.debug("reranker model path: %s", reranker_model_path)
    reranker_model = LangchainReranker(top_n=3,
                                       device=embedding_device(),
                                       max_length=RERANKER_MAX_LENGTH,
                                       model_name_or_path=reranker_model_path)
    return reranker_model.compress_documents(documents=docs, query=query)


def _build_context(docs, knowledge_base_name):
    """Join the retrieved documents into the text substituted for ``context`` in the prompt.

    For the policy knowledge base, each record is paired with its counterpart from
    metadata (a title record with its ``content``, a content record with its
    ``title``) so the LLM always sees title and body together.
    """
    if knowledge_base_name != _POLICY_KB_NAME:
        return "\n".join(doc.page_content for doc in docs)

    knowledge = []
    for doc in docs:
        doc_type = doc.metadata.get("_type")
        if doc_type == "title":
            knowledge.append(doc.page_content + "\n" + (doc.metadata.get("content") or ""))
        elif doc_type == "content":
            knowledge.append((doc.metadata.get("title") or "") + "\n" + doc.page_content)
    return "\n\n".join(knowledge)


def _build_source_documents(docs, knowledge_base_name, request):
    """Render one markdown "出处" (source) line per document.

    Policy documents link to their own ``metadata["source"]`` URL; other documents
    link to this server's download endpoint. When ``request`` is None (direct,
    non-HTTP calls) a server-relative URL is produced instead of an absolute one.
    A single fallback notice is returned when there are no documents.
    """
    source_documents = []
    if knowledge_base_name == _POLICY_KB_NAME:
        for doc in docs:
            filename = doc.metadata.get("title")
            detail_url = doc.metadata.get("source")
            label = filename if filename else "原文地址"
            source_documents.append(f"出处: [{label}]({detail_url}) \n\n")
    else:
        base_url = str(request.base_url) if request is not None else "/"
        for doc in docs:
            filename = doc.metadata.get("source")
            parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
            url = f"{base_url}knowledge_base/download_doc?" + parameters
            label = filename if filename else "原文地址"
            source_documents.append(f"出处: [{label}]({url}) \n\n")

    if not source_documents:
        # No documents found: tell the user the answer comes from the LLM alone.
        source_documents.append("未找到相关文档,该回答为大模型自身能力解答!")
    return source_documents


async def article_overview(query: str = Body("你好", description="用户输入", examples=["你好"]),
                           knowledge_base_name: str = Body(..., description="知识库名称",
                                                           examples=["t_policy_total_bce_v1"]),
                           stream: bool = Body(False, description="流式输出"),
                           model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                           temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                           max_tokens: Optional[int] = Body(
                               MAX_TOKENS,
                               description="限制LLM生成Token数量,默认None代表模型最大值"
                           ),
                           prompt_name: str = Body(
                               "Article Overview",
                               description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                           ),
                           source_name_list: List[str] = Body([], description="资源列表"),
                           request: Request = None,
                           ):
    """Summarise the given source files of a knowledge base with an LLM.

    The incoming ``query`` and ``prompt_name`` are overridden: the query is rebuilt
    from ``source_name_list`` and the prompt is "Article Overview" for a single file
    or "Article Overview2" for several.

    Returns:
        ``BaseResponse`` with code 404 if the knowledge base does not exist; otherwise
        an ``EventSourceResponse`` emitting JSON events. When ``stream`` is true it
        emits ``{"answer": token}`` per token followed by ``{"docs": [...]}``;
        otherwise it emits one ``{"answer": ..., "docs": [...]}`` event.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # The summarisation request is derived from the selected files, not the user text.
    query = "帮我对以下文件进行总结 :" + ",".join(source_name_list)
    prompt_name = "Article Overview2" if len(source_name_list) > 1 else "Article Overview"

    async def article_overview_iterator(
            query: str,
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        # A non-positive limit means "no limit" for the model wrapper.
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        docs = await run_in_threadpool(kb.get_doc_by_sources_name, source_name_list=source_name_list)

        if USE_RERANKER and docs:
            logger.debug("docs before rerank: %s", docs)
            # Model load + inference are blocking; keep them off the event loop.
            docs = await run_in_threadpool(_rerank, docs, query)
            logger.debug("docs after rerank: %s", docs)

        context = _build_context(docs, knowledge_base_name)
        logger.debug("context: %s", context)

        if not docs:
            # Nothing retrieved: fall back to the "empty" prompt template.
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_msg])
        logger.debug("chat_prompt: %s", chat_prompt)

        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Run the chain in the background; tokens arrive through `callback`.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )

        source_documents = _build_source_documents(docs, knowledge_base_name, request)

        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps({"answer": token}, ensure_ascii=False)
            yield json.dumps({"docs": source_documents}, ensure_ascii=False)
        else:
            answer = ""
            async for token in callback.aiter():
                answer += token
            yield json.dumps({"answer": answer,
                              "docs": source_documents},
                             ensure_ascii=False)
        await task

    return EventSourceResponse(article_overview_iterator(query, model_name=model_name, prompt_name=prompt_name))


class ArticleOverview:
    """Programmatic (non-HTTP) entry point for the article-overview flow."""

    query = "请给我对文件进行一下总结"

    def __init__(self):
        # jinja2-style placeholders, matching the prompt templates used by
        # History.to_msg_template; the previous text contained pasted Python
        # quote residue and was parsed as an f-string, so nothing was substituted.
        self._PROMPT_TEMPLATE = (
            "<角色> 你是由浪潮开发的知冶大模型中所选定的文件综述助手。\n\n"
            "Your task is to write a detailed summary of the provided file. "
            "Ensure that your summary is longer than 300 words and captures the essence of the content. "
            "Focus on the main points, key findings, and any important implications or conclusions. "
            "Maintain an unbiased tone and avoid relying on stereotypes. "
            "Organize the summary in a clear and coherent manner, using appropriate headings or "
            "bullet points if necessary. Remember to keep the summary concise while preserving the core "
            "information. Let's start with a brief overview of the file's main topic and then delve into "
            "the specifics. "
            "PLEASE ALWAYS RESPOND IN CHINESE!\n"
            "<已知信息>{{ context }}\n"
            "<问题>{{ question }}\n"
        )
        self.PROMPT = PromptTemplate(
            input_variables=["context", "question"],
            template=self._PROMPT_TEMPLATE,
            template_format="jinja2",
        )

    def query_out(self, knowledge_base_name: str, source_name_list: list):
        """Return the ``article_overview`` coroutine for the given files.

        Every FastAPI ``Body(...)`` parameter is passed explicitly: when the endpoint
        function is called directly (not via FastAPI), its defaults are ``FieldInfo``
        placeholders rather than real values.
        """
        self.query = "帮我对以下文件进行总结 :" + ",".join(source_name_list)
        return article_overview(self.query,
                                knowledge_base_name=knowledge_base_name,
                                stream=False,
                                model_name=LLM_MODELS[0],
                                temperature=TEMPERATURE,
                                max_tokens=MAX_TOKENS,
                                prompt_name="Article Overview",
                                source_name_list=source_name_list,
                                request=None,
                                )