[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
583
langchain-chat/server/chat/utils.py
Normal file
583
langchain-chat/server/chat/utils.py
Normal file
@@ -0,0 +1,583 @@
|
||||
import asyncio
|
||||
import os
|
||||
from fastapi.responses import FileResponse
|
||||
import pymysql
|
||||
from contextvars import ContextVar
|
||||
from configs.kb_config import DOWNLOAD_HOST_CK, KB_ROOT_PATH2, ck_mysql_config
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import re
|
||||
import threading
|
||||
import requests
|
||||
import time
|
||||
import oss2
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.prompts.chat import ChatMessagePromptTemplate
|
||||
from configs import logger, log_verbose
|
||||
from typing import Any, List, Tuple, Dict, Union
|
||||
from configs.oss_config import *
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from configs.kb_config import similarity_url,similarity_score,similarity_internet
|
||||
import httpx
|
||||
import json
|
||||
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from configs import kb_config as knowledgeBase_config
|
||||
|
||||
app = FastAPI()
scheduler = BackgroundScheduler()
# ContextVars holding the shared key/value store and, per key, its expiry timestamp.
# NOTE(review): each default is a single mutable dict object, so every context that
# reads the default sees (and mutates) the same dict - confirm this sharing is intended.
shared_variable: ContextVar[Dict[str, Any]] = ContextVar("shared_variable", default={})
expiration_times: ContextVar[Dict[str, float]] = ContextVar("expiration_times", default={})
|
||||
|
||||
def get_similar_documents(index, sentences, query: str, docs: List, top_k: int = 3):
    """Keep only the documents whose position appears in ``index``.

    Args:
        index: Collection of positions (into ``docs``) to keep.
        sentences: Unused; only the disabled similarity-service fallback read it.
        query: Unused; only the disabled similarity-service fallback read it.
        docs: Candidate documents.
        top_k: Unused; only the disabled similarity-service fallback read it.

    Returns:
        An empty list when ``index`` is empty, otherwise the selected documents
        in their original order.
    """
    # The former fallback (re-scoring through the similarity service when no
    # index was supplied) is disabled, so an empty index selects nothing.
    if len(index) == 0:
        return []
    return [doc for position, doc in enumerate(docs) if position in index]
|
||||
|
||||
|
||||
def get_similar_documents1(index, sentences, query: str, docs: List[DocumentWithVSId], top_k: int = 3):
    """Keep only the knowledge-base documents whose position appears in ``index``.

    Args:
        index: Collection of positions (into ``docs``) to keep.
        sentences: Unused; only the disabled similarity-service fallback read it.
        query: Unused; only the disabled similarity-service fallback read it.
        docs: Candidate documents.
        top_k: Unused; only the disabled similarity-service fallback read it.

    Returns:
        An empty list when ``index`` is empty, otherwise the selected documents
        in their original order.
    """
    # The former fallback (re-scoring through the similarity service when no
    # index was supplied) is disabled, so an empty index selects nothing.
    if len(index) == 0:
        return []
    return [doc for position, doc in enumerate(docs) if position in index]
|
||||
def remove_docs1(titles: List, docs: List[DocumentWithVSId]):
    """Drop documents whose title is already listed in ``titles``.

    Documents are compared by ``metadata["title"]``; if that fails for any
    reason the comparison falls back to ``metadata["source"]``.

    Returns:
        A ``(kept_documents, kept_names)`` tuple, where ``kept_names`` holds the
        title (or source) of every kept document.
    """
    def _filter_by(field):
        kept = [doc for doc in docs if doc.metadata[field] not in titles]
        return kept, [doc.metadata[field] for doc in kept]

    try:
        res, title = _filter_by("title")
    except Exception:
        # Best-effort fallback for documents that carry no usable "title".
        res, title = _filter_by("source")
    return res, title
|
||||
def remove_docs(titles: List, docs: List):
    """Drop dict-style documents whose ``"title"`` is already listed in ``titles``.

    Returns:
        A ``(kept_documents, kept_titles)`` tuple.
    """
    kept = [doc for doc in docs if doc["title"] not in titles]
    kept_titles = [doc["title"] for doc in kept]
    return kept, kept_titles
|
||||
|
||||
def upload_image_to_oss(local_image_path, oss_file_name):
    """Upload a local file to the OSS bucket under ``files/chat_pic/``.

    Credentials, endpoint and bucket name are the names star-imported from
    the OSS config module.

    Returns:
        The public ``http://`` URL of the uploaded object, or ``None`` when the
        upload fails (the error is printed, not raised).
    """
    credentials = oss2.Auth(access_key_id, access_key_secret)
    bucket = oss2.Bucket(credentials, endpoint, bucket_name)
    try:
        object_name = f"files/chat_pic/{oss_file_name}"
        with open(local_image_path, 'rb') as fileobj:
            bucket.put_object(object_name, fileobj)
        print(f"文件 {oss_file_name} 上传成功!")
        # Build the public URL by hand; a signed URL is not used.
        return f"http://{bucket_name}.{endpoint}/{object_name}"
    except Exception as e:
        print(f"上传失败: {e}")
        return None
|
||||
|
||||
def get_personal_knowledge_map(uuid_names: list[str]):
    """Map stored (uuid) file names to the original names users uploaded.

    Looks up ``gpt_upload_file`` rows whose ``filepath`` is ``./files/<uuid_name>``.

    Args:
        uuid_names: Stored file names, without the ``./files/`` prefix.

    Returns:
        Dict of ``uuid_name -> original filename`` for every row found. Names
        without a row are simply absent. An empty input yields ``{}`` without
        touching the database.
    """
    # An empty list would render "IN ()", which MySQL rejects as a syntax error.
    if not uuid_names:
        return {}

    conn = pymysql.connect(**ck_mysql_config)
    try:
        with conn.cursor() as cursor:
            # One %s per value keeps the query parameterized.
            placeholders = ','.join(['%s'] * len(uuid_names))
            sql = f"SELECT filepath,filename FROM gpt_upload_file WHERE filepath IN ({placeholders})"
            filepaths = [f'./files/{name}' for name in uuid_names]
            cursor.execute(sql, filepaths)
            rows = cursor.fetchall()
        return {row[0].replace("./files/", ""): row[1] for row in rows}
    finally:
        conn.close()
|
||||
|
||||
|
||||
def solve_knowledge_map(knowledge_map: list) -> list:
    """Replace display names of knowledge bases with their configured identifiers.

    The list is modified in place and also returned. Entries that match no
    known display name are left untouched.
    """
    # (display name literal, config attribute holding the identifier)
    literal_aliases = (
        ('政策库', 'DEFAULT_POLICY_BASE'),
        ('期刊论文库', 'DEFAULT_JOURNAL_BASE'),
        ('报告库', 'DEFAULT_REPORT_BASE1'),
        ('冶金行业新闻库', 'GY_NEWS_BASE'),
        ('冶金行业报告库', 'GY_REPORT_BASE'),
        ('冶金专业知识库', 'GY_JOURNAL_BASE'),
    )
    # Metallurgy series: (config attribute holding the display name,
    #                     config attribute holding the identifier)
    config_aliases = (
        ('YJ_CH_JOURNAL_BASE_NAME', 'YJ_CH_JOURNAL_BASE'),
        ('YJ_NEWS_BASE_NAME', 'YJ_NEWS_BASE'),
        ('YJ_FOR_JOURNAL_BASE_NAME', 'YJ_FOR_JOURNAL_BASE'),
        ('YJ_OA_JOURNAL_BASE_NAME', 'YJ_OA_JOURNAL_BASE'),
        ('YJ_POLICYS_BASE_NAME', 'YJ_POLICYS_BASE'),
        ('STEEL_KB_NAME', 'STEEL_KB'),
    )
    for position, entry in enumerate(knowledge_map):
        # Generators keep the lookups lazy and first-match-wins.
        target_attr = next(
            (attr for label, attr in literal_aliases if entry == label),
            None,
        )
        if target_attr is None:
            target_attr = next(
                (attr for label_attr, attr in config_aliases
                 if entry == getattr(knowledgeBase_config, label_attr)),
                None,
            )
        if target_attr is not None:
            knowledge_map[position] = getattr(knowledgeBase_config, target_attr)
    return knowledge_map
|
||||
|
||||
def get_shared_variable(key: str) -> Any:
    """Return the value stored under ``key`` in the shared dict, or ``None`` if absent."""
    store = shared_variable.get()
    return store.get(key)
|
||||
|
||||
def set_shared_variable(key: str, value: Any, ttl: int = 300):
    """Store ``value`` under ``key`` and record when it expires.

    Args:
        key: Name of the entry.
        value: Object to store.
        ttl: Lifetime in seconds, counted from now.
    """
    store = shared_variable.get()
    deadlines = expiration_times.get()

    store[key] = value
    # Expiry is kept as a whole-second Unix timestamp.
    deadlines[key] = int(datetime.now().timestamp()) + ttl

    shared_variable.set(store)
    expiration_times.set(deadlines)
|
||||
# Manual removal of a shared entry.
def remove_shared_variable(key: str):
    """Delete ``key`` and its expiry record from the shared dicts, if present."""
    store = shared_variable.get()
    deadlines = expiration_times.get()

    store.pop(key, None)
    deadlines.pop(key, None)

    shared_variable.set(store)
    expiration_times.set(deadlines)
|
||||
def _purge_expired_keys() -> None:
    """Run one cleanup pass: drop every shared key whose expiry time has passed."""
    current_time = int(datetime.now().timestamp())
    context_dict = shared_variable.get()
    expiration_dict = expiration_times.get()

    # Snapshot the items first: other code may add keys while this pass runs.
    expired_keys = [
        key for key, exp_time in list(expiration_dict.items())
        if exp_time <= current_time
    ]
    for key in expired_keys:
        # pop() tolerates a key that is missing from either dict.
        context_dict.pop(key, None)
        expiration_dict.pop(key, None)

    shared_variable.set(context_dict)
    expiration_times.set(expiration_dict)


async def clear_expired_keys():
    """Endlessly purge expired shared keys, pausing two seconds between passes.

    The pause uses ``asyncio.sleep`` so the event loop keeps serving other
    tasks; a blocking ``time.sleep`` here would stall the whole loop.
    """
    while True:
        await asyncio.sleep(2)
        print("Checking expired keys...")
        _purge_expired_keys()


@app.on_event("startup")
async def app_start():
    """Start the background scheduler that purges expired keys every 3 seconds."""
    # The scheduler calls plain callables from its own worker threads; handing it
    # the coroutine function would only create coroutine objects that are never
    # awaited, so the synchronous one-pass helper is scheduled instead.
    # NOTE(review): the job relies on both ContextVars exposing their shared
    # default dict in the scheduler thread - confirm no caller replaces it.
    scheduler.add_job(_purge_expired_keys, 'interval', seconds=3)
    scheduler.start()
|
||||
|
||||
def compute_lps(pattern: str) -> list:
    """Build the KMP failure table for ``pattern``.

    Entry ``i`` is the length of the longest proper prefix of
    ``pattern[:i + 1]`` that is also a suffix of it.
    """
    table = [0] * len(pattern)
    matched = 0  # length of the prefix currently matched
    for pos in range(1, len(pattern)):
        # Fall back through shorter borders until one can be extended.
        while matched and pattern[pos] != pattern[matched]:
            matched = table[matched - 1]
        if pattern[pos] == pattern[matched]:
            matched += 1
        table[pos] = matched
    return table
|
||||
|
||||
# Check whether a string starts with a given pattern.
def kmp_match(text: str, pattern: str) -> bool:
    """Return True when ``pattern`` matches ``text`` starting at its first character.

    The previous implementation ran a full KMP search and then checked that
    the first hit was at offset 0, which is exactly a prefix test. It scanned
    the whole text to do so, and raised ``IndexError`` for an empty pattern
    on non-empty text.

    Args:
        text: String to inspect.
        pattern: Prefix to look for.

    Returns:
        True if ``text`` begins with ``pattern``. An empty pattern returns
        False, which is what the old code returned for empty text (its only
        non-crashing empty-pattern case).
    """
    if not pattern:
        return False
    return text.startswith(pattern)
|
||||
|
||||
|
||||
def remove_before_and_including(text: str, substring: str) -> str:
    """Cut ``text`` at the first occurrence of ``substring`` and return the tail.

    Args:
        text: String to trim.
        substring: Delimiter to search for.

    Returns:
        The part of ``text`` starting at the LAST character of the first
        ``substring`` occurrence. ``text`` is returned unchanged when
        ``substring`` is empty or not found.
    """
    # An empty delimiter used to yield text[-1:] (only the last character).
    if not substring:
        return text

    index = text.find(substring)
    if index == -1:
        return text
    # NOTE(review): the "- 1" keeps the delimiter's final character in the result,
    # which contradicts the function name; kept as-is because callers may depend
    # on it - confirm whether text[index + len(substring):] was intended.
    return text[index + len(substring) - 1:]
|
||||
def remove_after_and_including(text: str, substring: str) -> str:
    """Return the part of ``text`` before the first ``substring``.

    ``text`` is returned unchanged when ``substring`` does not occur.
    """
    cut = text.find(substring)
    return text if cut == -1 else text[:cut]
|
||||
|
||||
def remove_after_and_includings(text, keyword):
    """Return the part of ``text`` before the first literal ``keyword``.

    The keyword is regex-escaped, so it is matched verbatim. ``text`` is
    returned unchanged when the keyword does not occur.
    """
    hit = re.compile(re.escape(keyword)).search(text)
    if hit is None:
        return text
    return text[:hit.start()]
|
||||
def download_self_doc(knowledge_name: str, file_name: str):
    """Serve a file from a personal knowledge base as a download.

    Args:
        knowledge_name: Knowledge-base directory under ``KB_ROOT_PATH2``.
        file_name: Stored file name inside that base's ``content`` directory.

    Returns:
        A ``FileResponse`` whose download name is the original upload name
        when the database knows it, otherwise the stored name.

    Raises:
        HTTPException: 404 when the file does not exist or the requested
            location resolves outside ``KB_ROOT_PATH2``.
    """
    # Path-traversal guard, part 1: drop any directory components of the file name.
    safe_name = Path(file_name).name

    # Part 2: knowledge_name is caller-supplied as well ("..", absolute paths),
    # so resolve the final path and require it to stay under the root.
    root = Path(KB_ROOT_PATH2).resolve()
    file_path = (root / knowledge_name / "content" / safe_name).resolve()

    if not file_path.is_relative_to(root) or not file_path.is_file():
        raise HTTPException(status_code=404, detail="暂未找到文件")

    name_map = get_personal_knowledge_map([safe_name])
    # Prefer the original upload name; fall back to the stored name.
    download_name = name_map.get(safe_name, safe_name)
    return FileResponse(
        path=str(file_path),
        media_type="application/octet-stream",
        filename=download_name,
    )
|
||||
|
||||
# Convert retrieved documents into citation strings.
def doc_to_list(num: int, knowledge_name: str, docs: List[DocumentWithVSId], source_documents):
    """Append one Markdown citation line per document to ``source_documents``.

    Args:
        num: Citation number of the item preceding the first document; it is
            incremented before each document is formatted.
        knowledge_name: Knowledge-base name; selects the citation format.
        docs: Documents to cite.
        source_documents: Output list, extended in place. Line breaks are
            removed from every entry before it is appended.
    """
    for doc in docs:
        num += 1
        metadata = doc.metadata

        if knowledge_name == knowledgeBase_config.DEFAULT_REPORT_BASE_NAME:
            filename = metadata.get("source").strip()
            text = f"""[{num}] [{filename}]"""
        elif knowledge_name in (
            knowledgeBase_config.GY_NEWS_BASE_NAME,
            knowledgeBase_config.GY_REPORT_BASE_NAME,
            knowledgeBase_config.GY_JOURNAL_BASE_NAME,
            knowledgeBase_config.YJ_NEWS_BASE_NAME,
        ):
            filename = metadata.get("title").strip()
            detail_url = metadata.get("url") or metadata.get("releaseUrl", "")
            text = f"""[{num}] [{filename}]({detail_url})"""
        # Personal knowledge bases link to the download endpoint.
        elif knowledgeBase_config.SELF_KNOWLEDGE_BASE.match(knowledge_name) or knowledge_name == "coding":
            filename = metadata.get("source").strip()
            uuid_name = metadata.get("uuid_name").strip()
            detail_url = f"{DOWNLOAD_HOST_CK}?filename={knowledge_name}/content/{uuid_name}&originname={filename}"
            text = f"[{num}] [{Path(filename).stem}]({detail_url})\n"
        elif knowledge_name in (
            knowledgeBase_config.YJ_CH_JOURNAL_BASE_NAME,
            knowledgeBase_config.YJ_FOR_JOURNAL_BASE_NAME,
            knowledgeBase_config.YJ_OA_JOURNAL_BASE_NAME,
        ):
            filename = metadata.get("title", "未命名").strip()
            text = f"[{num}] {filename}"
        elif knowledge_name == knowledgeBase_config.STEEL_KB_NAME:
            filename = metadata.get("title", "未命名").strip()
            detail_url = metadata.get("url", "")
            text = f"[{num}] [{filename}]({detail_url})"
        else:
            # A missing title must reach the "原文地址" fallback below instead of
            # failing on None.strip().
            filename = (metadata.get("title") or "").strip()
            try:
                detail_url = "https://policy.ckcest.cn/detail/" + metadata.get("primary_key") + ".html"
            except TypeError:
                # No usable primary_key (str + None): use the journal-article URL.
                detail_url = "https://kgo.ckcest.cn/kgo/detail/1002/ads_journal_article/" + metadata.get("ID") + ".html"
            # The former "_type == 'title'" branches produced identical text.
            text = f"[{num}] [{filename or '原文地址'}]({detail_url})"

        source_documents.append(text.replace('\n', '').replace('\r', ''))
|
||||
|
||||
|
||||
def solve_mental_data(knowledge_name,docs: List[DocumentWithVSId],doc,seen_docs,duplicate_indices,knowledge,inum):
    """Format one retrieved document as a reference entry, skipping duplicates.

    A document is identified by (title or source, page_content). The first time
    an identifier is seen it is added to ``seen_docs`` and a formatted
    "参考资料[n] ..." string is appended to ``knowledge``; otherwise ``inum`` is
    appended to ``duplicate_indices``. Documents whose summary is 15 characters
    or shorter are also recorded in ``duplicate_indices``.

    Args:
        knowledge_name: Knowledge-base name; selects the entry format.
        docs: Not used by this function.
        doc: The document to process (reads ``metadata`` and ``page_content``).
        seen_docs: Set of identifiers already emitted; updated in place.
        duplicate_indices: List receiving ``inum`` for skipped documents.
        knowledge: List receiving the formatted entries.
        inum: Position of ``doc`` in the caller's iteration.

    NOTE(review): indentation was reconstructed from the if/else pairing of a
    whitespace-stripped view of this file - confirm against the repository copy.
    """
    # Personal knowledge bases: identified by source, numbered by position.
    if knowledgeBase_config.SELF_KNOWLEDGE_BASE.match(knowledge_name) or knowledge_name == "coding":
        doc_identifier = (doc.metadata["source"], doc.page_content)
        if doc_identifier not in seen_docs:
            seen_docs.add(doc_identifier)
            knowledge.append(f"""参考资料[{inum + 1}] {doc.metadata["summary"]}""")
        else:
            duplicate_indices.append(inum)
    else:
        if "summary" in doc.metadata:
            if len(doc.metadata['summary'])>15:
                doc_identifier = (doc.metadata['title'], doc.page_content) if "title" in doc.metadata else (doc.metadata["source"], doc.page_content)
                # Only emit a document the first time its identifier is seen.
                if doc_identifier not in seen_docs:
                    seen_docs.add(doc_identifier)
                    # Entries in this branch are numbered by len(knowledge) + 1, not by inum.
                    if knowledge_name == knowledgeBase_config.DEFAULT_JOURNAL_BASE:
                        date_str = doc.metadata['publish_date']
                        date_obj = datetime.strptime(str(date_str), '%Y%m%d')
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}资料来源: {doc.metadata['data_source'].replace('.pdf','')}资料时间:{date_obj} \n資料内容: {doc.metadata['summary']}""")
                    elif knowledge_name == knowledgeBase_config.DEFAULT_REPORT_BASE:
                        try:
                            # When the chunk is just the title, use the summary as content.
                            if doc.metadata['title'] == doc.page_content:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}\n资料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}\n资料内容: {doc.page_content}""")
                        except Exception as e:
                            # Fallback keyed on "source" when the title-based path fails.
                            if doc.metadata['source'] == doc.page_content:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n资料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n资料内容: {doc.page_content}""")
                    elif knowledge_name == knowledgeBase_config.GY_NEWS_BASE or knowledge_name == knowledgeBase_config.GY_REPORT_BASE or knowledge_name == knowledgeBase_config.GY_JOURNAL_BASE:
                        # Metallurgy news base (2024 and earlier).
                        # NOTE(review): this inner test can only succeed if YJ_NEWS_BASE equals one of
                        # the three GY_* values checked above; otherwise the branch is unreachable.
                        # Confirm whether YJ_NEWS_BASE belongs in the enclosing condition.
                        if knowledge_name == knowledgeBase_config.YJ_NEWS_BASE:
                            # date_str is computed but not used in the entry.
                            date_str = doc.metadata.get('publish_date', '') or doc.metadata.get('publish_year', '')
                            # The line break inside the f-string is part of the emitted text.
                            # NOTE(review): the continuation is kept at column 0 as it appears in the
                            # stripped view; confirm its original leading whitespace.
                            knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata.get('title','')}
资料内容: {doc.metadata.get('summary', doc.page_content)}""")
                        else:
                            if doc.metadata['title'] == doc.page_content:
                                # date_str is read (KeyError if absent) but not used in the entry.
                                date_str = doc.metadata['publish_date']
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}
资料内容: {doc.metadata['summary']}""")
                            else:
                                date_str = doc.metadata['publish_date']
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}
资料内容: {doc.page_content}""")
                    # ------- Metallurgy journal knowledge bases -------
                    elif knowledge_name in (knowledgeBase_config.YJ_CH_JOURNAL_BASE, knowledgeBase_config.YJ_FOR_JOURNAL_BASE, knowledgeBase_config.YJ_OA_JOURNAL_BASE):
                        # Journal literature: prefer the summary when the chunk is just the title.
                        if doc.metadata.get('title') == doc.page_content:
                            content_used = doc.metadata.get('summary', doc.page_content)
                        else:
                            content_used = doc.page_content
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata.get('title','未命名')}
资料内容: {content_used}""")
                    # ------- Metallurgy policy knowledge base -------
                    elif knowledge_name == knowledgeBase_config.YJ_POLICYS_BASE:
                        title = doc.metadata.get('title') or doc.metadata.get('source', '')
                        content_used = doc.metadata.get('summary', doc.page_content)
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {title}
资料内容: {content_used}""")
                    else:
                        try:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(str(date_str), '%Y%m%d')
                            knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}资料来源: {doc.metadata['source'].replace('.pdf','')}资料时间:{date_obj} \n資料内容: {doc.metadata['summary']}""")
                        except Exception as e:
                            # Reached when release_date/title/source is missing or the date does not parse.
                            # "in" here is a containment test against STEEL_KB, not equality.
                            if knowledge_name in knowledgeBase_config.STEEL_KB:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}]发布时间:{doc.metadata['date']} 资料来源: {doc.metadata['source']}标题:{doc.metadata['title']}\n資料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n資料内容: {doc.metadata['summary']}""")
                else:
                    duplicate_indices.append(inum)
            else:
                # Summaries of 15 characters or fewer are treated like duplicates.
                duplicate_indices.append(inum)
        else:
            # No summary in metadata: entries in this branch are numbered by inum + 1.
            doc_identifier = (doc.metadata['title'], doc.page_content)
            if doc_identifier not in seen_docs:
                seen_docs.add(doc_identifier)
                if doc.metadata["_type"] == "title":
                    # Title chunks: the body comes from metadata ("content" or "abstract").
                    try:
                        if knowledge_name in knowledgeBase_config.GY_BASE_NAME:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.metadata['content']}""")
                        else:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(date_str, '%Y%m%d')
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['content']}""")
                    except Exception as e:
                        # Fallback: use publish_date when present, otherwise an empty date.
                        date_str = doc.metadata['publish_date'] if 'publish_date' in doc.metadata else ""
                        date_obj = datetime.strptime(date_str, '%Y%m%d') if not date_str=="" else ""
                        if knowledge_name in knowledgeBase_config.DEFAULT_REPORT_BASE1:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['source']}\n文章内容 {doc.page_content}""")
                        else:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['abstract']}""")
                else:
                    # Non-title chunks: the body is the chunk text itself.
                    try:
                        if knowledge_name in knowledgeBase_config.GY_BASE_NAME:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.page_content}""")
                        else:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(date_str, '%Y%m%d')
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']}资料时间{date_obj} \n文章内容 {doc.page_content}""")
                    except Exception as e:
                        # Fallback: use publish_date when present, otherwise an empty date.
                        date_str = doc.metadata['publish_date'] if 'publish_date' in doc.metadata else ""
                        date_obj = datetime.strptime(date_str, '%Y%m%d') if not date_str=="" else ""
                        if knowledge_name in knowledgeBase_config.DEFAULT_REPORT_BASE1:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['source']}\n文章内容 {doc.page_content}""")
                        else:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['abstract']}""")
            else:
                duplicate_indices.append(inum)
|
||||
@app.post("/chat_translate")
async def chat_translate(query, lang: str):
    """Translate ``query`` through the internal translation services.

    Args:
        query: Text to translate.
        lang: ``"zh2en"`` or ``"en2zh"``; selects the service endpoint.

    Returns:
        The ``translation`` field of the service's JSON reply.

    Raises:
        ValueError: ``lang`` is neither ``"zh2en"`` nor ``"en2zh"``.
        HTTPException: The service answered with a non-200 status; the same
            status code is propagated.
    """
    if lang == "zh2en":
        url = 'http://192.168.56.123:8845/translate'
    elif lang == "en2zh":
        url = 'http://192.168.56.123:8846/translate'
    else:
        raise ValueError("lang参数错误")

    data = {'text': query}
    headers = {'Content-Type': 'application/json'}

    async with httpx.AsyncClient() as client:
        response = await client.post(url, data=json.dumps(data), headers=headers)
        print("response", response)

    if response.status_code == 200:
        result = response.json()
        return result['translation']
    # Exception() accepts no keyword arguments, so the previous
    # "raise Exception(status_code=..., detail=...)" failed with TypeError.
    raise HTTPException(status_code=response.status_code, detail="请求失败")
|
||||
|
||||
class History(BaseModel):
    """
    One message of the conversation history.

    Can be built from a dict, e.g.
        h = History(**{"role":"user","content":"你好"})
    and converted to a tuple, e.g.
        h.to_msg_tuple() == ("user", "你好")
    """
    role: str = Field(...)
    content: str = Field(...)

    def to_msg_tuple(self):
        """Return ``(role, content)``; any role other than "assistant"/"system" becomes "user"."""
        if self.role == "assistant":
            return ("assistant", self.content)
        if self.role == "system":
            return ("system", self.content)
        return ("user", self.content)

    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
        """Build a jinja2 chat-message template for this message.

        Args:
            is_raw: When true (default) the content is wrapped in a jinja2
                raw block so it is not interpreted as template syntax.
        """
        aliases = {
            "ai": "assistant",
            "human": "user",
            "system": "system",
        }
        resolved_role = aliases.get(self.role, self.role)
        # History messages are plain text without input variables by default.
        body = ("{% raw %}" + self.content + "{% endraw %}") if is_raw else self.content
        return ChatMessagePromptTemplate.from_template(
            body,
            "jinja2",
            role=resolved_role,
        )

    @classmethod
    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
        """Build a History from a ``(role, content)`` sequence or a dict.

        Any other input (including sequences shorter than two items) is
        returned unchanged.
        """
        if isinstance(h, dict):
            return cls(**h)
        if isinstance(h, (list, tuple)) and len(h) >= 2:
            return cls(role=h[0], content=h[1])
        return h
|
||||
|
||||
# Regular expression (character class) of Markdown special characters, kept as a constant.
MARKDOWN_SPECIAL_CHARS = r'[\\`*_{}\[\]()#+\-\.!|>~:]'


def get_first_sentence_by_regex(text: str, fallback_length: int = 10) -> str:
    """Extract a short first sentence from ``text``.

    Steps:
      1. Strip surrounding whitespace.
      2. Remove every character matched by ``MARKDOWN_SPECIAL_CHARS``.
      3. Keep only the first line and strip it again, so whitespace left
         behind by removed markers (e.g. "# Title") does not leak out.
      4. Return everything up to and including the first "。".
         Without a "。", return the text itself when it has at most
         ``fallback_length`` characters, otherwise its first
         ``fallback_length`` characters followed by "...".

    Args:
        text: Source text; a falsy value yields "".
        fallback_length: Maximum length returned untruncated when no "。" exists.
    """
    if not text:
        return ""

    cleaned = re.sub(MARKDOWN_SPECIAL_CHARS, '', text.strip())
    first_line = cleaned.split('\n', 1)[0].strip()

    sentence_end = first_line.find('。')
    if sentence_end != -1:
        # Include the full stop itself.
        return first_line[:sentence_end + 1].strip()

    # No full stop: fall back to the first N characters plus "...".
    if len(first_line) <= fallback_length:
        return first_line
    return first_line[:fallback_length] + "..."
|
||||
|
||||
def get_text_by_regex(text: str) -> str:
    """Strip surrounding whitespace and remove Markdown special characters.

    A falsy ``text`` yields "".
    """
    if text:
        trimmed = text.strip()
        return re.sub(MARKDOWN_SPECIAL_CHARS, '', trimmed)
    return ""
|
||||
|
||||
def split_questions(text: str) -> str:
    """Split ``text`` into questions and join them with ``<strip>`` markers.

    Three strategies are tried in order; the first that yields more than one
    part wins:
      1. Split before each ``Q<number>`` followed by a colon.
      2. Split on runs of line breaks / tabs.
      3. Split after sentence-ending punctuation (always used as last resort).

    ASCII and full-width forms of ":", "?" and "!" are both recognised. The
    full-width forms are written as ``\\uXXXX`` escapes so they cannot be lost
    to character normalisation.

    Returns:
        The parts joined by ``"<strip>\\n\\n"`` with a trailing ``"<strip>"``.
    """
    def _join(parts):
        return '<strip>\n\n'.join(parts) + '<strip>'

    def _clean(pieces):
        return [piece.strip() for piece in pieces if piece.strip()]

    text = text.strip()

    # 1) Split in front of "Q1:", "Q2\uFF1A", ...
    parts = _clean(re.split(r'(?=Q\d+[:\uFF1A])', text))
    if len(parts) > 1:
        return _join(parts)

    # 2) Split on line breaks and tabs.
    parts = _clean(re.split(r'[\r\n\t]+', text))
    if len(parts) > 1:
        return _join(parts)

    # 3) Split after ? ! (ASCII or full-width) and the ideographic full stop.
    marked = re.sub(r'([?\uFF1F!\uFF01。])', r'\1<sep>', text)
    return _join(_clean(marked.split('<sep>')))
|
||||
Reference in New Issue
Block a user