Files
gangyan/langchain-chat/server/chat/knowledge_base_chat.py

565 lines
34 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from fastapi import Body, Request
from langchain.chains.question_answering import load_qa_chain
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
TEMPERATURE,
USE_RERANKER,
RERANKER_MODEL,
RERANKER_MAX_LENGTH,
MODEL_PATH,
MAX_TOKENS,
MAX_CUT_TOKENS,
POLICY_KNOWLEDGE_BASE,
REPORT_KNOWLEDGE_BASE,
JOURNAL_KNOWLEDGE_BASE,
OLD_POLICY_BASE
)
from configs.kb_config import OLD_JOURNAL_BASE
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template, get_format_template
from server.utils import get_strategy_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
from server.chat.policy_fun import add_summary_retrieved_results,get_llm_model_response
from server.chat.policy_fun_iast import get_llm_model_response
import json
from langchain.memory import ConversationSummaryBufferMemory, ConversationBufferWindowMemory, ConversationBufferMemory
from langchain_core.prompts import PromptTemplate
import itertools
from datetime import datetime
import time
from langchain.schema import Document
# Each rule pairs a group of superseded knowledge-base names with the single
# name that is searched instead when a request mentions any of them.
# NOTE(review): the first element of each rule is passed to ``set()`` by the
# consumer, so it is assumed to be an iterable of names -- a bare string would
# be split into characters. Confirm against configs.
_POLICY_REPLACEMENT = (OLD_POLICY_BASE, "t_policy_total_bge_new_v2")
_JOURNAL_REPLACEMENT = (OLD_JOURNAL_BASE, "t_journal_article_bge_v1")
REPLACEMENT_RULES = [_POLICY_REPLACEMENT, _JOURNAL_REPLACEMENT]
# Personal knowledge bases whose search is redirected to another collection
# (crawled data is used in place of the OA resource).
_PERSONAL_KB_ALIASES = {
    'yj_oa_journal_bge_v2_yejinbak': 'yj_oa_article_v1_yejinbak',
}

# Personal knowledge bases whose source entries are rendered with a title and
# a publication year instead of a bare file reference.
_TITLED_PERSONAL_KBS = (
    "yj_policys_bge_v1_yejinbak",
    "yj_oa_journal_bge_v2_yejinbak",
    "yj_for_journal_bge_v1_yejinbak",
    "yj_ch_journal_bge_v1_yejinbak",
)


def _extract_rewritten_query(raw_response: Optional[str], key: str, fallback: str) -> str:
    """Return the rewritten search query stored under ``key`` in an LLM reply.

    The reply is expected to contain one JSON object, optionally wrapped in a
    Markdown code fence or surrounded by other text. The value under ``key``
    (normally a list of strings) is concatenated into a single query string.

    ``fallback`` is returned when the reply is not a string, holds no JSON
    object, cannot be parsed, lacks ``key``, holds non-string fragments, or
    yields an empty query.
    """
    if not isinstance(raw_response, str):
        return fallback
    # Slice from the first "{" to the last "}" rather than using str.strip():
    # strip() treats its argument as a set of characters, not as a prefix.
    start = raw_response.find("{")
    end = raw_response.rfind("}")
    if start == -1 or end < start:
        return fallback
    try:
        fragments = json.loads(raw_response[start:end + 1])[key]
        rewritten = "".join(fragments)
    except (ValueError, KeyError, TypeError):
        return fallback
    return rewritten or fallback


def _dedupe_docs(docs, identify):
    """Return ``docs`` without later duplicates, keeping the original order.

    Two documents are duplicates when ``identify(doc)`` returns equal
    (hashable) values.
    """
    seen = set()
    unique = []
    for doc in docs:
        identifier = identify(doc)
        if identifier not in seen:
            seen.add(identifier)
            unique.append(doc)
    return unique


def _format_personal_source(doc, kb_name: str, number: int) -> str:
    """Build the source-list entry for a document of a personal knowledge base.

    ``kb_name`` is the collection that was actually searched (after aliasing)
    and ``number`` is the 1-based position of the entry in the source list.
    """
    meta = doc.metadata
    if kb_name not in _TITLED_PERSONAL_KBS:
        return f"""参考文档[{number}] [{meta.get("source")}]()\n"""
    if meta.get('_type') != 'title':
        return f"""[{number}] 《{meta.get("title")}\n资料年份:{meta.get("publish_year")}\n\n"""
    if meta.get("summary"):
        return f"""[{number}] 《{doc.page_content}\n{meta.get("summary")}\n资料年份:{meta.get("publish_year")}\n\n"""
    return f"""[{number}] 《{doc.page_content}\n资料年份:{meta.get("publish_year")}\n\n"""


async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              fileName: List = Body([], description="文件名称", examples=[["123.txt"]]),
                              knowledge_base_name_list: list = Body(..., description="多种知识库名称",
                                                                    examples=[[ "t_policy_total_bge_v1","t_strategy_report_20_bge_v2","t_journal_article_bge_v1"]]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                              score_threshold: float = Body(
                                  SCORE_THRESHOLD,
                                  description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右",
                                  ge=0,
                                  le=2
                              ),
                              history: List[History] = Body(
                                  [],
                                  description="历史对话",
                                  examples=[[
                                      {"role": "user",
                                       "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                      {"role": "assistant",
                                       "content": "虎头虎脑"}]]
                              ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: Optional[int] = Body(
                                  MAX_TOKENS,
                                  description="限制LLM生成Token数量默认None代表模型最大值"
                              ),
                              prompt_name: str = Body(
                                  "default",
                                  description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                              ),
                              request: Request = None,
                              use_summary = True,
                              use_model_self_response = True,
                              chunk_size: int = 20000,
                              min_chunk_size: int = 2000,
                              summary_model_name = LLM_MODELS[0],
                              query_rewrite_model_name = LLM_MODELS[0]
                              ):
    """Answer ``query`` with an LLM grounded in one or more knowledge bases.

    Superseded knowledge-base names are first swapped for their replacements
    (see ``REPLACEMENT_RULES``). If any requested knowledge base cannot be
    resolved, a ``BaseResponse`` with code 404 is returned. Otherwise the
    result is an ``EventSourceResponse`` that emits JSON strings: one or more
    ``{"answer": ...}`` events (one per token when ``stream`` is true, a
    single complete answer otherwise) followed by one ``{"docs": [...]}``
    event listing the sources.
    """
    # Swap superseded knowledge-base names for their replacements while
    # keeping the order of the names that remain.
    original_kb_set = set(knowledge_base_name_list)
    new_elements_added = []
    for old_bases, new_base in REPLACEMENT_RULES:
        to_remove = original_kb_set & set(old_bases)
        if to_remove:
            knowledge_base_name_list = [
                elem for elem in knowledge_base_name_list
                if elem not in to_remove
            ]
            new_elements_added.append(new_base)
    # Append each replacement once, unless the caller already asked for it.
    for new_base in new_elements_added:
        if new_base not in knowledge_base_name_list:
            knowledge_base_name_list.append(new_base)
    print(f'========== 当前检索的知识库:{knowledge_base_name_list} ==========')

    # Names left in this copy after the policy/report/journal branches have
    # removed theirs are treated as personal knowledge bases.
    new_knowledge_base_name_list = knowledge_base_name_list[:]
    for knowledge_base_name in knowledge_base_name_list:
        kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
        if kb is None:
            return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    history = [History.from_data(h) for h in history]
    # Reference point for the time-to-first-token measurement.
    start_time = time.time()
    print(f"========== 当前的对话历史为==========\n{history}")
    # Current date formatted as YYYYMMDD, passed to the prompts.
    current_time = datetime.now().strftime("%Y%m%d")

    async def knowledge_base_chat_iterator(
            query: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        """Retrieve references, run the QA chain and yield JSON event payloads."""
        nonlocal max_tokens
        callback = AsyncIteratorCallbackHandler()
        memory = None
        policydocs = []
        reportdocs = []
        journaldocs = []
        # (searched collection name, de-duplicated docs) per personal knowledge
        # base, reused when the source list is built.
        personal_results = []
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None
        if prompt_name == "policy_chat":
            model_name = LLM_MODELS[0]
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )
        knowledge = []
        self_knowledge = []
        # Collected once so every query rewrite sees each past user message
        # exactly once, however many knowledge bases are searched.
        user_queries = [message.content for message in history if message.role == 'user']

        # NOTE(review): get_llm_model_response is called synchronously from this
        # coroutine; if it performs blocking network I/O it stalls the event
        # loop. Consider run_in_threadpool once its implementation is confirmed.
        if use_model_self_response:
            # Ask the model for its own answer, without any retrieved context.
            modelself_response = get_llm_model_response(
                strategy_name="self response",
                llm_model_name=query_rewrite_model_name,
                template_prompt_name="self_response",
                prompt_param_dict={"query": query},
                temperature=0.01,
                max_tokens=512
            )
            self_knowledge.append(f"""{modelself_response}""")

        # Policy knowledge base.
        if POLICY_KNOWLEDGE_BASE in knowledge_base_name_list:
            rewrite_response = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=query_rewrite_model_name,
                template_prompt_name="query_rewrite_policy",
                prompt_param_dict={"query": query, "history": user_queries, "time": current_time},
                temperature=0.01,
                max_tokens=512
            )
            print("search_query: ", query)
            print("search_history: ", user_queries)
            search_query = _extract_rewritten_query(rewrite_response, 'policies', query)
            print('policy search query', search_query)
            policydocs = await run_in_threadpool(search_docs,
                                                 fileName=fileName,
                                                 query=search_query,
                                                 usr_query=query,
                                                 knowledge_base_name=POLICY_KNOWLEDGE_BASE,
                                                 top_k=top_k,
                                                 score_threshold=score_threshold)
            if use_summary:
                # Keep only documents with a summary longer than 15 characters,
                # then drop repeated (title, content) pairs. The filtered list
                # is also what the source list is built from, so reference
                # numbers and source numbers stay aligned.
                summarized = [doc for doc in policydocs if len(doc.metadata['summary']) > 15]
                policydocs = _dedupe_docs(
                    summarized,
                    lambda doc: (doc.metadata['title'], doc.page_content),
                )
                for doc in policydocs:
                    knowledge.append(f"""参考资料[{len(knowledge) + 1}] 文章标题: {doc.metadata['title']} \n文章内容: {doc.metadata['summary']}""")
            else:
                for inum, doc in enumerate(policydocs):
                    if doc.metadata["_type"] == "title":
                        knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content} \n文章内容 {doc.metadata['content']}""")
                    if doc.metadata["_type"] == "content":
                        knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.page_content}""")
            new_knowledge_base_name_list.remove(POLICY_KNOWLEDGE_BASE)

        # Report knowledge base.
        if REPORT_KNOWLEDGE_BASE in knowledge_base_name_list:
            rewrite_response = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=query_rewrite_model_name,
                template_prompt_name="query_rewrite_report",
                prompt_param_dict={"query": query, "history": user_queries + [query]},
                temperature=0.01,
                max_tokens=512
            )
            print("search_query: ", query)
            print("search_history: ", user_queries)
            search_query = _extract_rewritten_query(rewrite_response, 'report', query)
            print('report search query', search_query)
            reportdocs = await run_in_threadpool(search_docs,
                                                 fileName=fileName,
                                                 query=search_query,
                                                 knowledge_base_name=REPORT_KNOWLEDGE_BASE,
                                                 top_k=top_k,
                                                 score_threshold=score_threshold,
                                                 expr = " _type == 'content'")
            # Drop repeated (source, content) pairs.
            reportdocs = _dedupe_docs(
                reportdocs,
                lambda doc: (doc.metadata['source'], doc.page_content),
            )
            for doc in reportdocs:
                knowledge.append(f"""参考资料[{len(knowledge) + 1}] 报告来源: {doc.metadata['source'].replace('.pdf','')} \n报告内容: {doc.page_content}""")
            new_knowledge_base_name_list.remove(REPORT_KNOWLEDGE_BASE)

        # Journal knowledge base.
        if JOURNAL_KNOWLEDGE_BASE in knowledge_base_name_list:
            rewrite_response = get_llm_model_response(
                strategy_name="query rewrite",
                llm_model_name=query_rewrite_model_name,
                template_prompt_name="query_rewrite",
                prompt_param_dict={"query": query, "history": user_queries + [query]},
                temperature=0.01,
                max_tokens=512
            )
            print("search_query: ", query)
            print("search_history: ", user_queries)
            # NOTE(review): the journal rewrite is read from the 'report' key,
            # same as the report branch -- confirm the "query_rewrite" template
            # really emits that key.
            search_query = _extract_rewritten_query(rewrite_response, 'report', query)
            print('journal search query', search_query)
            journaldocs = await run_in_threadpool(search_docs,
                                                  fileName=fileName,
                                                  query=search_query,
                                                  knowledge_base_name=JOURNAL_KNOWLEDGE_BASE,
                                                  top_k=top_k,
                                                  score_threshold=score_threshold)
            # Drop repeated (title, abstract) pairs.
            journaldocs = _dedupe_docs(
                journaldocs,
                lambda doc: (doc.metadata['title'], doc.metadata['abstract']),
            )
            for doc in journaldocs:
                knowledge.append(f"""参考资料[{len(knowledge) + 1}] 论文标题: {doc.metadata['title']} \n论文摘要: {doc.metadata['abstract']}""")
            new_knowledge_base_name_list.remove(JOURNAL_KNOWLEDGE_BASE)

        # Personal knowledge bases: every name not handled above. They are
        # searched with the raw user query (no rewrite) and de-duplicated by
        # content within each knowledge base.
        for kb_name in new_knowledge_base_name_list:
            search_name = _PERSONAL_KB_ALIASES.get(kb_name, kb_name)
            found = await run_in_threadpool(search_docs,
                                            fileName=fileName,
                                            query=query,
                                            knowledge_base_name=search_name,
                                            top_k=top_k,
                                            score_threshold=score_threshold)
            unique_docs = _dedupe_docs(found, lambda doc: doc.page_content)
            personal_results.append((search_name, unique_docs))
            for doc in unique_docs:
                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.page_content}""")

        docs = [Document(page_content=k) for k in knowledge]

        format_list = ["Abstract Assistant", "Outline Assistant"]
        if prompt_name in format_list:
            format_template = get_format_template("knowledge_base_chat", "abstract_format")
        else:
            format_template = get_format_template("knowledge_base_chat", "default")

        # Use the "empty" template when nothing was retrieved and no file was
        # supplied, except for the Abstract Assistant.
        if len(knowledge) == 0 and not fileName and prompt_name != "Abstract Assistant":
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
            # NOTE(review): placement of this print inside the else branch was
            # reconstructed from unindented source -- confirm.
            print("prompt_name(no history):", prompt_name)

        if history and prompt_name not in ["Question Assistant"]:
            # With history the named template is always used (this overrides
            # the "empty" template chosen above) and the past turns are
            # supplied to the chain through a buffer memory.
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
            print("prompt_name(with history):", prompt_name)
            chat_prompt = PromptTemplate.from_template(template=prompt_template, template_format='jinja2')
            memory = ConversationBufferMemory(memory_key="history", input_key="question")
            for message in history:
                if message.role == 'user':
                    memory.chat_memory.add_user_message(message.content)
                elif message.role == 'assistant':
                    memory.chat_memory.add_ai_message(message.content)
        else:
            input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages([input_prompt])

        # NOTE(review): applied unconditionally; the nesting level of this
        # statement was reconstructed from unindented source -- confirm.
        query = query.replace("原文", "")

        chain = load_qa_chain(
            model, chain_type="stuff", memory=memory, prompt=chat_prompt, verbose=True
        )
        task = asyncio.create_task(wrap_done(
            chain.acall({
                "input_documents": docs,
                "self_knowledge": self_knowledge,
                "history": history,
                "question": query,
                "file_name": str(fileName),
                "format_template": format_template,
                "time": current_time
            }),
            callback.done),
        )

        # Build the source list in the same order as the references above so
        # the numbers line up.
        source_documents = []
        for doc in policydocs:
            filename = doc.metadata.get("title")
            # Formatting (rather than concatenating) keeps a missing or
            # non-string primary key from raising in the middle of the stream.
            detail_url = f"https://policy.ckcest.cn/detail/{doc.metadata.get('primary_key', '')}.html"
            if filename:
                filename = filename.replace('\r', '').replace('\n', '')
                text = f"""_政策[{len(source_documents) + 1}] [{filename}]({detail_url})_\n"""
            else:
                text = f"""_政策[{len(source_documents) + 1}] [{"原文地址"}]({detail_url})_"""
            source_documents.append(text)
        for doc in reportdocs:
            text = f"""_报告[{len(source_documents) + 1}] [{doc.metadata.get("source").replace('.pdf','')}](https://kgo.ckcest.cn/kgo/list?dbId=1010&word=&shortName=ALL&page=1&order=1)_"""
            source_documents.append(text)
        for doc in journaldocs:
            text = f"""_期刊论文[{len(source_documents) + 1}] [{doc.metadata.get("title")}](https://kgo.ckcest.cn/kgo/detail/1002/ads_journal_article/{doc.metadata.get("ID")}.html)_"""
            source_documents.append(text)
        # Personal sources reuse the documents retrieved for the references
        # instead of running every search a second time.
        # NOTE(review): one entry per unique document; the original append's
        # nesting was reconstructed from unindented source -- confirm.
        for search_name, unique_docs in personal_results:
            for doc in unique_docs:
                source_documents.append(
                    _format_personal_source(doc, search_name, len(source_documents) + 1)
                )

        if len(source_documents) == 0:  # nothing relevant was found
            source_documents.append("<span style='color:red'>未找到相关文档,该回答为大模型自身能力解答!</span>")

        answer = ""
        first_token = True
        async for token in callback.aiter():
            if first_token:
                first_token = False
                # Time from receiving the request to the first generated token.
                time_elapsed = time.time() - start_time
                print(f"接收响应到模型吐出第一个字耗时: {time_elapsed:.2f} seconds")
            answer += token
            if stream:
                # Server-sent events: forward each token as it arrives.
                yield json.dumps({"answer": token}, ensure_ascii=False)
        if stream:
            print(f'=====知识库问答模型返回结果=====\n {answer}')
        else:
            yield json.dumps({"answer": answer}, ensure_ascii=False)
        await task
        yield json.dumps({"docs": source_documents}, ensure_ascii=False)

    return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history, model_name, prompt_name))