[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
120
langchain-chat/server/custom/abstract_search.py
Normal file
120
langchain-chat/server/custom/abstract_search.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterable, List, Optional
|
||||
from fastapi import Body, Request, HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
from configs import (
|
||||
TEMPERATURE,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
SCORE_THRESHOLD,
|
||||
LLM_MODELS,
|
||||
JOURNAL_KNOWLEDGE_BASE,
|
||||
MAX_TOKENS,
|
||||
)
|
||||
from server.utils import BaseResponse, get_prompt_template
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
|
||||
async def abstract_search(
|
||||
query: str = Body(..., description="用户输入", examples=["人工智能领域相关的文献有啥"]),
|
||||
knowledge_base_name_list: List[str] = Body(..., description="知识库名称", examples=[["t_journal_article_bge_v1"]]),
|
||||
# top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数", ge=1),
|
||||
# fileName: List[str] = Body([], description="文件名称", examples=[["123.txt"]]),
|
||||
# stream: bool = Body(False, description="流式输出"),
|
||||
# model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
|
||||
# temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
# max_tokens: Optional[int] = Body(MAX_TOKENS,description="限制LLM生成Token数量,默认None代表模型最大值"),
|
||||
# score_threshold: float = Body(SCORE_THRESHOLD,
|
||||
# description=(
|
||||
# "知识库匹配相关度阈值,取值范围在0-1之间,"
|
||||
# "SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右"
|
||||
# ),
|
||||
# ge=0.0,
|
||||
# le=1.0
|
||||
# ),
|
||||
# request: Request = None,
|
||||
# query_rewrite_model_name: str = Body(
|
||||
# 'Qwen2-72B-Instruct',
|
||||
# description="Query rewrite model name."
|
||||
# ),
|
||||
) -> EventSourceResponse:
|
||||
|
||||
async def search_iterator() -> AsyncIterable[str]:
|
||||
try:
|
||||
# 重写查询
|
||||
search_query_response = await run_in_threadpool(
|
||||
get_llm_model_response,
|
||||
strategy_name="query rewrite",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="query_rewrite",
|
||||
prompt_param_dict={"query": query, "history": []},
|
||||
temperature=0.01,
|
||||
max_tokens=512
|
||||
)
|
||||
|
||||
# 解析重写后的查询
|
||||
try:
|
||||
data = json.loads(search_query_response.strip("```json\n").strip("```"))
|
||||
queries = data.get('query', [])
|
||||
search_query = ' '.join(queries)
|
||||
except json.JSONDecodeError:
|
||||
search_query = query
|
||||
|
||||
logging.info(f'段落推荐输入问题: {search_query}')
|
||||
|
||||
# 执行文档搜索(这里只搜索期刊知识库)
|
||||
journaldocs = await run_in_threadpool(
|
||||
search_docs,
|
||||
query=search_query,
|
||||
fileName=[],
|
||||
knowledge_base_name=JOURNAL_KNOWLEDGE_BASE,
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
score_threshold=SCORE_THRESHOLD
|
||||
)
|
||||
logging.info(f"段落推荐召回资料: {journaldocs}")
|
||||
# 处理搜索结果,去重并整理
|
||||
seen_ids = set()
|
||||
journallist = []
|
||||
# journaldocs = []
|
||||
for doc in journaldocs:
|
||||
doc_id = doc.metadata.get('ID')
|
||||
content = doc.metadata.get('abstract', '').strip()
|
||||
#处理content不是完整段落的情况
|
||||
if content:
|
||||
if not content.endswith('。'):
|
||||
last_period_index = content.rfind('。')
|
||||
if last_period_index != -1:
|
||||
content = content[:last_period_index + 1]
|
||||
else:
|
||||
content += '...'
|
||||
|
||||
if doc_id and content and doc_id not in seen_ids:
|
||||
journallist.append({
|
||||
'id': doc_id,
|
||||
'title': doc.metadata.get('title', ''),
|
||||
'content': content
|
||||
})
|
||||
seen_ids.add(doc_id)
|
||||
|
||||
search_results = {'journal_article': journallist}
|
||||
logging.info(f"段落推荐召回资料整理结果:{journallist}")
|
||||
|
||||
# 如果没有找到相关文档,返回默认响应
|
||||
# if not journallist:
|
||||
# search_results['journal_article'].append({
|
||||
# 'id': 'empty',
|
||||
# 'title': 'empty',
|
||||
# 'content': '暂未找到相关资料。'
|
||||
# })
|
||||
|
||||
yield json.dumps(search_results, ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
# 捕获并返回错误信息
|
||||
error_response = BaseResponse(code=500, msg=str(e))
|
||||
yield json.dumps(error_response.dict(), ensure_ascii=False)
|
||||
|
||||
return EventSourceResponse(search_iterator())
|
||||
Reference in New Issue
Block a user