"""Knowledge-base chat endpoint.

Searches the configured policy / report / journal knowledge bases plus any
other ("personal") knowledge bases named in the request, stuffs the retrieved
passages into a LangChain "stuff" QA chain, and returns the answer through an
``EventSourceResponse``, followed by a final event listing the cited sources.
"""
from fastapi import Body, Request
from langchain.chains.question_answering import load_qa_chain
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
                     USE_RERANKER, RERANKER_MODEL, RERANKER_MAX_LENGTH, MODEL_PATH,
                     MAX_TOKENS, MAX_CUT_TOKENS, POLICY_KNOWLEDGE_BASE,
                     REPORT_KNOWLEDGE_BASE, JOURNAL_KNOWLEDGE_BASE, OLD_POLICY_BASE)
from configs.kb_config import OLD_JOURNAL_BASE
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template, get_format_template
from server.utils import get_strategy_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, Callable, Iterable, Iterator, List, Optional, Tuple
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
from server.chat.policy_fun import add_summary_retrieved_results, get_llm_model_response
# NOTE: this import rebinds get_llm_model_response imported from policy_fun just
# above; the policy_fun_iast implementation is the one used in this module.
from server.chat.policy_fun_iast import get_llm_model_response
from langchain.memory import ConversationSummaryBufferMemory, ConversationBufferWindowMemory, ConversationBufferMemory
from langchain_core.prompts import PromptTemplate
import itertools
from datetime import datetime
import time
from langchain.schema import Document

# (deprecated KB name(s), replacement KB name). Requests naming a deprecated KB
# are redirected to its replacement before any search happens.
REPLACEMENT_RULES = [
    (OLD_POLICY_BASE, "t_policy_total_bge_new_v2"),
    (OLD_JOURNAL_BASE, "t_journal_article_bge_v1"),
]

# Personal KBs that are searched under a different collection name
# (original note: collected data is used in place of the OA resources).
_KB_ALIASES = {"yj_oa_journal_bge_v2_yejinbak": "yj_oa_article_v1_yejinbak"}

# Personal KBs whose cited sources are rendered as 《title》 plus publication year.
# NOTE(review): the membership test uses the *aliased* KB name, so the
# "yj_oa_journal_bge_v2_yejinbak" entry can never match (it is always redirected
# through _KB_ALIASES first). Kept as-is to preserve the existing output format.
_TITLED_PERSONAL_KBS = frozenset({
    "yj_policys_bge_v1_yejinbak",
    "yj_oa_journal_bge_v2_yejinbak",
    "yj_for_journal_bge_v1_yejinbak",
    "yj_ch_journal_bge_v1_yejinbak",
})

# Sampling settings for the auxiliary LLM calls (self-response and query rewrite).
_AUX_TEMPERATURE = 0.01
_AUX_MAX_TOKENS = 512

# Policy summaries this short or shorter are treated as missing and the doc is skipped.
_MIN_POLICY_SUMMARY_LEN = 15


def _apply_replacement_rules(kb_names: List[str]) -> List[str]:
    """Return ``kb_names`` with deprecated knowledge-base names replaced.

    For each ``(old_bases, new_base)`` in REPLACEMENT_RULES, every requested name
    contained in ``old_bases`` is dropped and ``new_base`` is appended once
    (unless it is already present). The order of the surviving names is kept.
    """
    requested = set(kb_names)
    result = list(kb_names)
    replacements = []
    for old_bases, new_base in REPLACEMENT_RULES:
        # set("name") would split a bare string into characters and never match.
        old_set = {old_bases} if isinstance(old_bases, str) else set(old_bases)
        to_remove = requested & old_set
        if to_remove:
            result = [name for name in result if name not in to_remove]
            replacements.append(new_base)
    for new_base in replacements:
        if new_base not in result:
            result.append(new_base)
    return result


def _parse_rewritten_query(raw, result_key: str, fallback: str) -> str:
    """Extract a search query from a query-rewrite LLM reply.

    The reply is parsed as a JSON object, optionally wrapped in a Markdown
    code fence. The fragments stored under ``result_key`` are concatenated
    without a separator, as before. Returns ``fallback`` when the reply is not a
    string, cannot be parsed, lacks ``result_key``, or yields an empty query.
    """
    if not isinstance(raw, str):
        return fallback
    text = raw.strip()
    # Remove a leading ```json / ``` fence and a trailing ``` fence (prefix/suffix
    # removal, unlike str.strip which treats its argument as a character set).
    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    try:
        fragments = json.loads(text)[result_key]
        joined = "".join(fragments)
    except (ValueError, KeyError, TypeError):
        # ValueError covers json.JSONDecodeError; TypeError covers non-dict JSON
        # and non-string fragments.
        return fallback
    return joined or fallback


async def _rewrite_search_query(model_name: str, template_prompt_name: str,
                                prompt_param_dict: dict, result_key: str,
                                fallback: str) -> str:
    """Ask the LLM to rewrite the user question into a search query.

    get_llm_model_response is called synchronously in the original code, so it
    is run in the thread pool to keep the event loop responsive.
    """
    raw = await run_in_threadpool(
        get_llm_model_response,
        strategy_name="query rewrite",
        llm_model_name=model_name,
        template_prompt_name=template_prompt_name,
        prompt_param_dict=prompt_param_dict,
        temperature=_AUX_TEMPERATURE,
        max_tokens=_AUX_MAX_TOKENS,
    )
    return _parse_rewritten_query(raw, result_key, fallback)


def _dedupe_by(docs: Iterable[Document], key: Callable[[Document], object]) -> Iterator[Document]:
    """Yield the docs whose ``key(doc)`` has not been seen earlier (first one wins)."""
    seen = set()
    for doc in docs:
        ident = key(doc)
        if ident not in seen:
            seen.add(ident)
            yield doc


def _collect_policy_knowledge(docs: List[Document], knowledge: List[str],
                              use_summary: bool) -> List[Document]:
    """Append policy passages to ``knowledge``; return the docs that are cited.

    With ``use_summary`` only docs with a summary longer than
    _MIN_POLICY_SUMMARY_LEN are kept, de-duplicated by (title, page_content).
    Otherwise every doc is cited, and entries are numbered by their position
    in ``docs`` so they line up with the source list built from the same docs.
    """
    if use_summary:
        with_summary = (d for d in docs
                        if len(d.metadata.get('summary') or '') > _MIN_POLICY_SUMMARY_LEN)
        kept = list(_dedupe_by(with_summary, lambda d: (d.metadata['title'], d.page_content)))
        for doc in kept:
            knowledge.append(f"""参考资料[{len(knowledge) + 1}] 文章标题: {doc.metadata['title']} \n文章内容: {doc.metadata['summary']}""")
        return kept

    for inum, doc in enumerate(docs):
        doc_type = doc.metadata.get("_type")
        if doc_type == "title":
            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content} \n文章内容 {doc.metadata['content']}""")
        elif doc_type == "content":
            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.page_content}""")
    return docs


def _collect_report_knowledge(docs: List[Document], knowledge: List[str]) -> List[Document]:
    """Append report passages (de-duplicated by source + content) to ``knowledge``."""
    kept = list(_dedupe_by(docs, lambda d: (d.metadata['source'], d.page_content)))
    for doc in kept:
        knowledge.append(f"""参考资料[{len(knowledge) + 1}] 报告来源: {doc.metadata['source'].replace('.pdf','')} \n报告内容: {doc.page_content}""")
    return kept


def _collect_journal_knowledge(docs: List[Document], knowledge: List[str]) -> List[Document]:
    """Append journal titles/abstracts (de-duplicated by title + abstract) to ``knowledge``."""
    kept = list(_dedupe_by(docs, lambda d: (d.metadata['title'], d.metadata['abstract'])))
    for doc in kept:
        knowledge.append(f"""参考资料[{len(knowledge) + 1}] 论文标题: {doc.metadata['title']} \n论文摘要: {doc.metadata['abstract']}""")
    return kept


def _collect_personal_knowledge(docs: List[Document], knowledge: List[str]) -> List[Document]:
    """Append personal-KB passages (de-duplicated by page_content) to ``knowledge``."""
    kept = list(_dedupe_by(docs, lambda d: d.page_content))
    for doc in kept:
        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.page_content}""")
    return kept


def _policy_source(doc: Document, index: int) -> str:
    """Markdown citation line for a policy doc, linking to its detail page."""
    title = doc.metadata.get("title")
    # `or ""` avoids a TypeError on a missing primary_key.
    primary_key = doc.metadata.get("primary_key") or ""
    detail_url = f"https://policy.ckcest.cn/detail/{primary_key}.html"
    if title:
        title = title.replace('\r', '').replace('\n', '')
        return f"""_政策[{index}] [{title}]({detail_url})_\n"""
    return f"""_政策[{index}] [原文地址]({detail_url})_"""


def _report_source(doc: Document, index: int) -> str:
    """Markdown citation line for a report doc (links to the report list page)."""
    source = doc.metadata['source'].replace('.pdf', '')
    return f"""_报告[{index}] [{source}](https://kgo.ckcest.cn/kgo/list?dbId=1010&word=&shortName=ALL&page=1&order=1)_"""


def _journal_source(doc: Document, index: int) -> str:
    """Markdown citation line for a journal article, linking to its detail page."""
    title = doc.metadata.get('title')
    article_id = doc.metadata.get('ID')
    return f"""_期刊论文[{index}] [{title}](https://kgo.ckcest.cn/kgo/detail/1002/ads_journal_article/{article_id}.html)_"""


def _personal_source(doc: Document, index: int, kb_name: str) -> str:
    """Citation line for a personal-KB doc; format depends on the KB and the doc type."""
    meta = doc.metadata
    if kb_name not in _TITLED_PERSONAL_KBS:
        return f"""参考文档[{index}] [{meta.get('source')}]()\n"""
    year = meta.get('publish_year')
    if meta.get('_type') == 'title':
        if meta.get('summary'):
            return f"""[{index}] 《{doc.page_content}》\n{meta.get('summary')}\n资料年份:{year}\n\n"""
        return f"""[{index}] 《{doc.page_content}》\n资料年份:{year}\n\n"""
    return f"""[{index}] 《{meta.get('title')}》\n资料年份:{year}\n\n"""


async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
                              fileName: List = Body([], description="文件名称", examples=[["123.txt"]]),
                              knowledge_base_name_list: list = Body(..., description="多种知识库名称",
                                                                    examples=[["t_policy_total_bge_v1", "t_strategy_report_20_bge_v2", "t_journal_article_bge_v1"]]),
                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                              score_threshold: float = Body(
                                  SCORE_THRESHOLD,
                                  description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右",
                                  ge=0,
                                  le=2
                              ),
                              history: List[History] = Body(
                                  [],
                                  description="历史对话",
                                  examples=[[
                                      {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                                      {"role": "assistant", "content": "虎头虎脑"}]]
                              ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
                              max_tokens: Optional[int] = Body(
                                  MAX_TOKENS,
                                  description="限制LLM生成Token数量,默认None代表模型最大值"
                              ),
                              prompt_name: str = Body(
                                  "default",
                                  description="使用的prompt模板名称(在configs/prompt_config.py中配置)"
                              ),
                              request: Request = None,
                              use_summary=True,
                              use_model_self_response=True,
                              chunk_size: int = 20000,
                              min_chunk_size: int = 2000,
                              summary_model_name=LLM_MODELS[0],
                              query_rewrite_model_name=LLM_MODELS[0]
                              ):
    """Answer ``query`` from the requested knowledge bases.

    Returns ``BaseResponse(code=404)`` if any requested knowledge base does not
    exist. Otherwise returns an EventSourceResponse that emits JSON events:
    ``{"answer": token}`` per token when ``stream`` is true, or a single
    ``{"answer": full_answer}`` otherwise, followed by ``{"docs": [...]}`` with
    the cited sources.

    ``request``, ``chunk_size``, ``min_chunk_size`` and ``summary_model_name`` are
    accepted for interface compatibility but are not read in this function.
    """
    knowledge_base_name_list = _apply_replacement_rules(knowledge_base_name_list)
    print(f'========== 当前检索的知识库:{knowledge_base_name_list} ==========')

    for knowledge_base_name in knowledge_base_name_list:
        if KBServiceFactory.get_service_by_name(knowledge_base_name) is None:
            return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # Every requested KB that is not one of the three built-in ones is searched
    # with the raw user query and cited generically.
    builtin_kbs = (POLICY_KNOWLEDGE_BASE, REPORT_KNOWLEDGE_BASE, JOURNAL_KNOWLEDGE_BASE)
    personal_kb_names = [name for name in knowledge_base_name_list if name not in builtin_kbs]

    history = [History.from_data(h) for h in history]
    # Used to report the latency until the first generated token.
    start_time = time.time()
    print(f"========== 当前的对话历史为==========\n{history}")

    # Current date formatted as YYYYMMDD; passed to the policy rewrite and the chain.
    current_time = datetime.now().strftime("%Y%m%d")

    async def knowledge_base_chat_iterator(
            query: str,
            top_k: int,
            history: Optional[List[History]],
            model_name: str = model_name,
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        """Retrieve context, run the QA chain and yield JSON-encoded SSE payloads."""
        history = history or []
        callback = AsyncIteratorCallbackHandler()
        memory = None

        # Non-positive token limits are passed on as None.
        llm_max_tokens = None if isinstance(max_tokens, int) and max_tokens <= 0 else max_tokens
        if prompt_name == "policy_chat":
            model_name = LLM_MODELS[0]
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=llm_max_tokens,
            callbacks=[callback],
        )

        # The model's own answer to the question, passed to the chain as extra context.
        self_knowledge = []
        if use_model_self_response:
            modelself_response = await run_in_threadpool(
                get_llm_model_response,
                strategy_name="self response",
                llm_model_name=query_rewrite_model_name,
                template_prompt_name="self_response",
                prompt_param_dict={"query": query},
                temperature=_AUX_TEMPERATURE,
                max_tokens=_AUX_MAX_TOKENS,
            )
            self_knowledge.append(f"{modelself_response}")

        # Previous user turns, collected once and shared by all query rewrites.
        user_queries = [message.content for message in history if message.role == 'user']

        knowledge = []  # numbered reference passages fed to the chain
        policydocs, reportdocs, journaldocs = [], [], []
        personal_results: List[Tuple[str, List[Document]]] = []  # (searched KB name, cited docs)

        # Policy KB
        if POLICY_KNOWLEDGE_BASE in knowledge_base_name_list:
            search_query = await _rewrite_search_query(
                query_rewrite_model_name, "query_rewrite_policy",
                {"query": query, "history": user_queries, "time": current_time},
                result_key="policies", fallback=query,
            )
            print("search_query: ", query)
            print("search_history: ", user_queries)
            print('policy search query', search_query)
            found = await run_in_threadpool(search_docs, fileName=fileName, query=search_query, usr_query=query,
                                            knowledge_base_name=POLICY_KNOWLEDGE_BASE, top_k=top_k,
                                            score_threshold=score_threshold)
            policydocs = _collect_policy_knowledge(found, knowledge, use_summary)

        # Report KB
        if REPORT_KNOWLEDGE_BASE in knowledge_base_name_list:
            search_query = await _rewrite_search_query(
                query_rewrite_model_name, "query_rewrite_report",
                {"query": query, "history": user_queries + [query]},
                result_key="report", fallback=query,
            )
            print("search_query: ", query)
            print("search_history: ", user_queries)
            print('report search query', search_query)
            found = await run_in_threadpool(search_docs, fileName=fileName, query=search_query,
                                            knowledge_base_name=REPORT_KNOWLEDGE_BASE, top_k=top_k,
                                            score_threshold=score_threshold, expr=" _type == 'content'")
            reportdocs = _collect_report_knowledge(found, knowledge)

        # Journal KB
        if JOURNAL_KNOWLEDGE_BASE in knowledge_base_name_list:
            # NOTE(review): reads the "report" key like the report rewrite does;
            # confirm this matches the "query_rewrite" prompt's output format.
            search_query = await _rewrite_search_query(
                query_rewrite_model_name, "query_rewrite",
                {"query": query, "history": user_queries + [query]},
                result_key="report", fallback=query,
            )
            print("search_query: ", query)
            print("search_history: ", user_queries)
            print('journal search query', search_query)
            found = await run_in_threadpool(search_docs, fileName=fileName, query=search_query,
                                            knowledge_base_name=JOURNAL_KNOWLEDGE_BASE, top_k=top_k,
                                            score_threshold=score_threshold)
            journaldocs = _collect_journal_knowledge(found, knowledge)

        # Personal KBs: searched once. The same deduplicated docs feed both the
        # context and the source list, so the citation numbers stay aligned.
        for requested_name in personal_kb_names:
            kb_name = _KB_ALIASES.get(requested_name, requested_name)
            found = await run_in_threadpool(search_docs, fileName=fileName, query=query,
                                            knowledge_base_name=kb_name, top_k=top_k,
                                            score_threshold=score_threshold)
            personal_results.append((kb_name, _collect_personal_knowledge(found, knowledge)))

        docs = [Document(page_content=k) for k in knowledge]

        if prompt_name in ("Abstract Assistant", "Outline Assistant"):
            format_template = get_format_template("knowledge_base_chat", "abstract_format")
        else:
            format_template = get_format_template("knowledge_base_chat", "default")

        # With no retrieved passages and no uploaded files, use the "empty"
        # prompt (except for the abstract assistant).
        if not knowledge and not fileName and prompt_name != "Abstract Assistant":
            prompt_template = get_prompt_template("knowledge_base_chat", "empty")
        else:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
            print("prompt_name(no history):", prompt_name)

        if history and prompt_name not in ["Question Assistant"]:
            prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
            print("prompt_name(with history):", prompt_name)
            chat_prompt = PromptTemplate.from_template(template=prompt_template, template_format='jinja2')
            # Replay the conversation into a buffer memory exposed as "history".
            memory = ConversationBufferMemory(memory_key="history", input_key="question")
            for message in history:
                if message.role == 'user':
                    memory.chat_memory.add_user_message(message.content)
                elif message.role == 'assistant':
                    memory.chat_memory.add_ai_message(message.content)
        else:
            input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
            chat_prompt = ChatPromptTemplate.from_messages([input_prompt])

        # Build the citation list before starting the LLM task so a formatting
        # error here cannot leave an orphaned generation running.
        source_documents = []
        for doc in policydocs:
            source_documents.append(_policy_source(doc, len(source_documents) + 1))
        for doc in reportdocs:
            source_documents.append(_report_source(doc, len(source_documents) + 1))
        for doc in journaldocs:
            source_documents.append(_journal_source(doc, len(source_documents) + 1))
        for kb_name, kept in personal_results:
            for doc in kept:
                source_documents.append(_personal_source(doc, len(source_documents) + 1, kb_name))
        if not source_documents:
            # No documents found: the answer relies on the model alone.
            source_documents.append("未找到相关文档,该回答为大模型自身能力解答!")

        query = query.replace("原文", "")
        chain = load_qa_chain(model, chain_type="stuff", memory=memory, prompt=chat_prompt, verbose=True)
        task = asyncio.create_task(wrap_done(
            chain.acall({
                "input_documents": docs,
                "self_knowledge": self_knowledge,
                "history": history,
                "question": query,
                "file_name": str(fileName),
                "format_template": format_template,
                "time": current_time,
            }),
            callback.done),
        )

        answer = ""
        first_token = True
        async for token in callback.aiter():
            if first_token:
                first_token = False
                time_elapsed = time.time() - start_time
                print(f"接收响应到模型吐出第一个字耗时: {time_elapsed:.2f} seconds")
            answer += token
            if stream:
                yield json.dumps({"answer": token}, ensure_ascii=False)
        if stream:
            print(f'=====知识库问答模型返回结果=====\n {answer}')
        else:
            yield json.dumps({"answer": answer}, ensure_ascii=False)

        await task
        yield json.dumps({"docs": source_documents}, ensure_ascii=False)

    return EventSourceResponse(knowledge_base_chat_iterator(query, top_k, history, model_name, prompt_name))