[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
76
langchain-chat/server/knowledge_base/TexkRank.py
Normal file
76
langchain-chat/server/knowledge_base/TexkRank.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import multiprocessing
|
||||
import re
|
||||
import time
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
from textrank4zh import TextRank4Keyword, TextRank4Sentence
|
||||
from joblib import Parallel, delayed, parallel_backend
|
||||
import logging
|
||||
# Compatibility alias: expose networkx's from_numpy_array under the older
# from_numpy_matrix name. NOTE(review): presumably needed because textrank4zh
# still calls the old name on newer networkx releases — confirm.
nx.from_numpy_matrix = nx.from_numpy_array
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
|
||||
def process_text_segment(text_segment, num_sentences):
    """Run keyword and key-sentence TextRank over a single text segment.

    Args:
        text_segment: The text to analyse.
        num_sentences: Number of key sentences requested from the sentence ranker.

    Returns:
        A ``(keywords, summaries)`` tuple. ``keywords`` is a list of
        ``(word, weight)`` pairs obtained from ``get_keywords(30, word_min_len=4)``;
        ``summaries`` is the list of sentence strings returned by
        ``get_key_sentences(num=num_sentences)``.
    """
    keyword_ranker = TextRank4Keyword()
    keyword_ranker.analyze(text=text_segment, lower=True, window=5)
    keywords = []
    for entry in keyword_ranker.get_keywords(30, word_min_len=4):
        keywords.append((entry.word, entry.weight))

    sentence_ranker = TextRank4Sentence()
    sentence_ranker.analyze(text=text_segment, lower=True, source='all_filters')
    summaries = []
    for entry in sentence_ranker.get_key_sentences(num=num_sentences):
        summaries.append(entry.sentence)

    return keywords, summaries
|
||||
|
||||
|
||||
def split_text_by_sentences(text, n_parts):
    """Split *text* into ``n_parts`` chunks of whole sentences.

    Sentences are found with a regular expression that breaks on whitespace
    following ``.``, ``?`` or ``!`` (with lookbehinds that skip some
    abbreviation-like patterns). The sentences are then dealt out so that the
    first ``len(sentences) % n_parts`` chunks hold one sentence more than the
    rest; chunks may be empty strings when there are fewer sentences than parts.

    Args:
        text: The text to split.
        n_parts: Number of chunks to produce.

    Returns:
        A list of ``n_parts`` strings, each the space-joined sentences of a chunk.
    """
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', text)
    base_size, remainder = divmod(len(sentences), n_parts)

    chunks = []
    cursor = 0
    for part_index in range(n_parts):
        # The first `remainder` chunks absorb one extra sentence each.
        size = base_size + 1 if part_index < remainder else base_size
        chunks.append(' '.join(sentences[cursor:cursor + size]))
        cursor += size
    return chunks
|
||||
from nltk.tokenize import sent_tokenize
|
||||
def split_text_balanced(text, n_parts):
    """Split *text* into at most ``n_parts`` chunks of whole sentences.

    Sentences come from ``sent_tokenize``. The requested part count is capped
    at ``len(sentences) // 10`` (so a chunk is not created for fewer than ten
    sentences) and is never allowed to drop below one. Sentences are then
    distributed so chunk sizes differ by at most one.

    Args:
        text: The text to split.
        n_parts: Upper bound on the number of chunks.

    Returns:
        A list of space-joined sentence chunks.
    """
    min_sentences_per_part = 10
    sentences = sent_tokenize(text)

    allowed_parts = min(n_parts, len(sentences) // min_sentences_per_part)
    n_parts = max(1, allowed_parts)

    base_size, remainder = divmod(len(sentences), n_parts)
    chunks = []
    cursor = 0
    for part_index in range(n_parts):
        # The first `remainder` chunks absorb one extra sentence each.
        size = base_size + 1 if part_index < remainder else base_size
        chunks.append(' '.join(sentences[cursor:cursor + size]))
        cursor += size
    return chunks
|
||||
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
def TextRank(text,num_sentences, n_cores=multiprocessing.cpu_count()):
    """Summarise *text* by concatenating the key sentences of each segment.

    The text is split with ``split_text_balanced(text, n_cores)`` and every
    segment is processed one after another with ``process_text_segment``;
    ``n_cores`` therefore only bounds the number of segments. All collected
    keywords are printed in descending weight order, and start/end messages
    (with elapsed time) are written through ``logging``.

    Args:
        text: The text to summarise.
        num_sentences: Key sentences requested per segment.
        n_cores: Upper bound on the number of segments; the default is
            evaluated once, when the function is defined.

    Returns:
        The key sentences of all segments joined without a separator.
    """
    started_at = time.time()
    logging.info("TextRank 函数开始执行")

    keyword_pool = []
    summary_pool = []
    for segment in split_text_balanced(text, n_cores):
        segment_keywords, segment_summaries = process_text_segment(segment, num_sentences)
        keyword_pool.extend(segment_keywords)
        summary_pool.extend(segment_summaries)

    ranked_keywords = sorted(keyword_pool, key=lambda pair: pair[1], reverse=True)
    for word, weight in ranked_keywords:
        print(word, weight)

    merged_summary = "".join(summary_pool)

    elapsed_time = time.time() - started_at
    logging.info(f"TextRank 函数执行结束,耗时: {elapsed_time:.2f} 秒")

    return merged_summary
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Demo run: summarise a short sample text and report the size reduction.
    text = """中华人民共和国国民经济和社会发展第十四个五年(2021-2025年)规划和2035年远景目标纲要"""
    num_sentences = 80
    summary = TextRank(text, num_sentences)
    print(f"原文长度{len(text)},压缩文本后长度 {len(summary)}")
|
||||
3
langchain-chat/server/knowledge_base/__init__.py
Normal file
3
langchain-chat/server/knowledge_base/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# from .kb_api import list_kbs, create_kb, delete_kb
|
||||
# from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
|
||||
# from .utils import KnowledgeFile, KBServiceFactory
|
||||
372
langchain-chat/server/knowledge_base/cleanpdf.py
Normal file
372
langchain-chat/server/knowledge_base/cleanpdf.py
Normal file
@@ -0,0 +1,372 @@
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
from bs4 import BeautifulSoup
|
||||
from collections import defaultdict
|
||||
import cssutils
|
||||
from server.knowledge_base.file_converter import FileConverter
|
||||
import uuid
|
||||
import base64
|
||||
|
||||
class PdfConverter(FileConverter):
|
||||
def _clean_pdf_html(self, html: str) -> str:
    """Post-process converted PDF HTML.

    Steps, in order:
      1. Drop ``<span>`` elements that contain no visible text.
      2. Rewrite every ``<style>`` sheet: remove text-selection restrictions,
         page-container backgrounds and ``.pf`` box-shadow/border-collapse.
      3. Apply the same clean-up to the inline styles of
         ``#page-container-1`` and ``.pf`` elements.
      4. Assign ids through ``_add_pdf_element_ids``.
      5. If a page-container id was recorded, patch the ``<head>`` section.
      6. Strip all ``<script>`` elements.

    Args:
        html: The HTML document to clean.

    Returns:
        The cleaned HTML with surrounding whitespace stripped.
    """
    selection_props = ('user-select', '-webkit-user-select', '-moz-user-select', '-ms-user-select')
    soup = BeautifulSoup(html, 'html.parser')

    def relax_rule(rule):
        """Clean one stylesheet rule, recursing into @media rules."""
        if rule.type == rule.MEDIA_RULE:
            for inner_rule in rule:
                relax_rule(inner_rule)
            return
        if rule.type != rule.STYLE_RULE:
            return

        # Remove properties that forbid selecting text.
        for name in selection_props:
            rule.style.removeProperty(name)

        selector_texts = [selector.selectorText for selector in rule.selectorList]

        if any('#page-container-1' in text for text in selector_texts):
            rule.style.removeProperty('background-color')
            rule.style.removeProperty('background-image')

        if any(re.search(r'(^|[\s>+~])\.pf($|[\s\[.:>+~])', text) for text in selector_texts):
            for name in ('box-shadow', 'border-collapse'):
                # Up to three attempts, stopping as soon as a removal reports a value.
                for _ in range(3):
                    if rule.style.removeProperty(name):
                        break

    def relax_inline_style(tag):
        """Apply the same property removals to a tag's ``style`` attribute."""
        if not tag.has_attr('style'):
            return
        style = cssutils.parseStyle(tag['style'])
        for name in selection_props:
            style.removeProperty(name)
        if tag.get('id') == 'page-container-1':
            style.removeProperty('background-color')
            style.removeProperty('background-image')
        if 'pf' in tag.get('class', []):
            style.removeProperty('box-shadow')
            style.removeProperty('border-collapse')
        tag['style'] = style.cssText.replace('\n', ' ').strip()
        # Drop the attribute entirely when nothing is left in it.
        if not tag['style']:
            del tag['style']

    # Remove spans that are empty or whitespace-only.
    for span in soup.find_all('span'):
        if not span.text.strip():
            span.decompose()
            continue
        # Normalise whitespace-only content inside the span.
        if span.string and span.string.isspace():
            span.string.replace_with(' ')
        elif all(isinstance(child, str) and child.isspace() for child in span.contents):
            span.replace_with(' ')

    # Rewrite embedded stylesheets; a sheet that fails to parse is left as is.
    for style_tag in soup.find_all('style'):
        if not style_tag.string:
            continue
        try:
            sheet = cssutils.parseString(style_tag.string)
            for rule in sheet:
                relax_rule(rule)
            css_text = sheet.cssText.decode('utf-8')
            css_text = css_text.replace('\\n', '\n')
            css_text = css_text.replace(' !important', '!important')
            style_tag.string = css_text
        except Exception as e:
            print(f"CSS处理错误: {str(e)}")
            continue

    for container in soup.select('#page-container-1'):
        relax_inline_style(container)
    for pf_element in soup.select('.pf'):
        relax_inline_style(pf_element)

    content = self._add_pdf_element_ids(str(soup))

    new_id = getattr(self, 'page_container_id', None)
    if new_id:
        head_pattern = re.compile(
            r'(<head[^>]*>)(.*?)(</head>)',
            re.DOTALL | re.IGNORECASE
        )

        def drop_background(m):
            return m.group(1) if m.group(2) else m.group(0)

        def replace_head(match):
            # Point page-container references in <head> at the new id ...
            head_content = re.sub(
                r'(id\s*=\s*["\']?)page-container(["\'\]>])',
                f'\\g<1>{new_id}\\g<2>',
                match.group(2)
            )
            # ... and drop a background declaration from id-selector rules.
            head_content = re.sub(
                r'(#[^{\s>]+?{.*?)(\bbackground-(color|image)\s*:[^;]+;?)',
                drop_background,
                head_content,
                flags=re.DOTALL|re.IGNORECASE
            )
            return f"{match.group(1)}{head_content}{match.group(3)}"

        content = head_pattern.sub(replace_head, content)

    content = re.sub(
        r'<script\b[^>]*>[\s\S]*?</script>',
        '',
        content,
        flags=re.IGNORECASE
    )

    return content.strip()
|
||||
|
||||
def _add_pdf_element_ids(self, content: str) -> str:
|
||||
"""为元素添加唯一ID"""
|
||||
counters = defaultdict(int)
|
||||
self.page_container_id = None # 重置ID记录
|
||||
|
||||
def replace_tag(match):
|
||||
tag = match.group(1).lower()
|
||||
attrs = match.group(2)
|
||||
|
||||
# 处理page-container的特殊逻辑
|
||||
if tag == "div":
|
||||
id_match = re.search(
|
||||
r'\bid\s*=\s*["\']page-container["\']',
|
||||
attrs,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
if id_match:
|
||||
# 生成唯一ID并记录
|
||||
if not self.page_container_id:
|
||||
counters['page-container'] += 1
|
||||
self.page_container_id = f"page-container-{counters['page-container']}"
|
||||
# 保留其他属性
|
||||
clean_attrs = re.sub(r'\s+id="[^"]*"', '', attrs)
|
||||
return f'<div id="{self.page_container_id}"{clean_attrs}>'
|
||||
|
||||
# 常规标签处理
|
||||
counters[tag] += 1
|
||||
clean_attrs = re.sub(r'\s+id="[^"]*"', '', attrs)
|
||||
return f'<{tag} id="{tag}-{counters[tag]}"{clean_attrs}>'
|
||||
|
||||
# 处理所有目标标签
|
||||
return re.sub(
|
||||
r'<(h[1-6]|p|div|span)(\b[^>]*)>',
|
||||
replace_tag,
|
||||
content,
|
||||
flags=re.IGNORECASE
|
||||
)
|
||||
def _save_pdf_html(self, content: str, output_path: Optional[str] = None) -> str:
|
||||
"""统一保存方法"""
|
||||
cleaned = self._clean_pdf_html(content)
|
||||
# cleaned = self._add_pdf_element_ids(content)
|
||||
if output_path:
|
||||
with open(output_path, 'w', encoding='utf-8') as f:
|
||||
f.write(cleaned)
|
||||
return cleaned
|
||||
|
||||
# def pdf_to_html(self, input_path: str, output_path: Optional[str] = None) -> str:
|
||||
# """PDF转换方法"""
|
||||
# cmd = [
|
||||
# 'pdf2htmlEX',
|
||||
# '--zoom', '1.2', # 放大
|
||||
# '--split-pages', '0', # 保持整体布局
|
||||
# # '--embed-css', '0', # 避免内联样式冲突
|
||||
# # '--embed-image', '0', # 避免内联图片冲突
|
||||
# # '--optimize-text', '1', # 优化文本渲染
|
||||
# input_path
|
||||
# ]
|
||||
# result = subprocess.run(
|
||||
# 'cd /data3/pdffiles && ' + ' '.join(cmd),
|
||||
# shell=True,
|
||||
# stdout=subprocess.PIPE,
|
||||
# stderr=subprocess.STDOUT
|
||||
# )
|
||||
# print(f"转换状态: {result.returncode}\n输出: {result.stdout.decode()[:200]}")
|
||||
|
||||
# # 准备文件名
|
||||
# file_name = os.path.basename(input_path)[:-3] + "html"
|
||||
# html_path = f"/data3/pdffiles/{file_name}"
|
||||
|
||||
# if not os.path.exists(html_path):
|
||||
# return f"{file_name} 转换失败"
|
||||
|
||||
# # 读取并处理HTML内容
|
||||
# with open(html_path, 'r', encoding='utf-8') as file:
|
||||
# soup = BeautifulSoup(file, 'html.parser')
|
||||
|
||||
# # 移除注释
|
||||
# for comment in soup.find_all(string=lambda text: isinstance(text, str) and "Created by pdf2htmlEX" in text):
|
||||
# comment.extract()
|
||||
|
||||
# # 移除loading-indicator
|
||||
# for div in soup.find_all('div', class_='loading-indicator'):
|
||||
# div.decompose()
|
||||
|
||||
# # 移除所有包含sidebar的div
|
||||
# for div in soup.find_all('div', id=lambda x: x and 'sidebar' in x.lower()):
|
||||
# div.decompose()
|
||||
|
||||
# # 转换为字符串并处理base64
|
||||
# html_content = str(soup)
|
||||
|
||||
# # 清理临时文件
|
||||
# os.remove(html_path)
|
||||
|
||||
# # 处理base64图片
|
||||
# html_content = self.read_and_replace_base64(
|
||||
# html_content,
|
||||
# output_dir={GENERATED_IMAGES_BASE_PATH}
|
||||
# )
|
||||
|
||||
# return f"{self._save_pdf_html(html_content, output_path)}"
|
||||
|
||||
def pdf_to_html(self, input_path: str, output_path: Optional[str] = None) -> str:
    """Build the PDF preview by delegating to the base-class implementation.

    Per the original note, the base class extracts the text locally with
    PyMuPDF; any post-processing could be added here by wrapping the result.

    Args:
        input_path: Path of the PDF to convert.
        output_path: Optional destination forwarded to the base class.

    Returns:
        Whatever the base-class ``pdf_to_html`` returns.
    """
    preview = super().pdf_to_html(input_path, output_path)
    return preview
|
||||
|
||||
def read_and_replace_base64(self, html_content, output_dir):
    """Replace inline base64 images with URLs of files saved in *output_dir*.

    Every ``data:image/(png|jpg|jpeg);base64,...`` URI in *html_content* is
    decoded, written to ``output_dir`` under a unique name, and replaced by
    the image-serving URL for that file. A payload that cannot be decoded is
    left in the HTML untouched instead of aborting the whole substitution.

    Args:
        html_content: HTML text that may contain base64 data URIs.
        output_dir: Directory for the extracted images; created if missing.

    Returns:
        The HTML text with decodable data URIs replaced by URLs.
    """
    # Create the target directory up front so the first write cannot fail
    # with FileNotFoundError.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    image_index = 0  # part of the generated file name

    def replace_base64(match):
        nonlocal image_index
        data_uri = match.group(0)
        # Split "data:image/<ext>;base64" from the encoded payload.
        header, data = data_uri.split(',', 1)
        file_extension = header.split(';')[0].split('/')[1]
        try:
            # Decode before opening the file so a bad payload leaves no empty file.
            image_bytes = base64.b64decode(data)
        except ValueError:  # binascii.Error is a ValueError subclass
            return data_uri
        file_name = f'image_{uuid.uuid1()}_{image_index}.{file_extension}'
        file_path = os.path.join(output_dir, file_name)
        with open(file_path, 'wb') as image_file:
            image_file.write(image_bytes)
        image_index += 1
        # URL under which the saved file is served.
        return f"http://127.0.0.1:8099/chat_web_backend/get-image?file_name={os.path.basename(file_path)}"

    base64_pattern = r'data:image/(png|jpg|jpeg);base64,[A-Za-z0-9+/=]+'
    return re.sub(base64_pattern, replace_base64, html_content)
|
||||
|
||||
# def pdf_to_html(self, input_path: str, output_path: Optional[str] = None) -> str:
|
||||
# """PDF转换方法"""
|
||||
# try:
|
||||
# doc = fitz.open(input_path)
|
||||
# page_width = doc[0].rect.width
|
||||
# page_height = doc[0].rect.height
|
||||
# border_radius = 5
|
||||
# html = ['<style>','pre { background-color: #2d2d2d;color: #f8f8f2; padding: 10px;margin: 0;width: 80%;box-sizing: border-box;border-radius: 0px;}', '</style>', '<body style="position: relative;">']
|
||||
# image_save_path = '{GENERATED_IMAGES_BASE_PATH}'
|
||||
# pic_num =0
|
||||
# # 确保图片保存路径存在
|
||||
# os.makedirs(image_save_path, exist_ok=True)
|
||||
|
||||
# for page in doc:
|
||||
# blocks = page.get_text("dict")["blocks"]
|
||||
# sorted_blocks = sorted(blocks, key=lambda b: (b["bbox"][1], b["bbox"][0])) # 按y坐标和x坐标排序
|
||||
|
||||
# for block in sorted_blocks:
|
||||
# if "image" in block:
|
||||
# pic_num += 1
|
||||
# bbox = block["bbox"]
|
||||
# image_bytes = block["image"]
|
||||
# image_ext = block["ext"]
|
||||
# image_name = f'image_{page.number}_{pic_num}.{image_ext}'
|
||||
# image_url = f'http://127.0.0.1:8099/chat_web_backend/get-image?file_name={image_name}'
|
||||
# image_path = os.path.join(image_save_path, image_name)
|
||||
# # 保存图片到指定路径
|
||||
# with open(image_path, 'wb') as img_file:
|
||||
# img_file.write(image_bytes)
|
||||
# percent_left = (bbox[0]) / page_width * 100
|
||||
# # 获取页面的宽度和高度
|
||||
# container_width = page_width # 页面宽度
|
||||
# container_height = page_height # 页面高度
|
||||
|
||||
# # 计算图像的宽度和高度
|
||||
# img_width = bbox[2] - bbox[0] # 计算宽度
|
||||
# img_height = bbox[3] - bbox[1] # 计算高度
|
||||
|
||||
# # 计算百分比
|
||||
# width_percent = (img_width / container_width) * 100
|
||||
# height_percent = (img_height / container_height) * 100
|
||||
# html.append(f'<div style="width: {width_percent}%; height: {height_percent}%; margin-left: {percent_left}%;clear: both;overflow: auto;"><img src="{image_url}" alt="Image {pic_num}" style="max-width: 100%; height: auto;display: block;"/></div>')
|
||||
# if "lines" in block:
|
||||
# text_nums = 0
|
||||
# for line in block["lines"]:
|
||||
# is_code_block =any(span["font"].startswith(("Courier", "NSimSun")) for span in line["spans"]) # 假设代码使用Courier字体
|
||||
# if is_code_block:
|
||||
# html.append(f"<pre>")
|
||||
# for span in line["spans"]:
|
||||
# text_nums += 1
|
||||
# bbox = span["bbox"]
|
||||
# text = span["text"]
|
||||
# font = span["font"] # 字体
|
||||
# size = span["size"] # 字体大小
|
||||
# color = span["color"] # 字体颜色
|
||||
|
||||
# # 动态生成CSS样式
|
||||
# css_style = f'font-family: {font}; font-size: {size}px; color: #{color:06x};'
|
||||
# percent_left = (bbox[0]) / page_width * 100
|
||||
# # 根据字体大小判断标题
|
||||
# if size > 20: # 假设大于20的字体为标题
|
||||
# if text_nums == 1:
|
||||
# html.append(f'<h2 style="{css_style};display: inline; margin-left: {percent_left}%; ">{text.strip()}</h2>')
|
||||
# else:
|
||||
# html.append(f'<h3 style="{css_style};display: inline;">{text.strip()}</h3>')
|
||||
# else:
|
||||
# if text_nums == 1:
|
||||
# html.append(f'<p style="{css_style};display: inline; margin-left: {percent_left}%; ">{text.strip()}</p>')
|
||||
# else:
|
||||
# html.append(f'<p style="{css_style};display: inline; ">{text.strip()}</p>')
|
||||
|
||||
# if is_code_block or size<=20:
|
||||
# if is_code_block:
|
||||
# html.append("</pre>")
|
||||
# else:
|
||||
# html.append("<br>")
|
||||
# else:
|
||||
# html.append('<br>')
|
||||
# # html.append('<br>')
|
||||
|
||||
# html.append('</body>')
|
||||
|
||||
# # 将HTML内容保存到指定路径
|
||||
# html_content = ''.join(html)
|
||||
# if output_path:
|
||||
# with open(output_path, 'w', encoding='utf-8') as file:
|
||||
# file.write(html_content)
|
||||
# else:
|
||||
# # 如果没有指定路径,使用默认路径或返回HTML内容
|
||||
# output_path = 'output.html'
|
||||
# with open(output_path, 'w', encoding='utf-8') as file:
|
||||
# file.write(html_content)
|
||||
|
||||
# return output_path
|
||||
# except Exception as e:
|
||||
# raise RuntimeError(f"PDF转换失败: {str(e)}")
|
||||
# def replace_base64_with_url(self,html_content, output_dir):
|
||||
# image_index = 0 # 用于生成唯一的文件名
|
||||
# def replace_base64(match):
|
||||
# nonlocal image_index
|
||||
# base64_data = match.group(0)
|
||||
# # 保存 Base64 图片并获取文件路径
|
||||
# # 提取文件类型和实际的 Base64 数据
|
||||
# header, data = base64_data.split(',', 1)
|
||||
# file_extension = header.split(';')[0].split('/')[1] # 获取文件扩展名
|
||||
# file_name = f'image_{uuid.uuid1()}_{image_index}.{file_extension}' # 生成文件名
|
||||
# file_path = os.path.join(output_dir, file_name)
|
||||
|
||||
# # 将 Base64 数据解码并保存为文件
|
||||
# with open(file_path, 'wb') as image_file:
|
||||
# image_file.write(base64.b64decode(data))
|
||||
# image_index += 1
|
||||
# # 返回文件的 URL
|
||||
# return f"http://127.0.0.1:8099/chat_web_backend/get-image?file_name={os.path.basename(file_path)}"
|
||||
|
||||
# # 使用正则表达式匹配 Base64 字符串
|
||||
# base64_pattern = r'data:image/(png|jpg|jpeg);base64,[A-Za-z0-9+/=]+'
|
||||
# updated_html_content = re.sub(base64_pattern, replace_base64, html_content)
|
||||
# return updated_html_content
|
||||
1154
langchain-chat/server/knowledge_base/file_converter.py
Normal file
1154
langchain-chat/server/knowledge_base/file_converter.py
Normal file
File diff suppressed because it is too large
Load Diff
65
langchain-chat/server/knowledge_base/kb_api.py
Normal file
65
langchain-chat/server/knowledge_base/kb_api.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import urllib
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import validate_kb_name
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.repository.knowledge_base_repository import list_kbs_from_db
|
||||
from configs import EMBEDDING_MODEL, logger, log_verbose
|
||||
from fastapi import Body
|
||||
|
||||
|
||||
def list_kbs():
    """Return the result of ``list_kbs_from_db()`` wrapped in a ``ListResponse``."""
    knowledge_bases = list_kbs_from_db()
    return ListResponse(data=knowledge_bases)
|
||||
|
||||
|
||||
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
              vector_store_type: str = Body("faiss"),
              embed_model: str = Body(EMBEDDING_MODEL),
              ) -> BaseResponse:
    """Create a new knowledge base.

    Args:
        knowledge_base_name: Name of the knowledge base to create.
        vector_store_type: Vector store backend passed to the service factory.
        embed_model: Embedding model passed to the service factory.

    Returns:
        A ``BaseResponse`` with code 200 on success, 403 when the name fails
        ``validate_kb_name``, 404 when the name is blank or already taken,
        and 500 when creation raises.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    name_is_blank = knowledge_base_name is None or knowledge_base_name.strip() == ""
    if name_is_blank:
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")

    existing = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if existing is not None:
        return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")

    service = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
    try:
        service.create_kb()
    except Exception as e:
        msg = f"创建知识库出错: {e}"
        # The traceback is attached only in verbose mode.
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)
    else:
        return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
|
||||
|
||||
|
||||
def delete_kb(
        knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse:
    """Delete a knowledge base: clear its vector store, then drop it.

    The name is validated both as received and after URL-decoding, so an
    encoded name cannot slip past ``validate_kb_name``.

    Args:
        knowledge_base_name: Possibly URL-encoded name of the knowledge base.

    Returns:
        A ``BaseResponse`` with code 200 on success, 403 when the name fails
        validation, 404 when no such knowledge base exists, and 500 when
        dropping fails or raises.
    """
    # Import the submodule explicitly: a bare `import urllib` does not
    # guarantee that `urllib.parse` has been loaded.
    from urllib.parse import unquote

    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")
    knowledge_base_name = unquote(knowledge_base_name)
    # Re-check the decoded name, which is the one actually used below.
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    try:
        kb.clear_vs()
        status = kb.drop_kb()
        if status:
            return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
    except Exception as e:
        msg = f"删除知识库时出现意外: {e}"
        # The traceback is attached only in verbose mode.
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)

    return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")
|
||||
164
langchain-chat/server/knowledge_base/kb_cache/base.py
Normal file
164
langchain-chat/server/knowledge_base/kb_cache/base.py
Normal file
@@ -0,0 +1,164 @@
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
import threading
|
||||
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
|
||||
logger, log_verbose)
|
||||
from server.utils import embedding_device, get_model_path, list_online_embed_models, resolve_embed_model_name
|
||||
from contextlib import contextmanager
|
||||
from collections import OrderedDict
|
||||
from typing import List, Any, Union, Tuple
|
||||
|
||||
|
||||
class ThreadSafeObject:
    """Wrapper that guards an arbitrary object with a re-entrant lock.

    The wrapper also carries a "loaded" event so other threads can wait until
    the wrapped object has been populated, and an optional back-reference to
    the ``CachePool`` that owns it.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        # Starts unset; wait_for_loading() blocks until finish_loading() is called.
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        """The cache key this object was created with."""
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = "") -> FAISS:
        """Context manager that holds the lock and yields the wrapped object.

        On entry the owning pool (if any) is told this entry was just used.
        Start/end messages are logged when ``log_verbose`` is on.

        Args:
            owner: Label used in log messages; defaults to the native thread id.
            msg: Extra text appended to the log messages.

        Yields:
            The wrapped object (the return annotation is kept from the
            original code).
        """
        owner = owner or f"thread {threading.get_native_id()}"
        # Take the lock before the try block so the finally clause can never
        # release a lock that was not acquired.
        self._lock.acquire()
        try:
            if self._pool is not None:
                try:
                    self._pool._cache.move_to_end(self.key)
                except KeyError:
                    # The entry was evicted or popped from the pool in the
                    # meantime; the wrapped object itself is still usable.
                    pass
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self.key}。{msg}")
            yield self._obj
        finally:
            if log_verbose:
                logger.info(f"{owner} 结束操作:{self.key}。{msg}")
            self._lock.release()

    def start_loading(self):
        """Mark the object as not loaded (clears the event)."""
        self._loaded.clear()

    def finish_loading(self):
        """Mark the object as loaded and wake every waiter."""
        self._loaded.set()

    def wait_for_loading(self):
        """Block until ``finish_loading`` has been called."""
        self._loaded.wait()

    @property
    def obj(self):
        """The wrapped object."""
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
|
||||
|
||||
|
||||
class CachePool:
    """Ordered cache of ``ThreadSafeObject`` entries with an optional size cap."""

    def __init__(self, cache_num: int = -1):
        # A positive cache_num caps the number of entries; other values disable the cap.
        self._cache_num = cache_num
        self._cache = OrderedDict()
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        """Return the cached keys in their current order."""
        return list(self._cache)

    def _check_count(self):
        """Evict entries from the front until the size cap is respected."""
        if not isinstance(self._cache_num, int) or self._cache_num <= 0:
            return
        while len(self._cache) > self._cache_num:
            self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        """Return the entry for *key* once it has finished loading, else ``None``."""
        cache = self._cache.get(key)
        if cache:
            cache.wait_for_loading()
            return cache
        return None

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        """Store *obj* under *key*, enforce the size cap, and return *obj*."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        """Remove an entry.

        With a key, returns that entry or ``None`` if absent. Without a key,
        removes the first entry and returns what ``popitem`` returns.
        """
        if key is not None:
            return self._cache.pop(key, None)
        return self._cache.popitem(last=False)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Return the ``acquire`` context manager of the entry stored at *key*.

        Raises:
            RuntimeError: If nothing is cached under *key*.
        """
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        if not isinstance(cache, ThreadSafeObject):
            return cache
        # Mark the entry as most recently used before handing it out.
        self._cache.move_to_end(key)
        return cache.acquire(owner=owner, msg=msg)

    def load_kb_embeddings(
            self,
            kb_name: str,
            embed_device: str = embedding_device(),
            default_embed_model: str = EMBEDDING_MODEL,
    ) -> Embeddings:
        """Return the embeddings object to use for knowledge base *kb_name*.

        The model name comes from the knowledge base's stored details, falling
        back to *default_embed_model*. Online models are wrapped in an
        ``EmbeddingsFunAdapter``; others are loaded through ``embeddings_pool``.
        """
        from server.db.repository.knowledge_base_repository import get_kb_detail
        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

        kb_detail = get_kb_detail(kb_name)
        configured_model = kb_detail.get("embed_model", default_embed_model)
        embed_model = resolve_embed_model_name(configured_model)

        if embed_model in list_online_embed_models():
            return EmbeddingsFunAdapter(embed_model)
        return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
|
||||
|
||||
|
||||
class EmbeddingsPool(CachePool):
    """Cache pool of embedding model instances keyed by ``(model, device)``."""

    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
        """Return the cached embeddings for *model*, loading them on first use.

        If loading fails, the placeholder entry is removed from the cache and
        its waiters are released before the exception propagates. Without
        that, the unloaded placeholder would make every later lookup of the
        key block forever.

        Args:
            model: Model name; defaults to ``EMBEDDING_MODEL``.
            device: Currently ignored (see the note in the body).

        Returns:
            The embeddings object stored for the key.
        """
        self.atomic.acquire()
        model = model or EMBEDDING_MODEL
        # NOTE(review): the `device` argument is overwritten here, so callers
        # cannot choose a device — confirm whether that is intended.
        device = embedding_device()
        key = (model, device)
        if not self.get(key):
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                self.atomic.release()
                try:
                    item.obj = self._create_embeddings(model, device)
                except Exception:
                    # Do not leave a never-loaded placeholder behind.
                    if self._cache.get(key) is item:
                        self.pop(key)
                    raise
                finally:
                    # Always wake waiters, whether loading succeeded or not.
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key).obj

    @staticmethod
    def _create_embeddings(model: str, device: str) -> Embeddings:
        """Instantiate the embeddings implementation matching *model*."""
        if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
            from langchain.embeddings.openai import OpenAIEmbeddings
            return OpenAIEmbeddings(model=model,
                                    openai_api_key=get_model_path(model),
                                    chunk_size=CHUNK_SIZE)

        if 'bge-' in model:
            from langchain.embeddings import HuggingFaceBgeEmbeddings
            if 'zh' in model:
                # for chinese model
                query_instruction = "为这个句子生成表示以用于检索相关文章:"
            elif 'en' in model:
                # for english model
                query_instruction = "Represent this sentence for searching relevant passages:"
            else:
                # maybe ReRanker or else, just use empty string instead
                query_instruction = ""
            embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                  model_kwargs={'device': device},
                                                  query_instruction=query_instruction)
            if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                embeddings.query_instruction = ""
            return embeddings

        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(model_name=get_model_path(model),
                                     model_kwargs={'device': device})
|
||||
|
||||
|
||||
embeddings_pool = EmbeddingsPool(cache_num=1)
|
||||
175
langchain-chat/server/knowledge_base/kb_cache/faiss_cache.py
Normal file
175
langchain-chat/server/knowledge_base/kb_cache/faiss_cache.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
|
||||
from server.knowledge_base.kb_cache.base import *
|
||||
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
|
||||
from server.utils import load_local_embeddings
|
||||
from server.knowledge_base.utils import get_vs_path
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain.docstore.in_memory import InMemoryDocstore
|
||||
from langchain.schema import Document
|
||||
import os
|
||||
from langchain.schema import Document
|
||||
|
||||
|
||||
# patch FAISS to include doc id in Document.metadata
|
||||
def _new_ds_search(self, search: str) -> Union[str, Document]:
    """Docstore lookup that also records the id in ``Document.metadata``.

    Args:
        search: The document id to look up in ``self._dict``.

    Returns:
        The stored entry (with ``metadata["id"]`` set when it is a
        ``Document``), or a "not found" message string when the id is unknown.
    """
    if search in self._dict:
        found = self._dict[search]
        if isinstance(found, Document):
            found.metadata["id"] = search
        return found
    return f"ID {search} not found."
|
||||
InMemoryDocstore.search = _new_ds_search
|
||||
|
||||
|
||||
class ThreadSafeFaiss(ThreadSafeObject):
    """``ThreadSafeObject`` specialised for a FAISS vector store."""

    def __repr__(self) -> str:
        class_name = type(self).__name__
        return f"<{class_name}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        """Return the number of entries in the wrapped store's docstore."""
        docstore_entries = self._obj.docstore._dict
        return len(docstore_entries)

    def save(self, path: str, create_path: bool = True):
        """Save the vector store to *path* while holding the lock.

        Args:
            path: Target directory.
            create_path: Create the directory first when it does not exist.

        Returns:
            Whatever ``save_local`` returns.
        """
        with self.acquire():
            if not os.path.isdir(path) and create_path:
                os.makedirs(path)
            result = self._obj.save_local(path)
            logger.info(f"已将向量库 {self.key} 保存到磁盘")
        return result

    def clear(self):
        """Delete every document from the vector store while holding the lock.

        Returns:
            The result of ``delete`` when there was anything to delete,
            otherwise an empty list.
        """
        result = []
        with self.acquire():
            doc_ids = list(self._obj.docstore._dict.keys())
            if doc_ids:
                result = self._obj.delete(doc_ids)
                assert len(self._obj.docstore._dict) == 0
            logger.info(f"已将向量库 {self.key} 清空")
        return result
|
||||
|
||||
|
||||
class _FaissPool(CachePool):
    """Shared helpers for the FAISS vector-store pools."""

    def new_vector_store(
            self,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> FAISS:
        """Create an empty FAISS store for *embed_model*.

        The store is built from one placeholder document, which is deleted
        again before the store is returned.
        """
        embeddings = EmbeddingsFunAdapter(embed_model)
        placeholder = Document(page_content="init", metadata={})
        store = FAISS.from_documents([placeholder], embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
        placeholder_ids = list(store.docstore._dict.keys())
        store.delete(placeholder_ids)
        return store

    def save_vector_store(self, kb_name: str, path: str=None):
        """Save the cached store for *kb_name* to *path*; no-op when not cached."""
        cache = self.get(kb_name)
        if cache:
            return cache.save(path)
        return None

    def unload_vector_store(self, kb_name: str):
        """Drop the cached store for *kb_name* from the pool, if present."""
        if not self.get(kb_name):
            return
        self.pop(kb_name)
        logger.info(f"成功释放向量库:{kb_name}")
|
||||
|
||||
|
||||
class KBFaissPool(_FaissPool):
    """Pool of on-disk knowledge-base vector stores keyed by ``(kb_name, vector_name)``."""

    def load_vector_store(
            self,
            kb_name: str,
            vector_name: str = None,
            create: bool = True,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached vector store, loading or creating it on first use.

        If loading fails (including the "not exist" error raised when
        *create* is false), the placeholder entry is removed from the cache
        and its waiters are released before the exception propagates.
        Without that, the unloaded placeholder would make every later lookup
        of the key block forever.

        Args:
            kb_name: Knowledge base name.
            vector_name: Vector store name; defaults to *embed_model*.
            create: Create an empty store when none exists on disk.
            embed_model: Embedding model used for loading/creating.
            embed_device: Device forwarded to the embedding loader.

        Returns:
            The pool entry for ``(kb_name, vector_name)``.

        Raises:
            RuntimeError: If no store exists on disk and *create* is false.
        """
        self.atomic.acquire()
        vector_name = vector_name or embed_model
        key = (kb_name, vector_name)  # a tuple key is preferable to a concatenated string
        cache = self.get(key)
        if cache is None:
            item = ThreadSafeFaiss(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                self.atomic.release()
                try:
                    item.obj = self._open_vector_store(
                        kb_name, vector_name, create, embed_model, embed_device)
                except Exception:
                    # Do not leave a never-loaded placeholder behind.
                    if self._cache.get(key) is item:
                        self.pop(key)
                    raise
                finally:
                    # Always wake waiters, whether loading succeeded or not.
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key)

    def _open_vector_store(self, kb_name, vector_name, create, embed_model, embed_device) -> FAISS:
        """Load the store from disk, or create and save an empty one."""
        logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
        vs_path = get_vs_path(kb_name, vector_name)

        if os.path.isfile(os.path.join(vs_path, "index.faiss")):
            embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
            return FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")

        if create:
            # create an empty vector store
            if not os.path.exists(vs_path):
                os.makedirs(vs_path)
            vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
            vector_store.save_local(vs_path)
            return vector_store

        raise RuntimeError(f"knowledge base {kb_name} not exist.")
|
||||
|
||||
|
||||
class MemoFaissPool(_FaissPool):
    """Pool of in-memory vector stores keyed by name."""

    def load_vector_store(
            self,
            kb_name: str,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached in-memory store for *kb_name*, creating it on first use.

        If creation fails, the placeholder entry is removed from the cache and
        its waiters are released before the exception propagates. Without
        that, the unloaded placeholder would make every later lookup of the
        key block forever.

        Args:
            kb_name: Cache key of the store.
            embed_model: Embedding model for the new store.
            embed_device: Device forwarded to ``new_vector_store``.

        Returns:
            The pool entry for *kb_name*.
        """
        self.atomic.acquire()
        cache = self.get(kb_name)
        if cache is None:
            item = ThreadSafeFaiss(kb_name, pool=self)
            self.set(kb_name, item)
            with item.acquire(msg="初始化"):
                self.atomic.release()
                try:
                    logger.info(f"loading vector store in '{kb_name}' to memory.")
                    # create an empty vector store
                    item.obj = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                except Exception:
                    # Do not leave a never-loaded placeholder behind.
                    if self._cache.get(kb_name) is item:
                        self.pop(kb_name)
                    raise
                finally:
                    # Always wake waiters, whether creation succeeded or not.
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(kb_name)
|
||||
|
||||
|
||||
kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
|
||||
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Ad-hoc concurrency smoke test: worker threads randomly add to, search,
    # or clear the same vector store through the pool.
    import random
    import time
    from pprint import pprint

    kb_names = ["vs1", "vs2", "vs3"]
    # for name in kb_names:
    #     memo_faiss_pool.load_vector_store(name)

    def worker(vs_name: str, name: str):
        vs_name = "samples"  # every worker targets the same store
        time.sleep(random.randint(1, 5))
        embeddings = load_local_embeddings()
        r = random.randint(1, 3)

        with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
            if r == 1:  # add docs
                ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
                pprint(ids)
            elif r == 2:  # search docs
                docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
                pprint(docs)
        if r == 3:  # delete docs
            logger.warning(f"清除 {vs_name} by {name}")
            # Pool keys are (kb_name, vector_name) tuples, so a plain-string
            # get() finds nothing; go through load_vector_store() instead.
            kb_faiss_pool.load_vector_store(vs_name).clear()

    threads = []
    for n in range(1, 30):
        t = threading.Thread(target=worker,
                             kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
                             daemon=True)
        t.start()
        threads.append(t)

    for t in threads:
        t.join()
|
||||
673
langchain-chat/server/knowledge_base/kb_doc_api.py
Normal file
673
langchain-chat/server/knowledge_base/kb_doc_api.py
Normal file
@@ -0,0 +1,673 @@
|
||||
import asyncio
|
||||
import os
|
||||
import urllib
|
||||
from fastapi import File, Form, Body, Query, Response, UploadFile
|
||||
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
EXPR,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
||||
logger, log_verbose, POLICY_KNOWLEDGE_BASE)
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.knowledge_base.cleanpdf import PdfConverter
|
||||
from server.knowledge_base.file_converter import FileConverter
|
||||
from server.utils import BaseResponse, ListResponse, flatten, run_in_thread_pool
|
||||
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
|
||||
files2docs_in_thread, KnowledgeFile)
|
||||
from fastapi.responses import FileResponse
|
||||
from sse_starlette import EventSourceResponse
|
||||
from pydantic import Json
|
||||
import json
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.repository.knowledge_file_repository import get_file_detail
|
||||
from langchain.docstore.document import Document
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from typing import List, Dict
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from datetime import datetime
|
||||
|
||||
def search_docs(
        fileName: list = Body([], description="文件名称", examples=["123.txt"]),
        query: str = Body("", description="改写后的query", examples=["你好"]),
        usr_query: str = Body("", description="用户输入的问题", examples=["你好"]),
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(SCORE_THRESHOLD,
                                      description="知识库匹配相关度阈值,取值范围在0-1之间,"
                                                  "SCORE越小,相关度越高,"
                                                  "取到1相当于不筛选,建议设置在0.5左右",
                                      ge=0, le=1),
        expr: str = Body(EXPR, description="milvus混合检索条件"),
        file_name: str = Body("", description="文件名称,支持 sql 通配符"),
        metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
        custom_strategy_config: dict = Body({}, description="自定义策略配置"),
        query_rewrite_model_name = LLM_MODELS[0]

) -> List[DocumentWithVSId]:
    """Search a knowledge base by query, or list its documents by file name / metadata.

    :param fileName: when non-empty, the query is prefixed with the first file
        name and hits are kept only if their ``source`` metadata is in this list.
    :param query: rewritten query text used for the vector search.
    :param usr_query: original user question; only used to derive ``expr`` for
        policy knowledge bases.
    :param knowledge_base_name: name of the knowledge base to search.
    :param top_k: forwarded to ``kb.search_docs``.
    :param score_threshold: forwarded to ``kb.search_docs``.
    :param expr: filter expression forwarded to ``kb.search_docs``; replaced by
        an LLM-generated expression when ``POLICY_KNOWLEDGE_BASE`` occurs in
        ``knowledge_base_name``.
    :param file_name: file-name filter for ``kb.list_docs`` (used when ``query`` is empty).
    :param metadata: metadata filter for ``kb.list_docs`` (used when ``query`` is empty).
    :param custom_strategy_config: accepted but not used by this function.
    :param query_rewrite_model_name: LLM used to generate the policy expression.
    :return: matching documents, or an empty list when the knowledge base is
        not found or no search criterion is given.
    """
    # Current date formatted as YYYYMMDD (named ``today`` so it does not shadow
    # the ``time`` name imported at module level).
    today = datetime.now().strftime("%Y%m%d")
    if POLICY_KNOWLEDGE_BASE in knowledge_base_name:
        expr = get_llm_model_response(
            strategy_name="get policy time",
            llm_model_name=query_rewrite_model_name,
            template_prompt_name="get_policy_time",
            prompt_param_dict={"query": usr_query, "time": today},
            temperature=0.01,
            max_tokens=512
        ).replace("None", "")
    logger.info(f'Milvus混合检索表达式:{expr}')

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if not isinstance(expr, str):
        expr = EXPR
    if kb is None:
        return []

    if query:
        if fileName:
            # Steer the search towards the requested file, then keep only hits from the listed files.
            search_text = "请查询以下几篇文件:" + str(fileName[0]) + "并" + query
            hits = kb.search_docs(search_text, top_k, score_threshold, expr)
            hits = [x for x in hits if x[0].metadata.get("source") in fileName]
        else:
            hits = kb.search_docs(query, top_k, score_threshold, expr)
        return [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in hits]

    if file_name or metadata:
        return kb.list_docs(file_name=file_name, metadata=metadata)
    return []
|
||||
|
||||
def search_self_docs(
        fileNames: list = Body([], description="文件名称", examples=["123.txt"]),
        query: str = Body("", description="改写后的query", examples=["你好"]),
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(SCORE_THRESHOLD,
                                      description="知识库匹配相关度阈值,取值范围在0-1之间,"
                                                  "SCORE越小,相关度越高,"
                                                  "取到1相当于不筛选,建议设置在0.5左右",
                                      ge=0, le=1),
        expr: str = Body("", description="milvus混合检索条件"),
) -> List[DocumentWithVSId]:
    """Search a knowledge base, restricting results to the given source files.

    :param fileNames: file names to restrict to; may be a flat list or a list of lists.
    :param query: query text forwarded to ``kb.search_docs``.
    :param knowledge_base_name: name of the knowledge base to search.
    :param top_k: forwarded to ``kb.search_docs``; above 50 the raw search
        result is returned without filtering or wrapping.
    :param score_threshold: forwarded to ``kb.search_docs``.
    :param expr: filter expression; when empty or not a string it is built from
        the file names as ``source == "<name>"`` clauses joined by `` || ``.
    :return: wrapped documents whose ``source`` is one of the file names (each
        ``page_content`` prefixed with its 1-based hit index), the raw hits when
        ``top_k > 50``, or an empty list when the knowledge base is not found.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

    # Normalise to a flat list; nesting is detected from the first element only.
    if not fileNames:
        flat_fileNames = []
    elif isinstance(fileNames[0], list):
        flat_fileNames = flatten(fileNames)
    else:
        flat_fileNames = fileNames

    if not (expr and isinstance(expr, str)):
        if flat_fileNames:
            clauses = [f'source == "{fileName}"' for fileName in flat_fileNames]
            expr = ' || '.join(clauses)
        else:
            expr = ""
    logger.info(f"个人知识库检索EXPR: {expr}")

    if kb is None:
        logger.warning(f"未找到知识库服务: {knowledge_base_name}")
        return []

    docs = kb.search_docs(query, top_k, score_threshold, expr)
    if top_k > 50:
        return docs

    data = []
    for index, x in enumerate(docs):
        if x[0].metadata.get("source") not in flat_fileNames:
            continue
        # Everything except page_content is copied over; page_content is
        # rebuilt with the hit's position among all returned hits.
        fields = {k: v for k, v in x[0].dict().items() if k != 'page_content'}
        data.append(DocumentWithVSId(
            **fields,
            score=x[1],
            id=x[0].metadata.get("id"),
            page_content=f"【^[{index +1}]^ {x[0].page_content}】 ",
        ))
    return data
|
||||
|
||||
def update_docs_by_id(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
) -> BaseResponse:
    """Update document contents by document ID.

    :param knowledge_base_name: name of the knowledge base holding the documents.
    :param docs: mapping of document ID to its replacement ``Document``.
    :return: a success response when ``kb.update_doc_by_ids`` reports success;
        a ``code=500`` response when the knowledge base is missing or the
        update fails.
    """
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")
    if kb.update_doc_by_ids(docs=docs):
        return BaseResponse(msg="文档更新成功")
    # Report the failure with an explicit error code, like the missing-KB path
    # above, instead of falling back to the response's default code.
    return BaseResponse(code=500, msg="文档更新失败")
|
||||
|
||||
|
||||
def list_files(
        knowledge_base_name: str
) -> ListResponse:
    """List the files registered in a knowledge base.

    :param knowledge_base_name: knowledge base name, possibly URL-encoded.
    :return: ``ListResponse`` carrying the file names; ``code=403`` when the
        name fails validation, ``code=404`` when the knowledge base is not found.
    """
    decoded_name = urllib.parse.unquote(knowledge_base_name)
    # Validate the decoded value as well: it is the one actually used below,
    # and a percent-encoded payload would pass a check made on the raw form only.
    if not (validate_kb_name(knowledge_base_name) and validate_kb_name(decoded_name)):
        return ListResponse(code=403, msg="Don't attack me", data=[])

    knowledge_base_name = decoded_name
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])

    all_doc_names = kb.list_files()
    return ListResponse(data=all_doc_names)
|
||||
|
||||
|
||||
def _save_files_in_thread(files: List[UploadFile],
                          knowledge_base_name: str,
                          override: bool):
    """Save uploaded files into the knowledge base directory using a thread pool.

    Generator yielding one result per file, shaped like
    ``{"code": 200, "msg": "xxx", "data": {"knowledge_base_name": "xxx", "file_name": "xxx"}}``.

    :param files: uploaded files to persist.
    :param knowledge_base_name: knowledge base whose directory receives the files.
    :param override: when False, a file already on disk with the same size is
        reported as existing (code 404) and not rewritten.
    """

    def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
        """Save a single uploaded file and return its result dict."""
        # Bound before the try block so the error handler below can always
        # reference them, even when resolving the target path fails.
        filename = file.filename
        data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
        try:
            file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)

            file_content = file.file.read()  # read the uploaded content
            if (os.path.isfile(file_path)
                    and not override
                    and os.path.getsize(file_path) == len(file_content)
            ):
                file_status = f"文件 {filename} 已存在。"
                logger.warning(file_status)
                return dict(code=404, msg=file_status, data=data)

            # exist_ok avoids a race when several threads create the same directory.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, "wb") as f:
                f.write(file_content)
            return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
        except Exception as e:
            msg = f"{filename} 文件上传失败,报错信息为: {e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            return dict(code=500, msg=msg, data=data)

    params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
    for result in run_in_thread_pool(save_file, params=params):
        yield result
|
||||
|
||||
|
||||
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
|
||||
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
|
||||
# override: bool = Form(False, description="覆盖已有文件"),
|
||||
# save: bool = Form(True, description="是否将文件保存到知识库目录")):
|
||||
# def save_files(files, knowledge_base_name, override):
|
||||
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
|
||||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
# def files_to_docs(files):
|
||||
# for result in files2docs_in_thread(files):
|
||||
# yield json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
def upload_docs(
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
        override: bool = Form(False, description="覆盖已有文件"),
        to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """API endpoint: upload files and, optionally, vectorize them.

    :return: ``code=403`` for an invalid name, ``code=404`` when the knowledge
        base is missing, otherwise ``code=200`` with a ``failed_files`` mapping
        of file name to error message.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failed_files = {}
    file_names = list(docs.keys())

    # Step 1: persist the uploads to disk. Every file name is queued for
    # vectorization, including those whose save was reported as failed.
    for outcome in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
        saved_name = outcome["data"]["file_name"]
        if outcome["code"] != 200:
            failed_files[saved_name] = outcome["msg"]
        if saved_name not in file_names:
            file_names.append(saved_name)

    # Step 2: vectorize. Saving the vector store is deferred to this function
    # so it happens once, after all files have been processed.
    if to_vector_store:
        update_kwargs = dict(
            knowledge_base_name=knowledge_base_name,
            file_names=file_names,
            override_custom_docs=True,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
            docs=docs,
            not_refresh_vs_cache=True,
        )
        update_result = update_docs(**update_kwargs)
        failed_files.update(update_result.data["failed_files"])
        if not not_refresh_vs_cache:
            kb.save_vector_store()

    return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
def upload_docs_new(
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
        override: bool = Form(False, description="覆盖已有文件"),
        to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        docs: Json = Form({}, description="自定义的docs,需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """API endpoint: upload files, generate LLM summaries, and optionally vectorize.

    Unlike ``upload_docs``, a missing knowledge base is created automatically,
    and the response additionally carries per-file ``llm_results`` (full text,
    abstract, keywords, chapter overview).

    :return: ``code=403`` for an invalid name, ``code=500`` when the knowledge
        base cannot be created, otherwise ``code=200`` with ``failed_files``
        and ``llm_results``.
    """
    import concurrent.futures
    # Local import: further down this module ``from time import time`` rebinds
    # the module-level name ``time`` to the function.
    import time

    start_time = time.time()
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        # Create the knowledge base automatically.
        kb = KBServiceFactory.get_service(knowledge_base_name, DEFAULT_VS_TYPE, EMBEDDING_MODEL)
        try:
            kb.create_kb()
            logger.info(f"自动创建知识库: {knowledge_base_name}")
        except Exception as e:
            msg = f"创建知识库出错: {e}"
            logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None)
            return BaseResponse(code=500, msg=msg)

    def _run_on_private_loop(knowledge_file):
        """Run ``knowledge_file.get_llm_result()`` to completion on a fresh event loop."""
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        try:
            return new_loop.run_until_complete(knowledge_file.get_llm_result())
        finally:
            new_loop.close()

    def _llm_result_for(filename: str) -> dict:
        """Return the LLM summary dict for one file; placeholders on any failure."""
        try:
            knowledge_file = KnowledgeFile(filename=filename, knowledge_base_name=knowledge_base_name)
            # A worker thread with its own loop avoids clashing with an event
            # loop that may already be running in the calling thread.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                llm_result = executor.submit(_run_on_private_loop, knowledge_file).result()
            return {
                "full_text": llm_result.get("full_text", "获取全文失败"),
                "article_abstract": llm_result.get("article_abstract", "生成摘要失败"),
                "article_keywords": llm_result.get("article_keywords", "生成关键词失败"),
                "article_paragraph": llm_result.get("article_paragraph", "生成章节速览失败")
            }
        except Exception as e:
            logger.error(f"生成LLM结果时出错:{e}", exc_info=e if log_verbose else None)
            # Same keys as the success case so clients always see one schema.
            return {
                "full_text": "获取全文失败",
                "article_abstract": "生成摘要失败",
                "article_keywords": "生成关键词失败",
                "article_paragraph": "生成章节速览失败"
            }

    failed_files = {}
    file_names = list(docs.keys())

    # Per-file abstract / keywords / chapter overview.
    llm_results = {}

    # Save the uploads to disk first.
    for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
        filename = result["data"]["file_name"]
        if result["code"] != 200:
            failed_files[filename] = result["msg"]

        if filename not in file_names:
            file_names.append(filename)

        # Summaries are attempted for every reported file; failures yield placeholders.
        llm_results[filename] = _llm_result_for(filename)

    # Vectorize the saved files.
    if to_vector_store:
        update_st = time.time()
        result = _update_docs_impl(
            knowledge_base_name=knowledge_base_name,
            file_names=file_names,
            override_custom_docs=True,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
            docs=docs,
            not_refresh_vs_cache=True,
        )
        failed_files.update(result.data["failed_files"])
        if not not_refresh_vs_cache:
            kb.save_vector_store()
        logger.info(f'向量化用时:{time.time() - update_st}')
    logger.info(f"总执行时间: {time.time() - start_time:.2f}s")
    return BaseResponse(code=200, msg="文件上传与向量化完成", data={
        "failed_files": failed_files,
        "llm_results": llm_results
    })
|
||||
|
||||
def delete_docs(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
        delete_content: bool = Body(False),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """Delete files from a knowledge base.

    :param knowledge_base_name: knowledge base name, possibly URL-encoded.
    :param file_names: names of the files to delete.
    :param delete_content: forwarded to ``kb.delete_doc``.
    :param not_refresh_vs_cache: when True, the vector store is not saved afterwards.
    :return: ``code=403`` for an invalid name, ``code=404`` when the knowledge
        base is missing, otherwise ``code=200`` with a ``failed_files`` mapping.
    """
    decoded_name = urllib.parse.unquote(knowledge_base_name)
    # Validate the decoded value too: it is the one used below, and a
    # percent-encoded payload would pass a check made on the raw form only.
    if not (validate_kb_name(knowledge_base_name) and validate_kb_name(decoded_name)):
        return BaseResponse(code=403, msg="Don't attack me")

    knowledge_base_name = decoded_name
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failed_files = {}
    for file_name in file_names:
        if not kb.exist_doc(file_name):
            failed_files[file_name] = f"未找到文件 {file_name}"

        # NOTE(review): deletion is still attempted for files reported as not
        # found above; kept as-is because it may be intentional clean-up.
        try:
            kb_file = KnowledgeFile(filename=file_name,
                                    knowledge_base_name=knowledge_base_name)
            kb.delete_doc(kb_file, delete_content, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"{file_name} 文件删除失败,错误信息:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            failed_files[file_name] = msg

    if not not_refresh_vs_cache:
        kb.save_vector_store()

    return BaseResponse(code=200, msg="文件删除完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
def update_info(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
):
    """Replace the description text of a knowledge base.

    :param knowledge_base_name: name of the knowledge base to update.
    :param kb_info: new description, passed to ``kb.update_info``.
    :return: ``code=403`` for an invalid name, ``code=404`` when the knowledge
        base is missing, otherwise ``code=200`` echoing the new ``kb_info``.
    """
    if validate_kb_name(knowledge_base_name):
        service = KBServiceFactory.get_service_by_name(knowledge_base_name)
        if service is not None:
            service.update_info(kb_info)
            payload = {"kb_info": kb_info}
            return BaseResponse(code=200, msg="知识库介绍修改完成", data=payload)
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
    return BaseResponse(code=403, msg="Don't attack me")
|
||||
|
||||
from time import time
|
||||
|
||||
def _update_docs_impl(
        knowledge_base_name: str,
        file_names: List[str],
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
        override_custom_docs: bool = False,
        docs: Dict = None,
        not_refresh_vs_cache: bool = False,
) -> BaseResponse:
    """Core implementation of the knowledge-base document update (for internal callers).

    :param knowledge_base_name: name of the knowledge base to update.
    :param file_names: files whose documents are (re)loaded from disk and vectorized.
    :param chunk_size: forwarded to ``files2docs_in_thread``.
    :param chunk_overlap: forwarded to ``files2docs_in_thread``.
    :param zh_title_enhance: forwarded to ``files2docs_in_thread``.
    :param override_custom_docs: when False, files previously stored with
        custom docs are skipped.
    :param docs: optional mapping of file name to custom documents (``Document``
        objects or dicts of their fields); these files are not loaded from disk.
    :param not_refresh_vs_cache: when True, the vector store is not saved afterwards.
    :return: ``code=403`` for an invalid name, ``code=404`` when the knowledge
        base is missing, otherwise ``code=200`` with a ``failed_files`` mapping.
    """
    # ``None`` replaces the former shared mutable ``{}`` default.
    if docs is None:
        docs = {}

    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failed_files = {}
    kb_files = []

    # Build the list of files whose docs must be loaded from disk.
    for file_name in file_names:
        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
        # A file that previously used custom docs is skipped unless overriding was requested.
        if file_detail.get("custom_docs") and not override_custom_docs:
            continue
        if file_name not in docs:
            try:
                kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
            except Exception as e:
                msg = f"加载文档 {file_name} 时出错:{e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                failed_files[file_name] = msg

    update_st = time()
    # Generate docs from the files and vectorize them. Documents are loaded in
    # worker threads, then attached to a KnowledgeFile so they are not re-parsed.
    for status, result in files2docs_in_thread(kb_files,
                                               chunk_size=chunk_size,
                                               chunk_overlap=chunk_overlap,
                                               zh_title_enhance=zh_title_enhance):
        if status:
            kb_name, file_name, new_docs = result
            kb_file = KnowledgeFile(filename=file_name,
                                    knowledge_base_name=knowledge_base_name)
            kb_file.splited_docs = new_docs
            kb.update_doc(kb_file, not_refresh_vs_cache=True)
        else:
            kb_name, file_name, error = result
            failed_files[file_name] = error
    logger.info(f"use time: {time() - update_st}")

    # Vectorize the custom docs.
    for file_name, v in docs.items():
        try:
            v = [x if isinstance(x, Document) else Document(**x) for x in v]
            kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
            kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"为 {file_name} 添加自定义docs时出错:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            failed_files[file_name] = msg

    if not not_refresh_vs_cache:
        kb.save_vector_store()

    return BaseResponse(code=200, msg="更新文档完成", data={"failed_files": failed_files})
|
||||
|
||||
|
||||
def update_docs(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
        docs: Json = Body({}, description="自定义的docs,需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
) -> BaseResponse:
    """API route for updating knowledge-base documents.

    Thin wrapper: every argument is forwarded unchanged to ``_update_docs_impl``,
    whose response is returned as-is.
    """
    forwarded = dict(
        knowledge_base_name=knowledge_base_name,
        file_names=file_names,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        zh_title_enhance=zh_title_enhance,
        override_custom_docs=override_custom_docs,
        docs=docs,
        not_refresh_vs_cache=not_refresh_vs_cache,
    )
    response = _update_docs_impl(**forwarded)
    return response
|
||||
|
||||
|
||||
def download_doc(
        knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
        file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
        preview: bool = Query(True, description="是:浏览器内预览;否:下载"),
):
    """Download or preview a knowledge-base document, converting it to HTML when supported.

    Files whose extension appears in the conversion map are converted through
    ``FileConverter`` and returned as ``text/html``; all other files are
    returned unchanged as ``application/octet-stream``.

    :param knowledge_base_name: name of the knowledge base holding the file.
    :param file_name: name of the file inside the knowledge base.
    :param preview: True for an ``inline`` disposition, False for ``attachment``.
    :return: a ``Response`` with the content, or a ``BaseResponse`` with code
        403 / 404 / 500 on validation, lookup, or processing errors.
    """
    logger.info(f"是否预览: {preview}")
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    try:
        kb_file = KnowledgeFile(filename=file_name,
                                knowledge_base_name=knowledge_base_name)

        if not os.path.exists(kb_file.filepath):
            return BaseResponse(code=404, msg=f"文件 {file_name} 不存在")

        # Extension -> name of the FileConverter method that renders it as HTML.
        CONVERT_MAP = {
            "pdf": "pdf_to_html",
            "docx": "docx_to_html",
            "doc": "doc_to_html",
            "md": "md_to_html",
            "txt": "txt_to_html",
            "xlsx": "xlsx_to_html",
            "xls": "xls_to_html",
        }

        # Lower-cased extension without the leading dot.
        file_ext = os.path.splitext(file_name)[1].lower().lstrip('.')

        if file_ext in CONVERT_MAP:
            converter = FileConverter()
            convert_method = getattr(converter, CONVERT_MAP[file_ext])

            try:
                # Convert and get the HTML text back (no output file is written).
                html_content = convert_method(kb_file.filepath, output_path=None)
                # The converter signals failure through this marker in its output.
                if "转换失败" in html_content:
                    return BaseResponse(code=500, msg=f"文件:{file_name} 处理失败", data=html_content)
                new_filename = f"{os.path.splitext(os.path.basename(file_name))[0]}.html"
                # Percent-encode the name for the RFC 5987 ``filename*`` parameter.
                encoded_filename = urllib.parse.quote(new_filename)
                content_disposition = "inline" if preview else f"attachment; filename*=UTF-8''{encoded_filename}"

                return Response(
                    content=html_content.encode('utf-8'),
                    media_type="text/html",
                    headers={
                        "Content-Disposition": content_disposition,
                        "Cache-Control": "no-cache, no-store, must-revalidate",
                        "Pragma": "no-cache",
                        "Expires": "0"
                    }
                )

            except RuntimeError as e:
                msg = f"文件转换失败: {str(e)}"
                logger.error(msg)
                return BaseResponse(code=500, msg=msg)

        # File types that are served without conversion.
        content_disposition_type = "inline" if preview else "attachment"
        encoded_filename = urllib.parse.quote(kb_file.filename)
        with open(kb_file.filepath, 'rb') as file:
            file_content = file.read()

        # Always send the raw bytes: no HTML exists on this path, so the old
        # ``html_content`` fallback for downloads raised and produced a 500.
        return Response(
            content=file_content,
            media_type="application/octet-stream",
            headers={
                "Content-Disposition": f"{content_disposition_type}; filename*=UTF-8''{encoded_filename}",
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Pragma": "no-cache",
                "Expires": "0"
            }
        )

    except Exception as e:
        msg = f"{file_name} 处理失败,错误信息是:{e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)
|
||||
|
||||
|
||||
def recreate_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"),
):
    """Recreate the vector store from the files in the content folder.

    This is useful when users copy files into the content folder directly
    instead of uploading them over the network.
    By default, get_service_by_name only returns knowledge bases that are in
    info.db and contain document files.
    Setting allow_empty_kb to True makes this also apply to an empty knowledge
    base, i.e. one that is not in info.db or has no documents.

    :return: an ``EventSourceResponse`` streaming one JSON-encoded progress or
        error event per processed file.
    """

    def output():
        """Yield JSON progress events while rebuilding the vector store."""
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Serialized like every other event so clients parse a single format.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                             ensure_ascii=False)
            return

        if kb.exists():
            kb.clear_vs()
        kb.create_kb()
        files = list_files_from_folder(knowledge_base_name)
        kb_files = [(file, knowledge_base_name) for file in files]
        results = files2docs_in_thread(kb_files,
                                       chunk_size=chunk_size,
                                       chunk_overlap=chunk_overlap,
                                       zh_title_enhance=zh_title_enhance)
        # ``i`` counts every processed file, successful or not.
        for i, (status, result) in enumerate(results):
            if status:
                kb_name, file_name, docs = result
                kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
                kb_file.splited_docs = docs
                # Progress is reported before the (potentially slow) insertion.
                yield json.dumps({
                    "code": 200,
                    "msg": f"({i + 1} / {len(files)}): {file_name}",
                    "total": len(files),
                    "finished": i + 1,
                    "doc": file_name,
                }, ensure_ascii=False)
                kb.add_doc(kb_file, not_refresh_vs_cache=True)
            else:
                kb_name, file_name, error = result
                msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)
        if not not_refresh_vs_cache:
            kb.save_vector_store()

    return EventSourceResponse(output())
|
||||
775
langchain-chat/server/knowledge_base/kb_service/base.py
Normal file
775
langchain-chat/server/knowledge_base/kb_service/base.py
Normal file
@@ -0,0 +1,775 @@
|
||||
from datetime import datetime
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
import re
|
||||
import os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from joblib import Parallel, delayed
|
||||
import time
|
||||
from textrank4zh import TextRank4Keyword, TextRank4Sentence
|
||||
import multiprocessing
|
||||
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.chat.utils import get_personal_knowledge_map, get_similar_documents1
|
||||
from nltk.tokenize import sent_tokenize
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
# Configure the root logger when this module is imported: INFO level, "time - level - message" format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
def generate_weights_as_list(length, total_sum=80):
    """Return ``length`` integer weights that decay exponentially and sum to ``total_sum``.

    :param length: number of weights to generate.
    :param total_sum: value the returned weights add up to (for ``length >= 2``).
    :return: list of ints. An empty list for ``length <= 0``; the fixed list
        ``[50]`` for ``length == 1`` (``total_sum`` is not applied in that case).
    """
    if length <= 0:
        # Nothing to weight; the adjustment loop below would otherwise divide by zero.
        return []
    if length == 1:
        return [50]

    # Exponentially decreasing raw weights: exp(-i / (length / 5)) for i = 0..length-1.
    positions = np.linspace(0, length - 1, length)
    weights = np.exp(-positions / (length / 5))

    # Scale so the weights add up to total_sum, then round to integers.
    normalized_weights = weights / sum(weights) * total_sum
    integer_weights = np.round(normalized_weights).astype(int)

    # Rounding may leave the sum slightly off; spread the difference one unit
    # at a time starting from the first (largest) weight.
    adjustment = total_sum - int(integer_weights.sum())
    step = 1 if adjustment > 0 else -1
    for i in range(abs(adjustment)):
        integer_weights[i % length] += step

    return integer_weights.tolist()
|
||||
|
||||
def score_threshold_process(query,score_threshold, k, docs):
|
||||
"""
|
||||
根据分数阈值过滤和使用TextRank摘要文档,并返回前k个文档。
|
||||
|
||||
:param score_threshold: 相似度分数阈值;忽略低于此阈值的文档。
|
||||
:param k: 要返回的顶部文档数量。
|
||||
:param docs: 文档列表,每个文档是一个元组(文档,相似度分数)。
|
||||
:return: 根据分数阈值返回的前k个文档的列表。
|
||||
"""
|
||||
# 如果提供了score_threshold,则只过滤大于阈值的文档。
|
||||
if score_threshold is not None:
|
||||
cmp = (
|
||||
operator.le
|
||||
)
|
||||
docs = [
|
||||
(doc, similarity)
|
||||
for doc, similarity in docs
|
||||
if cmp(similarity, score_threshold)
|
||||
]
|
||||
# 当召回结果都大于score_threshold时
|
||||
if len(docs) == 0:
|
||||
return docs
|
||||
result = []
|
||||
try:
|
||||
for doc in docs:
|
||||
if query.replace(" ","").replace("\n","").replace("\r","") in doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r",""):
|
||||
result.append(doc)
|
||||
|
||||
except Exception as e:
|
||||
for doc in docs:
|
||||
if query.replace(" ","").replace("\n","").replace("\r","") in doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r",""):
|
||||
result.append(doc)
|
||||
|
||||
if len(docs) > 0 and not "h1" in docs[0][0].metadata:
|
||||
# 如果存在用户的问题在标题中的情况则进行去重操作且不需要再匹配相关度,只需要把问题在标题中的文献提交出去
|
||||
if len(result) > 0:
|
||||
temp={}
|
||||
for doc in result:
|
||||
if doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") not in temp:
|
||||
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")] = doc
|
||||
else:
|
||||
if temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
|
||||
continue
|
||||
elif temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") :
|
||||
continue
|
||||
else:
|
||||
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
|
||||
docs = []
|
||||
for i in temp:
|
||||
docs.append(temp[i])
|
||||
else:
|
||||
try:
|
||||
sentences = []
|
||||
sentences_page_content = []
|
||||
for doc in docs:
|
||||
meta = doc[0].metadata
|
||||
# 若缺少标题或为空,则用正文首句作为标题(最多50字)
|
||||
if meta["title"] == "":
|
||||
meta["title"] = doc[0].page_content
|
||||
# 如有摘要则替换 page_content,保证后续文本更简洁
|
||||
summary = meta.get("summary")
|
||||
if summary:
|
||||
doc[0].page_content = summary
|
||||
sentences = [doc[0].metadata["title"] for doc in docs]
|
||||
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["title"]+doc[0].page_content+"】" for i,doc in enumerate(docs)]
|
||||
except Exception as e:
|
||||
sentences = [doc[0].metadata["source"] for doc in docs]
|
||||
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"】" for i,doc in enumerate(docs)]
|
||||
res = get_llm_model_response(
|
||||
strategy_name="default_similar",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="default_similar",
|
||||
prompt_param_dict={"input": query, "title": f"{sentences_page_content}", "time": datetime.now().strftime("%Y%m%d")},
|
||||
temperature=0.01,
|
||||
max_tokens=512
|
||||
)
|
||||
try:
|
||||
index =[]
|
||||
if res == "无":
|
||||
index = []
|
||||
else:
|
||||
index = res.split(",")
|
||||
index = [int(i)-1 for i in index]
|
||||
docs = get_similar_documents1(index=index,sentences=sentences,query=query, docs=docs, top_k=k)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
docs = get_similar_documents1(index=[],sentences=sentences,query=query, docs=docs, top_k=k)
|
||||
|
||||
|
||||
# 去重操作只针对通用知识库
|
||||
temp={}
|
||||
for doc in docs:
|
||||
try:
|
||||
if doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") not in temp:
|
||||
temp[doc[0].metadata["title"]] = doc
|
||||
else:
|
||||
if temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
|
||||
continue
|
||||
elif temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") :
|
||||
continue
|
||||
else:
|
||||
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
|
||||
except Exception as e:
|
||||
print(e)
|
||||
if doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") not in temp:
|
||||
temp[doc[0].metadata["source"]] = doc
|
||||
else:
|
||||
if temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
|
||||
continue
|
||||
elif temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") :
|
||||
continue
|
||||
else:
|
||||
temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
|
||||
docs = []
|
||||
for i in temp:
|
||||
docs.append(temp[i])
|
||||
|
||||
|
||||
|
||||
|
||||
#只针对个人知识库
|
||||
if "h1" in docs[0][0].metadata:
|
||||
all_source = [doc[0].metadata["source"] for doc in docs]
|
||||
unique_source = list(set(all_source))
|
||||
all_title_map = get_personal_knowledge_map(unique_source)
|
||||
for doc in docs:
|
||||
doc[0].metadata["uuid_name"] = doc[0].metadata["source"]
|
||||
if doc[0].metadata["source"] in all_title_map:
|
||||
doc[0].metadata["source"] = all_title_map[doc[0].metadata["source"]]
|
||||
else:
|
||||
pass
|
||||
try:
|
||||
sentences = [doc[0].metadata["source"] for doc in docs]
|
||||
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"】" for i,doc in enumerate(docs)]
|
||||
except Exception as e:
|
||||
sentences = [doc[0].metadata["source"] for doc in docs]
|
||||
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"】" for i,doc in enumerate(docs)]
|
||||
kwargs = {}
|
||||
kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": True}}
|
||||
res = get_llm_model_response(
|
||||
strategy_name="default_similar",
|
||||
llm_model_name=LLM_MODELS[0],
|
||||
template_prompt_name="default_similar",
|
||||
prompt_param_dict={"input": query, "title": f"{sentences_page_content}", "time": datetime.now().strftime("%Y%m%d")},
|
||||
temperature=0.01,
|
||||
max_tokens=None,
|
||||
**kwargs
|
||||
)
|
||||
res = re.sub(r'<think>.*?</think>', '', res,flags=re.DOTALL)
|
||||
try:
|
||||
index =[]
|
||||
if res == "无":
|
||||
index = []
|
||||
else:
|
||||
index = res.split(",")
|
||||
index = [int(i)-1 for i in index]
|
||||
docs = get_similar_documents1(index=index,sentences=sentences,query=query, docs=docs, top_k=k)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
docs = get_similar_documents1(index=[],sentences=sentences,query=query, docs=docs, top_k=k)
|
||||
|
||||
|
||||
# 去重操作只针对通用知识库
|
||||
temp={}
|
||||
for doc in docs:
|
||||
if doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") not in temp:
|
||||
temp[doc[0].metadata["source"]] = doc
|
||||
else:
|
||||
if temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
|
||||
continue
|
||||
elif temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") :
|
||||
continue
|
||||
else:
|
||||
temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
|
||||
docs = []
|
||||
for i in temp:
|
||||
docs.append(temp[i])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
if len(docs) == 0:
|
||||
return docs
|
||||
# 为TextRank算法生成权重。
|
||||
cont = generate_weights_as_list(len(docs))
|
||||
|
||||
# 处理每个文档以提取或分配摘要。
|
||||
for i, (doc, _) in enumerate(docs):
|
||||
summary_sources = ['content', 'abstract', 'text'] # 根据不同知识库遍历字段
|
||||
for source in summary_sources:
|
||||
try:
|
||||
if docs[i][0].metadata["title"] in docs[i][0].page_content or len(docs[i][0].page_content) < 100:
|
||||
doc.metadata['summary'] = TextRank(doc.metadata[source], cont[i])
|
||||
if len(doc.metadata['summary']) >15000:
|
||||
doc.metadata['summary'] = TextRank(doc.metadata[source], 1)
|
||||
break
|
||||
else:
|
||||
doc.metadata['summary'] = docs[i][0].page_content
|
||||
except KeyError:
|
||||
doc.metadata['summary'] = docs[i][0].page_content
|
||||
continue # 如果当前源失败,则尝试下一个源。
|
||||
|
||||
# 返回前k个文档。
|
||||
return docs[:k]
|
||||
|
||||
# Monkey patch for TextRank compatibility: expose nx.from_numpy_array under the old name nx.from_numpy_matrix.
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
# 进行猴子补丁
|
||||
nx.from_numpy_matrix = nx.from_numpy_array
|
||||
# 进行猴子补丁,入数据类型兼容性检查
|
||||
|
||||
def process_text_segment(text_segment, num_sentences):
    """Extract keywords and key sentences from one piece of text.

    Args:
        text_segment: Text to analyse.
        num_sentences: Number of key sentences requested from the
            sentence ranker.

    Returns:
        A ``(keywords, summaries)`` tuple: ``keywords`` is a list of
        ``(word, weight)`` pairs (up to 30 requested, minimum word length 4)
        and ``summaries`` is a list of key sentences.
    """
    keyword_ranker = TextRank4Keyword()
    keyword_ranker.analyze(text=text_segment, lower=True, window=5)
    keywords = []
    for item in keyword_ranker.get_keywords(30, word_min_len=4):
        keywords.append((item.word, item.weight))

    sentence_ranker = TextRank4Sentence()
    sentence_ranker.analyze(text=text_segment, lower=True, source='all_filters')
    key_sentences = sentence_ranker.get_key_sentences(num=num_sentences)
    summaries = [item.sentence for item in key_sentences]

    return keywords, summaries
|
||||
|
||||
def split_text_balanced(text, n_parts):
    """Split text into groups of consecutive sentences of near-equal size.

    The requested group count is capped at ``len(sentences) // 10`` and is
    never lower than one. Sentences left over after an even split are handed
    out one each to the leading groups.

    Args:
        text: Text to tokenize into sentences with ``sent_tokenize``.
        n_parts: Requested number of groups.

    Returns:
        A list of strings, each the space-joined sentences of one group.
    """
    sentences = sent_tokenize(text)
    min_sentences_per_part = 10
    max_parts = len(sentences) // min_sentences_per_part
    n_parts = max(1, min(n_parts, max_parts))
    base, extra = divmod(len(sentences), n_parts)

    parts = []
    start = 0
    for idx in range(n_parts):
        # The first ``extra`` groups take one additional sentence.
        stop = start + base + (1 if idx < extra else 0)
        parts.append(' '.join(sentences[start:stop]))
        start = stop
    return parts
|
||||
|
||||
|
||||
def TextRank(text, num_sentences, n_cores=multiprocessing.cpu_count()):
    """Summarise ``text`` by running TextRank over sentence-balanced chunks.

    Args:
        text: Text to summarise.
        num_sentences: Key sentences requested per chunk.
        n_cores: Requested number of chunks passed to
            ``split_text_balanced``; defaults to the CPU count read when the
            function is defined.

    Returns:
        The key sentences of all chunks concatenated without a separator.

    Side effects:
        Prints every extracted keyword with its weight (highest weight
        first) and logs the elapsed time.
    """
    started = time.time()
    collected_keywords = []
    collected_summaries = []
    # Chunks are handled one after another in this process.
    for chunk in split_text_balanced(text, n_cores):
        chunk_keywords, chunk_summaries = process_text_segment(chunk, num_sentences)
        collected_keywords += chunk_keywords
        collected_summaries += chunk_summaries

    ranked = sorted(collected_keywords, key=lambda pair: pair[1], reverse=True)
    for word, weight in ranked:
        print(word, weight)

    finished = time.time()
    logging.info(f"TextRank耗时: {finished - started:.2f}秒")
    return "".join(collected_summaries)
|
||||
|
||||
|
||||
from server.db.repository.knowledge_base_repository import (
|
||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||
load_kb_from_db, get_kb_detail,
|
||||
)
|
||||
from server.db.repository.knowledge_file_repository import (
|
||||
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
|
||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
||||
list_docs_from_db,
|
||||
)
|
||||
|
||||
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,EXPR,
|
||||
EMBEDDING_MODEL, KB_INFO)
|
||||
from server.knowledge_base.utils import (
|
||||
get_kb_path, get_doc_path, KnowledgeFile,
|
||||
list_kbs_from_folder, list_files_from_folder,
|
||||
)
|
||||
|
||||
from typing import List, Union, Dict, Optional, Tuple
|
||||
|
||||
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
import time
|
||||
|
||||
|
||||
def get_emb_time(f):
    """Decorator that prints the wall-clock duration of each call to ``f``.

    Args:
        f: Callable to wrap.

    Returns:
        A wrapper that forwards all arguments to ``f``, prints the elapsed
        seconds after ``f`` returns, and returns ``f``'s result. Nothing is
        printed when ``f`` raises.
    """
    from functools import wraps

    @wraps(f)  # keep the wrapped callable's __name__/__doc__ for introspection
    def inner(*arg, **kwarg):
        s_time = time.time()
        res = f(*arg, **kwarg)
        e_time = time.time()
        print('向量化耗时:{}秒'.format(e_time - s_time))
        return res

    return inner
|
||||
|
||||
def normalize(embeddings: List[List[float]]) -> np.ndarray:
    """L2-normalise each embedding row.

    Stand-in for ``sklearn.preprocessing.normalize`` (L2 norm) that avoids
    installing scipy / scikit-learn.

    Args:
        embeddings: Sequence of equal-length vectors; ``None`` entries are
            dropped before normalising, so the result may have fewer rows
            than the input.

    Returns:
        A 2-D array with one unit-length row per non-``None`` input vector.
        All-zero rows are returned unchanged rather than as NaN.

    Raises:
        ValueError: If no non-``None`` embedding remains.
    """
    # Drop missing embeddings.
    valid = [e for e in embeddings if e is not None]
    if not valid:
        raise ValueError("No valid embeddings found (all are None)")
    matrix = np.array(valid)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Avoid 0/0 for all-zero rows: dividing by 1 leaves them as zeros.
    norms[norms == 0] = 1
    return np.divide(matrix, norms)
|
||||
|
||||
|
||||
class SupportedVSType:
    """String identifiers for the selectable vector-store backends."""

    # Identifier used when no specific backend is requested.
    DEFAULT = "default"

    # Named backends.
    FAISS = "faiss"
    MILVUS = "milvus"
    ZILLIZ = "zilliz"
    PG = "pg"
    ES = "es"
    CHROMADB = "chromadb"
|
||||
|
||||
|
||||
class KBService(ABC):
    """Abstract base for a knowledge base stored in a vector store.

    Each public operation pairs a backend hook (``do_*``, implemented by
    subclasses) with the matching bookkeeping call into the knowledge-base
    and knowledge-file repository functions.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        """Record names and paths, then run the subclass ``do_init`` hook.

        Args:
            knowledge_base_name: Name of the knowledge base.
            embed_model: Name of the embedding model to use.
        """
        self.kb_name = knowledge_base_name
        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        self.do_init()

    def __repr__(self) -> str:
        return f"{self.kb_name} @ {self.embed_model}"

    def save_vector_store(self):
        """Persist the vector store; a no-op in the base class.

        Original note: FAISS saves to disk, Milvus saves to its database,
        PGVector is not supported yet.
        """
        pass

    def create_kb(self):
        """Create the document folder, the backend store and the DB record.

        Returns:
            The status returned by ``add_kb_to_db``.
        """
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        self.do_create_kb()
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def clear_vs(self):
        """Remove everything from the vector store and its file records.

        Returns:
            The status returned by ``delete_files_from_db``.
        """
        self.do_clear_vs()
        status = delete_files_from_db(self.kb_name)
        return status

    def drop_kb(self):
        """Delete the knowledge base from the backend and from the DB.

        Returns:
            The status returned by ``delete_kb_from_db``.
        """
        self.do_drop_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        """Embed ``docs`` with this knowledge base's embedding model.

        Returns whatever ``embed_documents`` returns; the original note says
        it is the argument shape accepted by ``VectorStore.add_embeddings``.
        """
        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)

    @get_emb_time
    def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
        """Add a file to the knowledge base.

        Args:
            kb_file: File to add.
            docs: Pre-built documents. When given, the file is not parsed
                again and the DB record is flagged ``custom_docs=True``.
                The default list is never mutated.
            **kwargs: Forwarded to ``do_add_doc``.

        Returns:
            The status from ``add_file_to_db``, or ``False`` when there are
            no documents to add.
        """
        if docs:
            custom_docs = True
            for doc in docs:
                doc.metadata.setdefault("source", kb_file.filename)
        else:
            docs = kb_file.file2text()
            custom_docs = False

        if docs:
            # Rewrite absolute metadata["source"] paths relative to doc_path.
            for doc in docs:
                try:
                    source = doc.metadata.get("source", "")
                    if os.path.isabs(source):
                        rel_path = Path(source).relative_to(self.doc_path)
                        doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
                except Exception as e:
                    # Best effort: keep the original source and carry on.
                    print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
            # Remove any previous version of the file before re-adding it.
            self.delete_doc(kb_file)
            doc_infos = self.do_add_doc(docs, **kwargs)
            status = add_file_to_db(kb_file,
                                    custom_docs=custom_docs,
                                    docs_count=len(docs),
                                    doc_infos=doc_infos)
        else:
            status = False
        return status

    def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
        """Delete a file from the knowledge base.

        Args:
            kb_file: File to delete.
            delete_content: Also remove the file at ``kb_file.filepath``
                from disk when it exists.
            **kwargs: Forwarded to ``do_delete_doc``.

        Returns:
            The status returned by ``delete_file_from_db``.
        """
        self.do_delete_doc(kb_file, **kwargs)
        status = delete_file_from_db(kb_file)
        if delete_content and os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        return status

    def update_info(self, kb_info: str):
        """Update the knowledge base description and store it in the DB."""
        self.kb_info = kb_info
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
        """Re-index a file that exists on disk.

        When ``docs`` is given those documents are used and the DB record is
        flagged ``custom_docs=True`` (see ``add_doc``).

        Returns:
            The result of ``add_doc``, or ``None`` when ``kb_file.filepath``
            does not exist.
        """
        if os.path.exists(kb_file.filepath):
            self.delete_doc(kb_file, **kwargs)
            return self.add_doc(kb_file, docs=docs, **kwargs)

    def exist_doc(self, file_name: str):
        """Return whether ``file_name`` is recorded in the DB for this KB."""
        return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                               filename=file_name))

    def list_files(self):
        """Return the files recorded in the DB for this knowledge base."""
        return list_files_from_db(self.kb_name)

    def count_files(self):
        """Return the number of files recorded in the DB for this KB."""
        return count_files_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    expr: str = EXPR,
                    custom_strategy_config: dict = {}
                    ) -> List[Document]:
        """Search the knowledge base by delegating to ``do_search``.

        All five arguments are passed positionally to the subclass hook.
        """
        docs = self.do_search(query, top_k, score_threshold, expr, custom_strategy_config)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return documents by id; the base implementation returns ``[]``."""
        return []

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        """Return documents by source name; the base implementation returns ``[]``."""
        return []

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by id; must be implemented by subclasses."""
        raise NotImplementedError

    def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
        """Replace documents by id.

        Args:
            docs: Mapping ``{doc_id: Document, ...}``. An entry whose value
                is ``None`` or whose ``page_content`` is blank is deleted
                and not re-added.

        Returns:
            ``True`` once the surviving documents have been re-added.
        """
        self.del_doc_by_ids(list(docs.keys()))
        # Use separate names: rebinding ``docs`` here would discard the input
        # mapping before it is iterated.
        kept_docs = []
        kept_ids = []
        for doc_id, doc in docs.items():
            if not doc or not doc.page_content.strip():
                continue
            kept_ids.append(doc_id)
            kept_docs.append(doc)
        self.do_add_doc(docs=kept_docs, ids=kept_ids)
        return True

    def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
        """Look up documents by ``file_name`` and/or ``metadata``.

        Records found in the DB whose document cannot be fetched from the
        vector store are skipped.
        """
        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
        docs = []
        for x in doc_infos:
            found = self.get_doc_by_ids([x["id"]])
            # An empty result (e.g. the base implementation) means "not found".
            doc_info = found[0] if found else None
            if doc_info is not None:
                doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"])
                docs.append(doc_with_id)
        return docs

    @abstractmethod
    def do_create_kb(self):
        """Backend hook: create the knowledge base."""
        pass

    @staticmethod
    def list_kbs_type():
        """Return the backend names configured in ``kbs_config``."""
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        """Return the knowledge bases recorded in the DB."""
        return list_kbs_from_db()

    def exists(self, kb_name: str = None):
        """Return whether ``kb_name`` (default: this KB) exists in the DB."""
        kb_name = kb_name or self.kb_name
        return kb_exists(kb_name)

    @abstractmethod
    def vs_type(self) -> str:
        """Backend hook: return the vector-store type identifier."""
        pass

    @abstractmethod
    def do_init(self):
        """Backend hook: run at the end of ``__init__``."""
        pass

    @abstractmethod
    def do_drop_kb(self):
        """Backend hook: delete the knowledge base."""
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float,
                  expr: str,
                  custom_strategy_config: dict = {},
                  ) -> List[Tuple[Document, float]]:
        """Backend hook: search the knowledge base."""
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Backend hook: add documents to the knowledge base."""
        pass

    @abstractmethod
    def do_delete_doc(self,
                      kb_file: KnowledgeFile):
        """Backend hook: delete a file's documents from the knowledge base."""
        pass

    @abstractmethod
    def do_clear_vs(self):
        """Backend hook: delete all vectors from the knowledge base."""
        pass
|
||||
|
||||
|
||||
class KBServiceFactory:
    """Creates ``KBService`` instances for a given vector-store type."""

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> KBService:
        """Return the service implementation for ``vector_store_type``.

        Args:
            kb_name: Knowledge base name.
            vector_store_type: A ``SupportedVSType`` value or its name in any
                letter case.
            embed_model: Embedding model name passed to the service.

        Returns:
            The matching service, or ``None`` when no branch matches.

        Raises:
            AttributeError: If a string is given that is not an attribute
                name of ``SupportedVSType``.
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
        # Backend modules are imported lazily, only for the selected type.
        if SupportedVSType.FAISS == vector_store_type:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.PG == vector_store_type:
            from server.knowledge_base.kb_service.pg_kb_service import PGKBService
            return PGKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.MILVUS == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.ZILLIZ == vector_store_type:
            from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
            return ZillizKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.DEFAULT == vector_store_type:
            # NOTE(review): a later, identical DEFAULT branch returning
            # DefaultKBService could never run because this one matches
            # first; it was removed as dead code. Confirm Milvus is the
            # intended default.
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name,
                                   embed_model=embed_model)  # other milvus parameters are set in model_config.kbs_config
        elif SupportedVSType.ES == vector_store_type:
            from server.knowledge_base.kb_service.es_kb_service import ESKBService
            return ESKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.CHROMADB == vector_store_type:
            from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
            return ChromaKBService(kb_name, embed_model=embed_model)

    @staticmethod
    def get_service_by_name(kb_name: str) -> KBService:
        """Build the service for a knowledge base recorded in the DB.

        Returns:
            The service, or ``None`` when ``load_kb_from_db`` reports no
            record for ``kb_name``.
        """
        found_name, vs_type, embed_model = load_kb_from_db(kb_name)
        if found_name is None:  # kb not in db, just return None
            return None
        from server.utils import resolve_embed_model_name

        embed_model = resolve_embed_model_name(embed_model)
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        """Return the service for the knowledge base named ``"default"``."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
|
||||
|
||||
|
||||
def get_kb_details() -> List[Dict]:
    """Merge the knowledge bases found on disk with those recorded in the DB.

    Returns:
        One dict per knowledge base, folder entries first, each carrying
        ``in_folder`` / ``in_db`` flags and a 1-based ``No`` sequence number.
    """
    folder_kbs = list_kbs_from_folder()
    db_kbs = KBService.list_kbs()

    # Start from placeholders for every knowledge base present on disk.
    details = {
        name: {
            "kb_name": name,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for name in folder_kbs
    }

    # Overlay DB details; DB-only knowledge bases are appended.
    for name in db_kbs:
        db_detail = get_kb_detail(name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        if name in details:
            details[name].update(db_detail)
        else:
            db_detail["in_folder"] = False
            details[name] = db_detail

    rows = []
    for number, row in enumerate(details.values(), start=1):
        row['No'] = number
        rows.append(row)

    return rows
|
||||
|
||||
|
||||
def get_kb_file_details(kb_name: str) -> List[Dict]:
    """Merge the files found on disk with those recorded in the DB for a KB.

    DB file names are matched to folder file names case-insensitively.

    Returns:
        ``[]`` when the knowledge base has no service; otherwise one dict per
        file, folder entries first, each carrying ``in_folder`` / ``in_db``
        flags and a 1-based ``No`` sequence number.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    if service is None:
        return []

    folder_files = list_files_from_folder(kb_name)
    db_files = service.list_files()

    # Start from placeholders for every file present on disk.
    details = {
        file_name: {
            "kb_name": kb_name,
            "file_name": file_name,
            "file_ext": os.path.splitext(file_name)[-1],
            "file_version": 0,
            "document_loader": "",
            "docs_count": 0,
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for file_name in folder_files
    }

    # Lowercased folder name -> actual folder name, built once before merging.
    folder_name_by_lower = {name.lower(): name for name in details}
    for file_name in db_files:
        db_detail = get_file_detail(kb_name, file_name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        lowered = file_name.lower()
        if lowered in folder_name_by_lower:
            details[folder_name_by_lower[lowered]].update(db_detail)
        else:
            db_detail["in_folder"] = False
            details[file_name] = db_detail

    rows = []
    for number, row in enumerate(details.values(), start=1):
        row['No'] = number
        rows.append(row)

    return rows
|
||||
|
||||
|
||||
class EmbeddingsFunAdapter(Embeddings):
    """``Embeddings`` implementation that calls the project's embedding API.

    Every returned vector is L2-normalised with ``normalize``.
    """

    def __init__(self, embed_model: str = EMBEDDING_MODEL):
        self.embed_model = embed_model

    @staticmethod
    def _require_embeddings(result, texts: List[str]):
        """Return ``result.data`` or raise ``ValueError`` when it is missing or empty."""
        embeddings = result.data if result and hasattr(result, 'data') else None
        if not embeddings:
            raise ValueError(f"Failed to get embeddings for texts: {texts[:2]}...")
        return embeddings

    @staticmethod
    def _normalize_query(embeddings, text: str) -> List[float]:
        """Normalise the first (only) embedding and return it as a flat list."""
        query_embed = embeddings[0]
        if query_embed is None:
            raise ValueError(f"Failed to get embeddings for texts: {[text]}...")
        query_embed_2d = np.reshape(query_embed, (1, -1))  # 1-D vector -> single-row 2-D array
        normalized_query_embed = normalize(query_embed_2d)
        return normalized_query_embed[0].tolist()  # back to a flat list

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` as documents and return normalised vectors.

        Raises:
            ValueError: If the embedding call returns no data.
        """
        result = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        embeddings = self._require_embeddings(result, texts)
        return normalize(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed ``text`` as a query and return one normalised vector.

        Raises:
            ValueError: If the embedding call returns no data.
        """
        result = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        embeddings = self._require_embeddings(result, [text])
        return self._normalize_query(embeddings, text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of ``embed_documents``."""
        result = await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        embeddings = self._require_embeddings(result, texts)
        return normalize(embeddings).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        """Async counterpart of ``embed_query``."""
        result = await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        embeddings = self._require_embeddings(result, [text])
        return self._normalize_query(embeddings, text)
|
||||
|
||||
|
||||
# def score_threshold_process(score_threshold, k, docs):
|
||||
# if score_threshold is not None:
|
||||
# cmp = (
|
||||
# operator.le
|
||||
# )
|
||||
# docs = [
|
||||
# (doc, similarity)
|
||||
# for doc, similarity in docs
|
||||
# if cmp(similarity, score_threshold)
|
||||
# ]
|
||||
# return docs[:k]
|
||||
@@ -0,0 +1,105 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.types import (GetResult, QueryResult)
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from configs import SCORE_THRESHOLD
|
||||
from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
|
||||
KBService, SupportedVSType)
|
||||
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
|
||||
|
||||
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
    """Convert a Chroma ``get`` result into a list of ``Document`` objects.

    Args:
        get_result: Mapping with ``'documents'`` and ``'metadatas'`` entries.

    Returns:
        One ``Document`` per returned text; ``[]`` when there are no
        documents. Missing metadata becomes a fresh empty dict per document.
    """
    documents = get_result['documents']
    if not documents:
        return []

    metadatas = get_result['metadatas'] or [None] * len(documents)

    # ``metadata or {}`` builds a new dict for every document, so documents
    # never share one metadata object, and guards a missing individual entry.
    return [
        Document(page_content=page_content, metadata=metadata or {})
        for page_content, metadata in zip(documents, metadatas)
    ]
|
||||
|
||||
|
||||
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """Convert a Chroma query result into ``(Document, distance)`` pairs.

    Only the first entry of ``documents`` / ``metadatas`` / ``distances`` is
    read. Original note: from
    langchain_community.vectorstores.chroma import Chroma.
    """
    # TODO: Chroma can do batch querying,
    rows = zip(
        results["documents"][0],
        results["metadatas"][0],
        results["distances"][0],
    )
    docs_and_scores = []
    for text, metadata, distance in rows:
        document = Document(page_content=text, metadata=metadata or {})
        docs_and_scores.append((document, distance))
    return docs_and_scores
|
||||
|
||||
|
||||
class ChromaKBService(KBService):
    """``KBService`` backed by a persistent ChromaDB collection.

    The collection is named after the knowledge base and stored under the
    knowledge base's vector-store path.
    """

    vs_path: str
    kb_path: str

    client = None
    collection = None

    def vs_type(self) -> str:
        return SupportedVSType.CHROMADB

    def get_vs_path(self) -> str:
        """Return the vector-store path for this KB and embedding model."""
        return get_vs_path(self.kb_name, self.embed_model)

    def get_kb_path(self) -> str:
        """Return the root path of this knowledge base."""
        return get_kb_path(self.kb_name)

    def do_init(self) -> None:
        """Open the persistent client and get or create the collection."""
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()
        self.client = chromadb.PersistentClient(path=self.vs_path)
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_create_kb(self) -> None:
        # In ChromaDB, creating a KB is equivalent to creating a collection
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_drop_kb(self):
        # Dropping a KB is equivalent to deleting a collection in ChromaDB
        try:
            self.client.delete_collection(self.kb_name)
        except ValueError as e:
            # A collection that is already gone is not an error here.
            if not str(e) == f"Collection {self.kb_name} does not exist.":
                raise e

    def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD,
                  expr: str = None, custom_strategy_config: dict = None) -> List[
            Tuple[Document, float]]:
        """Return the ``top_k`` nearest documents to ``query`` with distances.

        ``score_threshold``, ``expr`` and ``custom_strategy_config`` are
        accepted so the signature matches how ``KBService.search_docs``
        calls this hook; they are not used by this backend.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
        return _results_to_docs_and_scores(query_result)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Embed ``docs`` and add them one by one under fresh UUID ids.

        Returns:
            A list of ``{"id": ..., "metadata": ...}`` dicts, one per added
            document.
        """
        doc_infos = []
        data = self._docs_to_embeddings(docs)
        ids = [str(uuid.uuid1()) for _ in range(len(data["texts"]))]
        for _id, text, embedding, metadata in zip(ids, data["texts"], data["embeddings"], data["metadatas"]):
            self.collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text)
            doc_infos.append({"id": _id, "metadata": metadata})
        return doc_infos

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by id from the collection."""
        get_result: GetResult = self.collection.get(ids=ids)
        return _get_result_to_documents(get_result)

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by id; always returns ``True``."""
        self.collection.delete(ids=ids)
        return True

    def do_clear_vs(self):
        # Clearing the vector store is done by dropping and recreating the
        # collection, so self.collection stays usable afterwards.
        self.do_drop_kb()
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every entry whose ``source`` metadata equals ``kb_file.filepath``."""
        # NOTE(review): confirm stored "source" values equal kb_file.filepath
        # rather than a path relative to the docs folder; verify against how
        # documents are added.
        return self.collection.delete(where={"source": kb_file.filepath})
|
||||
@@ -0,0 +1,38 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService
|
||||
|
||||
|
||||
class DefaultKBService(KBService):
    """No-op ``KBService``: every backend hook does nothing.

    Hook signatures accept the arguments that the ``KBService`` public
    methods pass to them, so calling through the base-class API does not
    raise ``TypeError``.
    """

    def do_create_kb(self):
        pass

    def do_drop_kb(self):
        pass

    def do_add_doc(self, docs: List[Document], **kwargs):
        pass

    def do_clear_vs(self):
        pass

    def vs_type(self) -> str:
        return "default"

    def do_init(self):
        pass

    def do_search(self, query=None, top_k=None, score_threshold=None, expr=None,
                  custom_strategy_config=None):
        pass

    def do_insert_multi_knowledge(self):
        pass

    def do_insert_one_knowledge(self):
        pass

    def do_delete_doc(self, kb_file=None, **kwargs):
        pass
|
||||
261
langchain-chat/server/knowledge_base/kb_service/es_kb_service.py
Normal file
261
langchain-chat/server/knowledge_base/kb_service/es_kb_service.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from typing import List
|
||||
import os
|
||||
import shutil
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.elasticsearch import ElasticsearchStore
|
||||
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
from server.utils import load_local_embeddings
|
||||
from elasticsearch import Elasticsearch,BadRequestError
|
||||
from configs import logger
|
||||
from configs import kbs_config
|
||||
|
||||
class ESKBService(KBService):
|
||||
|
||||
def do_init(self):
|
||||
self.kb_path = self.get_kb_path(self.kb_name)
|
||||
self.index_name = os.path.split(self.kb_path)[-1]
|
||||
self.IP = kbs_config[self.vs_type()]['host']
|
||||
self.PORT = kbs_config[self.vs_type()]['port']
|
||||
self.user = kbs_config[self.vs_type()].get("user",'')
|
||||
self.password = kbs_config[self.vs_type()].get("password",'')
|
||||
self.dims_length = kbs_config[self.vs_type()].get("dims_length",None)
|
||||
self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
||||
try:
|
||||
# ES python客户端连接(仅连接)
|
||||
if self.user != "" and self.password != "":
|
||||
self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}",
|
||||
basic_auth=(self.user,self.password))
|
||||
else:
|
||||
logger.warning("ES未配置用户名和密码")
|
||||
self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}")
|
||||
except ConnectionError:
|
||||
logger.error("连接到 Elasticsearch 失败!")
|
||||
raise ConnectionError
|
||||
except Exception as e:
|
||||
logger.error(f"Error 发生 : {e}")
|
||||
raise e
|
||||
try:
|
||||
# 首先尝试通过es_client_python创建
|
||||
mappings = {
|
||||
"properties": {
|
||||
"dense_vector": {
|
||||
"type": "dense_vector",
|
||||
"dims": self.dims_length,
|
||||
"index": True
|
||||
}
|
||||
}
|
||||
}
|
||||
self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
|
||||
except BadRequestError as e:
|
||||
logger.error("创建索引失败,重新")
|
||||
logger.error(e)
|
||||
|
||||
try:
|
||||
# langchain ES 连接、创建索引
|
||||
if self.user != "" and self.password != "":
|
||||
self.db_init = ElasticsearchStore(
|
||||
es_url=f"http://{self.IP}:{self.PORT}",
|
||||
index_name=self.index_name,
|
||||
query_field="context",
|
||||
vector_query_field="dense_vector",
|
||||
embedding=self.embeddings_model,
|
||||
es_user=self.user,
|
||||
es_password=self.password
|
||||
)
|
||||
else:
|
||||
logger.warning("ES未配置用户名和密码")
|
||||
self.db_init = ElasticsearchStore(
|
||||
es_url=f"http://{self.IP}:{self.PORT}",
|
||||
index_name=self.index_name,
|
||||
query_field="context",
|
||||
vector_query_field="dense_vector",
|
||||
embedding=self.embeddings_model,
|
||||
)
|
||||
except ConnectionError:
|
||||
print("### 初始化 Elasticsearch 失败!")
|
||||
logger.error("### 初始化 Elasticsearch 失败!")
|
||||
raise ConnectionError
|
||||
except Exception as e:
|
||||
logger.error(f"Error 发生 : {e}")
|
||||
raise e
|
||||
try:
|
||||
# 尝试通过db_init创建索引
|
||||
self.db_init._create_index_if_not_exists(
|
||||
index_name=self.index_name,
|
||||
dims_length=self.dims_length
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("创建索引失败...")
|
||||
logger.error(e)
|
||||
# raise e
|
||||
|
||||
|
||||
|
||||
@staticmethod
def get_kb_path(knowledge_base_name: str):
    """Return the directory of a knowledge base.

    The result is ``KB_ROOT_PATH`` joined with ``knowledge_base_name``.
    """
    kb_dir = os.path.join(KB_ROOT_PATH, knowledge_base_name)
    return kb_dir
|
||||
|
||||
@staticmethod
def get_vs_path(knowledge_base_name: str):
    """Return the ``vector_store`` directory inside a knowledge base folder."""
    kb_dir = ESKBService.get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "vector_store")
|
||||
|
||||
def do_create_kb(self):
    """Create the ``vector_store`` folder of this knowledge base.

    Nothing is done when ``self.doc_path`` does not exist. When the
    folder is already present a warning is logged instead.
    """
    if not os.path.exists(self.doc_path):
        return
    vs_dir = os.path.join(self.kb_path, "vector_store")
    if os.path.exists(vs_dir):
        logger.warning("directory `vector_store` already exists.")
    else:
        os.makedirs(vs_dir)
|
||||
|
||||
def vs_type(self) -> str:
    """Return the vector-store type constant for the Elasticsearch backend."""
    backend = SupportedVSType.ES
    return backend
|
||||
|
||||
def _load_es(self, docs, embed_model):
    """Connect to Elasticsearch and write ``docs`` into the index.

    The created store is kept on ``self.db``. Credentials are passed only
    when both ``self.user`` and ``self.password`` are non-empty. Failures
    are printed and logged, never raised.

    :param docs: documents handed to ``ElasticsearchStore.from_documents``.
    :param embed_model: embedding object used to vectorise the documents.
    """
    try:
        # Arguments shared by the authenticated and the anonymous connection.
        store_kwargs = dict(
            documents=docs,
            embedding=embed_model,
            es_url=f"http://{self.IP}:{self.PORT}",
            index_name=self.index_name,
            distance_strategy="COSINE",
            query_field="context",
            vector_query_field="dense_vector",
            verify_certs=False,
        )
        has_credentials = not (self.user == "" or self.password == "")
        if has_credentials:
            store_kwargs["es_user"] = self.user
            store_kwargs["es_password"] = self.password
        # Connects and writes the documents in a single call.
        self.db = ElasticsearchStore.from_documents(**store_kwargs)
    except ConnectionError as conn_err:
        print(conn_err)
        print("连接到 Elasticsearch 失败!")
        logger.error("连接到 Elasticsearch 失败!")
    except Exception as err:
        logger.error(f"Error 发生 : {err}")
        print(err)
|
||||
|
||||
|
||||
|
||||
def do_search(self, query:str, top_k: int, score_threshold: float,expr:str):
    """Run a text similarity search through ``self.db_init``.

    ``score_threshold`` and ``expr`` are part of the signature but are not
    used by this implementation.

    :return: whatever ``similarity_search_with_score`` returns for the
        ``top_k`` best matches of ``query``.
    """
    return self.db_init.similarity_search_with_score(query=query, k=top_k)
|
||||
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """Fetch documents from the index by their Elasticsearch ids.

    Ids that cannot be retrieved are logged and skipped, so the result may
    be shorter than ``ids``.

    :param ids: Elasticsearch document ids.
    :return: one ``Document`` per id that was retrieved successfully.
    """
    found = []
    for doc_id in ids:
        try:
            hit = self.es_client_python.get(index=self.index_name, id=doc_id)
            payload = hit["_source"]
            # The text lives in the "context" field, the rest in "metadata".
            found.append(Document(page_content=payload.get("context", ""),
                                  metadata=payload.get("metadata", {})))
        except Exception as e:
            logger.error(f"Error retrieving document from Elasticsearch! {e}")
    return found
|
||||
|
||||
def del_doc_by_ids(self, ids: List[str]) -> bool:
    """Delete documents from the index by their Elasticsearch ids.

    Every id is attempted even if an earlier one fails; failures are
    logged rather than raised.

    :param ids: Elasticsearch document ids to delete.
    :return: ``True`` when every delete succeeded (also for an empty
        list), ``False`` when at least one failed.
    """
    all_deleted = True
    for doc_id in ids:
        try:
            self.es_client_python.delete(index=self.index_name,
                                         id=doc_id,
                                         refresh=True)
        except Exception as e:
            logger.error(f"ES Docs Delete Error! {e}")
            all_deleted = False
    return all_deleted
|
||||
|
||||
def do_delete_doc(self, kb_file, **kwargs):
    """Delete every chunk of ``kb_file`` from this knowledge base's index.

    Chunks are matched on the keyword field ``metadata.source.keyword``
    being equal to ``kb_file.filepath``. The delete is restricted to
    ``self.index_name`` and handled by a single delete-by-query request,
    so it is not limited to one page of search hits.

    :param kb_file: object exposing the ``filepath`` of the file to remove.
    :return: always ``None``; failures are logged, not raised.
    """
    if not self.es_client_python.indices.exists(index=self.index_name):
        return None
    query = {
        "query": {
            "term": {
                "metadata.source.keyword": kb_file.filepath
            }
        }
    }
    try:
        self.es_client_python.delete_by_query(index=self.index_name,
                                              body=query,
                                              refresh=True)
    except Exception as e:
        logger.error(f"ES Docs Delete Error! {e}")
    return None
|
||||
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs):
    """Add documents to the knowledge base and report what was stored.

    The documents are written through ``_load_es``. The index is then
    queried for all chunks whose ``metadata.source.keyword`` equals the
    source of the first document.

    :param docs: documents to index; all are expected to share one source.
    :return: ``[{"id": str, "metadata": dict}, ...]`` for the stored
        chunks, ``[]`` when ``docs`` is empty, or ``None`` when the index
        does not exist after writing.
    :raises ValueError: when the lookup after writing finds no chunk.
    """
    if not docs:
        # Nothing to write, and docs[0] below would fail.
        return []

    print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}")
    print("*"*100)
    self._load_es(docs=docs, embed_model=self.embeddings_model)
    print("写入数据成功.")
    print("*"*100)

    if not self.es_client_python.indices.exists(index=self.index_name):
        return None

    file_path = docs[0].metadata.get("source")
    # A dict literal cannot hold two "term" keys (the later one silently
    # replaces the earlier), so the index is selected through ``index=``
    # and the query carries only the source filter.
    query = {
        "query": {
            "term": {
                "metadata.source.keyword": file_path
            }
        }
    }
    # Elasticsearch returns only 10 hits by default. Ask for at least as
    # many as were just written, capped at 10000 (the default value of
    # index.max_result_window).
    page_size = min(max(len(docs), 50), 10000)
    search_results = self.es_client_python.search(index=self.index_name,
                                                  body=query,
                                                  size=page_size)
    hits = search_results["hits"]["hits"]
    if len(hits) == 0:
        raise ValueError("召回元素个数为0")
    info_docs = [{"id": hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in hits]
    return info_docs
|
||||
|
||||
|
||||
def do_clear_vs(self):
    """Delete all vectors of this knowledge base by dropping its index."""
    # NOTE(review): the index is addressed by ``self.kb_name`` here while
    # the other methods use ``self.index_name`` -- confirm both always hold
    # the same value.
    target = self.kb_name
    if not self.es_client_python.indices.exists(index=target):
        return
    self.es_client_python.indices.delete(index=target)
|
||||
|
||||
|
||||
def do_drop_kb(self):
    """Delete the knowledge base by removing its directory tree.

    ``self.kb_path`` is the knowledge base directory; nothing happens when
    it does not exist.
    """
    kb_dir = self.kb_path
    if not os.path.exists(kb_dir):
        return
    shutil.rmtree(kb_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: index README.md into the "test" knowledge base,
    # then run one search against it.
    esKBService = ESKBService("test")
    #esKBService.clear_vs()
    #esKBService.create_kb()
    esKBService.add_doc(KnowledgeFile(filename="README.md", knowledge_base_name="test"))
    print(esKBService.search_docs("如何启动api服务"))
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from configs import SCORE_THRESHOLD, EXPR
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
|
||||
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
from server.utils import torch_gc
|
||||
from langchain.docstore.document import Document
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
|
||||
|
||||
class FaissKBService(KBService):
    """Knowledge-base service that keeps its vectors in a FAISS store.

    Vector stores are obtained from the shared ``kb_faiss_pool`` and are
    saved under ``vs_path``.
    """

    vs_path: str
    kb_path: str
    vector_name: str = None

    def vs_type(self) -> str:
        """Return the vector-store type constant for FAISS."""
        return SupportedVSType.FAISS

    def get_vs_path(self):
        """Return the vector-store path for this knowledge base and vector name."""
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        """Return the path of this knowledge base."""
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Load this knowledge base's vector store from ``kb_faiss_pool``."""
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name=self.vector_name,
                                               embed_model=self.embed_model)

    def save_vector_store(self):
        """Save the vector store to ``vs_path``."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return the stored documents for ``ids``.

        An id that is not in the docstore yields ``None`` at its position.
        """
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the vectors with the given ids.

        :return: ``True`` once ``vs.delete`` has completed; an exception
            raised by ``vs.delete`` propagates.
        """
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        return True

    def do_init(self):
        """Resolve the vector name (defaults to the embed model) and paths."""
        self.vector_name = self.vector_name or self.embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Create the vector-store directory and load the store."""
        os.makedirs(self.vs_path, exist_ok=True)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store, then remove the knowledge base directory.

        Removing the directory is best effort: errors are ignored.
        """
        self.clear_vs()
        shutil.rmtree(self.kb_path, ignore_errors=True)

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  expr: str = EXPR,
                  ) -> List[Tuple[Document, float]]:
        """Embed ``query`` and return the ``top_k`` closest documents with scores.

        ``expr`` is accepted but not used by this backend.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Embed ``docs`` and add them to the vector store.

        Recognised ``kwargs``: ``ids`` (explicit ids for the new vectors)
        and ``not_refresh_vs_cache`` (skip saving the store to disk).

        :return: ``[{"id": ..., "metadata": ...}, ...]``, one per document.
        """
        # Embedding happens before the store is acquired so that the lock
        # is held for a shorter time.
        data = self._docs_to_embeddings(docs)

        with self.load_vector_store().acquire() as vs:
            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
                                    metadatas=data["metadatas"],
                                    ids=kwargs.get("ids"))
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        torch_gc()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete every vector whose ``source`` metadata matches ``kb_file.filename``.

        The comparison is case-insensitive. Documents without a ``source``
        are skipped instead of raising. Pass ``not_refresh_vs_cache=True``
        to skip saving the store to disk.

        :return: the ids that were deleted.
        """
        wanted = kb_file.filename.lower()
        with self.load_vector_store().acquire() as vs:
            ids = [k for k, v in vs.docstore._dict.items()
                   if (v.metadata.get("source") or "").lower() == wanted]
            if ids:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Evict the store from the pool and recreate an empty store directory."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        # Best effort: the directory may already be gone.
        shutil.rmtree(self.vs_path, ignore_errors=True)
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Report where ``file_name`` is known.

        :return: ``"in_db"`` when the base class finds it, ``"in_folder"``
            when it only exists in the ``content`` folder, else ``False``.
        """
        if super().exist_doc(file_name):
            return "in_db"

        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test against the "test" knowledge base: add a file,
    # delete it, drop the knowledge base, then run one search.
    faissService = FaissKBService("test")
    faissService.add_doc(KnowledgeFile("README.md", "test"))
    faissService.delete_doc(KnowledgeFile("README.md", "test"))
    faissService.do_drop_kb()
    print(faissService.search_docs("如何启动api服务"))
|
||||
@@ -0,0 +1,207 @@
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.milvus import Milvus
|
||||
import os
|
||||
import logging
|
||||
|
||||
from configs import kbs_config
|
||||
from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
import numpy as np
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MilvusKBService(KBService):
    """Knowledge-base service that keeps its vectors in a Milvus collection."""

    # Stays None until _load_milvus succeeds, so the ``if self.milvus``
    # guards below work even when loading failed.
    milvus: Optional[Milvus] = None

    @staticmethod
    def get_collection(milvus_name):
        """Return the pymilvus ``Collection`` named ``milvus_name``."""
        from pymilvus import Collection
        return Collection(milvus_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Query documents by primary key.

        ``ids`` are converted to ``int`` before being used in the filter.
        The ``text`` field becomes the page content; every other returned
        field becomes metadata.
        """
        result = []
        if self.milvus and self.milvus.col:
            pk_list = [int(_id) for _id in ids]
            data_list = self.milvus.col.query(expr=f'pk in {pk_list}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        """Query documents whose ``source`` field is in ``source_name_list``."""
        result = []
        if self.milvus and self.milvus.col:
            data_list = self.milvus.col.query(expr=f'source in {source_name_list}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by primary key.

        ``ids`` are converted to ``int``, matching ``get_doc_by_ids``.

        :return: ``False`` when no collection is loaded, otherwise ``True``.
        """
        if not (self.milvus and self.milvus.col):
            return False
        if not ids:
            return True
        pk_list = [int(_id) for _id in ids]
        self.milvus.col.delete(expr=f'pk in {pk_list}')
        return True

    @staticmethod
    def search(milvus_name, content, limit=3):
        """Search collection ``milvus_name`` on its ``embeddings`` field.

        Uses the L2 metric with ``nprobe`` 10 and returns the ``content``
        output field for up to ``limit`` hits.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        c = MilvusKBService.get_collection(milvus_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """No-op for this backend."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type constant for Milvus."""
        return SupportedVSType.MILVUS

    def _load_milvus(self):
        """Create the ``Milvus`` wrapper and detect the collection's text field.

        When the wrapper cannot be created the error is logged and
        ``_create_collection_if_not_exists`` is called; ``self.milvus`` is
        not reloaded afterwards.
        """
        try:
            self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                 collection_name=self.kb_name,
                                 connection_args=kbs_config.get("milvus"),
                                 index_params=kbs_config.get("milvus_kwargs")["index_params"],
                                 search_params=kbs_config.get("milvus_kwargs")["search_params"],
                                 auto_id=True
                                 )
            logger.info("成功加载 Milvus 实例 'milvus'。")

            # -------- Support text fields of differing schemas --------
            # For a new knowledge base without a Milvus collection yet,
            # ``col`` is None; the schema only exists after the first
            # add_documents call creates the collection, so ``.col.schema``
            # must not be accessed here.
            try:
                col = self.milvus.col
                if col is None:
                    logger.debug(
                        "集合 %s 尚未在 Milvus 中建表,跳过文本字段探测(首次写入时会自动创建)",
                        self.kb_name,
                    )
                else:
                    field_names = [f.name for f in col.schema.fields]
                    if self.milvus._text_field not in field_names:
                        if "page_content" in field_names:
                            self.milvus._text_field = "page_content"
                        elif "content" in field_names:
                            self.milvus._text_field = "content"
                        else:
                            # Fall back to the first VARCHAR field.
                            for f in col.schema.fields:
                                if hasattr(f, "dtype") and str(f.dtype).startswith("DataType.VARCHAR"):
                                    self.milvus._text_field = f.name
                                    break
                    logger.info(f"集合 {self.kb_name} 使用文本字段: {self.milvus._text_field}")
            except Exception as e:
                logger.warning(f"检测并设置文本字段失败: {e}")
        except Exception as e:
            logger.error(f"加载 Milvus 实例 'milvus' 失败: {e}")
            self._create_collection_if_not_exists()
            # Reload (disabled):
            # self._load_milvus()

    def _create_collection_if_not_exists(self):
        """Create the Milvus collection for this knowledge base and index it."""
        from pymilvus import Collection, CollectionSchema, FieldSchema, DataType

        # Field definitions (adjust to your needs).
        fields = [
            FieldSchema(name="pk", dtype=DataType.Int64, is_primary=True, auto_id=True),
            # dim must match the embedding model's output size.
            FieldSchema(name="vector", dtype=DataType.FloatVector, dim=768),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=65535),
            # Add other custom fields here...
        ]

        schema = CollectionSchema(fields=fields, description=self.kb_name)

        # Create the collection.
        collection = Collection(name=self.kb_name, schema=schema, using="default")

        # Create the index.
        index_params = kbs_config.get("milvus_kwargs")["index_params"]
        collection.create_index(field_name="vector", index_params=index_params)

        logger.info(f"成功创建 Milvus 集合: {self.kb_name}")

    def do_init(self):
        """Load the Milvus wrapper."""
        self._load_milvus()

    def do_drop_kb(self):
        """Release the collection; dropping it is deliberately disabled."""
        if self.milvus and self.milvus.col:
            self.milvus.col.release()
            # self.milvus.col.drop()  # Dropping collections from chatchat is disabled.

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str,
                  custom_strategy_config: Optional[dict] = None):
        """Search the collection for ``query``.

        For ``top_k`` above 50 the hits are returned without scores, sorted
        by ascending ``pk``. Otherwise the scored hits are passed through
        ``score_threshold_process``. ``custom_strategy_config`` is accepted
        but not used here.

        :return: the documents found, or ``[]`` when the search fails.
        """
        self._load_milvus()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        try:
            embeddings = embed_func.embed_query(query)
            if top_k > 50:
                # Return the full text in order.
                docs = self.milvus.similarity_search_by_vector(embeddings, top_k, expr=expr)
                docs = sorted(docs, key=lambda doc: doc.metadata['pk'])  # ascending pk
                # return score_threshold_process(query,score_threshold, top_k, docs)
                return docs
            else:
                docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k, expr=expr)
                # TODO: dynamic score_threshold
                return score_threshold_process(query, score_threshold, top_k, docs)
        except Exception as e:
            logger.error(f"搜索 Milvus 集合 '{self.kb_name}' 失败: {e}")
            return []

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Add ``docs`` to the collection.

        Metadata values are converted to ``str``, every field known to the
        store gets a ``""`` default, and the text and vector fields are
        removed from the metadata before insertion.

        :return: ``[{"id": ..., "metadata": ...}, ...]``, one per document.
        """
        for doc in docs:
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.milvus._text_field, None)
            doc.metadata.pop(self.milvus._vector_field, None)

        ids = self.milvus.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete the vectors recorded in the database for ``kb_file``."""
        id_list = list_file_num_docs_id_by_kb_name_and_file_name(kb_file.kb_name, kb_file.filename)
        # Skip the request when there is nothing to delete.
        if id_list and self.milvus and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {id_list}')

        # Issue 2846, for windows
        # if self.milvus.col:
        #     file_path = kb_file.filepath.replace("\\", "\\\\")
        #     file_name = os.path.basename(file_path)
        #     id_list = [item.get("pk") for item in
        #                self.milvus.col.query(expr=f'source == "{file_name}"', output_fields=["pk"])]
        #     self.milvus.col.delete(expr=f'pk in {id_list}')

    def do_clear_vs(self):
        """Release the collection and reload the wrapper when a collection exists."""
        if self.milvus and self.milvus.col:
            self.do_drop_kb()
            self.do_init()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Used to test table creation.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    milvusService = MilvusKBService("t_policy_total_bce_v1")
    # milvusService.add_doc(KnowledgeFile("README.md", "test"))

    # print(milvusService.get_doc_by_ids(["444022434274215486"]))
    # milvusService.delete_doc(KnowledgeFile("README.md", "test"))
    # milvusService.do_drop_kb()
    # print(milvusService.search_docs("如何启动api服务"))
|
||||
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
|
||||
from sqlalchemy import text
|
||||
|
||||
from configs import kbs_config
|
||||
|
||||
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
import shutil
|
||||
import sqlalchemy
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class PGKBService(KBService):
    """Knowledge-base service that keeps its vectors in PostgreSQL via PGVector."""

    # One engine shared by all instances, created when the class is defined.
    engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)

    def _load_pg_vector(self):
        """Create the ``PGVector`` store for this knowledge base."""
        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                  collection_name=self.kb_name,
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
                                  connection=PGKBService.engine,
                                  connection_string=kbs_config.get("pg").get("connection_uri"))

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return the documents selected by ``ids`` from ``langchain_pg_embedding``.

        ``ids`` is bound as an expanding parameter so that the list is
        rendered as a proper ``IN (...)`` value list.
        """
        if not ids:
            return []
        # NOTE(review): the filter column is ``collection_id``; confirm the
        # ids passed in are collection ids and not per-document ids.
        stmt = text(
            "SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id IN :ids"
        ).bindparams(sqlalchemy.bindparam("ids", expanding=True))
        with Session(PGKBService.engine) as session:
            rows = session.execute(stmt, {'ids': list(ids)}).fetchall()
        return [Document(page_content=row[0], metadata=row[1]) for row in rows]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delegate to the base class implementation."""
        return super().del_doc_by_ids(ids)

    def do_init(self):
        """Create the PGVector store."""
        self._load_pg_vector()

    def do_create_kb(self):
        """No-op for this backend."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type constant for PostgreSQL."""
        return SupportedVSType.PG

    def do_drop_kb(self):
        """Delete this knowledge base's rows and remove its directory.

        The knowledge base name is passed as a bound parameter instead of
        being formatted into the SQL text.
        """
        params = {"kb_name": self.kb_name}
        with Session(PGKBService.engine) as session:
            # Delete the embeddings that belong to the collection first ...
            session.execute(text(
                "DELETE FROM langchain_pg_embedding "
                "WHERE collection_id IN ("
                "SELECT uuid FROM langchain_pg_collection WHERE name = :kb_name)"
            ), params)
            # ... then the collection record itself.
            session.execute(text(
                "DELETE FROM langchain_pg_collection WHERE name = :kb_name"
            ), params)
            session.commit()
        shutil.rmtree(self.kb_path)

    def do_search(self, query: str, top_k: int, score_threshold: float,expr:str):
        """Embed ``query`` and return the scored ``top_k`` hits.

        ``expr`` is accepted but not used by this backend.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
        # NOTE(review): MilvusKBService calls score_threshold_process with
        # ``query`` as an additional first argument -- confirm which
        # signature the helper currently has.
        return score_threshold_process(score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Add ``docs`` to the store.

        :return: ``[{"id": ..., "metadata": ...}, ...]``, one per document.
        """
        ids = self.pg_vector.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every embedding whose metadata ``source`` equals ``kb_file.filepath``.

        The JSON filter is built with ``json.dumps`` and passed as a bound
        parameter, so backslashes and quotes in the path are escaped
        correctly.
        """
        source_filter = json.dumps({"source": kb_file.filepath})
        with Session(PGKBService.engine) as session:
            session.execute(
                text("DELETE FROM langchain_pg_embedding "
                     "WHERE cmetadata::jsonb @> CAST(:source_filter AS jsonb)"),
                {"source_filter": source_filter})
            session.commit()

    def do_clear_vs(self):
        """Delete the collection and create it again."""
        self.pg_vector.delete_collection()
        self.pg_vector.create_collection()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: fetch one document by id from the "test"
    # knowledge base; the other steps are kept commented out.
    from server.db.base import Base, engine

    # Base.metadata.create_all(bind=engine)
    pGKBService = PGKBService("test")
    # pGKBService.create_kb()
    # pGKBService.add_doc(KnowledgeFile("README.md", "test"))
    # pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
    # pGKBService.drop_kb()
    print(pGKBService.get_doc_by_ids(["f1e51390-3029-4a19-90dc-7118aaa25772"]))
    # print(pGKBService.search_docs("如何启动api服务"))
|
||||
@@ -0,0 +1,97 @@
|
||||
from typing import List, Dict, Optional
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Zilliz
|
||||
from configs import kbs_config
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
|
||||
class ZillizKBService(KBService):
    """Knowledge-base service that keeps its vectors in a Zilliz collection."""

    zilliz: Zilliz

    @staticmethod
    def get_collection(zilliz_name):
        """Return the pymilvus ``Collection`` named ``zilliz_name``."""
        from pymilvus import Collection
        return Collection(zilliz_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Query documents by primary key.

        The ``text`` field becomes the page content; every other returned
        field becomes metadata.
        """
        result = []
        if self.zilliz.col:
            # ids = [int(id) for id in ids]  # for zilliz if needed #pr 2725
            data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by primary key.

        :return: ``False`` when no collection is loaded, otherwise ``True``.
        """
        if not self.zilliz.col:
            return False
        if ids:
            self.zilliz.col.delete(expr=f'pk in {ids}')
        return True

    @staticmethod
    def search(zilliz_name, content, limit=3):
        """Search collection ``zilliz_name`` on its ``embeddings`` field.

        Uses the IP metric and returns the ``content`` output field for up
        to ``limit`` hits.
        """
        search_params = {
            "metric_type": "IP",
            "params": {},
        }
        c = ZillizKBService.get_collection(zilliz_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """No-op for this backend."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type constant for Zilliz."""
        return SupportedVSType.ZILLIZ

    def _load_zilliz(self):
        """Create the ``Zilliz`` wrapper from the ``zilliz`` section of ``kbs_config``."""
        zilliz_args = kbs_config.get("zilliz")
        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                             collection_name=self.kb_name, connection_args=zilliz_args)

    def do_init(self):
        """Load the Zilliz wrapper."""
        self._load_zilliz()

    def do_drop_kb(self):
        """Release and drop the collection when one is loaded."""
        if self.zilliz.col:
            self.zilliz.col.release()
            self.zilliz.col.drop()

    def do_search(self, query: str, top_k: int, score_threshold: float,expr:str):
        """Embed ``query`` and return the scored ``top_k`` hits.

        ``expr`` is accepted but not used by this backend.
        """
        self._load_zilliz()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
        # NOTE(review): MilvusKBService calls score_threshold_process with
        # ``query`` as an additional first argument -- confirm which
        # signature the helper currently has.
        return score_threshold_process(score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Add ``docs`` to the collection.

        Metadata values are converted to ``str``, every field known to the
        store gets a ``""`` default, and the text and vector fields are
        removed from the metadata before insertion.

        :return: ``[{"id": ..., "metadata": ...}, ...]``, one per document.
        """
        for doc in docs:
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            for field in self.zilliz.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.zilliz._text_field, None)
            doc.metadata.pop(self.zilliz._vector_field, None)

        ids = self.zilliz.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every vector whose ``source`` equals ``kb_file.filepath``."""
        if self.zilliz.col:
            # Backslashes are doubled so they survive inside the quoted
            # filter expression.
            filepath = kb_file.filepath.replace('\\', '\\\\')
            delete_list = [item.get("pk") for item in
                           self.zilliz.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
            # Skip the request when nothing matched.
            if delete_list:
                self.zilliz.col.delete(expr=f'pk in {delete_list}')

    def do_clear_vs(self):
        """Drop the collection and reload the wrapper when a collection exists."""
        if self.zilliz.col:
            self.do_drop_kb()
            self.do_init()
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual smoke test: create the database tables, then build the
    # service for the "test" knowledge base.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    zillizService = ZillizKBService("test")
|
||||
78
langchain-chat/server/knowledge_base/kb_summary/base.py
Normal file
78
langchain-chat/server/knowledge_base/kb_summary/base.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import List
|
||||
|
||||
from configs import (
|
||||
EMBEDDING_MODEL,
|
||||
KB_ROOT_PATH)
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||
import os
|
||||
import shutil
|
||||
from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
class KBSummaryService(ABC):
    """Stores chunk summaries of a knowledge base in a FAISS vector store.

    The summaries live in the ``summary_vector_store`` folder of the
    knowledge base and are also recorded in the database.
    """

    kb_name: str
    embed_model: str
    vs_path: str
    kb_path: str

    # Vector name under which the summary store is loaded from the pool.
    _VECTOR_NAME = "summary_vector_store"

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL
                 ):
        """Resolve the paths and make sure the vector-store folder exists."""
        self.kb_name = knowledge_base_name
        self.embed_model = embed_model

        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

        os.makedirs(self.vs_path, exist_ok=True)

    def get_vs_path(self):
        """Return the ``summary_vector_store`` folder of the knowledge base."""
        return os.path.join(self.get_kb_path(), "summary_vector_store")

    def get_kb_path(self):
        """Return ``KB_ROOT_PATH`` joined with the knowledge base name."""
        return os.path.join(KB_ROOT_PATH, self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Load the summary vector store from ``kb_faiss_pool`` with ``create=True``."""
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name=self._VECTOR_NAME,
                                               embed_model=self.embed_model,
                                               create=True)

    def add_kb_summary(self, summary_combine_docs: List[Document]):
        """Add summary documents to the vector store and record them in the database.

        :param summary_combine_docs: summary documents; ``doc_ids`` is read
            from each document's metadata.
        :return: the status returned by ``add_summary_to_db``.
        """
        with self.load_vector_store().acquire() as vs:
            ids = vs.add_documents(documents=summary_combine_docs)
            vs.save_local(self.vs_path)

        summary_infos = [{"summary_context": doc.page_content,
                          "summary_id": summary_id,
                          "doc_ids": doc.metadata.get('doc_ids'),
                          "metadata": doc.metadata} for summary_id, doc in zip(ids, summary_combine_docs)]
        status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos)
        return status

    def create_kb_summary(self):
        """Create the folder that holds the knowledge base chunk summaries."""
        os.makedirs(self.vs_path, exist_ok=True)

    def drop_kb_summary(self):
        """Delete the chunk summaries of this knowledge base.

        Evicts the store from the pool, removes its folder and deletes the
        summary records from the database.
        """
        with kb_faiss_pool.atomic:
            # Use the same (kb_name, vector_name) key shape that
            # FaissKBService uses when popping from the pool.
            kb_faiss_pool.pop((self.kb_name, self._VECTOR_NAME))
            shutil.rmtree(self.vs_path)
        delete_summary_from_db(kb_name=self.kb_name)
|
||||
241
langchain-chat/server/knowledge_base/kb_summary/summary_chunk.py
Normal file
241
langchain-chat/server/knowledge_base/kb_summary/summary_chunk.py
Normal file
@@ -0,0 +1,241 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from langchain.schema.language_model import BaseLanguageModel
|
||||
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
from configs import (logger)
|
||||
from langchain.chains import StuffDocumentsChain, LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.output_parsers.regex import RegexParser
|
||||
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
|
||||
|
||||
class SummaryAdapter:
|
||||
_OVERLAP_SIZE: int
|
||||
token_max: int
|
||||
_separator: str = "\n\n"
|
||||
chain: MapReduceDocumentsChain
|
||||
|
||||
def __init__(self, overlap_size: int, token_max: int,
             chain: MapReduceDocumentsChain):
    """Store the overlap size, token limit and summarisation chain."""
    self.token_max = token_max
    self.chain = chain
    self._OVERLAP_SIZE = overlap_size
|
||||
|
||||
@classmethod
def form_summary(cls,
                 llm: BaseLanguageModel,
                 reduce_llm: BaseLanguageModel,
                 overlap_size: int,
                 token_max: int = 1300):
    """Build an instance wired to a map-reduce summarisation chain.

    :param llm: LLM used to generate the per-document summaries.
    :param reduce_llm: LLM used to merge the summaries.
    :param overlap_size: size of the overlapping part.
    :param token_max: passed to ``ReduceDocumentsChain`` as ``token_max``.
        Per the original author's note, each chunk must be shorter than
        this, and a first-pass summary longer than it raises an error.
    :return: a new instance holding the assembled chain.
    """
    document_variable_name = "context"

    # Map step: the prompt must take `document_variable_name` as an
    # input variable.
    map_prompt = PromptTemplate(
        template=(
            "根据文本执行任务。以下任务信息"
            "{task_briefing}"
            "文本内容如下: "
            "\r\n"
            "{context}"
        ),
        input_variables=["task_briefing", "context"]
    )
    map_chain = LLMChain(llm=llm, prompt=map_prompt)

    # Reduce step: how the individual summaries are combined.
    merge_prompt = PromptTemplate.from_template(
        "Combine these summaries: {context}"
    )
    merge_chain = LLMChain(llm=reduce_llm, prompt=merge_prompt)

    # Controls how each document is formatted before being stuffed into
    # the reduce prompt.
    per_document_prompt = PromptTemplate(
        input_variables=["page_content"],
        template="{page_content}"
    )
    stuff_chain = StuffDocumentsChain(
        llm_chain=merge_chain,
        document_prompt=per_document_prompt,
        document_variable_name=document_variable_name
    )
    reducer = ReduceDocumentsChain(
        token_max=token_max,
        combine_documents_chain=stuff_chain,
    )
    map_reduce_chain = MapReduceDocumentsChain(
        llm_chain=map_chain,
        document_variable_name=document_variable_name,
        reduce_documents_chain=reducer,
        # Also return the intermediate (per-document) results.
        return_intermediate_steps=True
    )
    return cls(overlap_size=overlap_size,
               chain=map_reduce_chain,
               token_max=token_max)
|
||||
|
||||
def summarize(self,
              file_description: str,
              docs: Optional[List[DocumentWithVSId]] = None
              ) -> List[Document]:
    """Synchronously run ``asummarize`` on an event loop.

    :param file_description: description stored in the summary metadata.
    :param docs: documents to summarise; ``None`` means no documents.
    :return: the list returned by ``asummarize``.
    """
    if docs is None:
        # Avoid sharing one mutable default list between calls.
        docs = []

    if sys.version_info < (3, 10):
        loop = asyncio.get_event_loop()
    else:
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()

        asyncio.set_event_loop(loop)
    # Run the coroutine to completion from synchronous code.
    # NOTE(review): run_until_complete raises RuntimeError on a loop that
    # is already running -- confirm this is never called from inside a
    # running loop.
    return loop.run_until_complete(self.asummarize(file_description=file_description,
                                                   docs=docs))
|
||||
|
||||
async def asummarize(self,
                     file_description: str,
                     docs: Optional[List[DocumentWithVSId]] = None) -> List[Document]:
    """Summarize ``docs`` with the configured chain and wrap the result.

    ``self.chain.combine_docs`` is called once with all documents and a fixed
    task briefing; it returns the combined summary plus the intermediate
    results, both of which are stored on the returned document.

    :param file_description: stored under ``file_description`` in the metadata.
    :param docs: documents to summarize; ``None`` is treated as an empty list.
        Each document must expose an ``id`` attribute.
    :return: a one-element list holding the summary ``Document``; its metadata
        carries ``file_description``, ``summary_intermediate_steps`` and the
        comma-joined ``doc_ids``.
    """
    # Avoid sharing one mutable default list between calls.
    docs = [] if docs is None else docs

    logger.info("start summary")
    summary_combine, summary_intermediate_steps = self.chain.combine_docs(
        docs=docs,
        task_briefing="描述不同方法之间的接近度和相似性,"
                      "以帮助读者理解它们之间的关系。")
    print(summary_combine)
    print(summary_intermediate_steps)

    # NOTE: a retry path for an empty combined summary (re-reduce with half of
    # the intermediate steps) was present only as commented-out code and is
    # intentionally not active.
    logger.info("end summary")

    doc_ids = ",".join([doc.id for doc in docs])
    _metadata = {
        "file_description": file_description,
        "summary_intermediate_steps": summary_intermediate_steps,
        "doc_ids": doc_ids
    }
    summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)

    return [summary_combine_doc]
|
||||
|
||||
def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
|
||||
"""
|
||||
# 将文档中page_content句子叠加的部分去掉
|
||||
:param docs:
|
||||
:param separator:
|
||||
:return:
|
||||
"""
|
||||
merge_docs = []
|
||||
|
||||
pre_doc = None
|
||||
for doc in docs:
|
||||
# 第一个文档直接添加
|
||||
if len(merge_docs) == 0:
|
||||
pre_doc = doc.page_content
|
||||
merge_docs.append(doc.page_content)
|
||||
continue
|
||||
|
||||
# 列表中上一个结尾与下一个开头重叠的部分,删除下一个开头重叠的部分
|
||||
# 迭代递减pre_doc的长度,每次迭代删除前面的字符,
|
||||
# 查询重叠部分,直到pre_doc的长度小于 self._OVERLAP_SIZE // 2 - 2len(separator)
|
||||
for i in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
|
||||
# 每次迭代删除前面的字符
|
||||
pre_doc = pre_doc[1:]
|
||||
if doc.page_content[:len(pre_doc)] == pre_doc:
|
||||
# 删除下一个开头重叠的部分
|
||||
merge_docs.append(doc.page_content[len(pre_doc):])
|
||||
break
|
||||
|
||||
pre_doc = doc.page_content
|
||||
|
||||
return merge_docs
|
||||
|
||||
def _join_docs(self, docs: List[str]) -> Optional[str]:
|
||||
text = self._separator.join(docs)
|
||||
text = text.strip()
|
||||
if text == "":
|
||||
return None
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Manual demo: remove the overlap between consecutive text chunks and
    # print the merged text.
    docs = [
        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
    ]
    _OVERLAP_SIZE = 1
    separator: str = "\n\n"
    merge_docs = []
    pre_doc = None
    # Suffixes of the previous chunk are tried down to this length (exclusive).
    lower_bound = _OVERLAP_SIZE - 2 * len(separator)
    for doc in docs:
        if not merge_docs:
            # The first chunk is taken as-is.
            pre_doc = doc
            merge_docs.append(doc)
            continue

        # Drop one leading character of the previous chunk per step until the
        # remainder matches the beginning of the current chunk.
        for i in range(len(pre_doc), lower_bound, -1):
            pre_doc = pre_doc[1:]
            if doc.startswith(pre_doc):
                # Keep only the part of the chunk after the overlap.
                page_content = doc[len(pre_doc):]
                merge_docs.append(page_content)

                pre_doc = doc
                break

    # Merge the de-overlapped chunks into one text.
    text = separator.join(merge_docs).strip()

    print(text)
|
||||
220
langchain-chat/server/knowledge_base/kb_summary_api.py
Normal file
220
langchain-chat/server/knowledge_base/kb_summary_api.py
Normal file
@@ -0,0 +1,220 @@
|
||||
from fastapi import Body
|
||||
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
|
||||
OVERLAP_SIZE,
|
||||
logger, log_verbose, )
|
||||
from server.knowledge_base.utils import (list_files_from_folder)
|
||||
from sse_starlette import EventSourceResponse
|
||||
import json
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from typing import List, Optional
|
||||
from server.knowledge_base.kb_summary.base import KBSummaryService
|
||||
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
|
||||
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
|
||||
from configs import LLM_MODELS, TEMPERATURE
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
|
||||
def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """Rebuild the summary store of one knowledge base, streaming progress.

    The existing summary store is dropped and recreated, then every file found
    in the knowledge base folder is summarized and added to it.

    :param knowledge_base_name: knowledge base to process.
    :param allow_empty_kb: when False, a missing knowledge base yields a 404 event.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name.
    :param file_description: description forwarded to the summarizer.
    :param model_name: LLM used for both the map and the reduce step.
    :param temperature: LLM sampling temperature.
    :param max_tokens: LLM generation limit, ``None`` for the model default.
    :return: an ``EventSourceResponse`` whose events are JSON strings carrying
        ``code`` and ``msg`` (plus progress fields on success).
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Yield a JSON string like every other event: a raw dict would be
            # unpacked as ServerSentEvent keyword arguments by sse-starlette.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                             ensure_ascii=False)
            return

        # Recreate the summary store from scratch.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text summary adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)
        files = list_files_from_folder(knowledge_base_name)
        total = len(files)

        for finished, file_name in enumerate(files, start=1):
            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(file_description=file_description,
                                     docs=doc_infos)

            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f"({finished} / {total}): {file_name} 总结完成")
                yield json.dumps({
                    "code": 200,
                    "msg": f"({finished} / {total}): {file_name}",
                    "total": total,
                    "finished": finished,
                    "doc": file_name,
                }, ensure_ascii=False)
            else:
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)

    return EventSourceResponse(output())
|
||||
|
||||
|
||||
def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
):
    """Summarize a single file of a knowledge base, streaming the outcome.

    :param knowledge_base_name: knowledge base that owns the file.
    :param file_name: file whose documents are summarized.
    :param allow_empty_kb: when False, a missing knowledge base yields a 404 event.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name.
    :param file_description: description forwarded to the summarizer.
    :param model_name: LLM used for both the map and the reduce step.
    :param temperature: LLM sampling temperature.
    :param max_tokens: LLM generation limit, ``None`` for the model default.
    :return: an ``EventSourceResponse`` whose events are JSON strings carrying
        ``code`` and ``msg``.
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Yield a JSON string like every other event: a raw dict would be
            # unpacked as ServerSentEvent keyword arguments by sse-starlette.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 ‘{knowledge_base_name}’"},
                             ensure_ascii=False)
            return

        # Make sure the summary store exists (it is not dropped here).
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text summary adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        doc_infos = kb.list_docs(file_name=file_name)
        docs = summary.summarize(file_description=file_description,
                                 docs=doc_infos)

        status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
        if status_kb_summary:
            logger.info(f" {file_name} 总结完成")
            yield json.dumps({
                "code": 200,
                "msg": f"{file_name} 总结完成",
                "doc": file_name,
            }, ensure_ascii=False)
        else:
            msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
            logger.error(msg)
            yield json.dumps({
                "code": 500,
                "msg": msg,
            }, ensure_ascii=False)

    return EventSourceResponse(output())
|
||||
|
||||
|
||||
def summary_doc_ids_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        doc_ids: List = Body([], examples=[["uuid"]]),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
) -> BaseResponse:
    """Summarize the documents identified by ``doc_ids`` in one knowledge base.

    :param knowledge_base_name: knowledge base to read the documents from.
    :param doc_ids: ids of the documents to summarize.
    :param vs_type: vector store type passed to ``KBServiceFactory.get_service``.
    :param embed_model: embedding model name.
    :param file_description: description forwarded to the summarizer.
    :param model_name: LLM used for both the map and the reduce step.
    :param temperature: LLM sampling temperature.
    :param max_tokens: LLM generation limit, ``None`` for the model default.
    :return: a 404 ``BaseResponse`` when the knowledge base does not exist,
        otherwise a 200 response whose ``data["summarize"]`` lists the summary
        documents as dicts.
    """
    kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
    if not kb.exists():
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})

    # The same settings are used for the map LLM and the reduce LLM.
    llm_settings = dict(
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    llm = get_ChatOpenAI(**llm_settings)
    reduce_llm = get_ChatOpenAI(**llm_settings)
    # Text summary adapter.
    summary = SummaryAdapter.form_summary(llm=llm,
                                          reduce_llm=reduce_llm,
                                          overlap_size=OVERLAP_SIZE)

    doc_infos = kb.get_doc_by_ids(ids=doc_ids)
    # Re-wrap every fetched document together with the id it was requested by.
    doc_info_with_ids = []
    for with_id, doc in zip(doc_ids, doc_infos):
        doc_info_with_ids.append(DocumentWithVSId(**doc.dict(), id=with_id))

    docs = summary.summarize(file_description=file_description,
                             docs=doc_info_with_ids)

    # Convert the summary documents to plain dicts for the response.
    resp_summarize = [{**summary_doc.dict()} for summary_doc in docs]

    return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize})
|
||||
192
langchain-chat/server/knowledge_base/migrate.py
Normal file
192
langchain-chat/server/knowledge_base/migrate.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from configs import (
|
||||
EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE,
|
||||
logger, log_verbose
|
||||
)
|
||||
from server.knowledge_base.utils import (
|
||||
get_file_path, list_kbs_from_folder,
|
||||
list_files_from_folder, files2docs_in_thread,
|
||||
KnowledgeFile
|
||||
)
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.models.conversation_model import ConversationModel
|
||||
from server.db.models.message_model import MessageModel
|
||||
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
|
||||
from server.db.repository.knowledge_metadata_repository import add_summary_to_db
|
||||
|
||||
from server.db.base import Base, engine
|
||||
from server.db.session import session_scope
|
||||
import os
|
||||
from dateutil.parser import parse
|
||||
from typing import Literal, List
|
||||
import time
|
||||
|
||||
def create_tables():
    """Create every table registered on ``Base.metadata`` using ``engine``."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def reset_tables():
    """Drop every table registered on ``Base.metadata`` and create them again."""
    metadata = Base.metadata
    metadata.drop_all(bind=engine)
    create_tables()
|
||||
|
||||
|
||||
def import_from_db(
        sqlite_path: str = None,
        # csv_path: str = None,
) -> bool:
    """Copy rows from a backup SQLite database into the current database.

    Intended for upgrades where the info database schema changed but the
    knowledge base and vector store did not, so nothing has to be re-embedded.
    Table names must match on both sides; only columns that exist on the
    current model are imported.  Only SQLite backups are supported.

    :param sqlite_path: path of the backup SQLite file.
    :return: True when the import finished, False when any error occurred
        (the error is printed).
    """
    import sqlite3 as sql
    from contextlib import closing
    from pprint import pprint

    models = list(Base.registry.mappers)

    try:
        # closing() guarantees the connection is released even when a row
        # fails to import half-way through.
        with closing(sql.connect(sqlite_path)) as con:
            con.row_factory = sql.Row
            cur = con.cursor()
            tables = [x["name"] for x in cur.execute("select name from sqlite_master where type='table'").fetchall()]
            for model in models:
                table = model.local_table.fullname
                if table not in tables:
                    continue
                print(f"processing table: {table}")
                with session_scope() as session:
                    # The table name comes from the ORM mapping, not from user input.
                    for row in cur.execute(f"select * from {table}").fetchall():
                        data = {k: row[k] for k in row.keys() if k in model.columns}
                        if "create_time" in data:
                            # Stored as text in the backup; parse it back to a datetime.
                            data["create_time"] = parse(data["create_time"])
                        pprint(data)
                        session.add(model.class_(**data))
        return True
    except Exception as e:
        print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}")
        return False
|
||||
|
||||
|
||||
def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
    """Wrap file names as ``KnowledgeFile`` objects, skipping the ones that fail.

    :param kb_name: knowledge base the files belong to.
    :param files: file names to wrap.
    :return: one ``KnowledgeFile`` per file that could be constructed; failures
        are logged and left out.
    """
    wrapped = []
    for name in files:
        try:
            wrapped.append(KnowledgeFile(filename=name, knowledge_base_name=kb_name))
        except Exception as e:
            msg = f"{e},已跳过"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
    return wrapped
|
||||
|
||||
|
||||
def folder2db(
        kb_names: List[str],
        mode: Literal["recreate_vs", "update_in_db", "increment"],
        vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
        embed_model: str = EMBEDDING_MODEL,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
):
    """Populate the database and/or vector store from files in the local folders.

    ``mode`` selects what is processed for every knowledge base:

    * ``recreate_vs``: clear the vector store and rebuild it from all local files.
    * ``update_in_db``: re-vectorize the files already listed in the database.
    * ``increment``: vectorize only local files that are not yet in the database.

    Any other mode is reported with a printed message.  When ``kb_names`` is
    empty, all folders returned by ``list_kbs_from_folder()`` are used.
    (The former ``fill_info_only`` mode is disabled.)
    """

    def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
        # Split the files into chunks, then add each file's chunks to the
        # vector store of the knowledge base currently bound to ``kb``.
        chunked = files2docs_in_thread(kb_files,
                                       chunk_size=chunk_size,
                                       chunk_overlap=chunk_overlap,
                                       zh_title_enhance=zh_title_enhance)
        for success, result in chunked:
            if not success:
                print(result)
                continue
            _, filename, docs = result
            print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
            kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            kb_file.splited_docs = docs
            # Vectorize; the cache refresh is deferred to save_vector_store().
            kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)

    kb_names = kb_names or list_kbs_from_folder()
    for kb_name in kb_names:
        kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
        if not kb.exists():
            kb.create_kb()

        if mode == "recreate_vs":
            # Clear the vector store and rebuild it from the local files.
            kb.clear_vs()
            kb.create_kb()
            kb_files = file_to_kbfile(kb_name, list_files_from_folder(kb_name))
            files2vs(kb_name, kb_files)
            start_time_by_save_vs = time.time()
            kb.save_vector_store()
            end_time_by_save_vs = time.time()
            print('向量入库耗时:{}秒'.format(end_time_by_save_vs - start_time_by_save_vs))
        elif mode in ("update_in_db", "increment"):
            if mode == "update_in_db":
                # Use the database file list as the reference.
                files = kb.list_files()
            else:
                # Only files present locally but missing from the database.
                db_files = kb.list_files()
                folder_files = list_files_from_folder(kb_name)
                files = list(set(folder_files) - set(db_files))
            kb_files = file_to_kbfile(kb_name, files)
            files2vs(kb_name, kb_files)
            kb.save_vector_store()
        else:
            print(f"unsupported migrate mode: {mode}")
|
||||
|
||||
|
||||
def prune_db_docs(kb_names: List[str]):
    """Delete database docs whose source file is gone from the local folder.

    Meant to be run after doc files were removed through a file browser.
    Knowledge bases that cannot be resolved by name are skipped.
    """
    for kb_name in kb_names:
        kb = KBServiceFactory.get_service_by_name(kb_name)
        if kb is None:
            continue
        in_db = set(kb.list_files())
        on_disk = set(list_files_from_folder(kb_name))
        stale = list(in_db - on_disk)
        for kb_file in file_to_kbfile(kb_name, stale):
            kb.delete_doc(kb_file, not_refresh_vs_cache=True)
            print(f"success to delete docs for file: {kb_name}/{kb_file.filename}")
        kb.save_vector_store()
|
||||
|
||||
|
||||
def prune_folder_files(kb_names: List[str]):
    """Delete local doc files that are not registered in the database.

    Frees disk space taken by unused doc files.  Knowledge bases that cannot
    be resolved by name are skipped.
    """
    for kb_name in kb_names:
        kb = KBServiceFactory.get_service_by_name(kb_name)
        if kb is None:
            continue
        in_db = set(kb.list_files())
        on_disk = set(list_files_from_folder(kb_name))
        orphans = list(on_disk - in_db)
        for file in orphans:
            os.remove(get_file_path(kb_name, file))
            print(f"success to delete file: {kb_name}/{file}")
|
||||
@@ -0,0 +1,10 @@
|
||||
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
class DocumentWithVSId(Document):
    """A vectorized document: a ``Document`` extended with an id and a score."""
    # Identifier carried alongside the document; unset (None) by default.
    id: str = None
    # Score carried alongside the document; defaults to 3.0.
    score: float = 3.0
|
||||
16
langchain-chat/server/knowledge_base/pdf_convert_url.py
Normal file
16
langchain-chat/server/knowledge_base/pdf_convert_url.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
PDF 转 Markdown 微服务的 HTTP 地址(仅此一处拼 URL,不 import configs,避免旧 .pyc / 错误包名导致仍用 0.0.0.0)。
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
# Address used when PDF_CONVERT_API_URL is unset, blank, or not an HTTP(S) URL.
_DEFAULT = "http://127.0.0.1:6006/convert/"


def resolve_pdf_convert_post_url() -> str:
    """Return the POST URL of the PDF-to-Markdown conversion service.

    The value comes from the ``PDF_CONVERT_API_URL`` environment variable
    (surrounding whitespace ignored), falling back to ``_DEFAULT`` when it is
    unset or blank.  ``0.0.0.0`` is rewritten to ``127.0.0.1`` because the
    former is a listen address, not one a client can connect to.  Values that
    do not start with ``http://`` or ``https://`` are rejected in favour of
    the default.
    """
    u = (os.environ.get("PDF_CONVERT_API_URL") or "").strip() or _DEFAULT
    u = u.replace("0.0.0.0", "127.0.0.1")
    # Require a real scheme; a bare "http" prefix would also accept "httpfoo".
    if not u.startswith(("http://", "https://")):
        return _DEFAULT
    return u
|
||||
580
langchain-chat/server/knowledge_base/utils.py
Normal file
580
langchain-chat/server/knowledge_base/utils.py
Normal file
@@ -0,0 +1,580 @@
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from configs import (
|
||||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE,
|
||||
logger,
|
||||
log_verbose,
|
||||
text_splitter_dict,
|
||||
LLM_MODELS,
|
||||
TEXT_SPLITTER_NAME,
|
||||
TEXT_SPLITTER_MAP
|
||||
)
|
||||
import importlib
|
||||
from server.chat.policy_fun_iast import get_llm_model_response_async
|
||||
from server.knowledge_base import kb_service as tr
|
||||
from server.knowledge_base.TexkRank import TextRank
|
||||
from text_splitter import zh_title_enhance as func_zh_title_enhance
|
||||
import langchain.document_loaders
|
||||
from langchain.docstore.document import Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from pathlib import Path
|
||||
from server.utils import run_in_thread_pool, get_model_worker_config
|
||||
import json
|
||||
from typing import List, Union,Dict, Tuple, Generator
|
||||
import chardet
|
||||
import time
|
||||
|
||||
def get_split_time(f):
    """Decorator that prints how long each call of ``f`` took, in seconds.

    The wrapper forwards all arguments and returns the wrapped function's
    result unchanged; ``functools.wraps`` keeps the original name and
    docstring visible on the decorated function.
    """
    from functools import wraps

    @wraps(f)
    def inner(*arg, **kwarg):
        s_time = time.time()
        res = f(*arg, **kwarg)
        e_time = time.time()
        print('切片耗时:{}秒'.format(e_time - s_time))
        return res

    return inner
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return False when the name contains the path-traversal marker ``../``.

    NOTE(review): only the literal ``../`` sequence is rejected; other
    traversal forms (a bare ``..`` or backslash separators) are not checked.
    """
    return "../" not in knowledge_base_id
|
||||
|
||||
|
||||
def get_kb_path(knowledge_base_name: str):
    """Return the directory of the named knowledge base under ``KB_ROOT_PATH``."""
    kb_dir = os.path.join(KB_ROOT_PATH, knowledge_base_name)
    return kb_dir
|
||||
|
||||
|
||||
def get_doc_path(knowledge_base_name: str):
    """Return the ``content`` directory inside the knowledge base directory."""
    kb_dir = get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "content")
|
||||
|
||||
|
||||
def get_vs_path(knowledge_base_name: str, vector_name: str):
    """Return ``<kb dir>/vector_store/<vector_name>`` for the knowledge base."""
    kb_dir = get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "vector_store", vector_name)
|
||||
|
||||
|
||||
def get_file_path(knowledge_base_name: str, doc_name: str):
    """Return the path of ``doc_name`` inside the knowledge base's content directory."""
    content_dir = get_doc_path(knowledge_base_name)
    return os.path.join(content_dir, doc_name)
|
||||
|
||||
|
||||
def list_kbs_from_folder():
    """Return the names of the directories located directly in ``KB_ROOT_PATH``."""
    kb_names = []
    for name in os.listdir(KB_ROOT_PATH):
        if os.path.isdir(os.path.join(KB_ROOT_PATH, name)):
            kb_names.append(name)
    return kb_names
|
||||
|
||||
|
||||
def list_files_from_folder(kb_name: str):
    """Recursively list the files in a knowledge base's content directory.

    Entries whose base name starts with ``temp``, ``tmp``, ``.`` or ``~$``
    (case-insensitive) are skipped together with everything below them.
    Symlinked directories are followed; symlinked files are listed under the
    link's own path; broken symlinks are skipped with a warning instead of
    aborting the scan.

    :param kb_name: knowledge base whose content directory is scanned.
    :return: POSIX-style paths relative to the content directory.
    """
    doc_path = get_doc_path(kb_name)
    result = []

    def is_skiped_path(path: str):
        # Temporary, hidden and Office lock files are ignored.
        tail = os.path.basename(path).lower()
        return tail.startswith(("temp", "tmp", ".", "~$"))

    def add_file(path: str):
        # Normalize to a POSIX path relative to the content directory.
        result.append(Path(os.path.relpath(path, doc_path)).as_posix())

    def process_entry(entry):
        if is_skiped_path(entry.path):
            return

        if entry.is_symlink():
            target_path = os.path.realpath(entry.path)
            if os.path.isdir(target_path):
                with os.scandir(target_path) as target_it:
                    for target_entry in target_it:
                        process_entry(target_entry)
            elif os.path.isfile(target_path):
                # os.scandir() would raise on a file target; list the link itself.
                add_file(entry.path)
            else:
                logger.warning(f"skip broken symlink: {entry.path}")
        elif entry.is_file():
            add_file(entry.path)
        elif entry.is_dir():
            with os.scandir(entry.path) as it:
                for sub_entry in it:
                    process_entry(sub_entry)

    with os.scandir(doc_path) as it:
        for entry in it:
            process_entry(entry)

    return result
|
||||
|
||||
|
||||
# Maps a loader class name to the file extensions it handles.  Order matters:
# get_LoaderClass returns the first loader whose list contains the extension.
LOADER_DICT = {
    "GCYHTMLLoader": ['.html', '.htm'],
    "GCYWordLoader2": ['.docx', '.doc'],
    # "GCYWordLoader": ['.docx'],  # .doc parsing currently has issues, disabled for now
    "MHTMLLoader": ['.mhtml'],
    "TextLoader": ['.md', '.txt'],
    "JSONLoader": [".json"],
    "JSONLinesLoader": [".jsonl"],
    "RapidOCRCSVLoader": [".csv"],
    # "CSVLoader": [".csv"],
    # "FilteredCSVLoader": [".csv"],  # use this for custom csv splitting
    "PyMuPDFLoader": [".pdf"],
    # "RapidOCRDocLoader": ['.docx', '.doc'],
    "RapidOCRPPTLoader": ['.ppt', '.pptx'],
    "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
    "UnstructuredFileLoader": [
        '.eml', '.msg', '.rst',
        '.rtf', '.xml',
        '.epub', '.odt', '.tsv',
    ],
    "UnstructuredEmailLoader": ['.eml', '.msg'],
    "UnstructuredEPubLoader": ['.epub'],
    "ExcelLoader": ['.xlsx', '.xls', '.xlsd'],
    "NotebookLoader": ['.ipynb'],
    "UnstructuredODTLoader": ['.odt'],
    "PythonLoader": ['.py'],
    "UnstructuredRSTLoader": ['.rst'],
    "UnstructuredRTFLoader": ['.rtf'],
    "SRTLoader": ['.srt'],
    "TomlLoader": ['.toml'],
    "UnstructuredTSVLoader": ['.tsv'],
    "UnstructuredXMLLoader": ['.xml'],
    "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
    "EverNoteLoader": ['.enex'],
}
# Flat list of every extension above, in dict order (duplicates are kept).
SUPPORTED_EXTS = []
for _loader_exts in LOADER_DICT.values():
    SUPPORTED_EXTS.extend(_loader_exts)
|
||||
|
||||
|
||||
# patch json.dumps to disable ensure_ascii
|
||||
def _new_json_dumps(obj, **kwargs):
    """Call the original ``json.dumps`` with ``ensure_ascii`` forced to False.

    Any ``ensure_ascii`` value supplied by the caller is overridden.
    """
    forced = {**kwargs, "ensure_ascii": False}
    return _origin_json_dumps(obj, **forced)


# Install the patch once: remember the current json.dumps, then replace it.
if json.dumps is not _new_json_dumps:
    _origin_json_dumps, json.dumps = json.dumps, _new_json_dumps
|
||||
|
||||
|
||||
class JSONLinesLoader(langchain.document_loaders.JSONLoader):
    """JSON loader for line-delimited JSON; intended for ``.jsonl`` files."""

    def __init__(self, *args, **kwargs):
        # Build a regular JSONLoader first, then switch it to JSON-lines mode.
        super().__init__(*args, **kwargs)
        self._json_lines = True


# Expose the class on langchain.document_loaders so it can be looked up there
# by name like the built-in loaders.
setattr(langchain.document_loaders, "JSONLinesLoader", JSONLinesLoader)
|
||||
|
||||
|
||||
def get_LoaderClass(file_extension):
    """Return the name of the first loader in ``LOADER_DICT`` that lists
    ``file_extension``, or ``None`` when no loader does."""
    candidates = (
        loader_name
        for loader_name, handled in LOADER_DICT.items()
        if file_extension in handled
    )
    return next(candidates, None)
|
||||
|
||||
def get_SplitterClass(file_extension):
    """Return the text splitter name configured for a file extension.

    The first entry of ``TEXT_SPLITTER_MAP`` that lists the extension wins;
    when none does, a notice is printed and ``TEXT_SPLITTER_NAME`` is returned.
    """
    not_found = object()
    chosen = next(
        (name for name, handled in TEXT_SPLITTER_MAP.items() if file_extension in handled),
        not_found,
    )
    if chosen is not not_found:
        return chosen
    print(f'未找到文件类型"{file_extension}"对应的切分器,使用默认值')
    return TEXT_SPLITTER_NAME
|
||||
|
||||
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
    '''
    Return a document loader instance for ``file_path``.

    The loader class named ``loader_name`` is looked up in the project's
    ``document_loaders`` module (for the custom loaders listed below) or in
    ``langchain.document_loaders``; when the lookup fails the error is logged
    and ``UnstructuredFileLoader`` is used instead.  Loader-specific default
    keyword arguments are filled in before the loader is constructed.

    :param loader_name: loader class name.
    :param file_path: path handed to the loader.
    :param loader_kwargs: extra keyword arguments for the loader; the dict
        passed in is copied and never modified.
    :return: the constructed loader.
    '''
    # Work on a copy so the defaults added below do not leak into the caller's dict.
    loader_kwargs = dict(loader_kwargs or {})
    try:
        if loader_name in ["RapidOCRLoader", "FilteredCSVLoader",
                           "GCYWordLoader", "GCYWordLoader2", "GCYHTMLLoader",
                           "RapidOCRPPTLoader", "RapidOCRCSVLoader", "ExcelLoader"]:
            document_loaders_module = importlib.import_module('document_loaders')
        else:
            document_loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(document_loaders_module, loader_name)
    except Exception as e:
        msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        document_loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")

    if loader_name in ["UnstructuredFileLoader", "TextLoader"]:
        loader_kwargs.setdefault("autodetect_encoding", True)
    elif loader_name == "CSVLoader":
        if not loader_kwargs.get("encoding"):
            # No encoding given: detect it from the file so the loader does
            # not fail with a decoding error.
            with open(file_path, 'rb') as struct_file:
                encode_detect = chardet.detect(struct_file.read())
            # Fall back to utf-8 when detection yields no usable encoding
            # (no result at all, or an "encoding" entry that is None).
            loader_kwargs["encoding"] = (encode_detect or {}).get("encoding") or "utf-8"
    elif loader_name in ("JSONLoader", "JSONLinesLoader"):
        loader_kwargs.setdefault("jq_schema", ".")
        loader_kwargs.setdefault("text_content", False)

    loader = DocumentLoader(file_path, **loader_kwargs)
    return loader
|
||||
|
||||
|
||||
def make_text_splitter(
        splitter_name: str = TEXT_SPLITTER_NAME,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        llm_model: str = LLM_MODELS[0],
):
    """
    Build the text splitter named ``splitter_name``.

    The splitter class is looked up in the project's ``text_splitter`` module
    first and in ``langchain.text_splitter`` second.  How it is constructed
    depends on the ``source`` entry of ``text_splitter_dict[splitter_name]``
    (``tiktoken``, ``huggingface``, ``no_tokenizer`` or anything else).  If
    anything fails, the error is printed and a ``RecursiveCharacterTextSplitter``
    is returned instead.

    :param splitter_name: splitter class name; falsy means ``SpacyTextSplitter``.
    :param chunk_size: maximum chunk size handed to the splitter.
    :param chunk_overlap: overlap between consecutive chunks.
    :param llm_model: model whose ``model_path`` supplies the tokenizer when the
        huggingface configuration leaves ``tokenizer_name_or_path`` empty.
    :return: the constructed text splitter.
    """
    splitter_name = splitter_name or "SpacyTextSplitter"
    try:
        if splitter_name == "MarkdownTextSplitter":  # special-cased header-based markdown splitter
            text_splitter_module = importlib.import_module('text_splitter')
            splitter_cls = getattr(text_splitter_module, splitter_name)
            headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
            text_splitter = splitter_cls(
                headers_to_split_on=headers_to_split_on,
                strip_headers=False,  # keep the headers inside the chunk text
                promote_headers=True
            )
        else:
            try:  # prefer a user-defined splitter
                text_splitter_module = importlib.import_module('text_splitter')
                splitter_cls = getattr(text_splitter_module, splitter_name)
            except Exception:  # otherwise fall back to langchain's splitter
                text_splitter_module = importlib.import_module('langchain.text_splitter')
                splitter_cls = getattr(text_splitter_module, splitter_name)

            splitter_conf = text_splitter_dict[splitter_name]
            source = splitter_conf["source"]
            if source == "tiktoken":  # load from tiktoken
                try:
                    text_splitter = splitter_cls.from_tiktoken_encoder(
                        encoding_name=splitter_conf["tokenizer_name_or_path"],
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
                except Exception:
                    # Retry without the spaCy pipeline argument.
                    text_splitter = splitter_cls.from_tiktoken_encoder(
                        encoding_name=splitter_conf["tokenizer_name_or_path"],
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
            elif source == "huggingface":  # load from huggingface
                if splitter_conf["tokenizer_name_or_path"] == "":
                    # No tokenizer configured: use the LLM's model path and
                    # store it back into the shared configuration.
                    config = get_model_worker_config(llm_model)
                    splitter_conf["tokenizer_name_or_path"] = \
                        config.get("model_path")

                if splitter_conf["tokenizer_name_or_path"] == "gpt2":
                    from transformers import GPT2TokenizerFast
                    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
                else:  # load the configured tokenizer
                    from transformers import AutoTokenizer
                    tokenizer = AutoTokenizer.from_pretrained(
                        splitter_conf["tokenizer_name_or_path"],
                        trust_remote_code=True)
                text_splitter = splitter_cls.from_huggingface_tokenizer(
                    tokenizer=tokenizer,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
            elif source == "no_tokenizer":  # IAST 0429: no tokenizer needed for now
                text_splitter = splitter_cls(
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
            else:
                try:
                    text_splitter = splitter_cls(
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
                except Exception:
                    # Retry without the spaCy pipeline argument.
                    text_splitter = splitter_cls(
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
    except Exception as e:
        print(e)
        text_splitter_module = importlib.import_module('langchain.text_splitter')
        splitter_cls = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
        text_splitter = splitter_cls(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287
    # text_splitter._tokenizer.max_length = 37016792
    # text_splitter._tokenizer.prefer_gpu()
    return text_splitter
|
||||
|
||||
|
||||
class KnowledgeFile:
    """A file that lives inside a knowledge base directory.

    The instance lazily loads the file into langchain ``Document`` objects
    (``file2docs``), splits them into chunks (``docs2texts`` / ``file2text``)
    and can ask the configured LLMs for an abstract, keywords and a chapter
    overview of the text (``get_llm_result``).
    """

    # Upper bound, in characters, on the text sent to the LLM by get_llm_result.
    MAX_LLM_INPUT_CHARS = 40000

    def __init__(
            self,
            filename: str,
            knowledge_base_name: str,
            loader_kwargs: Union[Dict, None] = None,
    ):
        """Describe a file of a knowledge base.

        The file must exist on disk before it can be vectorized or otherwise
        processed.

        Args:
            filename: Name of the file; stored in POSIX form.
            knowledge_base_name: Name of the knowledge base owning the file.
            loader_kwargs: Extra keyword arguments forwarded to the document
                loader. Defaults to a fresh empty dict per instance.

        Raises:
            ValueError: If the file extension is not in ``SUPPORTED_EXTS``.
        """
        self.kb_name = knowledge_base_name
        self.filename = str(Path(filename).as_posix())
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"暂未支持的文件格式 {self.filename}")
        # A ``{}`` default in the signature would be shared by every instance.
        self.loader_kwargs = {} if loader_kwargs is None else loader_kwargs
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.splited_docs = None
        self.document_loader_name = get_LoaderClass(self.ext)
        self.text_splitter_name = get_SplitterClass(self.ext)

    def get_full_text(self) -> Union[str, Dict[str, str]]:
        """Return the file name together with the full text of the file.

        Returns:
            On success, a JSON *string* encoding
            ``{"filename": ..., "full_text": ...}`` where ``full_text`` is the
            concatenation of every loaded document's ``page_content``.
            On failure, a plain *dict* with the same two keys whose
            ``full_text`` holds a fixed error message. Callers must therefore
            check the type of the result before parsing it.
        """
        try:
            docs = self.file2docs()
            full_text = "".join(doc.page_content for doc in docs)
            return json.dumps(
                {
                    "filename": self.filename,
                    "full_text": full_text,
                },
                ensure_ascii=False,
            )
        except Exception as e:
            logger.error(f"获取文件全文内容时出错:{e}", exc_info=e if log_verbose else None)
            return {
                "filename": self.filename,
                "full_text": "加载文件失败或文件内容为空"
            }

    async def get_llm_result(self) -> Dict[str, str]:
        """Generate an abstract, keywords and a chapter overview for the file.

        The full text is loaded in an executor, truncated to
        ``MAX_LLM_INPUT_CHARS`` characters, and passed as ``context`` to three
        ``get_llm_model_response_async`` calls that are awaited together.

        Returns:
            On success, a dict with the keys ``filename``, ``full_text``,
            ``article_abstract``, ``article_keywords`` and
            ``article_paragraph``. On any failure, a dict with ``filename``
            and fixed failure messages for the three generated fields.
        """
        try:
            # get_full_text is synchronous; run it in the default executor
            # instead of calling it directly inside the coroutine.
            loop = asyncio.get_running_loop()
            full_text_data = await loop.run_in_executor(None, self.get_full_text)

            # get_full_text returns a dict only when loading failed. Fail
            # explicitly instead of relying on json.loads raising TypeError.
            if isinstance(full_text_data, dict):
                raise RuntimeError(full_text_data.get("full_text"))

            try:
                # Parse the JSON string back into a dict.
                full_text = json.loads(full_text_data).get("full_text", "")
            except json.JSONDecodeError:
                logger.error("解析 JSON 数据时出错")
                full_text = ""

            # Over-long input is simply cut off; a TextRank-based compression
            # step that used to live here is disabled.
            full_text = full_text[:self.MAX_LLM_INPUT_CHARS]

            # Call the models asynchronously and wait for all three together.
            llm_time = time.time()
            article_abstract, article_keywords, article_paragraph = await asyncio.gather(
                get_llm_model_response_async(
                    strategy_name="gen_abstract",
                    llm_model_name=LLM_MODELS[1],
                    template_prompt_name="gen_abstract",
                    prompt_param_dict={"context": full_text},
                    temperature=0.7,
                    max_tokens=4096
                ),
                get_llm_model_response_async(
                    strategy_name="gen_keywords",
                    llm_model_name=LLM_MODELS[1],
                    template_prompt_name="gen_keywords",
                    prompt_param_dict={"context": full_text},
                    temperature=0.7,
                    max_tokens=512
                ),
                get_llm_model_response_async(
                    strategy_name="gen_paragraph",
                    llm_model_name=LLM_MODELS[0],
                    template_prompt_name="gen_paragraph",
                    prompt_param_dict={"context": full_text},
                    temperature=0.7,
                    max_tokens=8192
                ),
            )

            logger.info(f'生成导读用时:{time.time() - llm_time}')
            return {
                "filename": self.filename,
                "full_text": full_text,
                "article_abstract": article_abstract,
                "article_keywords": article_keywords,
                "article_paragraph": article_paragraph
            }
        except Exception as e:
            logger.error(f"生成LLM结果时出错:{e}", exc_info=e if log_verbose else None)
            return {
                "filename": self.filename,
                "article_abstract": "生成摘要失败",
                "article_keywords": "生成关键词失败",
                "article_paragraph": "生成章节速览失败"
            }

    def file2docs(self, refresh: bool = False):
        """Load the file into documents, caching the result on ``self.docs``.

        Args:
            refresh: When true, reload even if documents are already cached.

        Returns:
            The value returned by the loader's ``load()``.

        Raises:
            Exception: Whatever the loader raised, after logging it, unless
                the loader is ``GCYWordLoader``; in that case
                ``GCYWordLoader2`` is tried as a fallback and only its own
                errors propagate.
        """
        if self.docs is None or refresh:
            logger.info(f"{self.document_loader_name} used for {self.filepath}")
            try:
                loader = get_loader(loader_name=self.document_loader_name,
                                    file_path=self.filepath,
                                    loader_kwargs=self.loader_kwargs)
                self.docs = loader.load()
            except Exception as e:
                if self.document_loader_name != 'GCYWordLoader':
                    # No fallback exists for this loader. Re-raise the real
                    # error rather than calling load() again on a loader that
                    # is either unbound or has just failed.
                    logger.error(f"加载文件 {self.filepath} 时出错:{e}", exc_info=e if log_verbose else None)
                    raise
                fallback_loader = get_loader(loader_name='GCYWordLoader2',
                                             file_path=self.filepath,
                                             loader_kwargs=self.loader_kwargs)
                self.docs = fallback_loader.load()
        return self.docs

    @get_split_time
    def docs2texts(
            self,
            docs: List[Document] = None,
            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        """Split documents into chunks and cache them on ``self.splited_docs``.

        Args:
            docs: Documents to split; when empty or ``None`` the file is
                loaded through ``file2docs``.
            zh_title_enhance: When true, pass the chunks through
                ``func_zh_title_enhance``.
            refresh: Forwarded to ``file2docs`` when ``docs`` is not given.
            chunk_size: Chunk size used when a splitter has to be created.
            chunk_overlap: Chunk overlap used when a splitter has to be created.
            text_splitter: Splitter to use; created from the file extension
                when ``None``.

        Returns:
            The list of split documents, or ``[]`` when there is nothing to
            split or the splitter produced nothing.
        """
        docs = docs or self.file2docs(refresh=refresh)
        if not docs:
            return []

        if text_splitter is None:
            self.text_splitter_name = get_SplitterClass(self.ext)
            text_splitter = make_text_splitter(splitter_name=self.text_splitter_name,
                                               chunk_size=chunk_size,
                                               chunk_overlap=chunk_overlap)
        if self.text_splitter_name == "MarkdownTextSplitter":
            # Only the first document is split on the markdown path.
            doc_source = docs[0].metadata["source"]
            docs = text_splitter.split_markdown_text(docs[0].page_content, doc_source)
        else:
            docs = text_splitter.split_documents(docs)

        if not docs:
            return []

        # Make sure every chunk carries an 'h1' title field so it can be stored
        # alongside files that do have one. setdefault keeps existing titles
        # and fills the gap on any chunk that lacks the key.
        for doc in docs:
            doc.metadata.setdefault('h1', '')

        print(f"文档切分示例:{docs[0]}")
        if zh_title_enhance:
            docs = func_zh_title_enhance(docs)
        self.splited_docs = docs
        return self.splited_docs

    def file2text(
            self,
            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        """Load and split the file, reusing cached chunks when possible.

        Args:
            zh_title_enhance: Forwarded to ``docs2texts``.
            refresh: When true, reload the file and split it again even if
                chunks are already cached.
            chunk_size: Forwarded to ``docs2texts``.
            chunk_overlap: Forwarded to ``docs2texts``.
            text_splitter: Forwarded to ``docs2texts``.

        Returns:
            The cached or freshly produced list of split documents.
        """
        if self.splited_docs is None or refresh:
            # Forward refresh so a refresh really re-reads the file instead of
            # re-splitting the previously cached documents.
            docs = self.file2docs(refresh=refresh)
            self.splited_docs = self.docs2texts(docs=docs,
                                                zh_title_enhance=zh_title_enhance,
                                                refresh=refresh,
                                                chunk_size=chunk_size,
                                                chunk_overlap=chunk_overlap,
                                                text_splitter=text_splitter)
        return self.splited_docs

    def file_exist(self):
        """Return whether ``self.filepath`` is an existing regular file."""
        return os.path.isfile(self.filepath)

    def get_mtime(self):
        """Return the modification time of ``self.filepath``."""
        return os.path.getmtime(self.filepath)

    def get_size(self):
        """Return the size of ``self.filepath`` in bytes."""
        return os.path.getsize(self.filepath)
|
||||
|
||||
|
||||
def files2docs_in_thread(
        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
) -> Generator:
    """Convert files on disk into langchain Documents via ``run_in_thread_pool``.

    Args:
        files: Entries to process. Each entry is a ``KnowledgeFile``, a tuple
            ``(filename, kb_name)``, or a dict with the keys ``filename`` and
            ``kb_name``; any other dict keys are forwarded to
            ``KnowledgeFile.file2text``. The dict is not modified.
        chunk_size: Forwarded to ``file2text`` for every file.
        chunk_overlap: Forwarded to ``file2text`` for every file.
        zh_title_enhance: Forwarded to ``file2text`` for every file.

    Yields:
        ``status, (kb_name, file_name, docs | error)``. Entries that cannot be
        turned into a ``KnowledgeFile`` are reported first with
        ``status=False``; ``kb_name`` / ``file_name`` are ``None`` when they
        could not be determined.
    """

    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
        # Worker for one file: never raises, reports failure through the status flag.
        try:
            return True, (file.kb_name, file.filename, file.file2text(**kwargs))
        except Exception as e:
            msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            return False, (file.kb_name, file.filename, msg)

    kwargs_list = []
    for file in files:
        kwargs = {}
        # Reset per entry so an error report never uses unbound names or the
        # names of a previous entry.
        kb_name = getattr(file, "kb_name", None)
        filename = getattr(file, "filename", None)
        try:
            if isinstance(file, tuple) and len(file) >= 2:
                filename, kb_name = file[0], file[1]
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            elif isinstance(file, dict):
                # Work on a copy: popping from the caller's dict would mutate it.
                extra_kwargs = dict(file)
                filename = extra_kwargs.pop("filename")
                kb_name = extra_kwargs.pop("kb_name")
                kwargs.update(extra_kwargs)
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            kwargs["file"] = file
            # These always take precedence over same-named keys from a dict entry.
            kwargs["chunk_size"] = chunk_size
            kwargs["chunk_overlap"] = chunk_overlap
            kwargs["zh_title_enhance"] = zh_title_enhance
            kwargs_list.append(kwargs)
        except Exception as e:
            yield False, (kb_name, filename, str(e))

    for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
        yield result
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: load one sample CSV through KnowledgeFile.
    from pprint import pprint  # kept for ad-hoc inspection, e.g. pprint(docs[-1])

    sample_path = "/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv"
    kb_file = KnowledgeFile(filename=sample_path, knowledge_base_name="samples")
    # To force a specific splitter, set kb_file.text_splitter_name,
    # e.g. to "RecursiveCharacterTextSplitter".
    docs = kb_file.file2docs()
|
||||
Reference in New Issue
Block a user