Files
gangyan/langchain-chat/server/chat/relevant_articles.py

108 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
from fastapi import Body, HTTPException
from fastapi.concurrency import run_in_threadpool
from typing import List
import logging
from configs import (
SELF_TOP_K,
SCORE_THRESHOLD,
)
from configs.kb_config import DEFAULT_JOURNAL_BASE
from server.knowledge_base.kb_doc_api import search_docs
from server.utils import BaseResponse
from pydantic import BaseModel, Field # 用于请求模型验证
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Request body schema for the related-articles search; Pydantic validates the
# incoming payload against these field declarations.
class ArticleRequest(BaseModel):
    # Title of the document; required and must contain at least one character.
    file_name: str = Field(
        ...,
        min_length=1,
        description="文件标题",
        example="智能制造技术的发展与应用",
    )
    # Required string carrying the document keywords, separated by commas.
    key_words: str = Field(
        ...,
        description="逗号分隔的关键词",
        example="智能制造,工业互联网,物联网",
    )
# Keyword value that is dropped from the query. It reads "no keywords yet";
# presumably a placeholder sent by the caller when a document has no keywords
# -- TODO confirm against the client.
_NO_KEYWORDS_PLACEHOLDER = "暂无关键词"

# Upper bound on the number of articles returned in one response.
_MAX_ARTICLES = 5


def _build_query(file_name: str, key_words: str) -> str:
    """Build the search query from the title and the comma-separated keywords.

    Keywords are split on ASCII or full-width commas, trimmed, de-duplicated in
    first-seen order, and the "no keywords" placeholder is discarded. The title
    and the remaining keywords are joined with single spaces.

    Args:
        file_name: Document title; always placed first in the query.
        key_words: Raw keyword string as received in the request.

    Returns:
        The query string; just the title when no usable keyword remains.
    """
    candidates = (kw.strip() for kw in re.split(r"[,,]", key_words))
    # dict.fromkeys de-duplicates while keeping insertion order, so the same
    # request always yields the same query (a set would not guarantee that).
    # The placeholder is compared as a whole keyword: str.strip(chars) would
    # treat it as a set of characters and eat into real keywords.
    unique_keywords = dict.fromkeys(
        kw for kw in candidates if kw and kw != _NO_KEYWORDS_PLACEHOLDER
    )
    return " ".join([file_name, *unique_keywords])


async def relevant_articles(
    request: ArticleRequest = Body(...)  # validated against the Pydantic model
):
    """
    Related-literature search:\n
    - Input: document title and keywords\n
    - Output: list of related articles\n
    Structure of the returned value:\n
    {\n
    "code": status code; 200 = success, 404 = no related documents found, 500 = service error,\n
    "msg": message text,\n
    "data": {\n
    "total": number of related articles returned,\n
    "articles": [\n
    {\n
    "doc_id": document ID\n
    "title": title\n
    "author_info": authors\n
    "publish_date": publication date\n
    "data_source": link\n
    },\n
    ]\n
    }\n
    }\n
    """
    try:
        query = _build_query(request.file_name, request.key_words)
        logger.info(f"相关文献检索query{query}")

        # Collect hits from every configured knowledge base (currently one).
        all_docs = []
        for knowledge_name in [DEFAULT_JOURNAL_BASE]:
            # search_docs is invoked through the thread pool rather than
            # called directly from this coroutine.
            docs = await run_in_threadpool(
                search_docs,
                query=query,
                fileName=[],
                knowledge_base_name=knowledge_name,
                top_k=SELF_TOP_K,
                score_threshold=SCORE_THRESHOLD,
            )
            for doc in docs:
                # 17010101 looks like a "date unknown" sentinel -- TODO confirm.
                # Removing it makes the response report an empty publish_date.
                if doc.metadata.get("publish_date", None) == 17010101:
                    doc.metadata.pop("publish_date", None)
            all_docs.extend(docs)
        docs = all_docs[:_MAX_ARTICLES]

        if not docs:
            logger.info(f"未找到相关文档:{query}")
            return BaseResponse(code=404, msg="未找到相关文档")

        # Expose only the metadata fields the API contract lists; data_source
        # is passed through verbatim.
        formatted_docs = [
            {
                "doc_id": doc.metadata.get("ID", ""),
                "title": doc.metadata.get("title", ""),
                "author_info": doc.metadata.get("author_info", ""),
                "publish_date": doc.metadata.get("publish_date", ""),
                "data_source": doc.metadata.get("data_source", ""),
            }
            for doc in docs
        ]
        logger.info(f"成功返回 {len(docs)} 篇相关文档")
        return BaseResponse(
            code=200,
            msg="成功检索到相关文档",
            data={
                "total": len(formatted_docs),
                "articles": formatted_docs,
            },
        )
    except Exception as e:
        # Endpoint boundary: log the traceback and answer with a 500 payload
        # instead of propagating the exception.
        # NOTE(review): the exception text is echoed to the client; consider
        # whether that leaks internal details.
        logger.error(f"文档检索失败: {str(e)}", exc_info=True)
        return BaseResponse(
            code=500,
            msg=f"文档检索服务暂时不可用,错误信息:{str(e)}"
        )