584 lines
26 KiB
Python
584 lines
26 KiB
Python
|
|
import asyncio
|
|||
|
|
import os
|
|||
|
|
from fastapi.responses import FileResponse
|
|||
|
|
import pymysql
|
|||
|
|
from contextvars import ContextVar
|
|||
|
|
from configs.kb_config import DOWNLOAD_HOST_CK, KB_ROOT_PATH2, ck_mysql_config
|
|||
|
|
from datetime import datetime
|
|||
|
|
from pathlib import Path
|
|||
|
|
import re
|
|||
|
|
import threading
|
|||
|
|
import requests
|
|||
|
|
import time
|
|||
|
|
import oss2
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from langchain.prompts.chat import ChatMessagePromptTemplate
|
|||
|
|
from configs import logger, log_verbose
|
|||
|
|
from typing import Any, List, Tuple, Dict, Union
|
|||
|
|
from configs.oss_config import *
|
|||
|
|
from apscheduler.schedulers.background import BackgroundScheduler
|
|||
|
|
from fastapi import FastAPI, HTTPException
|
|||
|
|
from configs.kb_config import similarity_url,similarity_score,similarity_internet
|
|||
|
|
import httpx
|
|||
|
|
import json
|
|||
|
|
|
|||
|
|
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
|||
|
|
from configs import kb_config as knowledgeBase_config
|
|||
|
|
|
|||
|
|
# Module-level FastAPI application; the startup hook and the /chat_translate
# route in this file are registered on it.
app = FastAPI()

# APScheduler background (thread-based) scheduler; jobs are added and the
# scheduler is started in the startup hook.
scheduler = BackgroundScheduler()

# ContextVar holding the shared key/value dict.
# NOTE(review): the `default={}` dict is created once at import time and is what
# `.get()` returns in every context that never called `.set()`. The setters in
# this file re-`.set()` that same object, so in practice this behaves as a
# process-wide shared dict rather than a per-request one — the background
# cleanup appears to rely on that sharing. Confirm this is intended.
shared_variable: ContextVar[Dict[str, Any]] = ContextVar("shared_variable", default={})

# ContextVar mapping each key in `shared_variable` to its expiry time; the
# setter stores an integer Unix timestamp (seconds). Same shared-default caveat.
expiration_times: ContextVar[Dict[str, float]] = ContextVar("expiration_times", default={})
|
|||
|
|
|
|||
|
|
def get_similar_documents(index, sentences, query: str, docs: List, top_k: int = 3):
    """Select the documents located at the given positions.

    Args:
        index: Positions (into ``docs``) of the documents to keep. When it is
            empty, no document is kept.
        sentences: Unused; kept for interface compatibility (it was formerly
            sent to a similarity service).
        query: Unused; kept for interface compatibility.
        docs: Candidate documents.
        top_k: Unused; kept for interface compatibility.

    Returns:
        A new list with the documents of ``docs`` whose position appears in
        ``index``, in their original order. Out-of-range or repeated positions
        in ``index`` are ignored.
    """
    if len(index) == 0:
        return []
    # Build the lookup set once so each membership test is O(1) instead of a
    # linear scan of ``index`` for every document.
    wanted = set(index)
    return [doc for position, doc in enumerate(docs) if position in wanted]
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_similar_documents1(index, sentences, query: str, docs: List[DocumentWithVSId], top_k: int = 3):
    """Select the retrieved documents located at the given positions.

    Args:
        index: Positions (into ``docs``) of the documents to keep. When it is
            empty, no document is kept.
        sentences: Unused; kept for interface compatibility (it was formerly
            sent to a similarity service).
        query: Unused; kept for interface compatibility.
        docs: Candidate retrieved documents.
        top_k: Unused; kept for interface compatibility.

    Returns:
        A new list with the documents of ``docs`` whose position appears in
        ``index``, in their original order. Out-of-range or repeated positions
        in ``index`` are ignored.
    """
    if len(index) == 0:
        return []
    # One set construction, then O(1) membership per document.
    wanted = set(index)
    return [doc for position, doc in enumerate(docs) if position in wanted]
|
|||
|
|
def remove_docs1(titles: List, docs: List[DocumentWithVSId]):
    """Drop documents whose title is listed in ``titles``.

    Titles are read from ``doc.metadata["title"]``. If that raises for any
    document (typically a missing key), the whole list is filtered on
    ``doc.metadata["source"]`` instead.

    Returns:
        ``(kept_docs, kept_names)``, where ``kept_names`` holds the title (or
        source) of each kept document, in the same order.
    """
    def _keep_unlisted(field):
        survivors = [d for d in docs if d.metadata[field] not in titles]
        return survivors, [d.metadata[field] for d in survivors]

    try:
        return _keep_unlisted("title")
    except Exception:
        return _keep_unlisted("source")
|
|||
|
|
def remove_docs(titles: List, docs: List):
    """Drop dict documents whose ``"title"`` is listed in ``titles``.

    Returns:
        ``(kept_docs, kept_titles)``: the surviving documents and their titles,
        both in the original order.
    """
    kept = list(filter(lambda d: d["title"] not in titles, docs))
    return kept, [d["title"] for d in kept]
|
|||
|
|
|
|||
|
|
def upload_image_to_oss(local_image_path, oss_file_name):
    """Upload a local file to OSS under ``files/chat_pic/`` and return its URL.

    Credentials and location come from the star-imported OSS config
    (``access_key_id``, ``access_key_secret``, ``endpoint``, ``bucket_name``).

    Args:
        local_image_path: Path of the local file to upload.
        oss_file_name: Object name to use under ``files/chat_pic/``.

    Returns:
        ``http://{bucket_name}.{endpoint host}/files/chat_pic/{oss_file_name}``
        on success, or ``None`` if the upload failed (the error is logged).
    """
    auth = oss2.Auth(access_key_id, access_key_secret)
    bucket = oss2.Bucket(auth, endpoint, bucket_name)
    object_name = f"files/chat_pic/{oss_file_name}"
    try:
        with open(local_image_path, 'rb') as fileobj:
            bucket.put_object(object_name, fileobj)
    except Exception as e:
        # Best-effort upload: callers get None rather than an exception.
        logger.error(f"上传失败: {e}")
        return None
    logger.info(f"文件 {oss_file_name} 上传成功!")
    # ``endpoint`` may be configured with a scheme (e.g. "https://oss-...");
    # strip it so the URL does not become "http://bucket.https://...".
    host = re.sub(r'^https?://', '', str(endpoint), flags=re.IGNORECASE).rstrip('/')
    return f"http://{bucket_name}.{host}/{object_name}"
|
|||
|
|
|
|||
|
|
def get_personal_knowledge_map(uuid_names: list[str]):
    """Map stored (UUID) file names to their original upload names.

    Queries ``gpt_upload_file`` for rows whose ``filepath`` equals
    ``./files/{uuid}`` for each given name.

    Args:
        uuid_names: Stored file names, without the ``./files/`` prefix.

    Returns:
        ``{uuid_name: filename}`` for every matching row; names without a row
        are absent. An empty input returns ``{}`` without opening a connection.
    """
    # An empty list would render "IN ()", which is a MySQL syntax error.
    if not uuid_names:
        return {}
    prefix = "./files/"
    filepaths = [f'{prefix}{uuid}' for uuid in uuid_names]
    # Only the placeholder count is interpolated; the values are bound by the driver.
    placeholders = ','.join(['%s'] * len(filepaths))
    sql = f"SELECT filepath,filename FROM gpt_upload_file WHERE filepath IN ({placeholders})"
    conn = pymysql.connect(**ck_mysql_config)
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, filepaths)
            rows = cursor.fetchall()
    finally:
        conn.close()
    return {row[0].removeprefix(prefix): row[1] for row in rows}
|
|||
|
|
|
|||
|
|
|
|||
|
|
def solve_knowledge_map(knowledge_map: list) -> list:
    """Replace knowledge-base display names with configured identifiers, in place.

    An entry equal to one of the fixed display names, or to one of the
    ``*_NAME`` values in ``kb_config``, is replaced by the matching identifier.
    Any other entry is left untouched. Candidates are tried in a fixed order
    and the first match wins.

    Args:
        knowledge_map: List of knowledge-base names (mutated in place).

    Returns:
        The same list object that was passed in.
    """
    # Fixed display name -> config attribute holding the target identifier.
    fixed_labels = (
        ('政策库', 'DEFAULT_POLICY_BASE'),
        ('期刊论文库', 'DEFAULT_JOURNAL_BASE'),
        ('报告库', 'DEFAULT_REPORT_BASE1'),
        ('冶金行业新闻库', 'GY_NEWS_BASE'),
        ('冶金行业报告库', 'GY_REPORT_BASE'),
        ('冶金专业知识库', 'GY_JOURNAL_BASE'),
    )
    # Metallurgy-series bases: config attribute holding the display name ->
    # config attribute holding the identifier.
    configured_labels = (
        ('YJ_CH_JOURNAL_BASE_NAME', 'YJ_CH_JOURNAL_BASE'),
        ('YJ_NEWS_BASE_NAME', 'YJ_NEWS_BASE'),
        ('YJ_FOR_JOURNAL_BASE_NAME', 'YJ_FOR_JOURNAL_BASE'),
        ('YJ_OA_JOURNAL_BASE_NAME', 'YJ_OA_JOURNAL_BASE'),
        ('YJ_POLICYS_BASE_NAME', 'YJ_POLICYS_BASE'),
        ('STEEL_KB_NAME', 'STEEL_KB'),
    )

    def _target_attr(entry):
        # Config attributes are read lazily, in order, only as far as needed.
        for label, attr in fixed_labels:
            if entry == label:
                return attr
        for label_attr, attr in configured_labels:
            if entry == getattr(knowledgeBase_config, label_attr):
                return attr
        return None

    for position, entry in enumerate(knowledge_map):
        attr = _target_attr(entry)
        if attr is not None:
            knowledge_map[position] = getattr(knowledgeBase_config, attr)
    return knowledge_map
|
|||
|
|
|
|||
|
|
def get_shared_variable(key: str) -> Any:
    """Return the value stored under ``key`` in the shared dict, or ``None``.

    Expiry is not checked here. An expired entry stays readable until it is
    removed by the cleanup sweep or explicitly deleted.
    """
    store = shared_variable.get()
    return store.get(key)
|
|||
|
|
|
|||
|
|
def set_shared_variable(key: str, value: Any, ttl: int = 300):
    """Store ``value`` under ``key`` and record an expiry ``ttl`` seconds from now.

    The expiry is an integer Unix timestamp in seconds. Both dicts are mutated
    in place and then re-bound on their ContextVars.
    """
    store = shared_variable.get()
    deadlines = expiration_times.get()

    store[key] = value
    deadlines[key] = int(datetime.now().timestamp()) + ttl

    shared_variable.set(store)
    expiration_times.set(deadlines)
|
|||
|
|
# Manual deletion (does not wait for the TTL to expire).
def remove_shared_variable(key: str):
    """Remove ``key`` and its expiry timestamp; a missing key is ignored."""
    store = shared_variable.get()
    deadlines = expiration_times.get()

    store.pop(key, None)
    deadlines.pop(key, None)

    shared_variable.set(store)
    expiration_times.set(deadlines)
|
|||
|
|
async def clear_expired_keys():
    """Remove every shared-variable entry whose expiry time has passed.

    Performs one sweep and returns. Repetition is handled by the scheduler's
    ``'interval'`` job, so the function neither loops nor sleeps. The previous
    ``while True: time.sleep(...)`` never returned and would block any event
    loop that awaited it.
    """
    logger.debug("Checking expired keys...")
    current_time = int(datetime.now().timestamp())

    context_dict = shared_variable.get()
    expiration_dict = expiration_times.get()

    # Collect first: a dict must not change size while it is being iterated.
    keys_to_delete = [key for key, exp_time in list(expiration_dict.items()) if exp_time <= current_time]

    for key in keys_to_delete:
        # pop() tolerates a key that is missing from either dict instead of
        # aborting the whole sweep with KeyError.
        context_dict.pop(key, None)
        expiration_dict.pop(key, None)

    shared_variable.set(context_dict)
    expiration_times.set(expiration_dict)
|
|||
|
|
|
|||
|
|
@app.on_event("startup")
async def app_start():
    """Start the background scheduler that periodically purges expired shared variables.

    ``clear_expired_keys`` is a coroutine function. APScheduler's thread-pool
    executor (used by ``BackgroundScheduler``) calls job functions directly, so
    passing it as-is only creates coroutine objects that are never awaited. The
    job therefore drives it with ``asyncio.run`` on the worker thread, which
    has no running event loop.
    """
    def _run_clear_expired_keys():
        asyncio.run(clear_expired_keys())

    scheduler.add_job(
        _run_clear_expired_keys,
        'interval',
        seconds=3,
        id="clear_expired_keys",
        replace_existing=True,
    )
    # Calling start() on a running scheduler raises; guard against repeated startup events.
    if not scheduler.running:
        scheduler.start()
|
|||
|
|
|
|||
|
|
def compute_lps(pattern: str) -> list:
    """Build the KMP failure table for ``pattern``.

    Entry ``i`` is the length of the longest proper prefix of
    ``pattern[:i + 1]`` that is also a suffix of it.
    """
    table = [0] * len(pattern)
    matched = 0
    for pos in range(1, len(pattern)):
        # Fall back through shorter borders until the next character can extend one.
        while matched and pattern[pos] != pattern[matched]:
            matched = table[matched - 1]
        if pattern[pos] == pattern[matched]:
            matched += 1
        table[pos] = matched
    return table
|
|||
|
|
|
|||
|
|
# Check whether a string starts with the given pattern.
def kmp_match(text: str, pattern: str) -> bool:
    """Return True if ``text`` starts with ``pattern``.

    A prefix test only needs to compare the first ``len(pattern)`` characters,
    so ``str.startswith`` is used. The earlier KMP-based version scanned the
    whole of ``text`` and raised IndexError for an empty ``pattern``. An empty
    pattern is a prefix of every string and yields True.
    """
    return text.startswith(pattern)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def remove_before_and_including(text: str, substring: str) -> str:
    """Return the part of ``text`` after the first occurrence of ``substring``.

    Everything up to and including ``substring`` is removed. If ``substring``
    does not occur, or is empty, ``text`` is returned unchanged.

    The previous slice started at ``index + len(substring) - 1``, which kept
    the last character of ``substring``. For an empty ``substring`` it returned
    only the last character of ``text``.
    """
    index = text.find(substring)
    if index == -1:
        return text
    return text[index + len(substring):]
|
|||
|
|
def remove_after_and_including(text: str, substring: str) -> str:
    """Return the part of ``text`` before the first occurrence of ``substring``.

    If ``substring`` is absent, ``text`` is returned unchanged. An empty
    ``substring`` matches at position 0 and yields "".
    """
    cut = text.find(substring)
    return text if cut < 0 else text[:cut]
|
|||
|
|
|
|||
|
|
def remove_after_and_includings(text, keyword):
    """Return ``text`` truncated before the first literal occurrence of ``keyword``.

    ``keyword`` is regex-escaped, so it is matched literally. If it does not
    occur, ``text`` is returned unchanged.
    """
    found = re.search(re.escape(keyword), text)
    return text if found is None else text[:found.start()]
|
|||
|
|
def download_self_doc(knowledge_name: str, file_name: str):
    """Serve a file from a personal knowledge base's ``content`` directory.

    Args:
        knowledge_name: Knowledge-base directory under ``KB_ROOT_PATH2``.
        file_name: Stored file name; any directory components are discarded.

    Returns:
        A ``FileResponse`` that downloads the file under its original upload
        name (looked up via ``get_personal_knowledge_map``), falling back to
        the stored name.

    Raises:
        HTTPException: 404 if the file does not exist, is not a regular file,
            or the requested location would fall outside ``KB_ROOT_PATH2``.
    """
    # Guard against path traversal: keep only the last component of file_name.
    safe_name = Path(file_name).name
    root = Path(os.path.abspath(KB_ROOT_PATH2))
    # knowledge_name is caller-supplied too: "../x" or an absolute path would
    # otherwise escape the root. Normalize the path and require it to stay inside.
    file_path = Path(os.path.normpath(root / knowledge_name / "content" / safe_name))
    try:
        file_path.relative_to(root)
    except ValueError:
        raise HTTPException(status_code=404, detail="暂未找到文件") from None

    if not file_path.is_file():
        raise HTTPException(status_code=404, detail="暂未找到文件")

    name_map = get_personal_knowledge_map([safe_name])
    # Use the original upload name if known, otherwise fall back to safe_name.
    download_name = name_map.get(safe_name, safe_name)
    return FileResponse(
        path=str(file_path),
        media_type="application/octet-stream",
        filename=download_name,
    )
|
|||
|
|
|
|||
|
|
# Convert retrieved documents into numbered source-reference strings.
def doc_to_list(num: int, knowledge_name:str, docs: List[DocumentWithVSId],source_documents):
    """Append one formatted source reference per document to ``source_documents``.

    Entries look like ``[n] [title](url)``, with variants that depend on the
    knowledge base. They are numbered consecutively starting at ``num + 1``.
    Newlines and carriage returns are removed from every entry.

    Args:
        num: Number of the last reference already emitted; numbering continues after it.
        knowledge_name: Knowledge base the documents came from; selects the format.
        docs: Retrieved documents; only ``doc.metadata`` is read.
        source_documents: List that receives the formatted strings (mutated in place).

    Returns:
        None; results are delivered through ``source_documents``.
    """
    for doc in docs:
        num += 1
        entry = _format_source_entry(num, knowledge_name, doc.metadata)
        source_documents.append(entry.replace('\n', '').replace('\r', ''))


def _source_meta_text(metadata, key, default=""):
    """Return ``metadata[key]`` as a stripped string, or ``default`` when missing or None."""
    value = metadata.get(key)
    if value is None:
        value = default
    return str(value).strip()


def _source_detail_url(metadata):
    """Detail-page URL: policy site for a string ``primary_key``, else journal site by ``ID``.

    Returns "" when neither identifier is usable (previously this raised TypeError).
    """
    primary_key = metadata.get("primary_key")
    if isinstance(primary_key, str):
        return "https://policy.ckcest.cn/detail/" + primary_key + ".html"
    article_id = metadata.get("ID")
    if article_id is None:
        return ""
    return "https://kgo.ckcest.cn/kgo/detail/1002/ads_journal_article/" + str(article_id) + ".html"


def _format_source_entry(num, knowledge_name, metadata):
    """Build the (not yet newline-stripped) reference string for one document."""
    cfg = knowledgeBase_config
    if knowledge_name == cfg.DEFAULT_REPORT_BASE_NAME:
        filename = _source_meta_text(metadata, "source")
        return f"[{num}] [{filename}]"
    if knowledge_name in (cfg.GY_NEWS_BASE_NAME, cfg.GY_REPORT_BASE_NAME, cfg.GY_JOURNAL_BASE_NAME, cfg.YJ_NEWS_BASE_NAME):
        filename = _source_meta_text(metadata, "title")
        detail_url = metadata.get("url") or metadata.get("releaseUrl", "")
        return f"[{num}] [{filename}]({detail_url})"
    # Personal knowledge bases (and the "coding" base) link to the download endpoint.
    if cfg.SELF_KNOWLEDGE_BASE.match(knowledge_name) or knowledge_name == "coding":
        filename = _source_meta_text(metadata, "source")
        uuid_name = _source_meta_text(metadata, "uuid_name")
        detail_url = f"{DOWNLOAD_HOST_CK}?filename={knowledge_name}/content/{uuid_name}&originname={filename}"
        return f"[{num}] [{Path(filename).stem}]({detail_url})\n"
    if knowledge_name in (cfg.YJ_CH_JOURNAL_BASE_NAME, cfg.YJ_FOR_JOURNAL_BASE_NAME, cfg.YJ_OA_JOURNAL_BASE_NAME):
        filename = _source_meta_text(metadata, "title", "未命名")
        return f"[{num}] {filename}"
    if knowledge_name == cfg.STEEL_KB_NAME:
        filename = _source_meta_text(metadata, "title", "未命名")
        detail_url = metadata.get("url", "")
        return f"[{num}] [{filename}]({detail_url})"
    filename = _source_meta_text(metadata, "title")
    detail_url = _source_detail_url(metadata)
    # Without a title, fall back to a generic "original article" label.
    # Previously a missing title crashed before this fallback was reached.
    label = filename if filename else "原文地址"
    return f"[{num}] [{label}]({detail_url})"
|
|||
|
|
|
|||
|
|
|
|||
|
|
def solve_mental_data(knowledge_name,docs: List[DocumentWithVSId],doc,seen_docs,duplicate_indices,knowledge,inum):
    """Append a formatted "参考资料" entry for one retrieved document, de-duplicating as it goes.

    The entry format depends on ``knowledge_name``. A document whose identifier
    ``(title-or-source, page_content)`` is already in ``seen_docs`` is not
    appended; its position ``inum`` goes into ``duplicate_indices`` instead.
    In the non-personal branch, a document whose ``summary`` metadata exists
    but is 15 characters or shorter is handled the same way.

    Args:
        knowledge_name: Knowledge base the document came from; selects the branch/format.
        docs: Not read by this function.
        doc: Document to process; ``doc.metadata`` and ``doc.page_content`` are read.
        seen_docs: Set of identifier tuples already emitted (mutated).
        duplicate_indices: List receiving ``inum`` for skipped documents (mutated).
        knowledge: List receiving the formatted strings (mutated).
        inum: Zero-based position of ``doc`` in the caller's list.

    Returns:
        None; results go into ``seen_docs``, ``duplicate_indices`` and ``knowledge``.
    """
    # Personal knowledge base (matched by SELF_KNOWLEDGE_BASE) or the "coding" base?
    if knowledgeBase_config.SELF_KNOWLEDGE_BASE.match(knowledge_name) or knowledge_name == "coding":
        doc_identifier = (doc.metadata["source"], doc.page_content)
        if doc_identifier not in seen_docs:
            seen_docs.add(doc_identifier)
            # NOTE(review): numbered by inum + 1 here, while the summary branch below
            # numbers by len(knowledge) + 1 — the two schemes can disagree.
            knowledge.append(f"""参考资料[{inum + 1}] {doc.metadata["summary"]}""")
        else:
            duplicate_indices.append(inum)
    else:
        if "summary" in doc.metadata:
            # Only summaries longer than 15 characters are used; shorter ones are skipped.
            if len(doc.metadata['summary'])>15:
                doc_identifier = (doc.metadata['title'], doc.page_content) if "title" in doc.metadata else (doc.metadata["source"], doc.page_content)
                # Check whether this identifier has already been seen.
                if doc_identifier not in seen_docs:
                    seen_docs.add(doc_identifier)
                    # ...and append the document's information to the knowledge list.
                    if knowledge_name == knowledgeBase_config.DEFAULT_JOURNAL_BASE:
                        date_str = doc.metadata['publish_date']
                        # publish_date is parsed as YYYYMMDD; strptime raises ValueError otherwise.
                        date_obj = datetime.strptime(str(date_str), '%Y%m%d')
                        # NOTE(review): this label uses the traditional "資" while most others use "资" — confirm intended.
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}资料来源: {doc.metadata['data_source'].replace('.pdf','')}资料时间:{date_obj} \n資料内容: {doc.metadata['summary']}""")
                    elif knowledge_name == knowledgeBase_config.DEFAULT_REPORT_BASE:
                        try:
                            # Title-only chunks show the summary; content chunks show the chunk text.
                            if doc.metadata['title'] == doc.page_content:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}\n资料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}\n资料内容: {doc.page_content}""")
                        except Exception as e:
                            # Presumably reached when 'title' is missing (KeyError): fall back to 'source'.
                            if doc.metadata['source'] == doc.page_content:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n资料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n资料内容: {doc.page_content}""")
                    elif knowledge_name == knowledgeBase_config.GY_NEWS_BASE or knowledge_name == knowledgeBase_config.GY_REPORT_BASE or knowledge_name == knowledgeBase_config.GY_JOURNAL_BASE:
                        # Support for the metallurgy news base (2024 and earlier).
                        # NOTE(review): this branch is only entered for the three GY_* bases, so the
                        # YJ_NEWS_BASE check below is unreachable unless YJ_NEWS_BASE equals one of them.
                        if knowledge_name == knowledgeBase_config.YJ_NEWS_BASE:
                            # NOTE(review): date_str is assigned but not used in this branch.
                            date_str = doc.metadata.get('publish_date', '') or doc.metadata.get('publish_year', '')
                            knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata.get('title','')}
资料内容: {doc.metadata.get('summary', doc.page_content)}""")
                        else:
                            if doc.metadata['title'] == doc.page_content:
                                # NOTE(review): date_str is assigned but not used (though the lookup raises KeyError if missing).
                                date_str = doc.metadata['publish_date']
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}
资料内容: {doc.metadata['summary']}""")
                            else:
                                date_str = doc.metadata['publish_date']
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}
资料内容: {doc.page_content}""")
                    # ------- Metallurgy journal knowledge bases -------
                    elif knowledge_name in (knowledgeBase_config.YJ_CH_JOURNAL_BASE, knowledgeBase_config.YJ_FOR_JOURNAL_BASE, knowledgeBase_config.YJ_OA_JOURNAL_BASE):
                        # Journal articles: title-only chunks use the summary when available.
                        if doc.metadata.get('title') == doc.page_content:
                            content_used = doc.metadata.get('summary', doc.page_content)
                        else:
                            content_used = doc.page_content
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata.get('title','未命名')}
资料内容: {content_used}""")
                    # ------- Metallurgy policy knowledge base -------
                    elif knowledge_name == knowledgeBase_config.YJ_POLICYS_BASE:
                        title = doc.metadata.get('title') or doc.metadata.get('source', '')
                        content_used = doc.metadata.get('summary', doc.page_content)
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {title}
资料内容: {content_used}""")
                    else:
                        try:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(str(date_str), '%Y%m%d')
                            knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}资料来源: {doc.metadata['source'].replace('.pdf','')}资料时间:{date_obj} \n資料内容: {doc.metadata['summary']}""")
                        except Exception as e:
                            # NOTE(review): `in` is a membership test (a substring test if STEEL_KB is a str),
                            # not equality — confirm intended.
                            if knowledge_name in knowledgeBase_config.STEEL_KB:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}]发布时间:{doc.metadata['date']} 资料来源: {doc.metadata['source']}标题:{doc.metadata['title']}\n資料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n資料内容: {doc.metadata['summary']}""")
                else:
                    # Already-seen identifier.
                    duplicate_indices.append(inum)
            else:
                # Summary too short (15 characters or fewer): treated like a duplicate.
                duplicate_indices.append(inum)
        else:
            # No summary metadata: identify by title and page content.
            doc_identifier = (doc.metadata['title'], doc.page_content)
            if doc_identifier not in seen_docs:
                seen_docs.add(doc_identifier)
                # Title chunks take the article body from metadata['content'] / ['abstract'].
                if doc.metadata["_type"] == "title":
                    try:
                        # NOTE(review): membership test on GY_BASE_NAME (substring test if it is a str) — confirm.
                        if knowledge_name in knowledgeBase_config.GY_BASE_NAME:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.metadata['content']}""")
                        else:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(date_str, '%Y%m%d')
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['content']}""")
                    except Exception as e:
                        # Fallback formatting; errors raised here (e.g. a missing key) propagate.
                        date_str = doc.metadata['publish_date'] if 'publish_date' in doc.metadata else ""
                        date_obj = datetime.strptime(date_str, '%Y%m%d') if not date_str=="" else ""
                        if knowledge_name in knowledgeBase_config.DEFAULT_REPORT_BASE1:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['source']}\n文章内容 {doc.page_content}""")
                        else:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['abstract']}""")
                else:
                    # Content chunks use page_content as the article body.
                    try:
                        if knowledge_name in knowledgeBase_config.GY_BASE_NAME:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.page_content}""")
                        else:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(date_str, '%Y%m%d')
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']}资料时间{date_obj} \n文章内容 {doc.page_content}""")
                    except Exception as e:
                        date_str = doc.metadata['publish_date'] if 'publish_date' in doc.metadata else ""
                        date_obj = datetime.strptime(date_str, '%Y%m%d') if not date_str=="" else ""
                        if knowledge_name in knowledgeBase_config.DEFAULT_REPORT_BASE1:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['source']}\n文章内容 {doc.page_content}""")
                        else:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['abstract']}""")
            else:
                # Already-seen identifier.
                duplicate_indices.append(inum)
|
|||
|
|
@app.post("/chat_translate")
async def chat_translate(query, lang: str):
    """Translate ``query`` through the internal translation service.

    Args:
        query: Text to translate.
        lang: ``"zh2en"`` (Chinese to English) or ``"en2zh"`` (English to Chinese).

    Returns:
        The ``translation`` field of the service's JSON response.

    Raises:
        ValueError: If ``lang`` is not a supported direction.
        HTTPException: If the translation service answers with a non-200
            status (the upstream status code is propagated).
    """
    if lang == "zh2en":
        url = 'http://192.168.56.123:8845/translate'
    elif lang == "en2zh":
        url = 'http://192.168.56.123:8846/translate'
    else:
        raise ValueError("lang参数错误")

    data = {'text': query}
    headers = {'Content-Type': 'application/json'}

    async with httpx.AsyncClient() as client:
        # Pre-serialized JSON body is sent as raw content (data=<str> is deprecated in httpx).
        response = await client.post(url, content=json.dumps(data), headers=headers)
    logger.debug("response %s", response)

    if response.status_code == 200:
        result = response.json()
        return result['translation']
    # Exception() accepts no keyword arguments, so the former
    # `raise Exception(status_code=..., detail=...)` raised TypeError instead.
    raise HTTPException(status_code=response.status_code, detail="请求失败")
|
|||
|
|
|
|||
|
|
class History(BaseModel):
    """
    One chat-history message.

    Can be built from a dict, e.g.
        h = History(**{"role": "user", "content": "你好"})
    and converted to a message tuple, e.g.
        h.to_msg_tuple()  ->  ("user", "你好")
    """
    role: str = Field(...)
    content: str = Field(...)

    def to_msg_tuple(self):
        """Return ``(role, content)`` with the role normalized to "assistant", "system" or "user"."""
        if self.role == "assistant":
            kind = "assistant"
        elif self.role == "system":
            kind = "system"
        else:
            kind = "user"
        return (kind, self.content)

    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
        """Build a jinja2 ``ChatMessagePromptTemplate`` for this message.

        Roles "ai" and "human" are mapped to "assistant" and "user"; any other
        role is used unchanged. With ``is_raw`` (the default), the content is
        wrapped in ``{% raw %}...{% endraw %}`` so jinja2 treats it as literal
        text. History messages are currently plain text without template
        variables.
        """
        normalized_role = {
            "ai": "assistant",
            "human": "user",
            "system": "system",
        }.get(self.role, self.role)
        body = "{% raw %}" + self.content + "{% endraw %}" if is_raw else self.content
        return ChatMessagePromptTemplate.from_template(body, "jinja2", role=normalized_role)

    @classmethod
    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
        """Coerce a dict or a ``(role, content, ...)`` list/tuple into a ``History``.

        Any other value, including a list/tuple shorter than two items, is
        returned unchanged.
        """
        if isinstance(h, dict):
            return cls(**h)
        if isinstance(h, (list, tuple)) and len(h) >= 2:
            return cls(role=h[0], content=h[1])
        return h
|
|||
|
|
|
|||
|
|
# Regex character class of Markdown special characters, kept as a module constant.
MARKDOWN_SPECIAL_CHARS = r'[\\`*_{}\[\]()#+\-\.!|>~:]'


def get_first_sentence_by_regex(text: str, fallback_length: int = 10) -> str:
    """Return a short plain-text lead-in for ``text``.

    Steps:
      1. strip surrounding whitespace;
      2. delete every character matched by ``MARKDOWN_SPECIAL_CHARS``;
      3. if there is a newline, keep only the (stripped) first line;
      4. return everything up to and including the first Chinese full stop "。".

    If there is no "。", the text is returned as is when it has at most
    ``fallback_length`` characters. Otherwise its first ``fallback_length``
    characters are returned followed by "...".
    """
    if not text:
        return ""

    cleaned = re.sub(MARKDOWN_SPECIAL_CHARS, '', text.strip())

    first_line, newline, _rest = cleaned.partition('\n')
    if newline:
        cleaned = first_line.strip()

    # cleaned now contains no newline, so the first "。" ends the first sentence.
    stop = cleaned.find('。')
    if stop != -1:
        return cleaned[:stop + 1].strip()

    if len(cleaned) > fallback_length:
        return cleaned[:fallback_length] + "..."
    return cleaned


def get_text_by_regex(text: str) -> str:
    """Return ``text`` with surrounding whitespace and all Markdown special characters removed.

    Empty or ``None`` input yields "".
    """
    return re.sub(MARKDOWN_SPECIAL_CHARS, '', text.strip()) if text else ""
|
|||
|
|
|
|||
|
|
def split_questions(text: str) -> str:
    """Split ``text`` into questions and join them with ``<strip>`` markers.

    Three strategies are tried in order; the first that yields more than one
    piece wins:
      1. before each ``Q<number>:`` label;
      2. on runs of carriage returns, newlines or tabs;
      3. after sentence-ending punctuation (always used as the final fallback).

    Every piece is stripped, and empty pieces are dropped. Pieces are joined
    with ``"<strip>\\n\\n"``, and a trailing ``"<strip>"`` is appended.
    """
    def _clean(chunks):
        return [chunk.strip() for chunk in chunks if chunk.strip()]

    def _render(chunks):
        return '<strip>\n\n'.join(chunks) + '<strip>'

    body = text.strip()

    for pattern in (r'(?=Q\d+[::])', r'[\r\n\t]+'):
        pieces = _clean(re.split(pattern, body))
        if len(pieces) > 1:
            return _render(pieces)

    marked = re.sub(r'([?\?!!。])', r'\1<sep>', body)
    return _render(_clean(marked.split('<sep>')))
|