from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, SELF_TOP_K, SELF_SCORE_THRESHOLD,
                     TEMPERATURE, SELF_TEMPERATURE, USE_RERANKER, RERANKER_MODEL, RERANKER_MAX_LENGTH,
                     SELF_MAX_TOKENS, SELF_USE_RERANKER, MODEL_PATH)
from server.chat.policy_fun_iast import get_llm_model_response
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History, get_first_sentence_by_regex, get_text_by_regex
from server.knowledge_base.kb_service.base import KBServiceFactory, TextRank
import json
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_self_docs
from server.reranker.reranker import LangchainReranker
from server.utils import embedding_device
# NOTE: `logger` used below is not imported by name anywhere in this file, so
# it presumably comes from this star import - keep it (and its position).
from configs.basic_config import *
from langchain.memory import ConversationBufferMemory
from langchain.chains.question_answering import load_qa_chain

# Routing labels returned by the "self_kb_route" prompt.
_ROUTE_GLOBAL = "0"  # question about the document(s) as a whole
_ROUTE_LOCAL = "1"   # question answerable from a few retrieved passages

_GLOBAL_CONTEXT_CHAR_LIMIT = 30000  # max characters of document text sent on the global route
_HISTORY_SUMMARY_THRESHOLD = 20000  # history length above which TextRank is applied
_WEB_RESULT_LIMIT = 5               # number of web results kept for context and citations
_WEB_SNIPPET_CHARS = 300            # characters kept from each web result's content


def _extract_json_object(text: str) -> str:
    """Return the outermost ``{...}`` span of *text*, or *text* unchanged if there is none.

    This drops Markdown code fences (```json ... ```) or any other chatter the
    model wraps around its JSON answer.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end < start:
        return text
    return text[start:end + 1]


def _parse_rewritten_query(raw_response):
    """Parse the query-rewrite model output.

    Returns a ``(rewritten, joined)`` tuple, where ``rewritten`` is the value
    stored under the ``"query"`` key of the JSON answer and ``joined`` is the
    concatenation of its items, or ``None`` when the output is not usable
    (invalid JSON, missing key, non-string items, or an empty result).
    """
    try:
        data = json.loads(_extract_json_object(raw_response))
        rewritten = data['query']
        joined = ''.join(rewritten)
    except (ValueError, KeyError, TypeError, AttributeError):
        return None
    if not joined:
        return None
    return rewritten, joined


def _normalize_route(raw_route) -> str:
    """Map the routing model output onto ``_ROUTE_GLOBAL`` or ``_ROUTE_LOCAL``.

    Exact answers are returned as-is. Anything else (extra whitespace, quotes,
    explanatory text) is matched loosely and falls back to the local route,
    which performs a bounded top-k search.
    """
    text = "" if raw_route is None else str(raw_route).strip()
    if text in (_ROUTE_GLOBAL, _ROUTE_LOCAL):
        return text
    if _ROUTE_GLOBAL in text and _ROUTE_LOCAL not in text:
        return _ROUTE_GLOBAL
    return _ROUTE_LOCAL


async def self_kb_chat(
        query: str = Body(..., description="用户输入", examples=["智慧科协是什么"]),
        quote: str = Body(..., description="用户引用的文段,引用问答时传该参数", examples=["今年,“智慧科协2.0”要持之以恒深入贯彻落实习近平总书记的重要指示精神"]),
        # word: str = Body(..., description="用户需要解释的名词,名词解释时传该参数", examples=["GDP"]),  # disabled: term-explanation mode
        fileNames: List = Body([], description="文件名称", examples=[["孟庆海同志在“智慧科协2.0”5·30场景建设工作部署会议上的讲话.docx"]]),
        knowledge_base_name_list: list = Body(..., description="知识库列表", examples=[[ "p_cast0101011"]]),
        history: List[History] = Body(
            [],
            description="历史对话",
            examples=[[
                {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant", "content": "虎头虎脑"}]]
        ),
        stream: bool = Body(True, description="流式输出"),
        web_search: bool = Body(False, description="是否开启联网搜索"),
):
    """Personal knowledge-base chat API.

    Args:
        query: The user's question.
        quote: Passage quoted by the user; when non-empty (and documents were
            found) the "self_quote" prompt template is used.
        fileNames: File names forwarded to ``search_self_docs`` and to the prompt.
        knowledge_base_name_list: Names of the knowledge bases to use. Every
            name must resolve to an existing knowledge base.
        history: Previous conversation turns.
        stream: If true, answer tokens are emitted one event at a time;
            otherwise the full answer is emitted as a single event.
        web_search: If true, web search results are added to the context and
            to the returned source list.

    Returns:
        A ``BaseResponse`` with code 404 when the knowledge-base list is empty
        or contains an unknown name; otherwise an ``EventSourceResponse`` that
        emits JSON strings: ``{"text": ...}`` event(s) followed by one
        ``{"docs": [...]}`` event listing the sources.
    """
    logger.info(f"个人知识库对话入参:\nquery:{query}\nquote:{quote}\nfileNames:{fileNames}\nknowledge_base_name_list:{knowledge_base_name_list}\nhistory:{history}\nstream:{stream}")

    # An empty list previously slipped through validation and crashed later
    # inside the streaming generator (no knowledge-base name was ever bound).
    if not knowledge_base_name_list:
        return BaseResponse(code=404, msg="未找到知识库")

    for knowledge_base_name in knowledge_base_name_list:
        kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
        if kb is None:
            return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    # NOTE(review): only ONE knowledge base is searched - the last entry of the
    # list. That is what the previous code did implicitly by reusing the
    # validation loop variable; confirm whether all entries should be searched.
    target_kb_name = knowledge_base_name_list[-1]

    history_msgs = [History.to_msg_tuple(h) for h in history]

    async def knowledge_base_chat_iterator(
            query: str,
            model_name: str = LLM_MODELS[0],
            model_name1: str = LLM_MODELS[0],
            prompt_name: str = "self_default",
    ) -> AsyncIterable[str]:
        """Yield the answer (as JSON strings) followed by the source list."""
        callback = AsyncIteratorCallbackHandler()
        # Both models share one callback; exactly one of them is used per
        # request (`model` on the local route, `model1` on the global route).
        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=SELF_TEMPERATURE,
            max_tokens=SELF_MAX_TOKENS,
            callbacks=[callback],
        )
        model1 = get_ChatOpenAI(
            model_name=model_name1,
            temperature=SELF_TEMPERATURE,
            max_tokens=SELF_MAX_TOKENS,
            callbacks=[callback],
        )

        # --- Rewrite the original question -------------------------------
        # Only the user's own earlier messages are given to the rewriter.
        user_queries = [content for role, content in history_msgs if role == 'user']

        # NOTE(review): the result is used without `await`, so this appears to
        # be a synchronous call made from a coroutine; it may block the event
        # loop while the model answers - consider run_in_threadpool.
        rewrite_response = get_llm_model_response(
            strategy_name="self_query_rewrite",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="self_query_rewrite",
            prompt_param_dict={"query": query, "history": user_queries, "quote": quote},
            temperature=0.01,
            max_tokens=512
        )
        logger.info(f"个人知识库问答query: {query}")
        logger.info(f"个人知识库问答query_history: {user_queries}")

        parsed = _parse_rewritten_query(rewrite_response)
        if parsed is None:
            # Malformed rewrite output: fall back to the question as asked.
            search_query = query
        else:
            # As before, the rewritten value also replaces the question that is
            # sent to the routing prompt and to the answering chain.
            # NOTE(review): this value may be a list rather than a string,
            # depending on what the rewrite prompt returns - confirm intent.
            query, search_query = parsed
        logger.info(f"个人知识库问答search_query: {search_query}")

        # --- Route: global (whole document) vs. local (top-k passages) ----
        self_kb_route = get_llm_model_response(
            strategy_name="self_kb_route",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="self_kb_route",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512
        )
        # Normalise once so retrieval and chain construction always agree; an
        # unexpected answer used to leave `docs` (and the task) undefined.
        route = _normalize_route(self_kb_route)

        try:
            if route == _ROUTE_GLOBAL:
                logger.info(f"个人知识库问答路由结果:【全局问题】")
                docs = await run_in_threadpool(
                    search_self_docs,
                    query="",
                    fileNames=fileNames,
                    knowledge_base_name=target_kb_name,
                    top_k=999,
                    score_threshold=2
                )
            else:
                logger.info(f"个人知识库问答路由结果:【局部问题】")
                docs = await run_in_threadpool(
                    search_self_docs,
                    query=search_query,
                    fileNames=fileNames,
                    knowledge_base_name=target_kb_name,
                    top_k=SELF_TOP_K,
                    score_threshold=SELF_SCORE_THRESHOLD
                )
        except Exception as e:
            logger.error(f"个人知识库问答路由错误: {self_kb_route}", exc_info=True)
            docs = []
        logger.info(f"个人知识库问答source_documents: {len(docs)}条")

        # --- Optional web search -----------------------------------------
        web_search_context = ""
        web_search_results = []  # kept so the results can be cited at the end
        if web_search:
            try:
                from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
                searcher = ZhipuSearchAPIWrapper()
                web_results = searcher.zhipu_search(search_query)
                if web_results:
                    top_results = web_results[:_WEB_RESULT_LIMIT]
                    web_parts = []
                    for i, r in enumerate(top_results, 1):
                        title = r.get("title", "")
                        # `or ""` guards against an explicit None value.
                        content = (r.get("content") or "")[:_WEB_SNIPPET_CHARS]
                        url = r.get("url", "")
                        web_parts.append(f"[{i}] {title}\n{content}\n来源: {url}")
                    web_search_context = "\n\n【联网搜索结果】\n" + "\n\n".join(web_parts)
                    web_search_results = top_results
                    logger.info(f"联网搜索获取到 {len(web_results)} 条结果")
            except Exception as e:
                # Web search is best-effort: answer from the knowledge base alone.
                logger.error(f"联网搜索失败: {e}")

        # Reranking is currently disabled; kept for reference.
        # if SELF_USE_RERANKER:
        #     reranker_model_path = MODEL_PATH["reranker"].get(RERANKER_MODEL,"BAAI/bge-reranker-large")
        #     print("-----------------model path------------------")
        #     print(reranker_model_path)
        #     reranker_model = LangchainReranker(top_n=SELF_TOP_K,
        #                                     device=embedding_device(),
        #                                     max_length=RERANKER_MAX_LENGTH,
        #                                     model_name_or_path=reranker_model_path
        #                                     )
        #     for idx, doc in enumerate(docs, start=1):
        #         print(f"{idx}: score={doc.score}")
        #     docs = reranker_model.compress_documents(documents=docs,
        #                                              query=query)
        #     print("---------after rerank------------------")
        #     for idx, doc in enumerate(docs, start=1):
        #         print(f"{idx}: score={doc.score}")
        # context = "\n".join([doc.page_content for doc in docs])
        # (load_qa_chain needs DocumentWithVSId-typed documents as input)

        # --- Choose the prompt template ----------------------------------
        if not docs:
            prompt_name = "self_empty"  # nothing relevant was retrieved
        elif quote:
            prompt_name = "self_quote"  # the user quoted a passage
        # elif word:  # disabled: term-explanation mode
        #     prompt_name = "word_explain" if word else prompt_name

        prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
        input_msg = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_msg])

        # --- Start the answering chain -----------------------------------
        if route == _ROUTE_GLOBAL:
            # Strip non-breaking spaces from the ends. The previous argument
            # "xa0" stripped the letters 'x', 'a' and the digit '0' instead.
            context = "\n".join([doc.page_content for doc in docs]).strip("\xa0")
            logger.info(f"个人知识库问答 context 长度:{len(context)}")
            context = context[:_GLOBAL_CONTEXT_CHAR_LIMIT]
            if web_search_context:
                context += web_search_context
            logger.info(f"最终 context 长度:{len(context)}")

            chain_history = history_msgs
            if chain_history and len(chain_history) >= _HISTORY_SUMMARY_THRESHOLD:
                chain_history = TextRank(chain_history, num_sentences=1)

            chain = LLMChain(prompt=chat_prompt, llm=model1, verbose=True)
            task = asyncio.create_task(wrap_done(
                chain.acall({"context": context, "question": query, "history": chain_history,
                             "quote": quote, "fileName": fileNames}),
                callback.done),
            )
        else:
            # Web search results join the retrieved documents as one extra document.
            if web_search_context:
                from langchain.docstore.document import Document as LCDocument
                docs.append(LCDocument(page_content=web_search_context, metadata={"source": "web_search"}))
            chain = load_qa_chain(
                model,
                chain_type="stuff",
                prompt=chat_prompt,
                verbose=True
            )
            task = asyncio.create_task(wrap_done(
                chain.acall({"input_documents": docs, "question": query, "history": history_msgs,
                             "quote": quote, "fileName": fileNames}),
                callback.done),
            )

        # Per-document summary citations are currently disabled; kept for reference.
        # source_documents = []
        # seen_texts = set()  # processed_text values already emitted
        # counter = 1
        # for doc in docs:
        #     text = doc.metadata.get("summary", "")
        #     # processed_text = get_first_sentence_by_regex(text)
        #     processed_text = get_text_by_regex(text)
        #     # only add processed_text that has not been seen yet
        #     if processed_text and processed_text not in seen_texts:
        #         source_document = f"[{counter}] {processed_text}"
        #         source_documents.append(source_document)
        #         seen_texts.add(processed_text)
        #         counter += 1

        # --- Emit the answer ---------------------------------------------
        if stream:
            async for token in callback.aiter():
                # Use server-sent-events to stream the response
                yield json.dumps({"text": token}, ensure_ascii=False)
        else:
            tokens = []
            async for token in callback.aiter():
                tokens.append(token)
            yield json.dumps({"text": "".join(tokens)}, ensure_ascii=False)
        await task

        # --- Emit the sources --------------------------------------------
        source_documents = []
        if not docs and not web_search_context:
            source_documents.append(f"""暂未从本篇文献中找到答案,该回答为大模型自身能力解答!""")
        else:
            if docs:
                # Only the first document's source is cited.
                source_documents.append(f"""[{len(source_documents) + 1}] [{docs[0].metadata.get("source")}]()\n""")
            # Links to the web search results
            for r in web_search_results:
                title = (r.get("title") or "").replace("\n", "")
                url = r.get("url", "")
                if title and url:
                    source_documents.append(f"""[{len(source_documents) + 1}] [{title}]({url})\n""")
        yield json.dumps({"docs": source_documents}, ensure_ascii=False)

    return EventSourceResponse(knowledge_base_chat_iterator(query))