Files
gangyan/langchain-chat/server/chat/utils.py

584 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import os
from fastapi.responses import FileResponse
import pymysql
from contextvars import ContextVar
from configs.kb_config import DOWNLOAD_HOST_CK, KB_ROOT_PATH2, ck_mysql_config
from datetime import datetime
from pathlib import Path
import re
import threading
import requests
import time
import oss2
from pydantic import BaseModel, Field
from langchain.prompts.chat import ChatMessagePromptTemplate
from configs import logger, log_verbose
from typing import Any, List, Tuple, Dict, Union
from configs.oss_config import *
from apscheduler.schedulers.background import BackgroundScheduler
from fastapi import FastAPI, HTTPException
from configs.kb_config import similarity_url,similarity_score,similarity_internet
import httpx
import json
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import kb_config as knowledgeBase_config
# FastAPI application object; the startup hook and the /chat_translate route
# defined later in this file are registered on it.
app = FastAPI()
# APScheduler background scheduler; a job is added to it and it is started in
# the startup hook.
scheduler = BackgroundScheduler()
# ContextVars holding the shared key/value dict and the per-key expiry
# timestamps (epoch seconds) used by the *_shared_variable helpers.
# NOTE(review): each `default={}` dict is created once at import time, so every
# context that has not been given a different dict reads and mutates the same
# object - confirm that this process-wide sharing is intended.
shared_variable: ContextVar[Dict[str, Any]] = ContextVar("shared_variable", default={})
expiration_times: ContextVar[Dict[str, float]] = ContextVar("expiration_times", default={})
def get_similar_documents(index, sentences, query: str, docs: List, top_k: int = 3):
    """Keep only the entries of ``docs`` whose position appears in ``index``.

    Args:
        index: Collection of positions (into ``docs``) to keep.
        sentences: Unused; kept for call compatibility (the remote similarity
            scoring that consumed it is disabled).
        query: Unused; see ``sentences``.
        docs: Candidate documents.
        top_k: Unused; see ``sentences``.

    Returns:
        A new list with the selected documents in their original order, or an
        empty list when ``index`` is empty.
    """
    if len(index) == 0:
        return []
    return [doc for position, doc in enumerate(docs) if position in index]
def get_similar_documents1(index, sentences, query: str, docs: List[DocumentWithVSId], top_k: int = 3):
    """Keep only the entries of ``docs`` whose position appears in ``index``.

    Args:
        index: Collection of positions (into ``docs``) to keep.
        sentences: Unused; kept for call compatibility (the remote similarity
            scoring that consumed it is disabled).
        query: Unused; see ``sentences``.
        docs: Candidate documents.
        top_k: Unused; see ``sentences``.

    Returns:
        A new list with the selected documents in their original order, or an
        empty list when ``index`` is empty.
    """
    if len(index) == 0:
        return []
    return [doc for position, doc in enumerate(docs) if position in index]
def remove_docs1(titles: List, docs: List[DocumentWithVSId]):
    """Drop documents whose title is already listed in ``titles``.

    Documents are compared by ``metadata["title"]``; if that attempt raises
    for any reason, the whole filtering is redone using ``metadata["source"]``.

    Args:
        titles: Titles (or sources) to exclude.
        docs: Documents exposing a ``metadata`` mapping.

    Returns:
        A ``(kept_docs, kept_names)`` tuple, where ``kept_names`` holds the
        title (or source) of every kept document, in order.
    """
    def _filter_by(field):
        # Filter on one metadata field and collect that field from the survivors.
        kept = [doc for doc in docs if doc.metadata[field] not in titles]
        return kept, [doc.metadata[field] for doc in kept]

    try:
        return _filter_by("title")
    except Exception:
        return _filter_by("source")
def remove_docs(titles: List, docs: List):
    """Drop dict-style documents whose ``"title"`` is listed in ``titles``.

    Args:
        titles: Titles to exclude.
        docs: Documents supporting ``doc["title"]``.

    Returns:
        A ``(kept_docs, kept_titles)`` tuple, both in the original order.
    """
    kept_docs = []
    kept_titles = []
    for doc in docs:
        doc_title = doc["title"]
        if doc_title in titles:
            continue
        kept_docs.append(doc)
        kept_titles.append(doc_title)
    return kept_docs, kept_titles
def upload_image_to_oss(local_image_path, oss_file_name):
    """Upload a local file to the OSS bucket under ``files/chat_pic/``.

    The credentials, endpoint and bucket name come from the OSS config module.

    Args:
        local_image_path: Path of the local file to upload.
        oss_file_name: Object name to use below ``files/chat_pic/``.

    Returns:
        The ``http://<bucket>.<endpoint>/<object>`` URL on success, or ``None``
        when anything in the upload step raises (the error is printed).
    """
    credentials = oss2.Auth(access_key_id, access_key_secret)
    bucket = oss2.Bucket(credentials, endpoint, bucket_name)
    try:
        object_name = f"files/chat_pic/{oss_file_name}"
        with open(local_image_path, 'rb') as fileobj:
            bucket.put_object(object_name, fileobj)
        print(f"文件 {oss_file_name} 上传成功!")
        # Build the public URL of the uploaded object.
        return f"http://{bucket_name}.{endpoint}/{object_name}"
    except Exception as e:
        print(f"上传失败: {e}")
        return None
def get_personal_knowledge_map(uuid_names: list[str]):
    """Map stored file names to the original upload names via ``gpt_upload_file``.

    Each name is looked up as ``./files/<name>`` in the ``filepath`` column.

    Args:
        uuid_names: Stored file names (without the ``./files/`` prefix).

    Returns:
        A dict ``{stored_name: filename}`` containing only the names that were
        found. An empty input yields an empty dict without touching the
        database.
    """
    if not uuid_names:
        # An empty list would render "IN ()", which is a SQL syntax error.
        return {}

    prefix = "./files/"
    filepaths = [f"{prefix}{uuid}" for uuid in uuid_names]
    placeholders = ','.join(['%s'] * len(filepaths))
    sql = f"SELECT filepath,filename FROM gpt_upload_file WHERE filepath IN ({placeholders})"

    conn = pymysql.connect(**ck_mysql_config)
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, filepaths)
            rows = cursor.fetchall()
    finally:
        conn.close()

    # Strip only the leading prefix so a name that itself contains "./files/"
    # still maps back to the key the caller passed in.
    return {row[0].removeprefix(prefix): row[1] for row in rows}
def solve_knowledge_map(knowledge_map: list) -> list:
    """Replace knowledge-base display names with their configured internal names.

    The list is modified in place and also returned. Entries that match no
    known display name are left untouched. When several aliases could match,
    the first one in the tables below wins.

    Args:
        knowledge_map: List of knowledge-base names to translate.

    Returns:
        The same list object, after translation.
    """
    # (literal display name, config attribute holding the internal name)
    literal_aliases = (
        ('政策库', 'DEFAULT_POLICY_BASE'),
        ('期刊论文库', 'DEFAULT_JOURNAL_BASE'),
        ('报告库', 'DEFAULT_REPORT_BASE1'),
        ('冶金行业新闻库', 'GY_NEWS_BASE'),
        ('冶金行业报告库', 'GY_REPORT_BASE'),
        ('冶金专业知识库', 'GY_JOURNAL_BASE'),
    )
    # (config attribute holding the display name, config attribute holding the
    # internal name) for the metallurgy-series knowledge bases.
    configured_aliases = (
        ('YJ_CH_JOURNAL_BASE_NAME', 'YJ_CH_JOURNAL_BASE'),
        ('YJ_NEWS_BASE_NAME', 'YJ_NEWS_BASE'),
        ('YJ_FOR_JOURNAL_BASE_NAME', 'YJ_FOR_JOURNAL_BASE'),
        ('YJ_OA_JOURNAL_BASE_NAME', 'YJ_OA_JOURNAL_BASE'),
        ('YJ_POLICYS_BASE_NAME', 'YJ_POLICYS_BASE'),
        ('STEEL_KB_NAME', 'STEEL_KB'),
    )
    for position, label in enumerate(knowledge_map):
        target_attr = next(
            (attr for text, attr in literal_aliases if label == text),
            None,
        )
        if target_attr is None:
            target_attr = next(
                (
                    attr
                    for name_attr, attr in configured_aliases
                    if label == getattr(knowledgeBase_config, name_attr)
                ),
                None,
            )
        if target_attr is not None:
            knowledge_map[position] = getattr(knowledgeBase_config, target_attr)
    return knowledge_map
def get_shared_variable(key: str) -> Any:
    """Return the value stored under ``key`` in the shared dict.

    The dict is the one currently held by the ``shared_variable`` ContextVar.
    ``None`` is returned when the key is absent.
    """
    return shared_variable.get().get(key)
def set_shared_variable(key: str, value: Any, ttl: int = 300):
    """Store ``value`` under ``key`` and record when the entry expires.

    Args:
        key: Name of the entry.
        value: Value to store.
        ttl: Lifetime in seconds, counted from now (whole-second resolution).
    """
    store = shared_variable.get()
    deadlines = expiration_times.get()

    store[key] = value
    # Expiry is kept as an integer epoch timestamp.
    deadlines[key] = int(datetime.now().timestamp()) + ttl

    shared_variable.set(store)
    expiration_times.set(deadlines)
# Manual removal
def remove_shared_variable(key: str):
    """Delete ``key`` and its expiry record from the shared dicts, if present."""
    store = shared_variable.get()
    deadlines = expiration_times.get()

    # Missing keys are ignored in either dict.
    store.pop(key, None)
    deadlines.pop(key, None)

    shared_variable.set(store)
    expiration_times.set(deadlines)
async def clear_expired_keys():
    """Purge expired entries from the shared dicts, forever.

    Every 2 seconds the expiry table is scanned and every key whose deadline
    (integer epoch seconds) is at or before the current time is removed from
    both the value dict and the expiry dict. This coroutine never returns.
    """
    while True:
        # Sleep cooperatively: a blocking time.sleep() inside a coroutine would
        # stall every other task on the same event loop.
        await asyncio.sleep(2)
        print("Checking expired keys...")
        current_time = int(datetime.now().timestamp())
        context_dict = shared_variable.get()
        expiration_dict = expiration_times.get()
        # Collect the expired keys first so neither dict is mutated while it
        # is being iterated.
        keys_to_delete = [
            key for key, exp_time in expiration_dict.items() if exp_time <= current_time
        ]
        for key in keys_to_delete:
            # pop() tolerates a key that is missing from one of the dicts, so
            # a single inconsistency cannot terminate the cleanup loop.
            context_dict.pop(key, None)
            expiration_dict.pop(key, None)
        shared_variable.set(context_dict)
        expiration_times.set(expiration_dict)
@app.on_event("startup")
async def app_start():
    """Startup hook: register the expired-key cleanup job and start the scheduler.

    NOTE(review): ``clear_expired_keys`` is a coroutine function containing an
    endless loop, yet it is registered here as a 3-second interval job on a
    ``BackgroundScheduler``. Confirm that the scheduler's executor actually
    awaits coroutine functions; if it only calls them, the cleanup presumably
    never executes.
    """
    scheduler.add_job(clear_expired_keys, 'interval', seconds=3)
    scheduler.start()
def compute_lps(pattern: str) -> list:
    """Build the KMP failure table for ``pattern``.

    Entry ``i`` is the length of the longest proper prefix of
    ``pattern[:i + 1]`` that is also a suffix of it.

    Args:
        pattern: The pattern string.

    Returns:
        A list of ints with one entry per character of ``pattern``.
    """
    table = [0] * len(pattern)
    matched = 0  # length of the currently matched prefix
    for pos in range(1, len(pattern)):
        # Fall back through shorter borders until the next character fits.
        while matched and pattern[pos] != pattern[matched]:
            matched = table[matched - 1]
        if pattern[pos] == pattern[matched]:
            matched += 1
        table[pos] = matched
    return table
# Check whether a string starts with the given pattern
def kmp_match(text: str, pattern: str) -> bool:
    """Return True when ``pattern`` matches ``text`` starting at its first character.

    The previous hand-written KMP scan located the first occurrence of the
    pattern anywhere in the text and then checked that it was at offset 0,
    which is exactly a prefix test. ``str.startswith`` gives the same answer
    without scanning the rest of the text.

    Args:
        text: The text to inspect.
        pattern: The prefix to look for.

    Returns:
        ``True`` if ``text`` begins with a non-empty ``pattern``, else
        ``False``. An empty pattern yields ``False`` (it used to raise
        ``IndexError`` for non-empty text and return ``False`` for empty text).
    """
    if not pattern:
        return False
    return text.startswith(pattern)
def remove_before_and_including(text: str, substring: str) -> str:
    """Return the part of ``text`` that follows the first ``substring``.

    Everything up to and including the first occurrence of ``substring`` is
    removed. The earlier slice started one character too early, so the last
    character of ``substring`` leaked into the result (and an empty
    ``substring`` returned only the last character of ``text``).

    Args:
        text: The text to cut.
        substring: The marker to cut after.

    Returns:
        The remainder after the marker, or ``text`` unchanged when the marker
        does not occur.
    """
    index = text.find(substring)
    if index == -1:
        # Marker not present: nothing to remove.
        return text
    return text[index + len(substring):]
def remove_after_and_including(text: str, substring: str) -> str:
    """Return the part of ``text`` that precedes the first ``substring``.

    Args:
        text: The text to cut.
        substring: The marker to cut at.

    Returns:
        Everything before the marker, or ``text`` unchanged when the marker
        does not occur.
    """
    cut_at = text.find(substring)
    return text if cut_at == -1 else text[:cut_at]
def remove_after_and_includings(text, keyword):
    """Return the part of ``text`` before the first literal ``keyword``.

    The keyword is regex-escaped, so it is matched verbatim.

    Returns:
        Everything before the keyword, or ``text`` unchanged when the keyword
        does not occur.
    """
    hit = re.search(re.escape(keyword), text)
    if hit is None:
        return text
    # Keep only what precedes the keyword.
    return text[:hit.start()]
def download_self_doc(knowledge_name: str, file_name: str):
    """Serve a file from ``<KB_ROOT_PATH2>/<knowledge_name>/content/``.

    Args:
        knowledge_name: Knowledge-base directory, relative to the root.
        file_name: Stored file name; only its final path component is used.

    Returns:
        A ``FileResponse`` for the file. The download name is the original
        upload name when ``get_personal_knowledge_map`` knows the file,
        otherwise the stored name.

    Raises:
        HTTPException: 404 when the knowledge-base name would escape the root
            directory, or when the file does not exist.
    """
    # Guard against path traversal through the file name.
    safe_name = Path(file_name).name

    # Guard against path traversal through the knowledge-base name as well:
    # an anchored path (absolute or drive-qualified) would replace the root
    # when joined, and ".." components would climb out of it.
    kb_relative = Path(knowledge_name)
    if kb_relative.anchor or ".." in kb_relative.parts:
        raise HTTPException(status_code=404, detail="暂未找到文件")

    file_path = KB_ROOT_PATH2 / knowledge_name / "content" / safe_name
    if not file_path.exists() or not file_path.is_file():
        raise HTTPException(status_code=404, detail="暂未找到文件")

    name_map = get_personal_knowledge_map([safe_name])
    # Use the original name when the map has one, otherwise fall back to the
    # stored name.
    download_name = name_map.get(safe_name, safe_name)
    return FileResponse(
        path=str(file_path),
        media_type="application/octet-stream",
        filename=download_name,
    )
# Convert documents to a list of citation lines
def doc_to_list(num: int, knowledge_name:str, docs: List[DocumentWithVSId],source_documents):
    """Append one citation line per document to ``source_documents``.

    Each line starts with ``[<n>]``, where ``n`` continues counting from
    ``num``; the rest of the line (plain title or Markdown link) depends on
    which knowledge base ``knowledge_name`` denotes. Line breaks are removed
    from every line before it is appended.

    Args:
        num: Number of citations already emitted; the first line produced
            here is numbered ``num + 1``.
        knowledge_name: Display name of the knowledge base the docs came from.
        docs: Documents exposing a ``metadata`` mapping.
        source_documents: List that receives the citation lines (mutated).

    Returns:
        None. The advanced counter is not returned to the caller.
    """
    for doc in docs:
        num += 1
        if knowledge_name == knowledgeBase_config.DEFAULT_REPORT_BASE_NAME:
            filename = doc.metadata.get("source").strip()
            text = f"[{num}] [{filename}]"
        elif knowledge_name in (knowledgeBase_config.GY_NEWS_BASE_NAME, knowledgeBase_config.GY_REPORT_BASE_NAME, knowledgeBase_config.GY_JOURNAL_BASE_NAME, knowledgeBase_config.YJ_NEWS_BASE_NAME):
            filename = doc.metadata.get("title").strip()
            detail_url = doc.metadata.get("url") or doc.metadata.get("releaseUrl", "")
            text = f"[{num}] [{filename}]({detail_url})"
        # Personal knowledge base: link to the download endpoint.
        elif knowledgeBase_config.SELF_KNOWLEDGE_BASE.match(knowledge_name) or knowledge_name == "coding":
            filename = doc.metadata.get("source").strip()
            uuid_name = doc.metadata.get("uuid_name").strip()
            detail_url = f"{DOWNLOAD_HOST_CK}?filename={knowledge_name}/content/{uuid_name}&originname={filename}"
            text = f"[{num}] [{Path(filename).stem}]({detail_url})\n"
        elif knowledge_name in (knowledgeBase_config.YJ_CH_JOURNAL_BASE_NAME, knowledgeBase_config.YJ_FOR_JOURNAL_BASE_NAME, knowledgeBase_config.YJ_OA_JOURNAL_BASE_NAME):
            filename = doc.metadata.get("title", "未命名").strip()
            text = f"[{num}] {filename}"
        elif knowledge_name == knowledgeBase_config.STEEL_KB_NAME:
            filename = doc.metadata.get("title", "未命名").strip()
            detail_url = doc.metadata.get("url", "")
            text = f"[{num}] [{filename}]({detail_url})"
        else:
            # A missing title falls through to the generic link label below
            # instead of raising on None.strip().
            filename = (doc.metadata.get("title") or "").strip()
            try:
                detail_url = "https://policy.ckcest.cn/detail/" + doc.metadata.get("primary_key") + ".html"
            except TypeError:
                # No usable string "primary_key" (e.g. None): build the
                # journal-article URL from "ID" instead.
                detail_url = "https://kgo.ckcest.cn/kgo/detail/1002/ads_journal_article/" + doc.metadata.get("ID") + ".html"
            # The link is rendered the same way whatever the doc's "_type" is;
            # only an empty title changes the label.
            link_label = filename if filename else "原文地址"
            text = f"[{num}] [{link_label}]({detail_url})"
        source_documents.append(text.replace('\n', '').replace('\r', ''))
def solve_mental_data(knowledge_name,docs: List[DocumentWithVSId],doc,seen_docs,duplicate_indices,knowledge,inum):
    """Format one retrieved document as a reference entry, skipping duplicates.

    A document is identified by ``(title or source, page_content)``. If that
    identifier is new it is added to ``seen_docs`` and a formatted
    "参考资料[n] ..." string is appended to ``knowledge``; the layout of the
    string depends on ``knowledge_name`` and on which metadata keys exist.
    Otherwise ``inum`` is appended to ``duplicate_indices``.

    Args:
        knowledge_name: Name of the knowledge base the document came from.
        docs: Not referenced in this function.
        doc: The document; ``doc.metadata`` and ``doc.page_content`` are read.
        seen_docs: Collection of identifiers already emitted (``in`` and
            ``.add`` are used on it).
        duplicate_indices: List receiving the index of every skipped document.
        knowledge: List receiving the formatted reference strings.
        inum: Index of ``doc`` in the caller's sequence.

    Note:
        Entries are numbered ``inum + 1`` in the personal-knowledge-base and
        no-summary branches, but ``len(knowledge) + 1`` in the summary branch.
    """
    # Check whether this is a personal knowledge base
    if knowledgeBase_config.SELF_KNOWLEDGE_BASE.match(knowledge_name) or knowledge_name == "coding":
        doc_identifier = (doc.metadata["source"], doc.page_content)
        if doc_identifier not in seen_docs:
            seen_docs.add(doc_identifier)
            knowledge.append(f"""参考资料[{inum + 1}] {doc.metadata["summary"]}""")
        else:
            duplicate_indices.append(inum)
    else:
        if "summary" in doc.metadata:
            # Only summaries longer than 15 characters are used.
            if len(doc.metadata['summary'])>15:
                doc_identifier = (doc.metadata['title'], doc.page_content) if "title" in doc.metadata else (doc.metadata["source"], doc.page_content)
                # Check whether this identifier is already in the set
                if doc_identifier not in seen_docs:
                    seen_docs.add(doc_identifier)
                    # and add the document information to the knowledge list
                    if knowledge_name == knowledgeBase_config.DEFAULT_JOURNAL_BASE:
                        date_str = doc.metadata['publish_date']
                        date_obj = datetime.strptime(str(date_str), '%Y%m%d')
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}资料来源: {doc.metadata['data_source'].replace('.pdf','')}资料时间:{date_obj} \n資料内容: {doc.metadata['summary']}""")
                    elif knowledge_name == knowledgeBase_config.DEFAULT_REPORT_BASE:
                        # Prefer "title"; if anything in that attempt raises,
                        # redo the formatting with "source".
                        try:
                            if doc.metadata['title'] == doc.page_content:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}\n资料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}\n资料内容: {doc.page_content}""")
                        except Exception as e:
                            if doc.metadata['source'] == doc.page_content:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n资料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n资料内容: {doc.page_content}""")
                    elif knowledge_name == knowledgeBase_config.GY_NEWS_BASE or knowledge_name == knowledgeBase_config.GY_REPORT_BASE or knowledge_name == knowledgeBase_config.GY_JOURNAL_BASE:
                        # Support the metallurgy news base for 2024 and earlier
                        # NOTE(review): this inner test can only be true if
                        # YJ_NEWS_BASE equals one of the GY_* values checked
                        # on the enclosing elif - confirm it is reachable.
                        if knowledge_name == knowledgeBase_config.YJ_NEWS_BASE:
                            date_str = doc.metadata.get('publish_date', '') or doc.metadata.get('publish_year', '')
                            knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata.get('title','')}
资料内容: {doc.metadata.get('summary', doc.page_content)}""")
                        else:
                            # date_str is read in both branches below but is
                            # not included in the formatted entry.
                            if doc.metadata['title'] == doc.page_content:
                                date_str = doc.metadata['publish_date']
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}
资料内容: {doc.metadata['summary']}""")
                            else:
                                date_str = doc.metadata['publish_date']
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}
资料内容: {doc.page_content}""")
                    # ------- Metallurgy journal knowledge bases -------
                    elif knowledge_name in (knowledgeBase_config.YJ_CH_JOURNAL_BASE, knowledgeBase_config.YJ_FOR_JOURNAL_BASE, knowledgeBase_config.YJ_OA_JOURNAL_BASE):
                        # Journal documents: use the summary when the chunk is
                        # just the title, otherwise the chunk text itself
                        if doc.metadata.get('title') == doc.page_content:
                            content_used = doc.metadata.get('summary', doc.page_content)
                        else:
                            content_used = doc.page_content
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata.get('title','未命名')}
资料内容: {content_used}""")
                    # ------- Metallurgy policy knowledge base -------
                    elif knowledge_name == knowledgeBase_config.YJ_POLICYS_BASE:
                        title = doc.metadata.get('title') or doc.metadata.get('source', '')
                        content_used = doc.metadata.get('summary', doc.page_content)
                        knowledge.append(f"""参考资料[{len(knowledge) + 1}] {title}
资料内容: {content_used}""")
                    else:
                        # Default layout needs "release_date", "title" and
                        # "source"; any failure switches to the simpler
                        # layouts in the except branch.
                        try:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(str(date_str), '%Y%m%d')
                            knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['title']}资料来源: {doc.metadata['source'].replace('.pdf','')}资料时间:{date_obj} \n資料内容: {doc.metadata['summary']}""")
                        except Exception as e:
                            if knowledge_name in knowledgeBase_config.STEEL_KB:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}]发布时间:{doc.metadata['date']} 资料来源: {doc.metadata['source']}标题:{doc.metadata['title']}\n資料内容: {doc.metadata['summary']}""")
                            else:
                                knowledge.append(f"""参考资料[{len(knowledge) + 1}] {doc.metadata['source']}\n資料内容: {doc.metadata['summary']}""")
                else:
                    # Identifier already seen: record as duplicate.
                    duplicate_indices.append(inum)
            else:
                # Summary of 15 characters or fewer: the document is skipped
                # and recorded in duplicate_indices as well.
                duplicate_indices.append(inum)
        else:
            # No "summary" key in the metadata.
            doc_identifier = (doc.metadata['title'], doc.page_content)
            if doc_identifier not in seen_docs:
                seen_docs.add(doc_identifier)
                if doc.metadata["_type"] == "title":
                    # Title-type chunk: the article body comes from
                    # metadata["content"] (or "abstract" in the fallback).
                    try:
                        if knowledge_name in knowledgeBase_config.GY_BASE_NAME:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.metadata['content']}""")
                        else:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(date_str, '%Y%m%d')
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['content']}""")
                    except Exception as e:
                        # Fallback: use "publish_date" when present, else an
                        # empty date.
                        date_str = doc.metadata['publish_date'] if 'publish_date' in doc.metadata else ""
                        date_obj = datetime.strptime(date_str, '%Y%m%d') if not date_str=="" else ""
                        if knowledge_name in knowledgeBase_config.DEFAULT_REPORT_BASE1:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['source']}\n文章内容 {doc.page_content}""")
                        else:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['abstract']}""")
                else:
                    # Non-title chunk: the article body is the chunk text.
                    try:
                        if knowledge_name in knowledgeBase_config.GY_BASE_NAME:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']} \n文章内容 {doc.page_content}""")
                        else:
                            date_str = doc.metadata['release_date']
                            date_obj = datetime.strptime(date_str, '%Y%m%d')
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['title']}资料时间{date_obj} \n文章内容 {doc.page_content}""")
                    except Exception as e:
                        # Same fallback as in the title-type branch.
                        date_str = doc.metadata['publish_date'] if 'publish_date' in doc.metadata else ""
                        date_obj = datetime.strptime(date_str, '%Y%m%d') if not date_str=="" else ""
                        if knowledge_name in knowledgeBase_config.DEFAULT_REPORT_BASE1:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.metadata['source']}\n文章内容 {doc.page_content}""")
                        else:
                            knowledge.append(f"""参考资料[{inum + 1}] 文章标题 {doc.page_content}资料时间{date_obj} \n文章内容 {doc.metadata['abstract']}""")
            else:
                duplicate_indices.append(inum)
@app.post("/chat_translate")
async def chat_translate(query, lang: str):
    """Translate ``query`` by calling the internal translation service.

    Args:
        query: Text to translate.
        lang: Direction, either ``"zh2en"`` or ``"en2zh"``; each direction is
            served by its own endpoint.

    Returns:
        The ``"translation"`` field of the service's JSON reply.

    Raises:
        ValueError: If ``lang`` is not one of the two supported directions.
        HTTPException: If the service answers with a non-200 status; the
            upstream status code is passed through.
    """
    if lang == "zh2en":
        url = 'http://192.168.56.123:8845/translate'
    elif lang == "en2zh":
        url = 'http://192.168.56.123:8846/translate'
    else:
        raise ValueError("lang参数错误")
    data = {'text': query}
    headers = {'Content-Type': 'application/json'}
    async with httpx.AsyncClient() as client:
        response = await client.post(url, data=json.dumps(data), headers=headers)
        print("response", response)
        if response.status_code == 200:
            result = response.json()
            return result['translation']
        # The built-in Exception rejects keyword arguments, so the previous
        # `raise Exception(status_code=..., detail=...)` itself failed with a
        # TypeError; HTTPException is the type that carries these fields.
        raise HTTPException(status_code=response.status_code, detail="请求失败")
class History(BaseModel):
    """
    One message of the conversation history.

    Can be built from a dict:
        h = History(**{"role": "user", "content": "你好"})
    and converted to a tuple:
        h.to_msg_tuple() == ("user", "你好")
    """
    role: str = Field(...)
    content: str = Field(...)

    def to_msg_tuple(self):
        """Return ``(role, content)``; any role other than assistant/system becomes "user"."""
        if self.role == "assistant":
            return ("assistant", self.content)
        if self.role == "system":
            return ("system", self.content)
        return ("user", self.content)

    def to_msg_template(self, is_raw=True) -> ChatMessagePromptTemplate:
        """Build a jinja2 chat message template from this message.

        Args:
            is_raw: When true, the content is wrapped in a jinja2 raw block so
                it is taken literally rather than parsed for template
                variables.
        """
        alias_to_role = {
            "ai": "assistant",
            "human": "user",
            "system": "system",
        }
        resolved_role = alias_to_role.get(self.role, self.role)
        # By default history messages are plain text without input variables.
        body = ("{% raw %}" + self.content + "{% endraw %}") if is_raw else self.content
        return ChatMessagePromptTemplate.from_template(
            body,
            "jinja2",
            role=resolved_role,
        )

    @classmethod
    def from_data(cls, h: Union[List, Tuple, Dict]) -> "History":
        """Build a ``History`` from a ``(role, content, ...)`` sequence or a dict.

        Any other input (including a sequence shorter than two items) is
        returned unchanged.
        """
        if isinstance(h, dict):
            return cls(**h)
        if isinstance(h, (list, tuple)) and len(h) >= 2:
            return cls(role=h[0], content=h[1])
        return h
# Regular expression (a single character class) matching the Markdown special
# characters, defined once as a module-level constant.
MARKDOWN_SPECIAL_CHARS = r'[\\`*_{}\[\]()#+\-\.!|>~:]'
def get_first_sentence_by_regex(text: str, fallback_length: int = 10) -> str:
    """
    Extract a short first sentence from ``text``.

    1. Strip leading/trailing whitespace;
    2. Remove the Markdown special characters;
    3. Keep only the first line if there is a line break;
    4. Return the text up to and including the first "。".
    If there is no "。", return the line itself when it is at most
    ``fallback_length`` characters long, otherwise its first
    ``fallback_length`` characters followed by "...".
    """
    if not text:
        return ""

    # Steps 1-2: trim, then drop the Markdown special characters.
    cleaned = re.sub(MARKDOWN_SPECIAL_CHARS, '', text.strip())

    # Step 3: keep the first line only and trim what is left around it.
    first_line = cleaned.split('\n', 1)[0].strip()

    # Step 4: shortest prefix ending in the first full stop.
    sentence = re.search(r'^(.*?。)', first_line)
    if sentence:
        return sentence.group(1).strip()

    # No full stop: fall back to a length-limited prefix.
    if len(first_line) <= fallback_length:
        return first_line
    return first_line[:fallback_length] + "..."
def get_text_by_regex(text: str) -> str:
    """Return ``text`` stripped of surrounding whitespace and Markdown special characters.

    A falsy ``text`` yields an empty string.
    """
    if text:
        return re.sub(MARKDOWN_SPECIAL_CHARS, '', text.strip())
    return ""
def split_questions(text: str) -> str:
    """
    Split ``text`` into questions using three strategies in turn and return
    them joined with ``<strip>`` markers.

    The strategies are tried in order: split before each "Q<number>:" label,
    split on line breaks/tabs, and finally split after sentence-ending
    punctuation. The first two are used only if they yield more than one
    piece. Every piece in the result is followed by "<strip>", and pieces are
    separated by a blank line.
    """
    def _non_empty(chunks):
        # Trim each chunk and drop the ones that end up empty.
        return [chunk.strip() for chunk in chunks if chunk.strip()]

    def _tagged(chunks):
        return '<strip>\n\n'.join(chunks) + '<strip>'

    text = text.strip()

    # Strategies 1 and 2: "Q<n>:" labels, then line breaks / tabs.
    for splitter in (r'(?=Q\d+[:])', r'[\r\n\t]+'):
        pieces = _non_empty(re.split(splitter, text))
        if len(pieces) > 1:
            return _tagged(pieces)

    # Strategy 3: cut after sentence-ending punctuation.
    marked = re.sub(r'([\?!。])', r'\1<sep>', text)
    return _tagged(_non_empty(marked.split('<sep>')))