Files
gangyan/langchain-chat/server/custom/abstract_search.py

120 lines
5.4 KiB
Python
Raw Permalink Normal View History

import asyncio
import json
import logging
from typing import AsyncIterable, List, Optional
from fastapi import Body, Request, HTTPException
from fastapi.concurrency import run_in_threadpool
from sse_starlette.sse import EventSourceResponse
from server.knowledge_base.kb_doc_api import search_docs
from configs import (
TEMPERATURE,
VECTOR_SEARCH_TOP_K,
SCORE_THRESHOLD,
LLM_MODELS,
JOURNAL_KNOWLEDGE_BASE,
MAX_TOKENS,
)
from server.utils import BaseResponse, get_prompt_template
from server.chat.policy_fun_iast import get_llm_model_response
from server.knowledge_base.kb_service.base import KBServiceFactory
# Punctuation treated as the end of a complete sentence when tidying abstracts.
# NOTE(review): the original code compared against an empty string (always
# true), so the intended terminator is a best guess — confirm with the author.
_CJK_SENTENCE_ENDINGS = "。!?"
_ASCII_SENTENCE_ENDINGS = ".!?"


def _extract_search_query(llm_response: object, fallback: str) -> str:
    """Return the rewritten search query contained in *llm_response*.

    The response is expected to hold a JSON object with a ``query`` key whose
    value is a list of strings (a single string is accepted too). The object
    may be wrapped in a Markdown code fence or surrounded by other text; the
    span from the first ``{`` to the last ``}`` is parsed.

    *fallback* is returned whenever no usable, non-empty query can be
    extracted, so the caller never searches with an empty string.
    """
    if not isinstance(llm_response, str):
        return fallback

    start = llm_response.find("{")
    end = llm_response.rfind("}")
    if start == -1 or end <= start:
        return fallback

    try:
        data = json.loads(llm_response[start:end + 1])
    except json.JSONDecodeError:
        return fallback

    queries = data.get("query", [])
    if isinstance(queries, str):
        # Joining a bare string would interleave spaces between characters.
        queries = [queries]
    elif not isinstance(queries, (list, tuple)):
        return fallback

    rewritten = " ".join(str(item) for item in queries if item).strip()
    return rewritten or fallback


def _trim_to_complete_sentence(text: str) -> str:
    """Cut *text* back to its last complete sentence.

    Text that already ends with a sentence terminator is returned unchanged.
    Otherwise it is truncated just after the last terminator found; if there
    is none at all, ``'...'`` is appended to signal the text is partial.
    """
    if not text:
        return text

    length = len(text)
    for idx in range(length - 1, -1, -1):
        char = text[idx]
        if char in _CJK_SENTENCE_ENDINGS:
            return text[:idx + 1]
        # An ASCII mark only ends a sentence at the end of the text or before
        # whitespace, so decimals such as "3.5" are not split.
        if char in _ASCII_SENTENCE_ENDINGS and (
            idx + 1 == length or text[idx + 1].isspace()
        ):
            return text[:idx + 1]
    return text + '...'


async def abstract_search(
    query: str = Body(..., description="用户输入", examples=["人工智能领域相关的文献有啥"]),
    knowledge_base_name_list: List[str] = Body(..., description="知识库名称", examples=[["t_journal_article_bge_v1"]]),
    # NOTE: top_k, score threshold, model name, temperature and max_tokens are
    # currently not exposed as request parameters; the values imported from
    # ``configs`` (and the literals below) are used instead.
) -> EventSourceResponse:
    """Recommend journal-article abstracts related to the user's query.

    The query is first rewritten by an LLM, then used to search the journal
    knowledge base. Results are de-duplicated by document ID and streamed
    back as a single server-sent event.

    Args:
        query: The user's input text.
        knowledge_base_name_list: Knowledge base names supplied by the client.
            Currently not read by the function body; the search always
            targets ``JOURNAL_KNOWLEDGE_BASE``.

    Returns:
        An ``EventSourceResponse`` that emits one JSON string: either
        ``{"journal_article": [{"id", "title", "content"}, ...]}`` (the list
        is empty when nothing is recalled) or, on failure, a serialized
        ``BaseResponse`` with ``code=500`` and the error message.
    """

    async def search_iterator() -> AsyncIterable[str]:
        try:
            # Rewrite the query with the LLM (blocking call, so off-loaded to
            # the thread pool).
            search_query_response = await run_in_threadpool(
                get_llm_model_response,
                strategy_name="query rewrite",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="query_rewrite",
                prompt_param_dict={"query": query, "history": []},
                temperature=0.01,
                max_tokens=512
            )

            # Parse the rewritten query; fall back to the raw user query if
            # the LLM output is unusable.
            search_query = _extract_search_query(search_query_response, fallback=query)
            logging.info('段落推荐输入问题: %s', search_query)

            # Run the document search (journal knowledge base only).
            journaldocs = await run_in_threadpool(
                search_docs,
                query=search_query,
                fileName=[],
                knowledge_base_name=JOURNAL_KNOWLEDGE_BASE,
                top_k=VECTOR_SEARCH_TOP_K,
                score_threshold=SCORE_THRESHOLD
            )
            logging.info("段落推荐召回资料: %s", journaldocs)

            # De-duplicate by document ID and keep only entries that have a
            # non-empty abstract.
            seen_ids = set()
            journallist = []
            for doc in journaldocs or []:
                metadata = doc.metadata
                doc_id = metadata.get('ID')
                if not doc_id or doc_id in seen_ids:
                    continue

                # ``or ''`` guards against an explicit null abstract.
                content = _trim_to_complete_sentence(
                    str(metadata.get('abstract') or '').strip()
                )
                if not content:
                    continue

                seen_ids.add(doc_id)
                journallist.append({
                    'id': doc_id,
                    'title': metadata.get('title', ''),
                    'content': content
                })

            search_results = {'journal_article': journallist}
            logging.info("段落推荐召回资料整理结果:%s", journallist)

            # No placeholder entry is added when nothing is found; the client
            # receives an empty ``journal_article`` list.
            yield json.dumps(search_results, ensure_ascii=False)
        except Exception as e:
            # Top-level boundary: record the traceback, then report the error
            # to the client instead of breaking the event stream.
            logging.exception("abstract_search failed")
            error_response = BaseResponse(code=500, msg=str(e))
            yield json.dumps(error_response.dict(), ensure_ascii=False)

    return EventSourceResponse(search_iterator())