120 lines
5.4 KiB
Python
120 lines
5.4 KiB
Python
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import logging
|
|||
|
|
from typing import AsyncIterable, List, Optional
|
|||
|
|
from fastapi import Body, Request, HTTPException
|
|||
|
|
from fastapi.concurrency import run_in_threadpool
|
|||
|
|
from sse_starlette.sse import EventSourceResponse
|
|||
|
|
from server.knowledge_base.kb_doc_api import search_docs
|
|||
|
|
from configs import (
|
|||
|
|
TEMPERATURE,
|
|||
|
|
VECTOR_SEARCH_TOP_K,
|
|||
|
|
SCORE_THRESHOLD,
|
|||
|
|
LLM_MODELS,
|
|||
|
|
JOURNAL_KNOWLEDGE_BASE,
|
|||
|
|
MAX_TOKENS,
|
|||
|
|
)
|
|||
|
|
from server.utils import BaseResponse, get_prompt_template
|
|||
|
|
from server.chat.policy_fun_iast import get_llm_model_response
|
|||
|
|
from server.knowledge_base.kb_service.base import KBServiceFactory
|
|||
|
|
|
|||
|
|
def _strip_code_fence(text: str) -> str:
    """Remove a surrounding Markdown code fence (optionally tagged ``json``).

    ``str.strip`` with a fence-like argument strips a *set of characters*,
    not a prefix/suffix, so it is not a reliable way to drop a fence. This
    helper removes the fence markers explicitly and leaves the payload as-is.
    """
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        # The opening fence may carry a language tag, e.g. "```json".
        if cleaned[:4].lower() == "json":
            cleaned = cleaned[4:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()


def _parse_rewritten_query(response: object, fallback: str) -> str:
    """Turn the query-rewrite LLM output into a single search string.

    The rewrite output appears to be expected as JSON shaped like
    ``{"query": ["...", ...]}`` (possibly inside a code fence); the listed
    queries are joined with single spaces. If the output is not a string,
    is not valid JSON, does not have that shape, or yields no usable text,
    ``fallback`` (the user's original query) is returned so the document
    search never runs with an empty query.
    """
    if not isinstance(response, str):
        return fallback
    try:
        data = json.loads(_strip_code_fence(response))
    except json.JSONDecodeError:
        return fallback

    queries = data.get("query") if isinstance(data, dict) else None
    if isinstance(queries, str):
        # A bare string would otherwise be joined character by character.
        queries = [queries]
    if not isinstance(queries, list):
        return fallback

    parts = [q.strip() for q in queries if isinstance(q, str) and q.strip()]
    return " ".join(parts) or fallback


def _trim_to_sentence(text: str) -> str:
    """Cut ``text`` back to its last complete sentence ending in "。".

    Empty text and text that already ends in "。" are returned unchanged.
    If no "。" occurs anywhere, "..." is appended to mark the abstract as
    truncated.
    """
    if not text or text.endswith("。"):
        return text
    last_period = text.rfind("。")
    if last_period != -1:
        return text[:last_period + 1]
    return text + "..."


def _collect_journal_articles(docs) -> List[dict]:
    """Deduplicate search hits by metadata ``ID`` and shape them for output.

    Each returned item is ``{"id", "title", "content"}`` where ``content`` is
    the document's ``abstract`` metadata trimmed by ``_trim_to_sentence``.
    Hits without an ID or without a non-empty abstract are skipped; the first
    usable hit for each ID wins.
    """
    seen_ids = set()
    articles = []
    for doc in docs or []:
        metadata = doc.metadata or {}
        doc_id = metadata.get("ID")
        if not doc_id or doc_id in seen_ids:
            continue
        # `or ""` guards against an explicit None abstract in the metadata.
        content = _trim_to_sentence(str(metadata.get("abstract") or "").strip())
        if not content:
            continue
        articles.append({
            "id": doc_id,
            "title": metadata.get("title", ""),
            "content": content,
        })
        seen_ids.add(doc_id)
    return articles


async def abstract_search(
    query: str = Body(..., description="用户输入", examples=["人工智能领域相关的文献有啥"]),
    knowledge_base_name_list: List[str] = Body(..., description="知识库名称", examples=[["t_journal_article_bge_v1"]]),
) -> EventSourceResponse:
    """Recommend journal-article abstracts related to ``query`` over SSE.

    Steps:
      1. Rewrite the user query with the first configured LLM
         (``LLM_MODELS[0]``, prompt template "query_rewrite"); fall back to
         the original query if the rewrite output is unusable.
      2. Search ``JOURNAL_KNOWLEDGE_BASE`` with ``VECTOR_SEARCH_TOP_K`` and
         ``SCORE_THRESHOLD`` from config.
      3. Deduplicate hits by ID and trim abstracts to whole sentences.

    Args:
        query: The user's question.
        knowledge_base_name_list: Accepted for API compatibility but
            currently ignored; the search always targets
            ``JOURNAL_KNOWLEDGE_BASE``.

    Returns:
        An ``EventSourceResponse`` that emits a single JSON event
        ``{"journal_article": [...]}`` (an empty list when nothing matches),
        or a serialized ``BaseResponse`` with ``code=500`` if any step fails.
    """

    async def search_iterator() -> AsyncIterable[str]:
        try:
            # Rewrite the query; the blocking LLM call runs in a worker thread.
            rewrite_response = await run_in_threadpool(
                get_llm_model_response,
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="query_rewrite",
                prompt_param_dict={"query": query, "history": []},
                temperature=0.01,
                max_tokens=512,
            )
            search_query = _parse_rewritten_query(rewrite_response, fallback=query)
            logging.info("段落推荐输入问题: %s", search_query)

            # Document search (only the journal knowledge base is queried).
            journaldocs = await run_in_threadpool(
                search_docs,
                query=search_query,
                fileName=[],
                knowledge_base_name=JOURNAL_KNOWLEDGE_BASE,
                top_k=VECTOR_SEARCH_TOP_K,
                score_threshold=SCORE_THRESHOLD,
            )
            logging.info("段落推荐召回资料: %s", journaldocs)

            journallist = _collect_journal_articles(journaldocs)
            logging.info("段落推荐召回资料整理结果:%s", journallist)

            search_results = {"journal_article": journallist}
            yield json.dumps(search_results, ensure_ascii=False)

        except Exception as e:
            # Top-level boundary of the stream: log the failure with its
            # traceback, then report it to the client instead of dropping
            # the connection.
            logging.exception("段落推荐失败")
            error_response = BaseResponse(code=500, msg=str(e))
            yield json.dumps(error_response.dict(), ensure_ascii=False)

    return EventSourceResponse(search_iterator())
|