584 lines
26 KiB
Python
584 lines
26 KiB
Python
import asyncio
|
||
import os
|
||
from fastapi.responses import FileResponse
|
||
import pymysql
|
||
from contextvars import ContextVar
|
||
from configs.kb_config import DOWNLOAD_HOST_CK, KB_ROOT_PATH2, ck_mysql_config
|
||
from datetime import datetime
|
||
from pathlib import Path
|
||
import re
|
||
import threading
|
||
import requests
|
||
import time
|
||
import oss2
|
||
from pydantic import BaseModel, Field
|
||
from langchain.prompts.chat import ChatMessagePromptTemplate
|
||
from configs import logger, log_verbose
|
||
from typing import Any, List, Tuple, Dict, Union
|
||
from configs.oss_config import *
|
||
from apscheduler.schedulers.background import BackgroundScheduler
|
||
from fastapi import FastAPI, HTTPException
|
||
from configs.kb_config import similarity_url,similarity_score,similarity_internet
|
||
import httpx
|
||
import json
|
||
|
||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||
from configs import kb_config as knowledgeBase_config
|
||
|
||
# Module-level FastAPI application and APScheduler background scheduler.
app = FastAPI()
scheduler = BackgroundScheduler()
# Context variables intended to hold a request-scoped shared dictionary and the
# expiry timestamp recorded for each of its keys.
# NOTE(review): each default is a single mutable dict object, so every context
# that has not called .set() with a different object reads and mutates the same
# dictionary - confirm that this sharing is intended.
shared_variable: ContextVar[Dict[str, Any]] = ContextVar("shared_variable", default={})
expiration_times: ContextVar[Dict[str, float]] = ContextVar("expiration_times", default={})
|
||
|
||
def get_similar_documents(index, sentences, query: str, docs: List, top_k: int = 3):
    """Keep only the documents whose position is listed in ``index``.

    Args:
        index: Positions to keep; must support ``len()`` and ``in``.
        sentences: Unused. Only the disabled similarity-service path read it.
        query: Unused, for the same reason.
        docs: Candidate documents.
        top_k: Unused, for the same reason.

    Returns:
        A new list with the selected documents in their original order, or an
        empty list when ``index`` is empty.
    """
    # The remote similarity re-ranking that used to handle an empty index is
    # disabled, so an empty index simply yields no documents.
    if len(index) == 0:
        return []
    return [doc for position, doc in enumerate(docs) if position in index]
|
||
|
||
|
||
def get_similar_documents1(index, sentences, query: str, docs: List[DocumentWithVSId], top_k: int = 3):
    """Keep only the documents whose position is listed in ``index``.

    Args:
        index: Positions to keep; must support ``len()`` and ``in``.
        sentences: Unused. Only the disabled similarity-service path read it.
        query: Unused, for the same reason.
        docs: Candidate documents.
        top_k: Unused, for the same reason.

    Returns:
        A new list with the selected documents in their original order, or an
        empty list when ``index`` is empty.
    """
    # The score-based re-ranking that used to handle an empty index is
    # disabled, so an empty index simply yields no documents.
    if len(index) == 0:
        return []
    return [doc for position, doc in enumerate(docs) if position in index]
|
||
def remove_docs1(titles: List, docs: List[DocumentWithVSId]):
    """Drop documents whose label is already present in ``titles``.

    The label is ``metadata["title"]``. If filtering by title raises for any
    reason, the whole filter is redone using ``metadata["source"]`` instead.

    Args:
        titles: Labels that have already been used.
        docs: Documents exposing a ``metadata`` mapping.

    Returns:
        A ``(kept_docs, kept_labels)`` tuple, with labels in the same order as
        the kept documents.
    """
    def _filter_by(field):
        kept = [doc for doc in docs if doc.metadata[field] not in titles]
        return kept, [doc.metadata[field] for doc in kept]

    try:
        res, title = _filter_by("title")
    except Exception:
        # Any failure on the title pass switches every document to "source".
        res, title = _filter_by("source")
    return res, title
|
||
def remove_docs(titles: List, docs: List):
    """Drop dict-style documents whose ``"title"`` is already in ``titles``.

    Args:
        titles: Titles that have already been used.
        docs: Mappings that each contain a ``"title"`` key.

    Returns:
        A ``(kept_docs, kept_titles)`` tuple, with titles in the same order as
        the kept documents.
    """
    kept = [entry for entry in docs if entry["title"] not in titles]
    kept_titles = [entry["title"] for entry in kept]
    return kept, kept_titles
|
||
|
||
def upload_image_to_oss(local_image_path, oss_file_name):
    """Upload a local file to the OSS bucket under ``files/chat_pic/``.

    Credentials, endpoint and bucket name come from the names imported from
    the OSS config module.

    Args:
        local_image_path: Path of the file to upload.
        oss_file_name: Object name to use below ``files/chat_pic/``.

    Returns:
        The plain ``http://<bucket>.<endpoint>/<object>`` URL on success, or
        ``None`` when the upload raised (the error is printed, not re-raised).
    """
    credentials = oss2.Auth(access_key_id, access_key_secret)
    target_bucket = oss2.Bucket(credentials, endpoint, bucket_name)
    try:
        object_key = f"files/chat_pic/{oss_file_name}"
        with open(local_image_path, 'rb') as image_stream:
            target_bucket.put_object(object_key, image_stream)
        print(f"文件 {oss_file_name} 上传成功!")
        # Public-style URL; the signed-URL variant is not used.
        return f"http://{bucket_name}.{endpoint}/{object_key}"
    except Exception as error:
        print(f"上传失败: {error}")
    return None
|
||
|
||
def get_personal_knowledge_map(uuid_names: list[str]):
    """Map stored (uuid) file names to the original names users uploaded.

    Looks up ``gpt_upload_file`` rows whose ``filepath`` equals
    ``./files/<uuid_name>`` for each requested name.

    Args:
        uuid_names: Stored file names to look up.

    Returns:
        A dict mapping each found stored name (the ``./files/`` prefix
        removed) to its ``filename`` column. Names without a row are absent.
        An empty input returns ``{}`` without touching the database.
    """
    # "IN ()" is a SQL syntax error, so never build the query for no names.
    if not uuid_names:
        return {}

    conn = pymysql.connect(**ck_mysql_config)
    try:
        with conn.cursor() as cursor:
            # One %s per value keeps the query fully parameterised.
            placeholders = ','.join(['%s'] * len(uuid_names))
            sql = f"SELECT filepath,filename FROM gpt_upload_file WHERE filepath IN ({placeholders})"
            filepaths = [f'./files/{uuid}' for uuid in uuid_names]
            cursor.execute(sql, filepaths)
            rows = cursor.fetchall()
        return {row[0].replace("./files/", ""): row[1] for row in rows}
    finally:
        conn.close()
|
||
|
||
|
||
def _knowledge_base_for_label(label):
    """Return the configured knowledge-base name for a display label.

    Labels are checked in a fixed order and the first match wins. A label
    that matches nothing is returned unchanged.
    """
    if label == '政策库':
        return knowledgeBase_config.DEFAULT_POLICY_BASE
    if label == '期刊论文库':
        return knowledgeBase_config.DEFAULT_JOURNAL_BASE
    if label == '报告库':
        return knowledgeBase_config.DEFAULT_REPORT_BASE1
    if label == '冶金行业新闻库':
        return knowledgeBase_config.GY_NEWS_BASE
    if label == '冶金行业报告库':
        return knowledgeBase_config.GY_REPORT_BASE
    if label == '冶金专业知识库':
        return knowledgeBase_config.GY_JOURNAL_BASE
    # Metallurgy series: display names are taken from the config module.
    if label == knowledgeBase_config.YJ_CH_JOURNAL_BASE_NAME:
        return knowledgeBase_config.YJ_CH_JOURNAL_BASE
    if label == knowledgeBase_config.YJ_NEWS_BASE_NAME:
        return knowledgeBase_config.YJ_NEWS_BASE
    if label == knowledgeBase_config.YJ_FOR_JOURNAL_BASE_NAME:
        return knowledgeBase_config.YJ_FOR_JOURNAL_BASE
    if label == knowledgeBase_config.YJ_OA_JOURNAL_BASE_NAME:
        return knowledgeBase_config.YJ_OA_JOURNAL_BASE
    if label == knowledgeBase_config.YJ_POLICYS_BASE_NAME:
        return knowledgeBase_config.YJ_POLICYS_BASE
    if label == knowledgeBase_config.STEEL_KB_NAME:
        return knowledgeBase_config.STEEL_KB
    return label


def solve_knowledge_map(knowledge_map: list) -> list:
    """Replace display labels in ``knowledge_map`` with configured base names.

    The list is modified in place. Entries that match no known label are left
    as they are.

    Args:
        knowledge_map: Knowledge-base labels to translate.

    Returns:
        The same list object that was passed in.
    """
    for position, label in enumerate(knowledge_map):
        knowledge_map[position] = _knowledge_base_for_label(label)
    return knowledge_map
|
||
|
||
def get_shared_variable(key: str) -> Any:
    """Return the value stored under ``key`` in the shared dictionary.

    The dictionary is whatever ``shared_variable`` holds in the current
    context. ``None`` is returned when the key is absent.
    """
    return shared_variable.get().get(key)
|
||
|
||
def set_shared_variable(key: str, value: Any, ttl: int = 300):
    """Store ``value`` under ``key`` and record when it expires.

    Args:
        key: Name of the entry.
        value: Value to store.
        ttl: Lifetime in seconds, added to the current whole-second Unix
            timestamp to form the expiry time.
    """
    values = shared_variable.get()
    deadlines = expiration_times.get()

    values[key] = value
    deadlines[key] = int(datetime.now().timestamp()) + ttl

    # Re-bind the (mutated) dictionaries in the current context.
    shared_variable.set(values)
    expiration_times.set(deadlines)
|
||
#手动删除
|
||
def remove_shared_variable(key: str):
    """Delete ``key`` and its expiry record; a missing key is not an error."""
    values = shared_variable.get()
    deadlines = expiration_times.get()

    values.pop(key, None)
    deadlines.pop(key, None)

    # Re-bind the (mutated) dictionaries in the current context.
    shared_variable.set(values)
    expiration_times.set(deadlines)
|
||
def _purge_expired_keys() -> None:
    """Remove, in one pass, every shared key whose expiry time has passed."""
    current_time = int(datetime.now().timestamp())
    context_dict = shared_variable.get()
    expiration_dict = expiration_times.get()

    # Snapshot the items before filtering: when this runs on the scheduler's
    # worker thread, request handlers may be adding keys at the same time.
    expired_keys = [
        key
        for key, exp_time in list(expiration_dict.items())
        if exp_time <= current_time
    ]

    for key in expired_keys:
        # pop() tolerates a key that is present in only one dictionary or that
        # was removed by hand between the snapshot and now.
        context_dict.pop(key, None)
        expiration_dict.pop(key, None)


async def clear_expired_keys():
    """Purge expired shared keys forever, pausing two seconds between passes.

    This coroutine never returns; run it as a background task.
    """
    while True:
        # asyncio.sleep yields to the event loop. A blocking time.sleep here
        # would stall every other coroutine for the whole pause.
        await asyncio.sleep(2)
        _purge_expired_keys()


@app.on_event("startup")
async def app_start():
    """Start the background scheduler that expires shared keys every 3 seconds."""
    # BackgroundScheduler invokes jobs as plain callables on a worker thread,
    # so a coroutine function would only create a coroutine object that is
    # never awaited. Schedule the synchronous single-pass helper instead.
    # NOTE(review): the worker thread reaches the live data only because both
    # context variables fall back to their shared default dictionaries -
    # confirm nothing binds a different dict object via .set().
    scheduler.add_job(_purge_expired_keys, 'interval', seconds=3)
    scheduler.start()
|
||
|
||
def compute_lps(pattern: str) -> list:
    """Build the KMP failure table for ``pattern``.

    Entry ``i`` is the length of the longest proper prefix of
    ``pattern[:i + 1]`` that is also a suffix of it.
    """
    table = [0] * len(pattern)
    matched = 0  # length of the prefix currently matched

    for pos in range(1, len(pattern)):
        # Fall back through shorter borders until one can be extended.
        while matched and pattern[pos] != pattern[matched]:
            matched = table[matched - 1]
        if pattern[pos] == pattern[matched]:
            matched += 1
        table[pos] = matched
    return table
|
||
|
||
#匹配字符串是否以模式开头
|
||
def kmp_match(text: str, pattern: str) -> bool:
    """Report whether ``text`` begins with ``pattern``.

    The earlier implementation ran a full KMP search and then checked that the
    first occurrence was at offset 0. For a non-empty pattern that is exactly
    a prefix test, so ``str.startswith`` gives the same answer without
    scanning the rest of the text.

    Args:
        text: String to inspect.
        pattern: Prefix to look for.

    Returns:
        ``True`` when ``pattern`` is non-empty and ``text`` starts with it.
        An empty pattern returns ``False`` (previously it raised
        ``IndexError`` for non-empty text and returned ``False`` for empty
        text).
    """
    if not pattern:
        return False
    return text.startswith(pattern)
|
||
|
||
|
||
def remove_before_and_including(text: str, substring: str) -> str:
    """Cut ``text`` at the first occurrence of ``substring``.

    The result starts at the LAST character of the matched substring, i.e.
    everything before it and all but the final character of the match are
    removed. When ``substring`` does not occur, ``text`` is returned as is.

    NOTE(review): the name suggests the whole substring is removed; the
    one-character overlap is kept here because callers may rely on it.
    """
    start = text.find(substring)
    if start == -1:
        return text
    return text[start + len(substring) - 1:]
|
||
def remove_after_and_including(text: str, substring: str) -> str:
    """Return the part of ``text`` before the first ``substring``.

    When ``substring`` does not occur, ``text`` is returned unchanged.
    """
    cut = text.find(substring)
    return text if cut == -1 else text[:cut]
|
||
|
||
def remove_after_and_includings(text, keyword):
    """Return the part of ``text`` before the first literal ``keyword``.

    The keyword is escaped before searching, so regex metacharacters in it are
    matched literally. ``text`` is returned unchanged when there is no match.
    """
    found = re.compile(re.escape(keyword)).search(text)
    if found is None:
        return text
    return text[:found.start()]
|
||
def download_self_doc(knowledge_name: str, file_name: str):
    """Serve a file from a personal knowledge base as a download.

    The file is read from ``KB_ROOT_PATH2/<knowledge_name>/content/<name>``,
    where ``<name>`` is the final path component of ``file_name``.

    Args:
        knowledge_name: Knowledge-base directory name below the root.
        file_name: Stored file name; any directory part is discarded.

    Returns:
        A ``FileResponse`` whose download name is the original upload name
        when the database knows it, otherwise the stored name.

    Raises:
        HTTPException: 404 when the path leaves the knowledge-base root or is
            not an existing regular file.
    """
    # Keep only the last component so file_name cannot add directories.
    safe_name = Path(file_name).name
    file_path = KB_ROOT_PATH2 / knowledge_name / "content" / safe_name

    # knowledge_name is caller-supplied too: "../x" or an absolute path would
    # otherwise escape the root. The check is lexical (no symlink resolution).
    root_dir = os.path.abspath(KB_ROOT_PATH2)
    target = os.path.abspath(file_path)
    try:
        inside_root = os.path.commonpath([root_dir, target]) == root_dir
    except ValueError:
        # Raised e.g. for paths on different drives: not inside the root.
        inside_root = False

    if not inside_root or not file_path.exists() or not file_path.is_file():
        raise HTTPException(status_code=404, detail="暂未找到文件")

    name_map = get_personal_knowledge_map([safe_name])
    # Prefer the original upload name; fall back to the stored name.
    download_name = name_map.get(safe_name, safe_name)
    return FileResponse(
        path=str(file_path),
        media_type="application/octet-stream",
        filename=download_name,
    )
|
||
|
||
#文档转数组
|
||
def doc_to_list(num: int, knowledge_name: str, docs: List[DocumentWithVSId], source_documents):
    """Append one Markdown citation line per document to ``source_documents``.

    Citations are numbered from ``num + 1``. The text format depends on which
    knowledge base ``knowledge_name`` denotes. Line breaks are removed from
    every entry before it is appended.

    Args:
        num: Last citation number already used.
        knowledge_name: Display name of the knowledge base.
        docs: Documents exposing a ``metadata`` mapping.
        source_documents: List extended in place with the citation strings.

    Returns:
        None. ``num`` is a local copy, so callers do not see the new count.
    """
    for doc in docs:
        num += 1
        meta = doc.metadata

        if knowledge_name == knowledgeBase_config.DEFAULT_REPORT_BASE_NAME:
            filename = meta.get("source").strip()
            text = f"[{num}] [{filename}]"
        elif knowledge_name in (
            knowledgeBase_config.GY_NEWS_BASE_NAME,
            knowledgeBase_config.GY_REPORT_BASE_NAME,
            knowledgeBase_config.GY_JOURNAL_BASE_NAME,
            knowledgeBase_config.YJ_NEWS_BASE_NAME,
        ):
            filename = meta.get("title").strip()
            detail_url = meta.get("url") or meta.get("releaseUrl", "")
            text = f"[{num}] [{filename}]({detail_url})"
        # Personal knowledge bases link to the download endpoint.
        elif knowledgeBase_config.SELF_KNOWLEDGE_BASE.match(knowledge_name) or knowledge_name == "coding":
            filename = meta.get("source").strip()
            uuid_name = meta.get("uuid_name").strip()
            detail_url = f"{DOWNLOAD_HOST_CK}?filename={knowledge_name}/content/{uuid_name}&originname={filename}"
            text = f"[{num}] [{Path(filename).stem}]({detail_url})"
        elif knowledge_name in (
            knowledgeBase_config.YJ_CH_JOURNAL_BASE_NAME,
            knowledgeBase_config.YJ_FOR_JOURNAL_BASE_NAME,
            knowledgeBase_config.YJ_OA_JOURNAL_BASE_NAME,
        ):
            # Journal entries are listed by title only, without a link.
            filename = meta.get("title", "未命名").strip()
            text = f"[{num}] {filename}"
        elif knowledge_name == knowledgeBase_config.STEEL_KB_NAME:
            filename = meta.get("title", "未命名").strip()
            detail_url = meta.get("url", "")
            text = f"[{num}] [{filename}]({detail_url})"
        else:
            filename = meta.get("title").strip()
            # Use the policy detail page when a string primary key exists,
            # otherwise the journal-article page keyed by "ID". This replaces
            # a bare "except:" that relied on string concatenation failing.
            primary_key = meta.get("primary_key")
            if isinstance(primary_key, str):
                detail_url = "https://policy.ckcest.cn/detail/" + primary_key + ".html"
            else:
                detail_url = "https://kgo.ckcest.cn/kgo/detail/1002/ads_journal_article/" + meta.get("ID") + ".html"
            # An empty title falls back to a generic link label.
            label = filename if filename else "原文地址"
            text = f"[{num}] [{label}]({detail_url})"

        source_documents.append(text.replace('\n', '').replace('\r', ''))
|
||
|
||
|
||
def solve_mental_data(knowledge_name,docs: List[DocumentWithVSId],doc,seen_docs,duplicate_indices,knowledge,inum):
    """Add one formatted reference entry for ``doc``, or mark it as skipped.

    Exactly one of two things happens per call: a reference string is appended
    to ``knowledge``, or ``inum`` is appended to ``duplicate_indices``. The
    latter happens when the document was already seen, or when it has a
    summary of 15 characters or fewer.

    Args:
        knowledge_name: Knowledge-base identifier that selects the format.
        docs: Not used by this function.
        doc: Document exposing ``metadata`` (mapping) and ``page_content``.
        seen_docs: Collection of identifiers already emitted; needs ``in`` and
            ``add``.
        duplicate_indices: List receiving the index of each skipped document.
        knowledge: List receiving the formatted reference strings.
        inum: Index of ``doc`` in the caller's sequence.

    Returns:
        None.

    Numbering differs by branch: personal knowledge bases and documents
    without a summary use ``inum + 1``; documents with a summary use
    ``len(knowledge) + 1``.
    """
    # Personal knowledge bases (name matches the configured pattern) and "coding".
    if knowledgeBase_config.SELF_KNOWLEDGE_BASE.match(knowledge_name) or knowledge_name == "coding":
        doc_identifier = (doc.metadata["source"], doc.page_content)
        if doc_identifier not in seen_docs:
            seen_docs.add(doc_identifier)
            knowledge.append(f"""参考资料[{inum + 1}] {doc.metadata["summary"]}""")
        else:
            duplicate_indices.append(inum)
    else:
        if "summary" in doc.metadata:
            # Only summaries longer than 15 characters are used.
            if len(doc.metadata['summary'])>15:
                doc_identifier = (doc.metadata['title'], doc.page_content) if "title" in doc.metadata else (doc.metadata["source"], doc.page_content)
                # Skip documents whose identifier has already been emitted.
                if doc_identifier not in seen_docs:
                    seen_docs.add(doc_identifier)
                    # Format the entry according to the knowledge base.
                    if knowledge_name == knowledgeBase_config.DEFAULT_JOURNAL_BASE:
                        date_str = doc.metadata['publish_date']
                        date_obj = datetime.strptime(str(date_str), '%Y%m%d')
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}资料来源: {doc.metadata['data_source'].replace('.pdf','')}资料时间:{date_obj} \n資料内容: {doc.metadata['summary']}""")
                    elif knowledge_name == knowledgeBase_config.DEFAULT_REPORT_BASE:
                        # Title hits show the summary; other hits show the chunk text.
                        # Any failure on the "title" path falls back to "source".
                        try:
                            if doc.metadata['title'] == doc.page_content:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}\n资料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}\n资料内容: {doc.page_content}""")
                        except Exception as e:
                            if doc.metadata['source'] == doc.page_content:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n资料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n资料内容: {doc.page_content}""")
                    elif knowledge_name == knowledgeBase_config.GY_NEWS_BASE or knowledge_name == knowledgeBase_config.GY_REPORT_BASE or knowledge_name == knowledgeBase_config.GY_JOURNAL_BASE:
                        # Intended to support the metallurgy news base (2024 and earlier).
                        # NOTE(review): the enclosing condition never tests YJ_NEWS_BASE, so
                        # this branch runs only if that constant equals one of the three GY_*
                        # values - confirm whether YJ_NEWS_BASE belongs in the elif above.
                        if knowledge_name == knowledgeBase_config.YJ_NEWS_BASE:
                            # NOTE(review): date_str is assigned in this elif but never used.
                            date_str = doc.metadata.get('publish_date', '') or doc.metadata.get('publish_year', '')
                            knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata.get('title','')}
资料内容: {doc.metadata.get('summary', doc.page_content)}""")
                        else:
                            if doc.metadata['title'] == doc.page_content:
                                date_str = doc.metadata['publish_date']
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}
资料内容: {doc.metadata['summary']}""")
                            else:
                                date_str = doc.metadata['publish_date']
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}
资料内容: {doc.page_content}""")
                    # ------- Metallurgy journal knowledge bases -------
                    elif knowledge_name in (knowledgeBase_config.YJ_CH_JOURNAL_BASE, knowledgeBase_config.YJ_FOR_JOURNAL_BASE, knowledgeBase_config.YJ_OA_JOURNAL_BASE):
                        # Title hits prefer the summary; other hits use the chunk text.
                        if doc.metadata.get('title') == doc.page_content:
                            content_used = doc.metadata.get('summary', doc.page_content)
                        else:
                            content_used = doc.page_content
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata.get('title','未命名')}
资料内容: {content_used}""")
                    # ------- Metallurgy policy knowledge base -------
                    elif knowledge_name == knowledgeBase_config.YJ_POLICYS_BASE:
                        title = doc.metadata.get('title') or doc.metadata.get('source', '')
                        content_used = doc.metadata.get('summary', doc.page_content)
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {title}
资料内容: {content_used}""")
                    else:
                        # Default format needs release_date, title and source; any failure
                        # (missing key, unparsable date) drops to the simpler formats below.
                        try:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(str(date_str), '%Y%m%d')
                            knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}资料来源: {doc.metadata['source'].replace('.pdf','')}资料时间:{date_obj} \n資料内容: {doc.metadata['summary']}""")
                        except Exception as e:
                            # NOTE(review): "in" is a substring test if STEEL_KB is a str - confirm.
                            if knowledge_name in knowledgeBase_config.STEEL_KB:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}]发布时间:{doc.metadata['date']} 资料来源: {doc.metadata['source']}标题:{doc.metadata['title']}\n資料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n資料内容: {doc.metadata['summary']}""")
                else:
                    duplicate_indices.append(inum)
            else:
                # Summary too short: recorded in the same list as duplicates.
                duplicate_indices.append(inum)
                # Disabled alternative:
                # knowledge.append(f"""参考资料[{len(knowledge) + 1}] 报告来源: {doc.metadata['source'].replace('.pdf','')} \n报告内容: {doc.content}""")
        else:
            # No summary: identify the document by title and chunk text.
            doc_identifier = (doc.metadata['title'], doc.page_content)
            if doc_identifier not in seen_docs:
                seen_docs.add(doc_identifier)
                if doc.metadata["_type"] == "title":
                    # Title hit: page_content is used as the title and the body comes
                    # from metadata.
                    try:
                        # NOTE(review): membership semantics depend on the type of
                        # GY_BASE_NAME (substring test if it is a str) - confirm.
                        if knowledge_name in knowledgeBase_config.GY_BASE_NAME:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.metadata['content']}""")
                        else:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(date_str, '%Y%m%d')
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['content']}""")
                    except Exception as e:
                        # Fallback: use publish_date when present, otherwise no date.
                        date_str = doc.metadata['publish_date'] if 'publish_date' in doc.metadata else ""
                        date_obj = datetime.strptime(date_str, '%Y%m%d') if not date_str=="" else ""
                        if knowledge_name in knowledgeBase_config.DEFAULT_REPORT_BASE1:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['source']}\n文章内容 {doc.page_content}""")
                        else:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['abstract']}""")
                else:
                    # Content hit: the title comes from metadata and the body is the chunk.
                    try:
                        if knowledge_name in knowledgeBase_config.GY_BASE_NAME:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.page_content}""")
                        else:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(date_str, '%Y%m%d')
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']}资料时间{date_obj} \n文章内容 {doc.page_content}""")
                    except Exception as e:
                        # Same fallback as the title-hit branch above.
                        date_str = doc.metadata['publish_date'] if 'publish_date' in doc.metadata else ""
                        date_obj = datetime.strptime(date_str, '%Y%m%d') if not date_str=="" else ""
                        if knowledge_name in knowledgeBase_config.DEFAULT_REPORT_BASE1:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['source']}\n文章内容 {doc.page_content}""")
                        else:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['abstract']}""")
            else:
                duplicate_indices.append(inum)
|
||
@app.post("/chat_translate")
async def chat_translate(query, lang: str):
    """Translate ``query`` through the internal translation service.

    Args:
        query: Text to translate.
        lang: ``"zh2en"`` or ``"en2zh"``; selects the service endpoint.

    Returns:
        The ``translation`` field of the service's JSON reply.

    Raises:
        ValueError: ``lang`` is neither ``"zh2en"`` nor ``"en2zh"``.
        HTTPException: The service answered with a non-200 status; the same
            status code is passed on.
    """
    if lang == "zh2en":
        url = 'http://192.168.56.123:8845/translate'
    elif lang == "en2zh":
        url = 'http://192.168.56.123:8846/translate'
    else:
        raise ValueError("lang参数错误")

    data = {'text': query}
    headers = {'Content-Type': 'application/json'}

    async with httpx.AsyncClient() as client:
        response = await client.post(url, data=json.dumps(data), headers=headers)
    print("response", response)

    if response.status_code != 200:
        # Exception() takes no keyword arguments, so the previous
        # raise Exception(status_code=..., detail=...) failed with TypeError.
        raise HTTPException(status_code=response.status_code, detail="请求失败")
    return response.json()['translation']
|
||
|
||
class History(BaseModel):
    """
    One message of a chat history.

    Can be built from a dict, e.g.
    h = History(**{"role":"user","content":"你好"})
    and converted to a (role, content) tuple, e.g.
    h.to_msg_tuple() == ("user", "你好")
    """
    role: str = Field(...)
    content: str = Field(...)

    def to_msg_tuple(self):
        """Return ``(role, content)``.

        Any role other than "assistant" or "system" is reported as "user".
        """
        if self.role == "assistant":
            return ("assistant", self.content)
        if self.role == "system":
            return ("system", self.content)
        return ("user", self.content)

    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
        """Build a jinja2 chat prompt template for this message.

        Args:
            is_raw: When true (the default), the content is wrapped in a
                jinja2 raw block so it is treated as literal text with no
                input variables.
        """
        role_aliases = {
            "ai": "assistant",
            "human": "user",
            "system": "system",
        }
        # Unknown roles are passed through unchanged.
        role = role_aliases.get(self.role, self.role)
        content = "{% raw %}" + self.content + "{% endraw %}" if is_raw else self.content
        return ChatMessagePromptTemplate.from_template(
            content,
            "jinja2",
            role=role,
        )

    @classmethod
    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
        """Build a ``History`` from a ``(role, content)`` sequence or a dict.

        Any other input, including a sequence shorter than two items, is
        returned unchanged.
        """
        if isinstance(h, (list, tuple)) and len(h) >= 2:
            return cls(role=h[0], content=h[1])
        if isinstance(h, dict):
            return cls(**h)
        return h
|
||
|
||
# Regular-expression character class matching the Markdown special characters
# that the text-cleaning helpers remove; defined once as a module constant.
MARKDOWN_SPECIAL_CHARS = r'[\\`*_{}\[\]()#+\-\.!|>~:]'
|
||
|
||
def get_first_sentence_by_regex(text: str, fallback_length: int = 10) -> str:
    """Extract a short first sentence from ``text``.

    Steps:
        1. Trim surrounding whitespace.
        2. Remove the characters matched by ``MARKDOWN_SPECIAL_CHARS`` and
           trim again, so whitespace left behind by removed markup (for
           example after a leading ``#``) does not leak into the result.
        3. Keep only the first line.
        4. Return everything up to and including the first ``。``.

    Args:
        text: Source text; may be empty or ``None``.
        fallback_length: Number of characters kept when the first line
            contains no ``。``.

    Returns:
        The first sentence. Without a ``。`` the first line is returned when
        it has at most ``fallback_length`` characters, otherwise its first
        ``fallback_length`` characters followed by ``"..."``. Empty input
        returns ``""``.
    """
    if not text:
        return ""

    # Trimming after the removal is the fix: previously a single-line input
    # such as "# 标题" came back as " 标题", and input whose first line held
    # only markup came back empty even when real text followed.
    cleaned = re.sub(MARKDOWN_SPECIAL_CHARS, '', text.strip()).strip()

    first_line = cleaned.split('\n', 1)[0].strip()

    match = re.search(r'^(.*?。)', first_line)
    if match:
        return match.group(1).strip()

    if len(first_line) <= fallback_length:
        return first_line
    return first_line[:fallback_length] + "..."
|
||
|
||
def get_text_by_regex(text: str) -> str:
    """Trim ``text`` and remove every Markdown special character from it.

    Empty or ``None`` input returns ``""``. Whitespace is trimmed before the
    removal only, so whitespace that followed a removed leading character
    stays in the result.
    """
    if not text:
        return ""
    trimmed = text.strip()
    return re.compile(MARKDOWN_SPECIAL_CHARS).sub('', trimmed)
|
||
|
||
def split_questions(text: str) -> str:
    """Split ``text`` into questions and join them with ``<strip>`` markers.

    Three strategies are tried in order, and the first that yields more than
    one piece wins:
        1. split before each ``Q<number>`` followed by a colon;
        2. split on runs of line breaks or tabs;
        3. split after sentence-ending punctuation (always used as the last
           resort, even when it yields a single piece).

    Returns:
        The pieces, each followed by ``<strip>`` and separated by a blank
        line. Empty input yields just ``"<strip>"``.
    """
    def _non_blank(chunks):
        return [chunk.strip() for chunk in chunks if chunk.strip()]

    def _render(pieces):
        return '<strip>\n\n'.join(pieces) + '<strip>'

    body = text.strip()

    # Strategies 1 and 2: question markers, then line breaks / tabs.
    for splitter in (r'(?=Q\d+[::])', r'[\r\n\t]+'):
        pieces = _non_blank(re.split(splitter, body))
        if len(pieces) > 1:
            return _render(pieces)

    # Strategy 3: tag the end of every sentence, then split on the tag.
    tagged = re.sub(r'([?\?!!。])', r'\1<sep>', body)
    return _render(_non_blank(tagged.split('<sep>')))
|