108 lines
4.2 KiB
Python
108 lines
4.2 KiB
Python
import re
|
||
from fastapi import Body, HTTPException
|
||
from fastapi.concurrency import run_in_threadpool
|
||
from typing import List
|
||
import logging
|
||
from configs import (
|
||
SELF_TOP_K,
|
||
SCORE_THRESHOLD,
|
||
)
|
||
from configs.kb_config import DEFAULT_JOURNAL_BASE
|
||
from server.knowledge_base.kb_doc_api import search_docs
|
||
from server.utils import BaseResponse
|
||
from pydantic import BaseModel, Field # 用于请求模型验证
|
||
|
||
# Module-level logger, named after this module (obtained here, not configured)
|
||
logger = logging.getLogger(name=__name__)  # one logger per module, keyed by module name
|
||
|
||
# Request body schema for the related-articles endpoint (optional extra validation).
# NOTE: kept as comments rather than a class docstring, because pydantic would
# otherwise publish the docstring as the model description in the OpenAPI schema.
class ArticleRequest(BaseModel):
    # Title of the document to find related literature for; must be non-empty.
    file_name: str = Field(
        ...,
        min_length=1,
        description="文件标题",
        example="智能制造技术的发展与应用",
    )
    # Comma-separated keywords; no length constraint, so an empty string is accepted.
    key_words: str = Field(
        ...,
        description="逗号分隔的关键词",
        example="智能制造,工业互联网,物联网",
    )
|
||
|
||
# Placeholder passed in `key_words` when there are no keywords; it must not end
# up in the query. NOTE(review): presumably produced upstream — confirm.
_NO_KEYWORDS_PLACEHOLDER = "暂无关键词"
# publish_date value that is dropped from results before responding.
# NOTE(review): presumably an "unknown date" sentinel written at ingestion, and
# assumed to be stored as an int (a string "17010101" would not match) — confirm.
_PLACEHOLDER_PUBLISH_DATE = 17010101
# Maximum number of articles returned to the caller.
_MAX_ARTICLES = 5
# Keyword separators: ASCII comma and the full-width comma common in Chinese input.
_KEYWORD_SEPARATORS = re.compile(r"[,，]")


def _build_query(file_name: str, key_words: str) -> str:
    """Return the retrieval query: the title followed by the cleaned keywords.

    Keywords are split on ASCII or full-width commas and stripped. Empty
    entries and the no-keywords placeholder are dropped. Duplicates are
    removed while keeping first-seen order, so the query is deterministic
    across processes (``set`` ordering is hash-randomized for strings).
    The title and keywords are joined with single spaces.
    """
    keywords = (kw.strip() for kw in _KEYWORD_SEPARATORS.split(key_words))
    unique_keywords = dict.fromkeys(
        kw for kw in keywords if kw and kw != _NO_KEYWORDS_PLACEHOLDER
    )
    return " ".join([file_name, *unique_keywords])


def _format_doc(doc) -> dict:
    """Map a search result's ``metadata`` onto the response article schema.

    Any missing metadata field is reported as an empty string.
    """
    metadata = doc.metadata
    return {
        "doc_id": metadata.get("ID", ""),
        "title": metadata.get("title", ""),
        "author_info": metadata.get("author_info", ""),
        "publish_date": metadata.get("publish_date", ""),
        "data_source": metadata.get("data_source", ""),
    }


async def relevant_articles(
    request: ArticleRequest = Body(...)  # validated by the pydantic request model
):
    """
    Search for literature related to a document.

    - Input: document title and comma-separated keywords
    - Output: list of related articles (at most 5)

    Response structure:

        {
            "code": status code (200 = success, 404 = no related documents found,
                    500 = service error),
            "msg": message,
            "data": {
                "total": number of related articles,
                "articles": [
                    {
                        "doc_id": document ID,
                        "title": title,
                        "author_info": authors,
                        "publish_date": publication date,
                        "data_source": link
                    },
                ]
            }
        }
    """
    try:
        # Query = title, then de-duplicated keywords, all space-separated.
        query = _build_query(request.file_name, request.key_words)
        logger.info("相关文献检索query:%s", query)

        # Merge the results of every knowledge base searched (currently only one).
        all_docs = []
        for knowledge_name in [DEFAULT_JOURNAL_BASE]:
            # The search is run in a worker thread so it does not block the event loop.
            docs = await run_in_threadpool(
                search_docs,
                query=query,
                fileName=[],
                knowledge_base_name=knowledge_name,
                top_k=SELF_TOP_K,
                score_threshold=SCORE_THRESHOLD,
            )
            for doc in docs:
                # Remove the placeholder date so the response reports "" instead.
                if doc.metadata.get("publish_date") == _PLACEHOLDER_PUBLISH_DATE:
                    doc.metadata.pop("publish_date", None)
            all_docs.extend(docs)

        top_docs = all_docs[:_MAX_ARTICLES]
        if not top_docs:
            logger.info("未找到相关文档:%s", query)
            return BaseResponse(code=404, msg="未找到相关文档")

        articles = [_format_doc(doc) for doc in top_docs]
        logger.info("成功返回 %d 篇相关文档", len(articles))
        return BaseResponse(
            code=200,
            msg="成功检索到相关文档",
            data={
                "total": len(articles),
                "articles": articles,
            },
        )

    except Exception as e:
        # Top-level endpoint boundary: log with traceback, then report a 500.
        logger.exception("文档检索失败: %s", e)
        # NOTE(review): the raw exception text is echoed to the client; consider a
        # generic message if internal details should not be exposed.
        return BaseResponse(
            code=500,
            msg=f"文档检索服务暂时不可用,错误信息:{str(e)}",
        )