"""Term-explanation chat endpoint.

Checks whether the user input is a single term, retrieves supporting material
from the selected knowledge bases and from a web search, and streams the
model's explanation back as server-sent events.
"""
import logging
from datetime import datetime
import hashlib
import re
from fastapi import Body, Request
from sse_starlette.sse import EventSourceResponse
from fastapi.concurrency import run_in_threadpool
from configs import (
    LLM_MODELS,
    VECTOR_SEARCH_TOP_K,
    SCORE_THRESHOLD,
    SELF_TOP_K,
    SELF_SCORE_THRESHOLD,
    TEMPERATURE,
    SELF_TEMPERATURE,
    USE_RERANKER,
    RERANKER_MODEL,
    RERANKER_MAX_LENGTH,
    SELF_MAX_TOKENS,
    SELF_USE_RERANKER,
    MODEL_PATH
)
from server.chat.ZhipuSearchAPI import ZhipuSearchAPIWrapper
from server.chat.policy_fun_iast import get_llm_model_response
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable, List, Optional, Dict, Tuple, Union, AsyncGenerator
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History, get_similar_documents, get_text_by_regex
from server.knowledge_base.kb_service.base import KBServiceFactory
import json
from server.knowledge_base.kb_doc_api import search_self_docs
from configs.basic_config import *
from langchain.chains.question_answering import load_qa_chain
from typing import Generator

logger = logging.getLogger(__name__)


# NOTE(review): if this coroutine is registered as a FastAPI route, its
# docstring is presumably surfaced as the endpoint description, which is why
# the explicit "\n" escapes are kept.
async def word_explain(
        query: str = Body(..., description="用户输入", examples=["智慧科协"]),
        fileNames: List = Body([], description="文件名称", examples=[["孟庆海同志在“智慧科协2.0”5·30场景建设工作部署会议上的讲话.docx"]]),
        knowledge_base_name_list: List[str] = Body(..., description="知识库列表", examples=[["p_cast0101011"]]),
        history: List[History] = Body(
            [],
            description="历史对话",
            examples=[[
                {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant", "content": "虎头虎脑"}]
            ]
        ),
) -> EventSourceResponse:
    """
    Term explanation API\n
    Parameters:\n
    query: str  user input\n
    fileNames: List[str]  file names\n
    knowledge_base_name_list: List[str]  knowledge base names\n
    history: List[History]  conversation history (currently only its length is logged)\n
    Returns a server-sent-event stream on success, or a BaseResponse with
    code 404 (unknown knowledge base) or 500 (unexpected error).
    """
    logger.info(
        f"名词解释请求: query={query}, "
        f"文件={fileNames}, "
        f"知识库数={knowledge_base_name_list}, "
        f"历史记录数={len(history)}, "
    )

    try:
        # Reject the request as soon as one requested knowledge base is unknown.
        for kb_name in knowledge_base_name_list:
            if not KBServiceFactory.get_service_by_name(kb_name):
                return BaseResponse(code=404, msg=f"知识库 {kb_name} 不存在")

        # Ask the model whether the query is a single term. The call is
        # dispatched through run_in_threadpool so it does not run on the
        # event loop.
        check_result = await run_in_threadpool(
            get_llm_model_response,
            strategy_name="word_check",
            llm_model_name=LLM_MODELS[0],
            template_prompt_name="word_check",
            prompt_param_dict={"query": query},
            temperature=0.01,
            max_tokens=512
        )
        logger.info(f"名词检查结果: {check_result}")

        docs = []
        internet_context = ""
        internet_doc = []
        if '0' in check_result:
            # A '0' in the check output selects the rejection prompt; no
            # retrieval is performed in that case.
            prompt_name = "word_explain_reject"
        else:
            prompt_name = "word_explain"
            # Fetch knowledge-base documents and web results concurrently.
            docs, internet_context, internet_doc = await retrieve_documents(
                query=query,
                fileNames=fileNames,
                knowledge_base_names=knowledge_base_name_list,
                top_k=SELF_TOP_K,
                score_threshold=SELF_SCORE_THRESHOLD
            )

        return EventSourceResponse(
            generate_chat_response(
                query=query,
                docs=docs,
                internet_context=internet_context,
                internet_docs=internet_doc,
                fileNames=fileNames,
                model_name=LLM_MODELS[0],
                prompt_name=prompt_name
            )
        )

    except Exception as e:
        logger.error(f"处理请求异常: {str(e)}", exc_info=True)
        return BaseResponse(code=500, msg=f"处理请求时发生错误: {str(e)}")


async def _select_web_results(
        query: str,
        web_results: List[dict],
        top_k: int
) -> Tuple[str, List[str]]:
    """
    Filter web search hits by relevance and format them for the prompt.

    The hits are numbered and handed to the "default_similar" prompt; the
    reply is parsed as a comma-separated list of 1-based indices ("无" means
    none) and passed to get_similar_documents together with the hits.

    Returns:
        (concatenated content of the selected hits,
         markdown reference strings of the form "[n] [title](link)")
    """
    # Hits normally carry a "title"; fall back to "source" when that key is
    # missing or not a string.
    try:
        sentences = [doc["title"] for doc in web_results]
        sentences_page_content = [
            str(i + 1) + ":【" + doc["title"] + doc["content"] + "】"
            for i, doc in enumerate(web_results)
        ]
    except (KeyError, TypeError):
        sentences = [doc["source"] for doc in web_results]
        sentences_page_content = [
            str(i + 1) + ":【" + doc["source"] + doc["content"] + "】"
            for i, doc in enumerate(web_results)
        ]

    # Run the model call via run_in_threadpool, as word_explain does for the
    # same helper, instead of calling it directly from the coroutine.
    res = await run_in_threadpool(
        get_llm_model_response,
        strategy_name="default_similar",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="default_similar",
        prompt_param_dict={
            "input": query,
            "title": f"{sentences_page_content}",
            "time": datetime.now().strftime("%Y%m%d")
        },
        temperature=0.01,
        max_tokens=512
    )

    try:
        if res == "无":
            index = []
        else:
            # The reply uses 1-based numbering; convert to 0-based positions.
            index = [int(i) - 1 for i in res.split(",")]
        similar_document = get_similar_documents(
            index=index, sentences=sentences, query=query,
            docs=web_results, top_k=top_k
        )
    except Exception as e:
        # Best effort: an unparsable reply (or a failure with the parsed
        # indices) falls back to a call without any preselected index.
        logger.warning(f"互联网检索结果相似度过滤异常: {str(e)}")
        similar_document = get_similar_documents(
            index=[], sentences=sentences, query=query,
            docs=web_results, top_k=top_k
        )

    contents = []
    references = []
    for num, item in enumerate(similar_document, start=1):
        title = item.get('title', '')
        # NOTE(review): the trailing lazy ".*?" matches the empty string, so
        # only the literal label is removed here — confirm the intended pattern.
        cleaned_title = re.sub(r'(发布时间:.*?)', '', title)
        link = item.get('link', '')
        references.append(f"[{num}] [{cleaned_title}]({link})")
        contents.append(item.get('content', ''))
    return "".join(contents), references


async def retrieve_documents(
        query: str,
        fileNames: List[str],
        knowledge_base_names: List[str],
        top_k: int,
        score_threshold: float
) -> Tuple[List[DocumentWithVSId], str, List[str]]:
    """
    Retrieve knowledge-base documents and web search results concurrently.

    One search task is started per knowledge base plus one web search task.
    A failing task is logged and contributes nothing; any unexpected error
    while combining the results yields ([], "", []).

    Returns:
        (knowledge-base documents, web context text,
         web references formatted as "[n] [title](link)")
    """

    async def fetch_kb_docs(kb_name: str) -> List[DocumentWithVSId]:
        """Search one knowledge base; return [] if the search fails."""
        try:
            return await run_in_threadpool(
                search_self_docs,
                query=query,
                fileNames=fileNames,
                knowledge_base_name=kb_name,
                top_k=top_k,
                score_threshold=score_threshold
            )
        except Exception as e:
            logger.error(f"知识库 {kb_name} 检索失败: {str(e)}")
            return []

    # TODO: switch to DuckDuckGo
    async def fetch_web_results() -> List[dict]:
        """Run the web search; return [] if the search fails."""
        try:
            wrapper = ZhipuSearchAPIWrapper()
            return await run_in_threadpool(
                wrapper.zhipu_search,
                origin_query=query
            )
        except Exception as e:
            logger.error(f"网络检索失败: {str(e)}")
            return []

    try:
        tasks = [
            *[fetch_kb_docs(kb) for kb in knowledge_base_names],
            fetch_web_results()
        ]
        results = await asyncio.gather(*tasks, return_exceptions=True)

        self_docs: List[DocumentWithVSId] = []
        internet_context = ""
        internet_docs: List[str] = []

        for result in results:
            if isinstance(result, Exception):
                logger.error(f"并发任务出现异常: {str(result)}")
                continue
            if not isinstance(result, list):
                continue

            # Results are told apart by element type. An empty list satisfies
            # the first check and is a no-op.
            if all(isinstance(doc, DocumentWithVSId) for doc in result):
                self_docs.extend(result)
            elif all(isinstance(item, dict) for item in result):
                # Use the list being iterated rather than a fixed position in
                # `results`: the web task's position depends on how many
                # knowledge bases were requested.
                context, references = await _select_web_results(
                    query, result, top_k
                )
                internet_context += context
                internet_docs.extend(references)

        return self_docs, internet_context, internet_docs

    except Exception as e:
        logger.error(f"文档检索异常: {str(e)}")
        return [], "", []


def generate_chat_response(
        query: str,
        docs: List[DocumentWithVSId],
        internet_context: str,
        fileNames: List[str],
        internet_docs: List[str],
        model_name: str,
        prompt_name: str
) -> AsyncGenerator[str, None]:
    """
    Build the answer stream that combines local documents and web context.

    Returns an async generator of JSON strings:
      {"text": <token>}   for every generated token,
      {"docs": [...]}     once at the end, when docs or internet_docs is non-empty,
      {"error": "..."}    if generation fails.
    """
    callback = AsyncIteratorCallbackHandler()
    model = get_ChatOpenAI(
        model_name=model_name,
        temperature=SELF_TEMPERATURE,
        max_tokens=SELF_MAX_TOKENS,
        callbacks=[callback],
    )

    # Build the prompt from the configured template.
    prompt_template = get_prompt_template("knowledge_base_chat", prompt_name)
    chat_prompt = ChatPromptTemplate.from_messages([
        History(role="system", content=prompt_template).to_msg_template(False)
    ])

    # Configure the question-answering chain.
    qa_chain = load_qa_chain(
        model,
        chain_type="stuff",
        prompt=chat_prompt,
        verbose=True
    )

    async def generation_wrapper() -> AsyncGenerator[str, None]:
        task = None
        try:
            # Strip the trailing extension from each file name.
            clean_files = [re.sub(r"\.\w+$", "", f) for f in fileNames]

            context_data = {
                "input_documents": docs,
                "question": query,
                "fileName": ", ".join(clean_files),
                "internet_context": internet_context
            }

            task = asyncio.create_task(wrap_done(
                qa_chain.acall(context_data),
                callback.done
            ))

            # Forward tokens as the callback handler produces them.
            async for token in callback.aiter():
                yield json.dumps({"text": token}, ensure_ascii=False)

            # Append the web references once the answer is complete.
            if docs or internet_docs:
                yield json.dumps({"docs": internet_docs}, ensure_ascii=False)

            await task

        except Exception as e:
            logger.error(f"生成响应异常: {str(e)}")
            yield json.dumps({"error": "生成响应时发生错误"}, ensure_ascii=False)
        finally:
            # Do not leave the chain call running in the background when the
            # stream ends early (error above, or the generator being closed).
            if task is not None and not task.done():
                task.cancel()

    return generation_wrapper()


def process_source_documents(docs: List[Dict]) -> List[str]:
    """
    Format source documents as short, de-duplicated citation strings.

    For each document the text is taken from get_text_by_regex applied to the
    "summary" metadata entry, falling back to the first 200 characters of the
    page content. Documents whose text was already seen are skipped.

    NOTE(review): elements are accessed through `.metadata` and
    `.page_content` attributes, so the `List[Dict]` annotation looks
    inaccurate — confirm the element type against callers.
    NOTE(review): the citation number is the position in the input list, so
    numbering has gaps after skipped duplicates — confirm this is intended.
    """
    seen = set()
    processed = []

    for idx, doc in enumerate(docs, 1):
        summary = doc.metadata.get("summary", "")
        text = get_text_by_regex(summary) or doc.page_content[:200]

        # The MD5 digest is used only as a de-duplication key.
        content_hash = hashlib.md5(text.encode()).hexdigest()
        if content_hash in seen:
            continue
        seen.add(content_hash)

        source = doc.metadata.get("source", "未知来源")
        processed.append(f"[{idx}] {text[:10]}...(来源:{source})")

    logger.info(f"处理后的引用文档数: {len(processed)}")
    return processed