import asyncio
import json
import logging
from typing import AsyncIterable, List, Optional

from fastapi import Body, Request, HTTPException
from fastapi.concurrency import run_in_threadpool
from sse_starlette.sse import EventSourceResponse

from server.knowledge_base.kb_doc_api import search_docs
from configs import (
    TEMPERATURE,
    VECTOR_SEARCH_TOP_K,
    SCORE_THRESHOLD,
    LLM_MODELS,
    JOURNAL_KNOWLEDGE_BASE,
    MAX_TOKENS,
)
from server.utils import BaseResponse, get_prompt_template
from server.chat.policy_fun_iast import get_llm_model_response
from server.knowledge_base.kb_service.base import KBServiceFactory


def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence.

    Removes one leading ``` (optionally followed by the ``json`` language tag)
    and one trailing ```, plus surrounding whitespace. Unlike
    ``str.strip("```json\\n")``, which treats its argument as a *set of
    characters*, this never removes characters that belong to the payload.
    """
    text = text.strip()
    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def _parse_rewritten_query(raw_response, fallback: str) -> str:
    """Extract the rewritten search query from the LLM reply.

    The reply is expected to be a JSON object, optionally wrapped in a code
    fence, whose ``query`` key holds a list of strings; the items are joined
    with single spaces.

    Args:
        raw_response: Value returned by the query-rewrite call.
        fallback: Query to use whenever the reply is unusable.

    Returns:
        The joined rewritten query, or *fallback* when the reply is not a
        string, is not valid JSON, is not a JSON object, has a ``query`` value
        of an unexpected type, or yields a blank query.
    """
    if not isinstance(raw_response, str):
        return fallback
    try:
        data = json.loads(_strip_code_fence(raw_response))
    except json.JSONDecodeError:
        return fallback
    if not isinstance(data, dict):
        return fallback

    queries = data.get('query', [])
    if isinstance(queries, str):
        # A bare string must not be joined character by character.
        queries = [queries]
    elif not isinstance(queries, (list, tuple)):
        return fallback

    rewritten = ' '.join(str(item) for item in queries if item is not None)
    # A blank rewrite would run the search with an empty query.
    return rewritten if rewritten.strip() else fallback


def _trim_to_complete_sentence(text: str) -> str:
    """Make *text* end on a sentence boundary.

    Text that is empty or already ends with '。' is returned unchanged.
    Otherwise it is cut just after the last '。'; if it contains no '。' at
    all, '...' is appended to signal that it is a fragment.
    """
    if not text or text.endswith('。'):
        return text
    last_period_index = text.rfind('。')
    if last_period_index != -1:
        return text[:last_period_index + 1]
    return text + '...'


def _collect_unique_abstracts(docs) -> List[dict]:
    """Build the de-duplicated result list from the retrieved documents.

    Each document is read through its ``metadata`` mapping (keys ``ID``,
    ``abstract`` and ``title``). Documents without an ID, without abstract
    text, or whose ID was already seen are skipped; the first occurrence of
    an ID wins and input order is preserved.
    """
    seen_ids = set()
    results = []
    for doc in docs:
        metadata = doc.metadata
        doc_id = metadata.get('ID')
        # `or ''` also covers a key that is present but holds None.
        content = _trim_to_complete_sentence((metadata.get('abstract') or '').strip())
        if not doc_id or not content or doc_id in seen_ids:
            continue
        seen_ids.add(doc_id)
        results.append({
            'id': doc_id,
            'title': metadata.get('title', ''),
            'content': content,
        })
    return results


async def abstract_search(
        query: str = Body(..., description="用户输入", examples=["人工智能领域相关的文献有啥"]),
        knowledge_base_name_list: List[str] = Body(..., description="知识库名称", examples=[["t_journal_article_bge_v1"]]),
        # NOTE: further request parameters (top_k, fileName, stream, model_name,
        # temperature, max_tokens, score_threshold, request,
        # query_rewrite_model_name) are currently disabled; the values from
        # `configs` and the literals below are used instead.
) -> EventSourceResponse:
    """Recommend journal-article abstracts for a user query.

    The query is first rewritten by the LLM, then used to search the journal
    knowledge base; the hits are de-duplicated by ID and returned as a single
    server-sent event.

    Args:
        query: The user's input text.
        knowledge_base_name_list: Required by the request schema but not used
            here; the search always targets ``JOURNAL_KNOWLEDGE_BASE``.

    Returns:
        An ``EventSourceResponse`` that emits one JSON string: either
        ``{"journal_article": [{"id", "title", "content"}, ...]}`` (the list
        may be empty) or, on failure, a serialized ``BaseResponse`` with
        ``code=500`` and the exception text as ``msg``.
    """

    async def search_iterator() -> AsyncIterable[str]:
        try:
            # Rewrite the query; the LLM call is synchronous, so it is run in
            # the thread pool.
            search_query_response = await run_in_threadpool(
                get_llm_model_response,
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="query_rewrite",
                prompt_param_dict={"query": query, "history": []},
                temperature=0.01,
                max_tokens=512
            )

            # Fall back to the raw user query when the rewrite is unusable.
            search_query = _parse_rewritten_query(search_query_response, query)
            logging.info('段落推荐输入问题: %s', search_query)

            # Search documents (journal knowledge base only).
            journaldocs = await run_in_threadpool(
                search_docs,
                query=search_query,
                fileName=[],
                knowledge_base_name=JOURNAL_KNOWLEDGE_BASE,
                top_k=VECTOR_SEARCH_TOP_K,
                score_threshold=SCORE_THRESHOLD
            )
            logging.info("段落推荐召回资料: %s", journaldocs)

            # De-duplicate and tidy the hits. An empty list is emitted as-is;
            # no placeholder entry is added.
            journallist = _collect_unique_abstracts(journaldocs)
            search_results = {'journal_article': journallist}
            logging.info("段落推荐召回资料整理结果:%s", journallist)

            yield json.dumps(search_results, ensure_ascii=False)

        except Exception as e:
            # Top-level boundary of the stream: record the traceback, then
            # report the failure to the client as a payload.
            logging.exception("abstract_search failed")
            error_response = BaseResponse(code=500, msg=str(e))
            yield json.dumps(error_response.dict(), ensure_ascii=False)

    return EventSourceResponse(search_iterator())