120 lines
5.4 KiB
Python
120 lines
5.4 KiB
Python
import asyncio
|
||
import json
|
||
import logging
|
||
from typing import AsyncIterable, List, Optional
|
||
from fastapi import Body, Request, HTTPException
|
||
from fastapi.concurrency import run_in_threadpool
|
||
from sse_starlette.sse import EventSourceResponse
|
||
from server.knowledge_base.kb_doc_api import search_docs
|
||
from configs import (
|
||
TEMPERATURE,
|
||
VECTOR_SEARCH_TOP_K,
|
||
SCORE_THRESHOLD,
|
||
LLM_MODELS,
|
||
JOURNAL_KNOWLEDGE_BASE,
|
||
MAX_TOKENS,
|
||
)
|
||
from server.utils import BaseResponse, get_prompt_template
|
||
from server.chat.policy_fun_iast import get_llm_model_response
|
||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||
|
||
def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (e.g. ```json ... ```) from *text*.

    ``str.strip("```json\\n")`` strips a *set of characters* rather than a
    literal prefix/suffix, so the fence is removed explicitly here.
    """
    text = text.strip()
    if text.startswith("```"):
        text = text[3:]
        # Drop an optional language tag directly after the opening fence.
        if text[:4].lower() == "json":
            text = text[4:]
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


def _parse_rewritten_query(raw_response, fallback: str) -> str:
    """Extract the rewritten search query from the query-rewrite LLM output.

    The output is expected to be a JSON object (optionally wrapped in a
    Markdown code fence) whose ``"query"`` key holds a list of query strings;
    the non-empty entries are joined with single spaces.

    Returns *fallback* when the output is not a string, is not valid JSON,
    is not a JSON object, has an unusable ``"query"`` value, or yields an
    empty query.
    """
    if not isinstance(raw_response, str):
        return fallback
    try:
        data = json.loads(_strip_code_fence(raw_response))
    except json.JSONDecodeError:
        return fallback
    if not isinstance(data, dict):
        return fallback
    queries = data.get('query', [])
    if isinstance(queries, str):
        # A bare string would otherwise be joined character by character.
        queries = [queries]
    elif not isinstance(queries, list):
        return fallback
    search_query = ' '.join(str(q) for q in queries if q).strip()
    return search_query or fallback


def _truncate_to_sentence(content: str) -> str:
    """Trim *content* so it ends on a complete sentence terminated by '。'.

    Content already ending with '。' (or empty content) is returned unchanged.
    Content containing '。' is cut just after the last one. Content with no
    '。' at all gets '...' appended to mark it as a fragment.
    """
    if not content or content.endswith('。'):
        return content
    last_period_index = content.rfind('。')
    if last_period_index != -1:
        return content[:last_period_index + 1]
    return content + '...'


def _build_journal_list(docs) -> List[dict]:
    """Deduplicate retrieved documents by metadata ``ID`` and shape them for the client.

    Documents without an ``ID`` or without a non-empty ``abstract`` are
    skipped. The first occurrence of each ID wins, so retrieval order is
    preserved. Each entry is ``{'id', 'title', 'content'}``.
    """
    seen_ids = set()
    journal_list = []
    for doc in docs or []:
        metadata = doc.metadata or {}
        doc_id = metadata.get('ID')
        # `or ''` guards against an explicit None abstract, which would break .strip().
        content = _truncate_to_sentence(str(metadata.get('abstract') or '').strip())
        if not doc_id or not content or doc_id in seen_ids:
            continue
        seen_ids.add(doc_id)
        journal_list.append({
            'id': doc_id,
            'title': metadata.get('title', ''),
            'content': content,
        })
    return journal_list


async def abstract_search(
    query: str = Body(..., description="用户输入", examples=["人工智能领域相关的文献有啥"]),
    knowledge_base_name_list: List[str] = Body(..., description="知识库名称", examples=[["t_journal_article_bge_v1"]]),
) -> EventSourceResponse:
    """Recommend journal-article abstracts relevant to *query*, streamed over SSE.

    Pipeline:
      1. Rewrite *query* with ``LLM_MODELS[0]`` using the ``query_rewrite``
         prompt template. If the rewrite cannot be parsed, the original
         query is used.
      2. Search ``JOURNAL_KNOWLEDGE_BASE`` with the rewritten query, using the
         configured ``VECTOR_SEARCH_TOP_K`` and ``SCORE_THRESHOLD``.
      3. Deduplicate hits by ``ID`` and trim abstracts to whole sentences.

    Args:
        query: The user's question.
        knowledge_base_name_list: Accepted for API compatibility but currently
            unused; the search is hard-wired to ``JOURNAL_KNOWLEDGE_BASE``.

    Returns:
        An ``EventSourceResponse`` that emits a single JSON event,
        ``{"journal_article": [{"id", "title", "content"}, ...]}``. If any
        step raises, it instead emits a serialized ``BaseResponse`` with
        ``code=500``.
    """

    async def search_iterator() -> AsyncIterable[str]:
        try:
            # Rewrite the query; run in a thread pool so the (presumably
            # synchronous) LLM call does not block the event loop.
            search_query_response = await run_in_threadpool(
                get_llm_model_response,
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="query_rewrite",
                prompt_param_dict={"query": query, "history": []},
                temperature=0.01,
                max_tokens=512,
            )
            search_query = _parse_rewritten_query(search_query_response, fallback=query)
            logging.info('段落推荐输入问题: %s', search_query)

            # Only the journal knowledge base is searched here.
            journaldocs = await run_in_threadpool(
                search_docs,
                query=search_query,
                fileName=[],
                knowledge_base_name=JOURNAL_KNOWLEDGE_BASE,
                top_k=VECTOR_SEARCH_TOP_K,
                score_threshold=SCORE_THRESHOLD,
            )
            logging.info("段落推荐召回资料: %s", journaldocs)

            journallist = _build_journal_list(journaldocs)
            logging.info("段落推荐召回资料整理结果:%s", journallist)

            yield json.dumps({'journal_article': journallist}, ensure_ascii=False)
        except Exception as e:
            # Top-level boundary: keep the traceback in server logs, then
            # report the failure to the client as a 500 payload.
            logging.exception("段落推荐处理失败")
            error_response = BaseResponse(code=500, msg=str(e))
            yield json.dumps(error_response.dict(), ensure_ascii=False)

    return EventSourceResponse(search_iterator())