# Source: gangyan/langchain-chat/server/custom/article_overview.py
# Note: this file contains Unicode characters (e.g. in the Chinese UI strings) that might be
# confused with similar-looking ASCII characters; keep them as-is unless deliberately changing them.

import asyncio
import json
import logging
from typing import AsyncIterable, List, Optional
from urllib.parse import urlencode

from fastapi import Body, Request
from fastapi.concurrency import run_in_threadpool
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from sse_starlette.sse import EventSourceResponse

from configs import (TEMPERATURE,
                     USE_RERANKER,
                     RERANKER_MODEL,
                     RERANKER_MAX_LENGTH,
                     LLM_MODELS,
                     MODEL_PATH,
                     MAX_TOKENS)
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.reranker.reranker import LangchainReranker
from server.utils import BaseResponse, get_prompt_template
from server.utils import embedding_device
from server.utils import wrap_done, get_ChatOpenAI
# Knowledge base whose chunks store a title and its body separately (metadata "_type" is
# "title" or "content"); it gets dedicated context and source-link formatting below.
_POLICY_KB_NAME = "t_policy_total_bce_v1"

_logger = logging.getLogger(__name__)


def _build_context(docs, knowledge_base_name: str) -> str:
    """Join retrieved documents into the ``context`` text fed to the prompt.

    For the policy knowledge base each chunk is paired with its counterpart kept in
    metadata (title chunk + metadata["content"], or metadata["title"] + content chunk),
    and the pairs are separated by blank lines. Chunks whose ``_type`` is missing or not
    "title"/"content" are skipped. For any other knowledge base the page contents are
    simply joined with newlines.
    """
    if knowledge_base_name != _POLICY_KB_NAME:
        return "\n".join(doc.page_content for doc in docs)

    knowledge = []
    for doc in docs:
        doc_type = doc.metadata.get("_type")
        if doc_type == "title":
            # Title chunk: append the body text stored in metadata.
            knowledge.append(doc.page_content + "\n" + (doc.metadata.get("content") or ""))
        elif doc_type == "content":
            # Body chunk: prepend the title stored in metadata.
            knowledge.append((doc.metadata.get("title") or "") + "\n" + doc.page_content)
    return "\n\n".join(knowledge)


def _build_source_documents(docs, knowledge_base_name: str, request) -> List[str]:
    """Build the markdown "source" lines returned alongside the answer.

    Policy documents link to the URL stored in ``metadata["source"]``; other documents
    link to this server's ``knowledge_base/download_doc`` endpoint. ``request`` may be
    None (direct, non-HTTP calls), in which case a root-relative link is produced.
    A red notice is returned when there are no documents at all.
    """
    source_documents = []
    if knowledge_base_name == _POLICY_KB_NAME:
        for doc in docs:
            title = doc.metadata.get("title")
            detail_url = doc.metadata.get("source")
            source_documents.append(f"出处: [{title or '原文地址'}]({detail_url}) \n\n")
    else:
        base_url = str(request.base_url) if request is not None else "/"
        for doc in docs:
            filename = doc.metadata.get("source")
            parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name": filename})
            url = f"{base_url}knowledge_base/download_doc?{parameters}"
            source_documents.append(f"出处: [{filename or '原文地址'}]({url}) \n\n")

    if not source_documents:  # nothing was retrieved
        source_documents.append("<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")
    return source_documents


async def article_overview(query: str = Body("你好", description="用户输入", examples=["你好"]),
                           knowledge_base_name: str = Body(..., description="知识库名称", examples=["t_policy_total_bce_v1"]),
                           stream: bool = Body(False, description="流式输出"),
                           model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                           temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                           max_tokens: Optional[int] = Body(
                               MAX_TOKENS,
                               description="限制LLM生成Token数量默认None代表模型最大值"
                           ),
                           prompt_name: str = Body(
                               "Article Overview",
                               description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                           ),
                           source_name_list: List[str] = Body([], description="资源列表"),
                           request: Request = None,
                           ):
    """Summarize the given source files of a knowledge base and stream the result over SSE.

    Args:
        query: Ignored; the query is always rebuilt from ``source_name_list``.
        knowledge_base_name: Knowledge base to read the documents from.
        stream: If true, emit one SSE event per LLM token followed by a ``docs`` event;
            otherwise emit a single event with the full ``answer`` and ``docs``.
        model_name: LLM model name passed to ``get_ChatOpenAI``.
        temperature: Sampling temperature.
        max_tokens: Generation limit; non-positive values are treated as None (no limit).
        prompt_name: Ignored; "Article Overview2" is used for several files and
            "Article Overview" for one (or none).
        source_name_list: Names of the source files to summarize.
        request: Incoming HTTP request, used to build download links. May be None when
            called directly, in which case root-relative links are produced.

    Returns:
        A ``BaseResponse`` with code 404 if the knowledge base is unknown, otherwise an
        ``EventSourceResponse`` streaming JSON payloads.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # The caller-supplied query and prompt name are deliberately overridden.
    query = "帮我对以下文件进行总结 " + ",".join(source_name_list)
    prompt_name = "Article Overview2" if len(source_name_list) > 1 else "Article Overview"
    if isinstance(max_tokens, int) and max_tokens <= 0:
        max_tokens = None

    async def article_overview_iterator(
            query: str,
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        callback = AsyncIteratorCallbackHandler()
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        docs = await run_in_threadpool(kb.get_doc_by_sources_name,
                                       source_name_list=source_name_list)

        if USE_RERANKER:
            reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL, "BAAI/bge-reranker-large")
            _logger.debug("reranker model path: %s", reranker_model_path)
            # Loading and running the reranker model is blocking work; keep it off the
            # event loop. NOTE(review): the model is rebuilt per request; consider caching.
            reranker_model = await run_in_threadpool(
                LangchainReranker,
                top_n=3,
                device=embedding_device(),
                max_length=RERANKER_MAX_LENGTH,
                model_name_or_path=reranker_model_path,
            )
            _logger.debug("docs before rerank: %s", docs)
            docs = await run_in_threadpool(reranker_model.compress_documents,
                                           documents=docs, query=query)
            _logger.debug("docs after rerank: %s", docs)

        context = _build_context(docs, knowledge_base_name)
        _logger.debug("context: %s", context)

        # With no documents, fall back to the "empty" template.
        template_name = prompt_name if docs else "empty"
        prompt_template = get_prompt_template("knowledge_base_chat", template_name)
        input_msg = History(role="user", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_msg])
        _logger.debug("chat_prompt: %s", chat_prompt)
        chain = LLMChain(prompt=chat_prompt, llm=model)

        # Run the chain in the background; tokens arrive through ``callback``.
        task = asyncio.create_task(wrap_done(
            chain.acall({"context": context, "question": query}),
            callback.done),
        )
        source_documents = _build_source_documents(docs, knowledge_base_name, request)

        try:
            if stream:
                async for token in callback.aiter():
                    # Use server-sent events to stream the response token by token.
                    yield json.dumps({"answer": token}, ensure_ascii=False)
                yield json.dumps({"docs": source_documents}, ensure_ascii=False)
            else:
                answer = "".join([token async for token in callback.aiter()])
                yield json.dumps({"answer": answer,
                                  "docs": source_documents},
                                 ensure_ascii=False)
            await task
        finally:
            # If the client disconnects mid-stream, stop the LLM call instead of leaking it.
            if not task.done():
                task.cancel()

    return EventSourceResponse(article_overview_iterator(query, model_name=model_name, prompt_name=prompt_name))
class ArticleOverview:
    """Helper for running :func:`article_overview` outside of a FastAPI request."""

    # Default query; replaced on the instance by ``query_out``.
    query = "请给我对文件进行一下总结"

    def __init__(self):
        # Jinja2-style template ({{ var }} placeholders). The previous version was a mangled
        # paste of Python string-list fragments (stray quotes, trailing comma) and declared
        # variables that the template did not use.
        self._PROMPT_TEMPLATE = (
            "<角色> 你是由浪潮开发的知冶大模型中所选定的文件综述助手。</角色> \n\n"
            "Your task is to write a detailed summary of the provided file. Ensure that your summary is "
            "longer than 300 words and captures the essence of the content. Focus on the main points, "
            "key findings, and any important implications or conclusions. Maintain an unbiased tone and avoid relying "
            "on stereotypes. Organize the summary in a clear and coherent manner, using appropriate headings or "
            "bullet points if necessary. Remember to keep the summary concise while preserving the core information. "
            "Let's start with a brief overview of the file's main topic and then delve into the specifics. "
            "PLEASE ALWAYS RESPOND IN CHINESE!\n"
            "<已知信息>{{ context }}</已知信息>\n"
            "<问题>{{ question }}</问题>\n"
        )
        self.PROMPT = PromptTemplate(
            input_variables=["context", "question"],
            template=self._PROMPT_TEMPLATE,
            template_format="jinja2",
        )

    def query_out(self, knowledge_base_name: str, source_name_list: list):
        """Build the summary query for ``source_name_list`` and call ``article_overview``.

        All of ``article_overview``'s parameters are passed explicitly: when it is called
        directly (not through FastAPI), omitted parameters would otherwise receive the
        ``fastapi.Body(...)`` marker objects instead of real values.

        Returns:
            The ``article_overview`` coroutine; the caller must await it.
        """
        self.query = "帮我对以下文件进行总结 " + ",".join(source_name_list)
        return article_overview(self.query,
                                knowledge_base_name=knowledge_base_name,
                                stream=False,
                                model_name=LLM_MODELS[0],
                                temperature=TEMPERATURE,
                                max_tokens=MAX_TOKENS,
                                prompt_name="Article Overview",
                                source_name_list=source_name_list,
                                request=None,
                                )