[全量] 初始化项目代码、配置、文档及Agent协同harness

This commit is contained in:
2026-04-02 11:36:05 +08:00
parent 0553309cdf
commit 87e571d9ec
1133 changed files with 221948 additions and 0 deletions

View File

@@ -0,0 +1,76 @@
import multiprocessing
import re
import time
import networkx as nx
import numpy as np
from textrank4zh import TextRank4Keyword, TextRank4Sentence
from joblib import Parallel, delayed, parallel_backend
import logging
nx.from_numpy_matrix = nx.from_numpy_array
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def process_text_segment(text_segment, num_sentences):
    """Run TextRank keyword and key-sentence extraction on one text segment.

    Args:
        text_segment: The text to analyse.
        num_sentences: Number of key sentences requested from the sentence ranker.

    Returns:
        A ``(keywords, summaries)`` tuple, where ``keywords`` is a list of
        ``(word, weight)`` pairs and ``summaries`` is a list of sentences.
    """
    keyword_ranker = TextRank4Keyword()
    keyword_ranker.analyze(text=text_segment, lower=True, window=5)
    keywords = []
    for entry in keyword_ranker.get_keywords(30, word_min_len=4):
        keywords.append((entry.word, entry.weight))

    sentence_ranker = TextRank4Sentence()
    sentence_ranker.analyze(text=text_segment, lower=True, source='all_filters')
    summaries = []
    for entry in sentence_ranker.get_key_sentences(num=num_sentences):
        summaries.append(entry.sentence)

    return keywords, summaries
def split_text_by_sentences(text, n_parts):
    """Split *text* into ``n_parts`` chunks made of whole sentences.

    Sentences are found with a regular expression that splits on whitespace
    following ``.``, ``?`` or ``!``. The sentences are then distributed as
    evenly as possible: the first ``len(sentences) % n_parts`` chunks get one
    extra sentence.

    Args:
        text: The text to split.
        n_parts: Desired number of chunks. Values below 1 are treated as 1
            (previously 0 raised ``ZeroDivisionError`` and negative values
            returned an empty list).

    Returns:
        A list of exactly ``max(1, n_parts)`` strings. Chunks are empty
        strings when there are fewer sentences than chunks.
    """
    # Guard against division by zero and keep the contract in line with
    # split_text_balanced, which also never uses fewer than one part.
    n_parts = max(1, n_parts)
    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', text)
    base_size, extra = divmod(len(sentences), n_parts)

    parts = []
    start = 0
    for index in range(n_parts):
        # The first `extra` chunks absorb the remainder, one sentence each.
        end = start + base_size + (1 if index < extra else 0)
        parts.append(' '.join(sentences[start:end]))
        start = end
    return parts
from nltk.tokenize import sent_tokenize
def split_text_balanced(text, n_parts, min_sentences_per_part=10):
    """Split *text* into sentence chunks, capping the number of chunks.

    Sentences come from ``sent_tokenize``. The requested ``n_parts`` is
    reduced so that each chunk can hold at least ``min_sentences_per_part``
    sentences, and at least one chunk is always produced.

    Args:
        text: The text to split.
        n_parts: Upper bound on the number of chunks.
        min_sentences_per_part: Minimum number of sentences each chunk should
            hold. Defaults to 10, the value that was previously hard-coded.
            Values below 1 are treated as 1 to avoid a division by zero.

    Returns:
        A list of strings, each one the space-joined sentences of a chunk.
    """
    sentences = sent_tokenize(text)
    min_sentences_per_part = max(1, min_sentences_per_part)
    n_parts = max(1, min(n_parts, len(sentences) // min_sentences_per_part))
    k, m = divmod(len(sentences), n_parts)
    # The first `m` chunks receive k + 1 sentences, the remaining ones k.
    return [' '.join(sentences[i * k + min(i, m):(i + 1) * k + min(i + 1, m)]) for i in range(n_parts)]
from concurrent.futures import ProcessPoolExecutor, as_completed
def TextRank(text, num_sentences, n_cores=multiprocessing.cpu_count()):
    """Summarise *text* by running TextRank over balanced sentence chunks.

    The text is split with ``split_text_balanced`` (using ``n_cores`` as the
    upper bound on the number of chunks), every chunk is processed in turn by
    ``process_text_segment``, the collected keywords are printed in
    descending weight order, and the key sentences of all chunks are
    concatenated.

    Args:
        text: The text to summarise.
        num_sentences: Number of key sentences requested per chunk.
        n_cores: Upper bound on the number of chunks.

    Returns:
        The key sentences of all chunks joined without a separator.
    """
    started_at = time.time()
    logging.info("TextRank 函数开始执行")

    collected_keywords = []
    collected_sentences = []
    # Chunks are processed one after another in the current process.
    for segment in split_text_balanced(text, n_cores):
        segment_keywords, segment_sentences = process_text_segment(segment, num_sentences)
        collected_keywords += segment_keywords
        collected_sentences += segment_sentences

    ranked_keywords = sorted(collected_keywords, key=lambda pair: pair[1], reverse=True)
    for word, weight in ranked_keywords:
        print(word, weight)

    merged_summary = "".join(collected_sentences)
    elapsed_time = time.time() - started_at
    logging.info(f"TextRank 函数执行结束,耗时: {elapsed_time:.2f}")
    return merged_summary
if __name__ == '__main__':
    # Demo entry point: summarise a short sample text and print the length
    # of the input next to the length of the returned summary.
    # Pass in the required parameters.
    num_sentences = 80
    text = """中华人民共和国国民经济和社会发展第十四个五年20212025年规划和2035年远景目标纲要"""
    summary = TextRank(text, num_sentences)
    print(f"原文长度{len(text)},压缩文本后长度 {len(summary)}")

View File

@@ -0,0 +1,3 @@
# from .kb_api import list_kbs, create_kb, delete_kb
# from .kb_doc_api import list_docs, upload_doc, delete_doc, update_doc, download_doc, recreate_vector_store
# from .utils import KnowledgeFile, KBServiceFactory

View File

@@ -0,0 +1,372 @@
import os
import re
from typing import Optional
from bs4 import BeautifulSoup
from collections import defaultdict
import cssutils
from server.knowledge_base.file_converter import FileConverter
import uuid
import base64
class PdfConverter(FileConverter):
def _clean_pdf_html(self, html: str) -> str:
    """Post-process converted HTML.

    Steps, in order: remove ``<span>`` elements without visible text,
    rewrite every ``<style>`` sheet through cssutils, clean the inline
    styles of ``#page-container-1`` and ``.pf`` elements, assign element ids
    via ``_add_pdf_element_ids``, patch the ``<head>`` section when a
    page-container id was recorded, and strip ``<script>`` elements.

    Args:
        html: The HTML markup to clean.

    Returns:
        The cleaned markup with leading/trailing whitespace removed.
    """
    soup = BeautifulSoup(html, 'html.parser')

    # Process the CSS rules found in style sheets.
    def process_rule(rule):
        if rule.type == rule.MEDIA_RULE:
            # Media rules contain nested rules; recurse into them.
            for nested_rule in rule:
                process_rule(nested_rule)
        elif rule.type == rule.STYLE_RULE:
            # Remove the properties that restrict text selection.
            for prop in ['user-select', '-webkit-user-select', '-moz-user-select', '-ms-user-select']:
                rule.style.removeProperty(prop)
            # Pre-existing handling, kept unchanged.
            if any('#page-container-1' in selector.selectorText for selector in rule.selectorList):
                rule.style.removeProperty('background-color')
                rule.style.removeProperty('background-image')
            # Selectors that target the `.pf` class as a whole token.
            if any(re.search(r'(^|[\s>+~])\.pf($|[\s\[.:>+~])', selector.selectorText)
                   for selector in rule.selectorList):
                for prop in ['box-shadow', 'border-collapse']:
                    # Up to three removal attempts per property, stopping at
                    # the first call that returns a truthy value.
                    for _ in range(3):
                        if rule.style.removeProperty(prop):
                            break

    # Process inline styles.
    def clean_inline_styles(tag):
        if tag.has_attr('style'):
            style = cssutils.parseStyle(tag['style'])
            # Remove the properties that restrict text selection.
            for prop in ['user-select', '-webkit-user-select', '-moz-user-select', '-ms-user-select']:
                style.removeProperty(prop)
            # Pre-existing handling, kept unchanged.
            if tag.get('id') == 'page-container-1':
                style.removeProperty('background-color')
                style.removeProperty('background-image')
            if 'pf' in tag.get('class', []):
                style.removeProperty('box-shadow')
                style.removeProperty('border-collapse')
            tag['style'] = style.cssText.replace('\n', ' ').strip()
            # Drop the attribute entirely when nothing is left in it.
            if not tag['style']:
                del tag['style']

    # Remove span tags that are empty or contain only whitespace.
    for span in soup.find_all('span'):
        # Check whether the span holds any visible content.
        if not span.text.strip():
            span.decompose()
        else:
            # Normalise whitespace-only content.
            # NOTE(review): this branch is only reached when span.text has
            # non-whitespace characters, so both conditions below look
            # unreachable - confirm before relying on them.
            if span.string and span.string.isspace():
                span.string.replace_with(' ')
            # Handle spans made of several whitespace text nodes.
            elif all(isinstance(c, str) and c.isspace() for c in span.contents):
                span.replace_with(' ')

    # Pre-existing processing flow.
    for style_tag in soup.find_all('style'):
        if style_tag.string:
            try:
                sheet = cssutils.parseString(style_tag.string)
                for rule in sheet:
                    process_rule(rule)
                # Serialise the sheet back into the tag, turning literal
                # backslash-n sequences into newlines and removing the
                # space in front of `!important`.
                style_tag.string = sheet.cssText.decode('utf-8')\
                    .replace('\\n', '\n')\
                    .replace(' !important', '!important')
            except Exception as e:
                # Best effort: report the failure and move on to the next sheet.
                print(f"CSS处理错误: {str(e)}")
                continue

    for container in soup.select('#page-container-1'):
        clean_inline_styles(container)
    for pf_element in soup.select('.pf'):
        clean_inline_styles(pf_element)

    content = str(soup)
    content = self._add_pdf_element_ids(content)

    # Only patch <head> when _add_pdf_element_ids recorded a container id.
    if hasattr(self, 'page_container_id') and self.page_container_id:
        new_id = self.page_container_id
        head_pattern = re.compile(
            r'(<head[^>]*>)(.*?)(</head>)',
            re.DOTALL | re.IGNORECASE
        )

        def replace_head(match):
            head_content = match.group(2)
            # Rewrite `id=page-container` occurrences to the generated id.
            head_content = re.sub(
                r'(id\s*=\s*["\']?)page-container(["\'\]>])',
                f'\\g<1>{new_id}\\g<2>',
                head_content
            )
            # Drop a background-color / background-image declaration that
            # follows an id selector, keeping the text in front of it.
            head_content = re.sub(
                r'(#[^{\s>]+?{.*?)(\bbackground-(color|image)\s*:[^;]+;?)',
                lambda m: m.group(1) if m.group(2) else m.group(0),
                head_content,
                flags=re.DOTALL|re.IGNORECASE
            )
            return f"{match.group(1)}{head_content}{match.group(3)}"

        content = head_pattern.sub(replace_head, content)

    # Remove all <script> elements together with their content.
    # NOTE(review): the flattened source does not show whether this step was
    # nested inside the `if` above; it is kept unconditional here - confirm.
    content = re.sub(
        r'<script\b[^>]*>[\s\S]*?</script>',
        '',
        content,
        flags=re.IGNORECASE
    )
    return content.strip()
def _add_pdf_element_ids(self, content: str) -> str:
"""为元素添加唯一ID"""
counters = defaultdict(int)
self.page_container_id = None # 重置ID记录
def replace_tag(match):
tag = match.group(1).lower()
attrs = match.group(2)
# 处理page-container的特殊逻辑
if tag == "div":
id_match = re.search(
r'\bid\s*=\s*["\']page-container["\']',
attrs,
flags=re.IGNORECASE
)
if id_match:
# 生成唯一ID并记录
if not self.page_container_id:
counters['page-container'] += 1
self.page_container_id = f"page-container-{counters['page-container']}"
# 保留其他属性
clean_attrs = re.sub(r'\s+id="[^"]*"', '', attrs)
return f'<div id="{self.page_container_id}"{clean_attrs}>'
# 常规标签处理
counters[tag] += 1
clean_attrs = re.sub(r'\s+id="[^"]*"', '', attrs)
return f'<{tag} id="{tag}-{counters[tag]}"{clean_attrs}>'
# 处理所有目标标签
return re.sub(
r'<(h[1-6]|p|div|span)(\b[^>]*)>',
replace_tag,
content,
flags=re.IGNORECASE
)
def _save_pdf_html(self, content: str, output_path: Optional[str] = None) -> str:
"""统一保存方法"""
cleaned = self._clean_pdf_html(content)
# cleaned = self._add_pdf_element_ids(content)
if output_path:
with open(output_path, 'w', encoding='utf-8') as f:
f.write(cleaned)
return cleaned
# def pdf_to_html(self, input_path: str, output_path: Optional[str] = None) -> str:
# """PDF转换方法"""
# cmd = [
# 'pdf2htmlEX',
# '--zoom', '1.2', # 放大
# '--split-pages', '0', # 保持整体布局
# # '--embed-css', '0', # 避免内联样式冲突
# # '--embed-image', '0', # 避免内联图片冲突
# # '--optimize-text', '1', # 优化文本渲染
# input_path
# ]
# result = subprocess.run(
# 'cd /data3/pdffiles && ' + ' '.join(cmd),
# shell=True,
# stdout=subprocess.PIPE,
# stderr=subprocess.STDOUT
# )
# print(f"转换状态: {result.returncode}\n输出: {result.stdout.decode()[:200]}")
# # 准备文件名
# file_name = os.path.basename(input_path)[:-3] + "html"
# html_path = f"/data3/pdffiles/{file_name}"
# if not os.path.exists(html_path):
# return f"{file_name} 转换失败"
# # 读取并处理HTML内容
# with open(html_path, 'r', encoding='utf-8') as file:
# soup = BeautifulSoup(file, 'html.parser')
# # 移除注释
# for comment in soup.find_all(string=lambda text: isinstance(text, str) and "Created by pdf2htmlEX" in text):
# comment.extract()
# # 移除loading-indicator
# for div in soup.find_all('div', class_='loading-indicator'):
# div.decompose()
# # 移除所有包含sidebar的div
# for div in soup.find_all('div', id=lambda x: x and 'sidebar' in x.lower()):
# div.decompose()
# # 转换为字符串并处理base64
# html_content = str(soup)
# # 清理临时文件
# os.remove(html_path)
# # 处理base64图片
# html_content = self.read_and_replace_base64(
# html_content,
# output_dir={GENERATED_IMAGES_BASE_PATH}
# )
# return f"{self._save_pdf_html(html_content, output_path)}"
def pdf_to_html(self, input_path: str, output_path: Optional[str] = None) -> str:
    """PDF preview: identical to the base class (local text extraction with PyMuPDF).

    Delegates straight to ``super().pdf_to_html``; wrap the ``super()``
    result here if post-processing is needed.
    """
    return super().pdf_to_html(input_path, output_path)
def read_and_replace_base64(self, html_content, output_dir,
                            url_prefix="http://127.0.0.1:8099/chat_web_backend/get-image?file_name="):
    """Save inline base64 images to files and replace them with URLs.

    Every ``data:image/(png|jpg|jpeg);base64,...`` URI in *html_content* is
    decoded, written to ``output_dir`` under a generated file name, and
    replaced in the markup by ``url_prefix`` followed by that file name.

    Args:
        html_content: Markup that may contain inline base64 images.
        output_dir: Directory that receives the image files. It is created
            on demand when the first image is written.
        url_prefix: Text placed in front of the file name to form the
            replacement URL. Defaults to the previously hard-coded endpoint.

    Returns:
        The markup with decodable data URIs replaced. A data URI whose
        payload is not valid base64 is left untouched.
    """
    image_index = 0  # makes generated file names unique within one call

    def replace_base64(match):
        nonlocal image_index
        data_uri = match.group(0)
        # Split "data:image/<ext>;base64" from the payload.
        _header, encoded = data_uri.split(',', 1)
        try:
            # Decode before touching the disk so that a malformed payload
            # neither aborts the conversion nor leaves an empty file behind.
            image_bytes = base64.b64decode(encoded)
        except ValueError:
            return data_uri
        file_extension = match.group(1)
        file_name = f'image_{uuid.uuid1()}_{image_index}.{file_extension}'
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, file_name), 'wb') as image_file:
            image_file.write(image_bytes)
        image_index += 1
        # Return the URL that serves the saved file.
        return f"{url_prefix}{file_name}"

    # Match inline base64 images.
    base64_pattern = r'data:image/(png|jpg|jpeg);base64,[A-Za-z0-9+/=]+'
    return re.sub(base64_pattern, replace_base64, html_content)
# def pdf_to_html(self, input_path: str, output_path: Optional[str] = None) -> str:
# """PDF转换方法"""
# try:
# doc = fitz.open(input_path)
# page_width = doc[0].rect.width
# page_height = doc[0].rect.height
# border_radius = 5
# html = ['<style>','pre { background-color: #2d2d2d;color: #f8f8f2; padding: 10px;margin: 0;width: 80%;box-sizing: border-box;border-radius: 0px;}', '</style>', '<body style="position: relative;">']
# image_save_path = '{GENERATED_IMAGES_BASE_PATH}'
# pic_num =0
# # 确保图片保存路径存在
# os.makedirs(image_save_path, exist_ok=True)
# for page in doc:
# blocks = page.get_text("dict")["blocks"]
# sorted_blocks = sorted(blocks, key=lambda b: (b["bbox"][1], b["bbox"][0])) # 按y坐标和x坐标排序
# for block in sorted_blocks:
# if "image" in block:
# pic_num += 1
# bbox = block["bbox"]
# image_bytes = block["image"]
# image_ext = block["ext"]
# image_name = f'image_{page.number}_{pic_num}.{image_ext}'
# image_url = f'http://127.0.0.1:8099/chat_web_backend/get-image?file_name={image_name}'
# image_path = os.path.join(image_save_path, image_name)
# # 保存图片到指定路径
# with open(image_path, 'wb') as img_file:
# img_file.write(image_bytes)
# percent_left = (bbox[0]) / page_width * 100
# # 获取页面的宽度和高度
# container_width = page_width # 页面宽度
# container_height = page_height # 页面高度
# # 计算图像的宽度和高度
# img_width = bbox[2] - bbox[0] # 计算宽度
# img_height = bbox[3] - bbox[1] # 计算高度
# # 计算百分比
# width_percent = (img_width / container_width) * 100
# height_percent = (img_height / container_height) * 100
# html.append(f'<div style="width: {width_percent}%; height: {height_percent}%; margin-left: {percent_left}%;clear: both;overflow: auto;"><img src="{image_url}" alt="Image {pic_num}" style="max-width: 100%; height: auto;display: block;"/></div>')
# if "lines" in block:
# text_nums = 0
# for line in block["lines"]:
# is_code_block =any(span["font"].startswith(("Courier", "NSimSun")) for span in line["spans"]) # 假设代码使用Courier字体
# if is_code_block:
# html.append(f"<pre>")
# for span in line["spans"]:
# text_nums += 1
# bbox = span["bbox"]
# text = span["text"]
# font = span["font"] # 字体
# size = span["size"] # 字体大小
# color = span["color"] # 字体颜色
# # 动态生成CSS样式
# css_style = f'font-family: {font}; font-size: {size}px; color: #{color:06x};'
# percent_left = (bbox[0]) / page_width * 100
# # 根据字体大小判断标题
# if size > 20: # 假设大于20的字体为标题
# if text_nums == 1:
# html.append(f'<h2 style="{css_style};display: inline; margin-left: {percent_left}%; ">{text.strip()}</h2>')
# else:
# html.append(f'<h3 style="{css_style};display: inline;">{text.strip()}</h3>')
# else:
# if text_nums == 1:
# html.append(f'<p style="{css_style};display: inline; margin-left: {percent_left}%; ">{text.strip()}</p>')
# else:
# html.append(f'<p style="{css_style};display: inline; ">{text.strip()}</p>')
# if is_code_block or size<=20:
# if is_code_block:
# html.append("</pre>")
# else:
# html.append("<br>")
# else:
# html.append('<br>')
# # html.append('<br>')
# html.append('</body>')
# # 将HTML内容保存到指定路径
# html_content = ''.join(html)
# if output_path:
# with open(output_path, 'w', encoding='utf-8') as file:
# file.write(html_content)
# else:
# # 如果没有指定路径使用默认路径或返回HTML内容
# output_path = 'output.html'
# with open(output_path, 'w', encoding='utf-8') as file:
# file.write(html_content)
# return output_path
# except Exception as e:
# raise RuntimeError(f"PDF转换失败: {str(e)}")
# def replace_base64_with_url(self,html_content, output_dir):
# image_index = 0 # 用于生成唯一的文件名
# def replace_base64(match):
# nonlocal image_index
# base64_data = match.group(0)
# # 保存 Base64 图片并获取文件路径
# # 提取文件类型和实际的 Base64 数据
# header, data = base64_data.split(',', 1)
# file_extension = header.split(';')[0].split('/')[1] # 获取文件扩展名
# file_name = f'image_{uuid.uuid1()}_{image_index}.{file_extension}' # 生成文件名
# file_path = os.path.join(output_dir, file_name)
# # 将 Base64 数据解码并保存为文件
# with open(file_path, 'wb') as image_file:
# image_file.write(base64.b64decode(data))
# image_index += 1
# # 返回文件的 URL
# return f"http://127.0.0.1:8099/chat_web_backend/get-image?file_name={os.path.basename(file_path)}"
# # 使用正则表达式匹配 Base64 字符串
# base64_pattern = r'data:image/(png|jpg|jpeg);base64,[A-Za-z0-9+/=]+'
# updated_html_content = re.sub(base64_pattern, replace_base64, html_content)
# return updated_html_content

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,65 @@
import urllib
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_base_repository import list_kbs_from_db
from configs import EMBEDDING_MODEL, logger, log_verbose
from fastapi import Body
def list_kbs():
    # Get List of Knowledge Base
    # Wraps the result of list_kbs_from_db() in a ListResponse.
    return ListResponse(data=list_kbs_from_db())
def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]),
              vector_store_type: str = Body("faiss"),
              embed_model: str = Body(EMBEDDING_MODEL),
              ) -> BaseResponse:
    # Create selected knowledge base.
    #
    # Responses: 404 for a missing/blank name or a name that already exists,
    # 403 when validate_kb_name rejects the name, 500 when creation raises,
    # 200 on success.

    # Check for a missing or blank name first so that a None value is
    # answered with the "name required" message instead of being handed to
    # the validator.
    if knowledge_base_name is None or knowledge_base_name.strip() == "":
        return BaseResponse(code=404, msg="知识库名称不能为空,请重新填写知识库名称")
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is not None:
        return BaseResponse(code=404, msg=f"已存在同名知识库 {knowledge_base_name}")

    kb = KBServiceFactory.get_service(knowledge_base_name, vector_store_type, embed_model)
    try:
        kb.create_kb()
    except Exception as e:
        msg = f"创建知识库出错: {e}"
        # Attach the traceback only when verbose logging is enabled.
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)

    return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}")
def delete_kb(
        knowledge_base_name: str = Body(..., examples=["samples"])
) -> BaseResponse:
    # Delete selected knowledge base.
    #
    # Responses: 403 when the name fails validation, 404 when the knowledge
    # base is unknown, 200 on success, 500 when dropping fails or raises.
    from urllib.parse import unquote

    # Validate the URL-decoded name as well as the raw one: the decoded
    # value is what is actually used below, so validating only the raw
    # value would let a percent-encoded name slip past the check.
    decoded_name = unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name) or not validate_kb_name(decoded_name):
        return BaseResponse(code=403, msg="Don't attack me")
    knowledge_base_name = decoded_name

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    try:
        # Clear the vector store first; only the result of drop_kb decides
        # whether the deletion is reported as successful.
        kb.clear_vs()
        status = kb.drop_kb()
        if status:
            return BaseResponse(code=200, msg=f"成功删除知识库 {knowledge_base_name}")
    except Exception as e:
        msg = f"删除知识库时出现意外: {e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)

    return BaseResponse(code=500, msg=f"删除知识库失败 {knowledge_base_name}")

View File

@@ -0,0 +1,164 @@
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.faiss import FAISS
import threading
from configs import (EMBEDDING_MODEL, CHUNK_SIZE,
logger, log_verbose)
from server.utils import embedding_device, get_model_path, list_online_embed_models, resolve_embed_model_name
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple
class ThreadSafeObject:
    """A cached object guarded by a re-entrant lock and a "loaded" event.

    The wrapped object is reached through :meth:`acquire`, a context manager
    that holds the lock for the duration of the ``with`` block. The event
    lets other threads wait until the object has been populated.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        """Store the cache key, the wrapped object and the owning pool (if any)."""
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        """The key this object is cached under."""
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = "") -> "FAISS":
        """Hold the lock and yield the wrapped object.

        Args:
            owner: Label used in log messages; defaults to the native id of
                the calling thread.
            msg: Extra text appended to the log messages.

        Yields:
            The wrapped object.
        """
        owner = owner or f"thread {threading.get_native_id()}"
        # Take the lock before entering the try block: the cleanup below must
        # only run (and the lock only be released) once the lock is held.
        with self._lock:
            try:
                if self._pool is not None:
                    # Mark this entry as most recently used in the owning pool.
                    self._pool._cache.move_to_end(self.key)
                if log_verbose:
                    logger.info(f"{owner} 开始操作:{self.key}{msg}")
                yield self._obj
            finally:
                if log_verbose:
                    logger.info(f"{owner} 结束操作:{self.key}{msg}")

    def start_loading(self):
        """Clear the loaded event so that waiters block."""
        self._loaded.clear()

    def finish_loading(self):
        """Set the loaded event, releasing every waiter."""
        self._loaded.set()

    def wait_for_loading(self):
        """Block until the loaded event is set."""
        self._loaded.wait()

    @property
    def obj(self):
        """The wrapped object (accessed without taking the lock)."""
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val
class CachePool:
    """An insertion-ordered cache of ThreadSafeObject entries.

    When ``cache_num`` is a positive integer the pool keeps at most that
    many entries and drops the entries at the front of the ordering first.
    """

    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num
        self._cache = OrderedDict()
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        """Return the cached keys in their current order."""
        return list(self._cache)

    def _check_count(self):
        """Evict front entries until the pool is within its size limit."""
        limit = self._cache_num
        if not isinstance(limit, int) or limit <= 0:
            # No positive limit configured: nothing is ever evicted.
            return
        while len(self._cache) > limit:
            self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        """Return the entry for *key* after waiting for it to finish loading.

        Returns ``None`` when there is no (truthy) entry for the key.
        """
        cache = self._cache.get(key)
        if not cache:
            return None
        cache.wait_for_loading()
        return cache

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        """Store *obj* under *key*, enforce the size limit and return *obj*."""
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        """Remove an entry.

        With a key, the matching entry is removed and returned (``None`` if
        absent). Without a key, the front ``(key, value)`` pair is removed
        and returned.
        """
        if key is not None:
            return self._cache.pop(key, None)
        return self._cache.popitem(last=False)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Return the ``acquire`` context manager of the entry for *key*.

        Raises:
            RuntimeError: If there is no entry for *key*.
        """
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        if not isinstance(cache, ThreadSafeObject):
            # Entries of any other type are handed back as they are.
            return cache
        self._cache.move_to_end(key)
        return cache.acquire(owner=owner, msg=msg)

    def load_kb_embeddings(
            self,
            kb_name: str,
            embed_device: str = embedding_device(),
            default_embed_model: str = EMBEDDING_MODEL,
    ) -> Embeddings:
        """Return the embeddings object configured for a knowledge base.

        The model name is read from the knowledge-base detail record
        (falling back to *default_embed_model*) and passed through
        ``resolve_embed_model_name``. Online models are wrapped in
        ``EmbeddingsFunAdapter``; all others are loaded through
        ``embeddings_pool``.
        """
        # Imported here rather than at module level.
        from server.db.repository.knowledge_base_repository import get_kb_detail
        from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

        kb_detail = get_kb_detail(kb_name)
        configured_model = kb_detail.get("embed_model", default_embed_model)
        embed_model = resolve_embed_model_name(configured_model)

        if embed_model not in list_online_embed_models():
            return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)
        return EmbeddingsFunAdapter(embed_model)
class EmbeddingsPool(CachePool):
    """Pool of embedding models keyed by ``(model, device)``."""

    @staticmethod
    def _build_embeddings(model: str, device: str) -> Embeddings:
        """Instantiate the embeddings implementation matching *model*."""
        if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
            from langchain.embeddings.openai import OpenAIEmbeddings
            return OpenAIEmbeddings(model=model,
                                    openai_api_key=get_model_path(model),
                                    chunk_size=CHUNK_SIZE)
        if 'bge-' in model:
            from langchain.embeddings import HuggingFaceBgeEmbeddings
            if 'zh' in model:
                # for chinese model
                query_instruction = "为这个句子生成表示以用于检索相关文章:"
            elif 'en' in model:
                # for english model
                query_instruction = "Represent this sentence for searching relevant passages:"
            else:
                # maybe ReRanker or else, just use empty string instead
                query_instruction = ""
            embeddings = HuggingFaceBgeEmbeddings(model_name=get_model_path(model),
                                                  model_kwargs={'device': device},
                                                  query_instruction=query_instruction)
            if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                embeddings.query_instruction = ""
            return embeddings
        from langchain.embeddings.huggingface import HuggingFaceEmbeddings
        return HuggingFaceEmbeddings(model_name=get_model_path(model),
                                     model_kwargs={'device': device})

    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
        """Return the embeddings for *model* on *device*, loading them once.

        Args:
            model: Model name; defaults to ``EMBEDDING_MODEL``.
            device: Device name; defaults to ``embedding_device()``. An
                explicitly passed device is now honoured (it used to be
                overwritten unconditionally).

        Returns:
            The cached or freshly created embeddings object.

        Raises:
            Exception: Whatever the model constructor raises. The failed
                entry is removed from the pool so that a later call can retry.
        """
        self.atomic.acquire()
        atomic_held = True
        try:
            model = model or EMBEDDING_MODEL
            device = device or embedding_device()
            key = (model, device)

            cache = self.get(key)
            if cache is not None:
                return cache.obj

            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                # The entry is registered and locked, so other threads may
                # use the pool again while this one performs the slow load.
                self.atomic.release()
                atomic_held = False
                try:
                    item.obj = self._build_embeddings(model, device)
                except Exception:
                    # Do not keep a half-initialised entry in the cache.
                    self.pop(key)
                    raise
                finally:
                    # Always wake threads blocked in get(); otherwise a failed
                    # load would leave them waiting forever.
                    item.finish_loading()
            # Return the item directly: looking it up again could miss if
            # the pool evicted it in the meantime.
            return item.obj
        finally:
            if atomic_held:
                self.atomic.release()


embeddings_pool = EmbeddingsPool(cache_num=1)

View File

@@ -0,0 +1,175 @@
from configs import CACHED_VS_NUM, CACHED_MEMO_VS_NUM
from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter
from server.utils import load_local_embeddings
from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores.faiss import FAISS
from langchain.docstore.in_memory import InMemoryDocstore
from langchain.schema import Document
import os
from langchain.schema import Document
# patch FAISS to include doc id in Document.metadata
def _new_ds_search(self, search: str) -> Union[str, Document]:
    """Look up a document by id and record the id in its metadata.

    Returns the stored entry for *search*; when that entry is a ``Document``
    its ``metadata["id"]`` is set to *search* first. Returns a
    "not found" message string when the id is unknown.
    """
    if search in self._dict:
        found = self._dict[search]
        if isinstance(found, Document):
            found.metadata["id"] = search
        return found
    return f"ID {search} not found."


# Replace the docstore's search method with the id-annotating version.
InMemoryDocstore.search = _new_ds_search
class ThreadSafeFaiss(ThreadSafeObject):
    """ThreadSafeObject wrapping a FAISS vector store."""

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}, docs_count: {self.docs_count()}>"

    def docs_count(self) -> int:
        """Return the number of entries in the wrapped store's docstore."""
        return len(self._obj.docstore._dict)

    def save(self, path: str, create_path: bool = True):
        """Persist the vector store to *path* while holding the lock.

        Args:
            path: Target directory.
            create_path: Create the directory first when it does not exist.

        Returns:
            Whatever ``save_local`` returns.
        """
        with self.acquire():
            if create_path:
                # exist_ok avoids the check-then-create race between threads
                # or processes saving into the same directory.
                os.makedirs(path, exist_ok=True)
            ret = self._obj.save_local(path)
            logger.info(f"已将向量库 {self.key} 保存到磁盘")
        return ret

    def clear(self):
        """Delete every document from the vector store.

        Returns:
            The result of the store's ``delete`` call, or an empty list when
            the store was already empty.
        """
        ret = []
        with self.acquire():
            ids = list(self._obj.docstore._dict.keys())
            if ids:
                ret = self._obj.delete(ids)
                assert len(self._obj.docstore._dict) == 0
            logger.info(f"已将向量库 {self.key} 清空")
        return ret
class _FaissPool(CachePool):
    """Shared helpers for the FAISS vector-store pools."""

    def new_vector_store(
            self,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> FAISS:
        """Create an empty FAISS store for *embed_model*.

        The store is built from a single placeholder document, which is
        deleted again before the store is returned.
        """
        adapter = EmbeddingsFunAdapter(embed_model)
        placeholder = Document(page_content="init", metadata={})
        store = FAISS.from_documents(
            [placeholder],
            adapter,
            normalize_L2=True,
            distance_strategy="METRIC_INNER_PRODUCT",
        )
        placeholder_ids = list(store.docstore._dict.keys())
        store.delete(placeholder_ids)
        return store

    def save_vector_store(self, kb_name: str, path: str=None):
        """Save the cached store for *kb_name*; return ``None`` if not cached."""
        cache = self.get(kb_name)
        if cache:
            return cache.save(path)
        return None

    def unload_vector_store(self, kb_name: str):
        """Drop the cached store for *kb_name* from the pool, if present."""
        cache = self.get(kb_name)
        if not cache:
            return
        self.pop(kb_name)
        logger.info(f"成功释放向量库:{kb_name}")
class KBFaissPool(_FaissPool):
    """Pool of on-disk FAISS stores keyed by ``(kb_name, vector_name)``."""

    def load_vector_store(
            self,
            kb_name: str,
            vector_name: str = None,
            create: bool = True,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached store for a knowledge base, loading it if needed.

        Args:
            kb_name: Knowledge base name.
            vector_name: Name of the vector store; defaults to *embed_model*.
            create: Create and save an empty store when none exists on disk.
            embed_model: Embedding model name.
            embed_device: Device passed on to the embedding loader.

        Returns:
            The ``ThreadSafeFaiss`` wrapper for the store.

        Raises:
            RuntimeError: If no index exists on disk and *create* is false.
        """
        self.atomic.acquire()
        atomic_held = True
        try:
            vector_name = vector_name or embed_model
            key = (kb_name, vector_name)  # a tuple works better than a concatenated string

            cache = self.get(key)
            if cache is not None:
                return cache

            item = ThreadSafeFaiss(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                # The entry is registered and locked; free the pool for other
                # threads while the store is being loaded.
                self.atomic.release()
                atomic_held = False
                try:
                    logger.info(f"loading vector store in '{kb_name}/vector_store/{vector_name}' from disk.")
                    vs_path = get_vs_path(kb_name, vector_name)

                    if os.path.isfile(os.path.join(vs_path, "index.faiss")):
                        embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device, default_embed_model=embed_model)
                        vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True,distance_strategy="METRIC_INNER_PRODUCT")
                    elif create:
                        # create an empty vector store
                        if not os.path.exists(vs_path):
                            os.makedirs(vs_path)
                        vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                        vector_store.save_local(vs_path)
                    else:
                        raise RuntimeError(f"knowledge base {kb_name} not exist.")
                    item.obj = vector_store
                except Exception:
                    # Remove the half-initialised entry so that later calls
                    # retry instead of finding an empty wrapper.
                    self.pop(key)
                    raise
                finally:
                    # Always wake threads blocked in get(); without this a
                    # failed load would leave them waiting forever.
                    item.finish_loading()
            # Return the item directly: a second lookup could miss if the
            # pool evicted the entry in the meantime.
            return item
        finally:
            if atomic_held:
                self.atomic.release()
class MemoFaissPool(_FaissPool):
    """Pool of in-memory FAISS stores keyed by name."""

    def load_vector_store(
            self,
            kb_name: str,
            embed_model: str = EMBEDDING_MODEL,
            embed_device: str = embedding_device(),
    ) -> ThreadSafeFaiss:
        """Return the cached in-memory store for *kb_name*, creating it if needed.

        Args:
            kb_name: Cache key of the store.
            embed_model: Embedding model name used for a new store.
            embed_device: Device passed on when creating a new store.

        Returns:
            The ``ThreadSafeFaiss`` wrapper for the store.
        """
        self.atomic.acquire()
        atomic_held = True
        try:
            cache = self.get(kb_name)
            if cache is not None:
                return cache

            item = ThreadSafeFaiss(kb_name, pool=self)
            self.set(kb_name, item)
            with item.acquire(msg="初始化"):
                # The entry is registered and locked; free the pool for other
                # threads while the store is being created.
                self.atomic.release()
                atomic_held = False
                try:
                    logger.info(f"loading vector store in '{kb_name}' to memory.")
                    # create an empty vector store
                    item.obj = self.new_vector_store(embed_model=embed_model, embed_device=embed_device)
                except Exception:
                    # Remove the half-initialised entry so later calls retry.
                    self.pop(kb_name)
                    raise
                finally:
                    # Always wake threads blocked in get(), even on failure.
                    item.finish_loading()
            # Return the item directly: a second lookup could miss if the
            # pool evicted the entry in the meantime.
            return item
        finally:
            if atomic_held:
                self.atomic.release()


kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM)
memo_faiss_pool = MemoFaissPool(cache_num=CACHED_MEMO_VS_NUM)
if __name__ == "__main__":
    # Manual concurrency demo: 29 daemon threads each sleep for a random
    # time and then add to, search, or clear a vector store.
    import time, random
    from pprint import pprint

    kb_names = ["vs1", "vs2", "vs3"]
    # for name in kb_names:
    #     memo_faiss_pool.load_vector_store(name)

    def worker(vs_name: str, name: str):
        # The vs_name argument is overridden: every worker uses "samples".
        vs_name = "samples"
        time.sleep(random.randint(1, 5))
        embeddings = load_local_embeddings()
        r = random.randint(1, 3)

        with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs:
            if r == 1: # add docs
                ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings)
                pprint(ids)
            elif r == 2: # search docs
                docs = vs.similarity_search_with_score(f"{name}", k=3, score_threshold=1.0)
                pprint(docs)
        # NOTE(review): the flattened source does not show whether this
        # branch was nested inside the `with` block above - confirm.
        if r == 3: # delete docs
            logger.warning(f"清除 {vs_name} by {name}")
            # NOTE(review): KBFaissPool.load_vector_store caches entries under
            # a (kb_name, vector_name) tuple, so this string-key lookup looks
            # like it returns None - confirm.
            kb_faiss_pool.get(vs_name).clear()

    threads = []
    for n in range(1, 30):
        t = threading.Thread(target=worker,
                             kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"},
                             daemon=True)
        t.start()
        threads.append(t)

    # Wait for every worker to finish before exiting.
    for t in threads:
        t.join()

View File

@@ -0,0 +1,673 @@
import asyncio
import os
import urllib
from fastapi import File, Form, Body, Query, Response, UploadFile
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EXPR,
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
logger, log_verbose, POLICY_KNOWLEDGE_BASE)
from configs.model_config import LLM_MODELS
from server.knowledge_base.cleanpdf import PdfConverter
from server.knowledge_base.file_converter import FileConverter
from server.utils import BaseResponse, ListResponse, flatten, run_in_thread_pool
from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
files2docs_in_thread, KnowledgeFile)
from fastapi.responses import FileResponse
from sse_starlette import EventSourceResponse
from pydantic import Json
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.repository.knowledge_file_repository import get_file_detail
from langchain.docstore.document import Document
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from typing import List, Dict
from server.chat.policy_fun_iast import get_llm_model_response
from datetime import datetime
def search_docs(
        fileName: list = Body([], description="文件名称", examples=["123.txt"]),
        query: str = Body("", description="改写后的query", examples=["你好"]),
        usr_query: str = Body("", description="用户输入的问题", examples=["你好"]),
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(SCORE_THRESHOLD,
                                      description="知识库匹配相关度阈值取值范围在0-1之间"
                                                  "SCORE越小相关度越高"
                                                  "取到1相当于不筛选建议设置在0.5左右",
                                      ge=0, le=1),
        expr: str = Body(EXPR, description="milvus混合检索条件"),
        file_name: str = Body("", description="文件名称,支持 sql 通配符"),
        metadata: dict = Body({}, description="根据 metadata 进行过滤,仅支持一级键"),
        custom_strategy_config: dict = Body({}, description="自定义策略配置"),
        query_rewrite_model_name = LLM_MODELS[0]
) -> List[DocumentWithVSId]:
    # Search a knowledge base and return the matches as DocumentWithVSId
    # objects. With a query, a vector search is run (restricted to the
    # sources listed in fileName when that list is non-empty); without a
    # query, documents are listed by file_name / metadata instead. An empty
    # list is returned when the knowledge base is not found.

    # Get the current time formatted as YYYYMMDD.
    time = datetime.now().strftime("%Y%m%d")
    # For policy knowledge bases the filter expression is generated by an
    # LLM from the user's question and today's date.
    if POLICY_KNOWLEDGE_BASE in knowledge_base_name:
        expr = get_llm_model_response(
            strategy_name="get policy time",
            llm_model_name=query_rewrite_model_name,
            template_prompt_name="get_policy_time",
            prompt_param_dict={"query": usr_query, "time": time},
            temperature=0.01,
            max_tokens=512
        ).replace("None", "")
        # NOTE(review): the flattened source does not show whether this
        # print was nested inside the `if` - confirm.
        print(f'Milvus混合检索表达式{expr}')
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    data = []
    # Fall back to the configured expression when expr is not a plain str
    # (e.g. when the function is called directly and expr is still the
    # Body(...) default object).
    if type(expr) is not str:
        expr = EXPR
    query1 = ""
    if kb is not None:
        if fileName:
            if query:
                # Prefix the query with the first requested file name.
                query1 += "请查询以下几篇文件:" + str(fileName[0]) + "" + query
                docs = kb.search_docs(query1, top_k, score_threshold, expr)
                # Keep only hits whose source is one of the requested files.
                data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id"))for x in docs if x[0].metadata.get("source") in fileName]
            elif file_name or metadata:
                data = kb.list_docs(file_name=file_name, metadata=metadata)
        else:
            if query:
                docs = kb.search_docs(query, top_k, score_threshold, expr)
                data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
            elif file_name or metadata:
                data = kb.list_docs(file_name=file_name, metadata=metadata)
    return data
def search_self_docs(
        fileNames: list = Body([], description="文件名称", examples=["123.txt"]),
        query: str = Body("", description="改写后的query", examples=["你好"]),
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
        score_threshold: float = Body(SCORE_THRESHOLD,
                                      description="知识库匹配相关度阈值取值范围在0-1之间"
                                                  "SCORE越小相关度越高"
                                                  "取到1相当于不筛选建议设置在0.5左右",
                                      ge=0, le=1),
        expr: str = Body("", description="milvus混合检索条件"),
) -> List[DocumentWithVSId]:
    # Search a personal knowledge base, restricted to the given files.
    #
    # fileNames may be a flat list or a list of lists (flattened first).
    # When no usable expr is supplied, one is built that matches any of the
    # file names on the `source` field. For top_k > 50 the raw search
    # results are returned; otherwise hits from the requested files are
    # wrapped in DocumentWithVSId with an index marker prepended to the
    # page content. An empty list is returned when the knowledge base is
    # not found.
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    data = []

    if fileNames:
        # A nested list is flattened first; a flat list is used as it is.
        if isinstance(fileNames[0], list):
            flat_fileNames = flatten(fileNames)
        else:
            flat_fileNames = fileNames
    else:
        flat_fileNames = []

    if not expr or not isinstance(expr, str):
        # File names come from the request, so escape backslashes and double
        # quotes before embedding them in the quoted filter literal; a name
        # containing `"` could otherwise break out of the literal and alter
        # the filter. Joining an empty list yields "", as before.
        expr = ' || '.join(
            'source == "{}"'.format(
                str(fileName).replace('\\', '\\\\').replace('"', '\\"')
            )
            for fileName in flat_fileNames
        )
    logger.info(f"个人知识库检索EXPR: {expr}")

    if kb is not None:
        docs = kb.search_docs(query, top_k, score_threshold, expr)
        if top_k > 50:
            data = docs
        else:
            data = [
                DocumentWithVSId(
                    # Exclude the original page_content; it is replaced below.
                    **{k: v for k, v in x[0].dict().items() if k != 'page_content'},
                    score=x[1],
                    id=x[0].metadata.get("id"),
                    # Prepend the 1-based position of the hit in the results.
                    page_content=f"【^[{index +1}]^ {x[0].page_content}"
                )
                for index, x in enumerate(docs)
                if x[0].metadata.get("source") in flat_fileNames
            ]
    else:
        logger.warning(f"未找到知识库服务: {knowledge_base_name}")
    return data
def update_docs_by_id(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        docs: Dict[str, Document] = Body(..., description="要更新的文档内容,形如:{id: Document, ...}")
) -> BaseResponse:
    '''
    Update document contents by document ID.

    Returns a response with code 500 when the knowledge base does not exist
    or when the update fails, and a default (success) response otherwise.
    '''
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=500, msg=f"指定的知识库 {knowledge_base_name} 不存在")
    if kb.update_doc_by_ids(docs=docs):
        return BaseResponse(msg="文档更新成功")
    # Report the failure with an error code; without it the response carried
    # the same default code as a successful update.
    return BaseResponse(code=500, msg="文档更新失败")
def list_files(
        knowledge_base_name: str
) -> ListResponse:
    # List the files of a knowledge base.
    #
    # Responses: 403 when the name fails validation, 404 when the knowledge
    # base is unknown, otherwise the file names returned by kb.list_files().
    from urllib.parse import unquote

    # Validate the URL-decoded name as well as the raw one: the decoded
    # value is what is actually used below, so validating only the raw
    # value would let a percent-encoded name slip past the check.
    decoded_name = unquote(knowledge_base_name)
    if not validate_kb_name(knowledge_base_name) or not validate_kb_name(decoded_name):
        return ListResponse(code=403, msg="Don't attack me", data=[])
    knowledge_base_name = decoded_name

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])

    all_doc_names = kb.list_files()
    return ListResponse(data=all_doc_names)
def _save_files_in_thread(files: List[UploadFile],
                          knowledge_base_name: str,
                          override: bool):
    """
    Save uploaded files into the knowledge base directory using a thread pool.

    This is a generator; it yields one result dict per file, shaped like
    {"code": 200, "msg": "xxx", "data": {"knowledge_base_name": "xxx", "file_name": "xxx"}}.
    Code 404 means the file was left untouched because an identical-size file
    already exists and `override` is false; code 500 means saving failed.
    """

    def save_file(file: UploadFile, knowledge_base_name: str, override: bool) -> dict:
        """
        Save a single uploaded file and return its result dict.
        """
        # Bind these before the try block: the except handler below needs both,
        # and previously a failure in get_file_path() left them unbound, turning
        # the intended code-500 result into an UnboundLocalError.
        filename = file.filename
        data = {"knowledge_base_name": knowledge_base_name, "file_name": filename}
        try:
            file_path = get_file_path(knowledge_base_name=knowledge_base_name, doc_name=filename)
            file_content = file.file.read()  # read the uploaded content
            if (os.path.isfile(file_path)
                    and not override
                    and os.path.getsize(file_path) == len(file_content)
            ):
                file_status = f"文件 {filename} 已存在。"
                logger.warning(file_status)
                return dict(code=404, msg=file_status, data=data)

            # exist_ok avoids a FileExistsError when several worker threads
            # create the same parent directory at the same time.
            os.makedirs(os.path.dirname(file_path), exist_ok=True)
            with open(file_path, "wb") as f:
                f.write(file_content)
            return dict(code=200, msg=f"成功上传文件 {filename}", data=data)
        except Exception as e:
            msg = f"{filename} 文件上传失败,报错信息为: {e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            return dict(code=500, msg=msg, data=data)

    params = [{"file": file, "knowledge_base_name": knowledge_base_name, "override": override} for file in files]
    for result in run_in_thread_pool(save_file, params=params):
        yield result
# def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
# knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
# override: bool = Form(False, description="覆盖已有文件"),
# save: bool = Form(True, description="是否将文件保存到知识库目录")):
# def save_files(files, knowledge_base_name, override):
# for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
# yield json.dumps(result, ensure_ascii=False)
# def files_to_docs(files):
# for result in files2docs_in_thread(files):
# yield json.dumps(result, ensure_ascii=False)
def upload_docs(
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
        override: bool = Form(False, description="覆盖已有文件"),
        to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        docs: Json = Form({}, description="自定义的docs需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
    """
    API endpoint: upload files and, optionally, vectorize them.

    Files are first written to disk; every uploaded file name (including ones
    whose save reported a non-200 code) is then added to the names taken from
    `docs` and, when `to_vector_store` is true, handed to `update_docs`.
    The response data carries a `failed_files` mapping of name -> message.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failures = {}
    pending_names = [name for name in docs.keys()]

    # Step 1: persist the uploads.
    save_results = _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override)
    for saved in save_results:
        saved_name = saved["data"]["file_name"]
        if saved["code"] != 200:
            failures[saved_name] = saved["msg"]
        if saved_name not in pending_names:
            pending_names.append(saved_name)

    # Step 2: vectorize; the store is flushed once here rather than per file.
    if to_vector_store:
        update_result = update_docs(
            knowledge_base_name=knowledge_base_name,
            file_names=pending_names,
            override_custom_docs=True,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
            docs=docs,
            not_refresh_vs_cache=True,
        )
        failures.update(update_result.data["failed_files"])
        if not not_refresh_vs_cache:
            kb.save_vector_store()

    return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failures})
def upload_docs_new(
        files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
        knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
        override: bool = Form(False, description="覆盖已有文件"),
        to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"),
        chunk_size: int = Form(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Form(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Form(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        docs: Json = Form({}, description="自定义的docs需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
    """
    API endpoint: upload files, generate LLM results for them and optionally vectorize.

    Differences from `upload_docs` visible in this function:
    - a missing knowledge base is created automatically instead of returning 404;
    - for every saved file, `KnowledgeFile.get_llm_result()` is run and its
      `full_text`, `article_abstract`, `article_keywords` and `article_paragraph`
      values are returned under `llm_results`, keyed by file name.

    :return: BaseResponse whose data holds `failed_files` and `llm_results`.
    """
    # Local imports: `time` must be the module here (this file also contains
    # `from time import time`, which binds the bare name to the function).
    import concurrent.futures
    import time

    start_time = time.time()

    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        # Create the knowledge base on the fly.
        kb = KBServiceFactory.get_service(knowledge_base_name, DEFAULT_VS_TYPE, EMBEDDING_MODEL)
        try:
            kb.create_kb()
            logger.info(f"自动创建知识库: {knowledge_base_name}")
        except Exception as e:
            msg = f"创建知识库出错: {e}"
            logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None)
            return BaseResponse(code=500, msg=msg)

    # Placeholder values used both for keys missing from an LLM result and for
    # a complete failure, so every entry of llm_results has the same four keys.
    llm_fallbacks = {
        "full_text": "获取全文失败",
        "article_abstract": "生成摘要失败",
        "article_keywords": "生成关键词失败",
        "article_paragraph": "生成章节速览失败",
    }

    def run_llm_in_new_loop(knowledge_file):
        """Run get_llm_result() on a private event loop (called in a worker thread)."""
        new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(new_loop)
        try:
            return new_loop.run_until_complete(knowledge_file.get_llm_result())
        finally:
            new_loop.close()

    failed_files = {}
    file_names = list(docs.keys())
    # Summary / keywords / chapter overview per file.
    llm_results = {}

    # Save the uploaded files to disk first.
    for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
        filename = result["data"]["file_name"]
        if result["code"] != 200:
            failed_files[filename] = result["msg"]

        if filename not in file_names:
            file_names.append(filename)

        # Generate summary, keywords and chapter overview for the file.
        try:
            knowledge_file = KnowledgeFile(filename=filename, knowledge_base_name=knowledge_base_name)
            # Run the coroutine in a separate thread to avoid clashing with an
            # event loop that may already be running in the current thread.
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                llm_result = executor.submit(run_llm_in_new_loop, knowledge_file).result()
            llm_results[filename] = {
                key: llm_result.get(key, fallback) for key, fallback in llm_fallbacks.items()
            }
        except Exception as e:
            logger.error(f"生成LLM结果时出错{e}", exc_info=e if log_verbose else None)
            llm_results[filename] = dict(llm_fallbacks)

    # Vectorize the saved files.
    if to_vector_store:
        update_st = time.time()
        result = _update_docs_impl(
            knowledge_base_name=knowledge_base_name,
            file_names=file_names,
            override_custom_docs=True,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            zh_title_enhance=zh_title_enhance,
            docs=docs,
            not_refresh_vs_cache=True,
        )
        failed_files.update(result.data["failed_files"])
        if not not_refresh_vs_cache:
            kb.save_vector_store()
        logger.info(f'向量化用时:{time.time() - update_st}')

    logger.info(f"总执行时间: {time.time() - start_time:.2f}s")
    return BaseResponse(code=200, msg="文件上传与向量化完成", data={
        "failed_files": failed_files,
        "llm_results": llm_results
    })
def delete_docs(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]),
        delete_content: bool = Body(False),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
    """
    Delete the given files from a knowledge base.

    Per-file problems are collected in `failed_files` instead of aborting the
    request; the vector store is saved once at the end unless
    `not_refresh_vs_cache` is set.
    """
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failures = {}
    for name in file_names:
        registered = kb.exist_doc(name)
        if not registered:
            failures[name] = f"未找到文件 {name}"

        # NOTE(review): deletion is attempted even for files reported as not
        # found above — presumably to clean up leftovers; confirm this is intended.
        try:
            target = KnowledgeFile(filename=name,
                                   knowledge_base_name=knowledge_base_name)
            kb.delete_doc(target, delete_content, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"{name} 文件删除失败,错误信息:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            failures[name] = msg

    should_save = not not_refresh_vs_cache
    if should_save:
        kb.save_vector_store()

    return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failures})
def update_info(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        kb_info: str = Body(..., description="知识库介绍", examples=["这是一个知识库"]),
):
    """
    Replace the introduction text of a knowledge base.

    Responds with 403 for an invalid name, 404 when the knowledge base is
    unknown, and otherwise echoes the new text back in the response data.
    """
    if validate_kb_name(knowledge_base_name):
        service = KBServiceFactory.get_service_by_name(knowledge_base_name)
        if service is None:
            return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
        service.update_info(kb_info)
        payload = {"kb_info": kb_info}
        return BaseResponse(code=200, msg=f"知识库介绍修改完成", data=payload)
    return BaseResponse(code=403, msg="Don't attack me")
from time import time
def _update_docs_impl(
        knowledge_base_name: str,
        file_names: List[str],
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
        override_custom_docs: bool = False,
        docs: Dict = None,
        not_refresh_vs_cache: bool = False,
) -> BaseResponse:
    """
    Core implementation of "update knowledge base documents" (for internal callers).

    :param knowledge_base_name: target knowledge base.
    :param file_names: files to (re)load from disk and vectorize.
    :param chunk_size: maximum length of a single text chunk.
    :param chunk_overlap: overlap length between neighbouring chunks.
    :param zh_title_enhance: whether to enable Chinese title enhancement.
    :param override_custom_docs: when false, files previously stored with
        custom docs are skipped.
    :param docs: optional mapping of file name -> list of custom documents
        (Document instances or dicts of Document fields); these are vectorized
        as given instead of being loaded from the file. Defaults to no custom docs.
    :param not_refresh_vs_cache: when true, the vector store is not saved at the end.
    :return: BaseResponse with `failed_files` (name -> message) in its data,
        or a 403/404 response for an invalid/unknown knowledge base.
    """
    # None instead of a shared mutable `{}` default.
    docs = {} if docs is None else docs

    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    failed_files = {}
    kb_files = []

    # Build the list of files whose docs must be loaded from disk.
    for file_name in file_names:
        file_detail = get_file_detail(kb_name=knowledge_base_name, filename=file_name)
        # A file that previously used custom docs is skipped unless overriding is requested.
        if file_detail.get("custom_docs") and not override_custom_docs:
            continue
        if file_name not in docs:
            try:
                kb_files.append(KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name))
            except Exception as e:
                msg = f"加载文档 {file_name} 时出错:{e}"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
                failed_files[file_name] = msg

    update_st = time()
    # Turn files into docs in worker threads, then vectorize each result here.
    # The loaded documents are assigned to a fresh KnowledgeFile's `splited_docs`
    # before it is passed to kb.update_doc.
    for status, result in files2docs_in_thread(kb_files,
                                               chunk_size=chunk_size,
                                               chunk_overlap=chunk_overlap,
                                               zh_title_enhance=zh_title_enhance):
        if status:
            kb_name, file_name, new_docs = result
            kb_file = KnowledgeFile(filename=file_name,
                                    knowledge_base_name=knowledge_base_name)
            kb_file.splited_docs = new_docs
            kb.update_doc(kb_file, not_refresh_vs_cache=True)
        else:
            kb_name, file_name, error = result
            failed_files[file_name] = error
    # Timing goes through the module logger rather than stdout.
    logger.info(f"use time: {time() - update_st}")

    # Vectorize the caller-supplied custom docs.
    for file_name, v in docs.items():
        try:
            v = [x if isinstance(x, Document) else Document(**x) for x in v]
            kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=knowledge_base_name)
            kb.update_doc(kb_file, docs=v, not_refresh_vs_cache=True)
        except Exception as e:
            msg = f"{file_name} 添加自定义docs时出错{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            failed_files[file_name] = msg

    if not not_refresh_vs_cache:
        kb.save_vector_store()

    return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
def update_docs(
        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
        file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=[["file_name1", "text.txt"]]),
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"),
        docs: Json = Body({}, description="自定义的docs需要转为json字符串",
                          examples=[{"test.txt": [Document(page_content="custom doc")]}]),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
) -> BaseResponse:
    """
    API route for updating knowledge base documents.

    All request parameters are forwarded unchanged to `_update_docs_impl`,
    whose response is returned as-is.
    """
    forwarded = dict(
        knowledge_base_name=knowledge_base_name,
        file_names=file_names,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        zh_title_enhance=zh_title_enhance,
        override_custom_docs=override_custom_docs,
        docs=docs,
        not_refresh_vs_cache=not_refresh_vs_cache,
    )
    return _update_docs_impl(**forwarded)
def download_doc(
        knowledge_base_name: str = Query(..., description="知识库名称", examples=["samples"]),
        file_name: str = Query(..., description="文件名称", examples=["test.txt"]),
        preview: bool = Query(True, description="是:浏览器内预览;否:下载"),
):
    """
    Download or preview a knowledge base document.

    Files whose extension is listed in CONVERT_MAP are converted to HTML by the
    matching `FileConverter` method and returned as `text/html`; every other
    file is returned as raw bytes.

    :param knowledge_base_name: knowledge base that owns the file.
    :param file_name: name of the file inside the knowledge base.
    :param preview: true -> `inline` content disposition; false -> `attachment`.
    :return: a `Response` with the content, or a `BaseResponse` carrying
        403 / 404 / 500 on validation, lookup or processing errors.
    """
    logger.info(f"是否预览: {preview}")
    if not validate_kb_name(knowledge_base_name):
        return BaseResponse(code=403, msg="Don't attack me")

    kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
    if kb is None:
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")

    try:
        kb_file = KnowledgeFile(filename=file_name,
                                knowledge_base_name=knowledge_base_name)

        if not os.path.exists(kb_file.filepath):
            return BaseResponse(code=404, msg=f"文件 {file_name} 不存在")

        # Extension -> name of the FileConverter method that renders it as HTML.
        CONVERT_MAP = {
            "pdf": "pdf_to_html",
            "docx": "docx_to_html",
            "doc": "doc_to_html",
            "md": "md_to_html",
            "txt": "txt_to_html",
            "xlsx": "xlsx_to_html",
            "xls": "xls_to_html",
        }

        # Extension without the leading dot, lower-cased.
        file_ext = os.path.splitext(file_name)[1].lower().lstrip('.')

        if file_ext in CONVERT_MAP:
            converter = FileConverter()
            convert_method = getattr(converter, CONVERT_MAP[file_ext])

            try:
                # Convert and get the HTML text back (no output file is written).
                html_content = convert_method(kb_file.filepath, output_path=None)
                # The converter signals failure inside the returned text.
                if "转换失败" in html_content:
                    return BaseResponse(code=500, msg=f"文件:{file_name} 处理失败", data=html_content)

                new_filename = f"{os.path.splitext(os.path.basename(file_name))[0]}.html"
                # Percent-encode so non-ASCII names are valid in the header.
                encoded_filename = urllib.parse.quote(new_filename)
                content_disposition = "inline" if preview else f"attachment; filename*=UTF-8''{encoded_filename}"

                return Response(
                    content=html_content.encode('utf-8'),
                    media_type="text/html",
                    headers={
                        "Content-Disposition": content_disposition,
                        "Cache-Control": "no-cache, no-store, must-revalidate",
                        "Pragma": "no-cache",
                        "Expires": "0"
                    }
                )
            except RuntimeError as e:
                msg = f"文件转换失败: {str(e)}"
                logger.error(msg)
                return BaseResponse(code=500, msg=msg)

        # File types that are served without conversion.
        content_disposition_type = "inline" if preview else "attachment"
        encoded_filename = urllib.parse.quote(kb_file.filename)

        with open(kb_file.filepath, 'rb') as file:
            file_content = file.read()

        # Always send the raw bytes here: `html_content` only exists on the
        # conversion path above, so referencing it for preview=False raised a
        # NameError and every plain download ended in the 500 handler below.
        return Response(
            content=file_content,
            media_type="application/octet-stream",
            headers={
                "Content-Disposition": f"{content_disposition_type}; filename*=UTF-8''{encoded_filename}",
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Pragma": "no-cache",
                "Expires": "0"
            }
        )

    except Exception as e:
        msg = f"{file_name} 处理失败,错误信息是:{e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)
def recreate_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        chunk_size: int = Body(CHUNK_SIZE, description="知识库中单段文本最大长度"),
        chunk_overlap: int = Body(OVERLAP_SIZE, description="知识库中相邻文本重合长度"),
        zh_title_enhance: bool = Body(ZH_TITLE_ENHANCE, description="是否开启中文标题加强"),
        not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库用于FAISS"),
):
    """
    Recreate the vector store from the files in the content folder.

    This is useful when users copy files into the content folder directly
    instead of uploading them through the network.
    By default, get_service_by_name only returns knowledge bases that are in
    info.db and contain document files. Setting allow_empty_kb to True lets this
    also run on an empty knowledge base that is not in info.db or has no documents.

    Progress is streamed as server-sent events; every event is a JSON string
    with a `code` and a `msg` (plus `total`, `finished`, `doc` on success).
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Serialized like every other event of this stream; a raw dict is
            # not the same wire format as the JSON strings yielded below.
            yield json.dumps({"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"}, ensure_ascii=False)
        else:
            if kb.exists():
                kb.clear_vs()
            kb.create_kb()
            files = list_files_from_folder(knowledge_base_name)
            kb_files = [(file, knowledge_base_name) for file in files]
            docs_stream = files2docs_in_thread(kb_files,
                                               chunk_size=chunk_size,
                                               chunk_overlap=chunk_overlap,
                                               zh_title_enhance=zh_title_enhance)
            # `finished` counts processed files (successful or not), starting at 1.
            for finished, (status, result) in enumerate(docs_stream, start=1):
                if status:
                    kb_name, file_name, docs = result
                    kb_file = KnowledgeFile(filename=file_name, knowledge_base_name=kb_name)
                    kb_file.splited_docs = docs
                    yield json.dumps({
                        "code": 200,
                        "msg": f"({finished} / {len(files)}): {file_name}",
                        "total": len(files),
                        "finished": finished,
                        "doc": file_name,
                    }, ensure_ascii=False)
                    kb.add_doc(kb_file, not_refresh_vs_cache=True)
                else:
                    kb_name, file_name, error = result
                    msg = f"添加文件‘{file_name}’到知识库‘{knowledge_base_name}’时出错:{error}。已跳过。"
                    logger.error(msg)
                    yield json.dumps({
                        "code": 500,
                        "msg": msg,
                    }, ensure_ascii=False)
            if not not_refresh_vs_cache:
                kb.save_vector_store()

    return EventSourceResponse(output())

View File

@@ -0,0 +1,775 @@
from datetime import datetime
import operator
from abc import ABC, abstractmethod
import re
import os
from pathlib import Path
import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from joblib import Parallel, delayed
import time
from textrank4zh import TextRank4Keyword, TextRank4Sentence
import multiprocessing
from configs.model_config import LLM_MODELS
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.utils import get_personal_knowledge_map, get_similar_documents1
from nltk.tokenize import sent_tokenize
import logging
# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def generate_weights_as_list(length, total_sum=80):
    """
    Build `length` integer weights that decay exponentially and sum to `total_sum`.

    :param length: number of weights to generate.
    :param total_sum: exact sum the returned weights must add up to
        (ignored for `length == 1`, which always yields `[50]`).
    :return: list of Python ints; empty when `length` is zero or negative.
    """
    if length <= 0:
        # Nothing to weight; without this guard the rounding correction below
        # would evaluate `i % 0` and raise ZeroDivisionError.
        return []
    if length == 1:
        return [50]

    # Exponential decay over the positions 0 .. length-1.
    x = np.linspace(0, length - 1, length)
    weights = np.exp(-x / (length / 5))

    # Scale so the weights add up to the requested sum, then round to ints.
    normalized_weights = weights / sum(weights) * total_sum
    integer_weights = np.round(normalized_weights).astype(int)

    # Rounding can leave the sum slightly off; spread the difference one unit
    # at a time, starting from the first weight.
    adjustment = total_sum - sum(integer_weights)
    step = 1 if adjustment > 0 else -1
    for i in range(abs(adjustment)):
        integer_weights[i % length] += step

    return integer_weights.tolist()
def score_threshold_process(query,score_threshold, k, docs):
"""
Filter retrieved documents by score threshold, re-rank and de-duplicate them,
attach a summary to each, and return the leading k documents.
NOTE(review): indentation was lost in this view of the file, so the nesting of
the statements below is not reconstructed here; code lines are left untouched.
:param query: the user's query string; matched against document titles/sources
and passed to the LLM prompt as "input".
:param score_threshold: similarity score threshold; when not None, only
documents whose score is <= this value are kept.
:param k: number of leading documents to return.
:param docs: list of documents, each a tuple (document, similarity score).
:return: a list of (document, score) tuples; the final statement returns docs[:k].
"""
# If score_threshold is provided, keep only documents whose score is <= the threshold.
if score_threshold is not None:
cmp = (
operator.le
)
docs = [
(doc, similarity)
for doc, similarity in docs
if cmp(similarity, score_threshold)
]
# When every recalled result scored above score_threshold, nothing is left.
if len(docs) == 0:
return docs
# Collect the documents whose title contains the query (whitespace and line
# breaks are stripped from both sides before comparing); the except branch
# repeats the match against metadata["source"] instead.
result = []
try:
for doc in docs:
if query.replace(" ","").replace("\n","").replace("\r","") in doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r",""):
result.append(doc)
except Exception as e:
for doc in docs:
if query.replace(" ","").replace("\n","").replace("\r","") in doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r",""):
result.append(doc)
# Documents without an "h1" metadata key take this path (the "h1" path further
# down is labelled as the personal knowledge base).
if len(docs) > 0 and not "h1" in docs[0][0].metadata:
# If the user's question appears in some titles, de-duplicate those hits; no relevance
# matching is needed in that case, only the documents whose title contains the question are kept.
if len(result) > 0:
temp={}
for doc in result:
if doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") not in temp:
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")] = doc
else:
if temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
continue
elif temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") :
continue
else:
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
docs = []
for i in temp:
docs.append(temp[i])
else:
try:
sentences = []
sentences_page_content = []
for doc in docs:
meta = doc[0].metadata
# If the title is empty, use the body text as the title.
# NOTE(review): the original comment mentions "first sentence, at most 50 characters",
# but the code assigns the whole page_content — confirm which is intended.
if meta["title"] == "":
meta["title"] = doc[0].page_content
# If a summary exists, replace page_content with it to keep the later text shorter.
summary = meta.get("summary")
if summary:
doc[0].page_content = summary
sentences = [doc[0].metadata["title"] for doc in docs]
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["title"]+doc[0].page_content+"" for i,doc in enumerate(docs)]
except Exception as e:
sentences = [doc[0].metadata["source"] for doc in docs]
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"" for i,doc in enumerate(docs)]
# Ask the LLM which numbered entries are relevant; the reply is parsed below
# as a comma-separated list of 1-based indices.
res = get_llm_model_response(
strategy_name="default_similar",
llm_model_name=LLM_MODELS[0],
template_prompt_name="default_similar",
prompt_param_dict={"input": query, "title": f"{sentences_page_content}", "time": datetime.now().strftime("%Y%m%d")},
temperature=0.01,
max_tokens=512
)
try:
index =[]
if res == "":
index = []
else:
index = res.split(",")
index = [int(i)-1 for i in index]
docs = get_similar_documents1(index=index,sentences=sentences,query=query, docs=docs, top_k=k)
except Exception as e:
print(e)
docs = get_similar_documents1(index=[],sentences=sentences,query=query, docs=docs, top_k=k)
# De-duplication, general knowledge base only.
# NOTE(review): the membership tests below use the whitespace-stripped title/source,
# while new entries are stored under the raw (unstripped) value — these keys differ
# whenever the text contains spaces or line breaks; confirm and unify.
temp={}
for doc in docs:
try:
if doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") not in temp:
temp[doc[0].metadata["title"]] = doc
else:
if temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
continue
elif temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") :
continue
else:
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
except Exception as e:
print(e)
if doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") not in temp:
temp[doc[0].metadata["source"]] = doc
else:
if temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
continue
elif temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") :
continue
else:
temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
docs = []
for i in temp:
docs.append(temp[i])
# Personal knowledge base only.
if "h1" in docs[0][0].metadata:
# Keep the original source under "uuid_name" and, when the map returned by
# get_personal_knowledge_map has an entry for it, replace "source" with the mapped value.
all_source = [doc[0].metadata["source"] for doc in docs]
unique_source = list(set(all_source))
all_title_map = get_personal_knowledge_map(unique_source)
for doc in docs:
doc[0].metadata["uuid_name"] = doc[0].metadata["source"]
if doc[0].metadata["source"] in all_title_map:
doc[0].metadata["source"] = all_title_map[doc[0].metadata["source"]]
else:
pass
# NOTE(review): the try and except branches below are identical.
try:
sentences = [doc[0].metadata["source"] for doc in docs]
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"" for i,doc in enumerate(docs)]
except Exception as e:
sentences = [doc[0].metadata["source"] for doc in docs]
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"" for i,doc in enumerate(docs)]
# Same LLM selection as above, with "enable_thinking" switched on through extra_body;
# any <think>...</think> section is stripped from the reply before it is parsed.
kwargs = {}
kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": True}}
res = get_llm_model_response(
strategy_name="default_similar",
llm_model_name=LLM_MODELS[0],
template_prompt_name="default_similar",
prompt_param_dict={"input": query, "title": f"{sentences_page_content}", "time": datetime.now().strftime("%Y%m%d")},
temperature=0.01,
max_tokens=None,
**kwargs
)
res = re.sub(r'<think>.*?</think>', '', res,flags=re.DOTALL)
try:
index =[]
if res == "":
index = []
else:
index = res.split(",")
index = [int(i)-1 for i in index]
docs = get_similar_documents1(index=index,sentences=sentences,query=query, docs=docs, top_k=k)
except Exception as e:
print(e)
docs = get_similar_documents1(index=[],sentences=sentences,query=query, docs=docs, top_k=k)
# De-duplication by source (the original comment labels this "general knowledge base only").
# NOTE(review): same stripped-vs-raw key mismatch as in the de-duplication above.
temp={}
for doc in docs:
if doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") not in temp:
temp[doc[0].metadata["source"]] = doc
else:
if temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
continue
elif temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") :
continue
else:
temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
docs = []
for i in temp:
docs.append(temp[i])
if len(docs) == 0:
return docs
# Generate one weight per document; each weight is passed to TextRank as its num_sentences.
cont = generate_weights_as_list(len(docs))
# Process every document to extract or assign its summary.
for i, (doc, _) in enumerate(docs):
summary_sources = ['content', 'abstract', 'text'] # metadata fields to try; they differ between knowledge bases
for source in summary_sources:
try:
if docs[i][0].metadata["title"] in docs[i][0].page_content or len(docs[i][0].page_content) < 100:
doc.metadata['summary'] = TextRank(doc.metadata[source], cont[i])
if len(doc.metadata['summary']) >15000:
doc.metadata['summary'] = TextRank(doc.metadata[source], 1)
break
else:
doc.metadata['summary'] = docs[i][0].page_content
except KeyError:
doc.metadata['summary'] = docs[i][0].page_content
continue # if this field is missing, fall through to the next one
# Return the first k documents.
return docs[:k]
# 猴子补丁为了兼容TexRank
import networkx as nx
import numpy as np
# 进行猴子补丁
nx.from_numpy_matrix = nx.from_numpy_array
# 进行猴子补丁,入数据类型兼容性检查
def process_text_segment(text_segment, num_sentences):
    """
    Run TextRank keyword and key-sentence extraction on one text segment.

    :param text_segment: the text to analyse.
    :param num_sentences: how many key sentences to request.
    :return: a tuple (keywords, summaries) where keywords is a list of
        (word, weight) pairs and summaries is a list of sentence strings.
    """
    keyword_ranker = TextRank4Keyword()
    keyword_ranker.analyze(text=text_segment, lower=True, window=5)
    keywords = []
    for entry in keyword_ranker.get_keywords(30, word_min_len=4):
        keywords.append((entry.word, entry.weight))

    sentence_ranker = TextRank4Sentence()
    sentence_ranker.analyze(text=text_segment, lower=True, source='all_filters')
    summaries = []
    for entry in sentence_ranker.get_key_sentences(num=num_sentences):
        summaries.append(entry.sentence)

    return keywords, summaries
def split_text_balanced(text, n_parts):
    """
    Split text into at most n_parts chunks of whole sentences, as evenly as possible.

    The number of chunks is capped so each holds at least 10 sentences
    (never fewer than one chunk); earlier chunks receive the extra sentence
    when the count does not divide evenly.
    """
    sentences = sent_tokenize(text)
    total = len(sentences)
    min_sentences_per_part = 10
    n_parts = max(1, min(n_parts, total // min_sentences_per_part))
    base, extra = divmod(total, n_parts)

    chunks = []
    start = 0
    for part_no in range(n_parts):
        # The first `extra` chunks are one sentence longer than the rest.
        stop = start + base + (1 if part_no < extra else 0)
        chunks.append(' '.join(sentences[start:stop]))
        start = stop
    return chunks
def TextRank(text,num_sentences, n_cores=multiprocessing.cpu_count()):
    """
    Summarise text with TextRank and return the key sentences as one string.

    The text is split into up to n_cores sentence-balanced segments which are
    processed one after another; the key sentences of all segments are joined
    without a separator. Every extracted keyword is printed with its weight
    (highest weight first) and the elapsed time is logged.
    """
    started_at = time.time()
    segments = split_text_balanced(text, n_cores)

    keyword_pool = []
    summary_pool = []
    # Segments are handled sequentially here (no process pool).
    for segment in segments:
        segment_keywords, segment_summaries = process_text_segment(segment, num_sentences)
        keyword_pool += segment_keywords
        summary_pool += segment_summaries

    ranked_keywords = sorted(keyword_pool, key=lambda x: x[1], reverse=True)
    for word, weight in ranked_keywords:
        print(word, weight)

    finished_at = time.time()
    logging.info(f"TextRank耗时: {finished_at - started_at:.2f}")
    return "".join(summary_pool)
from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
list_docs_from_db,
)
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,EXPR,
EMBEDDING_MODEL, KB_INFO)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, KnowledgeFile,
list_kbs_from_folder, list_files_from_folder,
)
from typing import List, Union, Dict, Optional, Tuple
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
import time
def get_emb_time(f):
    """
    Decorator that prints how long each call of `f` took.

    The wrapped function's return value is passed through unchanged.
    """
    import functools  # local import keeps this block self-contained

    @functools.wraps(f)  # keep f's __name__ / __doc__ on the wrapper
    def inner(*arg,**kwarg):
        s_time = time.time()
        res = f(*arg,**kwarg)
        e_time = time.time()
        print('向量化耗时:{}'.format(e_time - s_time))
        return res
    return inner
def normalize(embeddings: List[List[float]]) -> np.ndarray:
    """
    L2-normalise each embedding row.

    A replacement for sklearn.preprocessing.normalize (L2 norm) that avoids
    installing scipy / scikit-learn. Entries that are None are dropped before
    normalising, so the result can have fewer rows than the input.

    :raises ValueError: when no non-None embedding is left.
    """
    valid_rows = [row for row in embeddings if row is not None]
    if len(valid_rows) == 0:
        raise ValueError("No valid embeddings found (all are None)")

    matrix = np.array(valid_rows)
    # One norm per row, shaped (n, 1) so it broadcasts across the columns.
    row_norms = np.linalg.norm(matrix, axis=1).reshape(-1, 1)
    return np.divide(matrix, row_norms)
class SupportedVSType:
    """String identifiers of the vector store types referenced by this module."""

    # Fallback identifier.
    DEFAULT = 'default'

    # Named vector store backends.
    CHROMADB = 'chromadb'
    ES = 'es'
    FAISS = 'faiss'
    MILVUS = 'milvus'
    PG = 'pg'
    ZILLIZ = 'zilliz'
class KBService(ABC):
def __init__(self,
knowledge_base_name: str,
embed_model: str = EMBEDDING_MODEL,
):
self.kb_name = knowledge_base_name
self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
self.embed_model = embed_model
self.kb_path = get_kb_path(self.kb_name)
self.doc_path = get_doc_path(self.kb_name)
self.do_init()
def __repr__(self) -> str:
return f"{self.kb_name} @ {self.embed_model}"
def save_vector_store(self):
'''
保存向量库:FAISS保存到磁盘milvus保存到数据库。PGVector暂未支持
'''
pass
def create_kb(self):
"""
创建知识库
"""
if not os.path.exists(self.doc_path):
os.makedirs(self.doc_path)
self.do_create_kb()
status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
return status
def clear_vs(self):
"""
删除向量库中所有内容
"""
self.do_clear_vs()
status = delete_files_from_db(self.kb_name)
return status
def drop_kb(self):
"""
删除知识库
"""
self.do_drop_kb()
status = delete_kb_from_db(self.kb_name)
return status
def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
'''
将 List[Document] 转化为 VectorStore.add_embeddings 可以接受的参数
'''
return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
@get_emb_time
def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
向知识库添加文件
如果指定了docs则不再将文本向量化并将数据库对应条目标为custom_docs=True
"""
if docs:
custom_docs = True
for doc in docs:
doc.metadata.setdefault("source", kb_file.filename)
else:
docs = kb_file.file2text()
custom_docs = False
if docs:
# 将 metadata["source"] 改为相对路径
for doc in docs:
try:
source = doc.metadata.get("source", "")
if os.path.isabs(source):
rel_path = Path(source).relative_to(self.doc_path)
doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
except Exception as e:
print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
self.delete_doc(kb_file)
doc_infos = self.do_add_doc(docs, **kwargs)
status = add_file_to_db(kb_file,
custom_docs=custom_docs,
docs_count=len(docs),
doc_infos=doc_infos)
else:
status = False
return status
def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
"""
从知识库删除文件
"""
self.do_delete_doc(kb_file, **kwargs)
status = delete_file_from_db(kb_file)
if delete_content and os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
return status
def update_info(self, kb_info: str):
"""
更新知识库介绍
"""
self.kb_info = kb_info
status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
return status
def update_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
"""
使用content中的文件更新向量库
如果指定了docs则使用自定义docs并将数据库对应条目标为custom_docs=True
"""
if os.path.exists(kb_file.filepath):
self.delete_doc(kb_file, **kwargs)
return self.add_doc(kb_file, docs=docs, **kwargs)
def exist_doc(self, file_name: str):
return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
filename=file_name))
def list_files(self):
return list_files_from_db(self.kb_name)
def count_files(self):
return count_files_from_db(self.kb_name)
def search_docs(self,
query: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
expr: str = EXPR,
custom_strategy_config: dict = {}
) ->List[Document]:
docs = self.do_search(query, top_k, score_threshold, expr, custom_strategy_config)
return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """Default implementation: no lookup is performed and an empty list is returned."""
    return list()
def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
    """Default implementation: no lookup is performed and an empty list is returned."""
    return list()
def del_doc_by_ids(self, ids: List[str]) -> bool:
    """Delete documents by id. Not provided by the base class: always raises NotImplementedError."""
    raise NotImplementedError()
def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
    '''
    Replace documents in the vector store by id.

    ``docs`` maps ``doc_id`` to ``Document``: ``{doc_id: Document, ...}``.
    Every id in the mapping is deleted first. Entries whose value is None or
    whose ``page_content`` is blank are not added back, i.e. they end up
    deleted; all other entries are re-added under their original id.

    Returns True.
    '''
    self.del_doc_by_ids(list(docs.keys()))
    # Use separate names here: rebinding ``docs`` to a list would hide the
    # mapping that still has to be iterated.
    kept_ids = []
    kept_docs = []
    for doc_id, doc in docs.items():
        if not doc or not doc.page_content.strip():
            continue
        kept_ids.append(doc_id)
        kept_docs.append(doc)
    # Nothing to write back when every entry was a deletion.
    if kept_docs:
        self.do_add_doc(docs=kept_docs, ids=kept_ids)
    return True
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[DocumentWithVSId]:
    '''
    Retrieve Documents by ``file_name`` and/or ``metadata``.

    The matching entries are read with ``list_docs_from_db``; each one is
    then resolved through ``get_doc_by_ids`` and wrapped in a
    ``DocumentWithVSId`` carrying its vector store id. Entries the vector
    store cannot resolve (empty result or None) are skipped.
    '''
    doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
    docs = []
    for info in doc_infos:
        found = self.get_doc_by_ids([info["id"]])
        # A backend may return an empty list for an unknown id; indexing it
        # directly would raise IndexError.
        doc = found[0] if found else None
        if doc is None:
            continue
        docs.append(DocumentWithVSId(**doc.dict(), id=info["id"]))
    return docs
@abstractmethod
def do_create_kb(self):
    """
    Create the knowledge base; each subclass implements its own logic.
    """
@staticmethod
def list_kbs_type():
    """Return the keys of ``kbs_config`` as a list."""
    return [vs_name for vs_name in kbs_config.keys()]
@classmethod
def list_kbs(cls):
    """Return the knowledge bases recorded in the database (``list_kbs_from_db``)."""
    kb_names = list_kbs_from_db()
    return kb_names
def exists(self, kb_name: str = None):
    """Return ``kb_exists`` for ``kb_name``, falling back to this service's own name when it is falsy."""
    return kb_exists(kb_name or self.kb_name)
@abstractmethod
def vs_type(self) -> str:
    """Return the vector store type of the concrete service; implemented by subclasses."""
@abstractmethod
def do_init(self):
    """Backend-specific initialisation hook; implemented by subclasses."""
@abstractmethod
def do_drop_kb(self):
    """
    Delete the knowledge base; each subclass implements its own logic.
    """
@abstractmethod
def do_search(self,
              query: str,
              top_k: int,
              score_threshold: float,
              expr: str,
              custom_strategy_config: dict = {},
              ) -> List[Tuple[Document, float]]:
    """
    Search the knowledge base; each subclass implements its own logic.
    """
@abstractmethod
def do_add_doc(self,
               docs: List[Document],
               **kwargs,
               ) -> List[Dict]:
    """
    Add documents to the knowledge base; each subclass implements its own logic.
    """
@abstractmethod
def do_delete_doc(self,
                  kb_file: KnowledgeFile):
    """
    Delete a file's documents from the knowledge base; each subclass implements its own logic.
    """
@abstractmethod
def do_clear_vs(self):
    """
    Remove every vector from the knowledge base; each subclass implements its own logic.
    """
class KBServiceFactory:
    """Factory that builds the ``KBService`` implementation for a vector store type."""

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> KBService:
        """
        Build the service for ``kb_name``.

        A string ``vector_store_type`` is resolved to the ``SupportedVSType``
        attribute of the same upper-cased name. Backend modules are imported
        lazily, only for the branch that is taken. Falls through (returning
        None) when no branch matches.
        """
        vs_type = vector_store_type
        if isinstance(vs_type, str):
            vs_type = getattr(SupportedVSType, vs_type.upper())

        if vs_type == SupportedVSType.FAISS:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)
        if vs_type == SupportedVSType.PG:
            from server.knowledge_base.kb_service.pg_kb_service import PGKBService
            return PGKBService(kb_name, embed_model=embed_model)
        if vs_type == SupportedVSType.MILVUS:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)
        if vs_type == SupportedVSType.ZILLIZ:
            from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
            return ZillizKBService(kb_name, embed_model=embed_model)
        if vs_type == SupportedVSType.DEFAULT:
            # other milvus parameters are set in model_config.kbs_config
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)
        if vs_type == SupportedVSType.ES:
            from server.knowledge_base.kb_service.es_kb_service import ESKBService
            return ESKBService(kb_name, embed_model=embed_model)
        if vs_type == SupportedVSType.CHROMADB:
            from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
            return ChromaKBService(kb_name, embed_model=embed_model)
        if vs_type == SupportedVSType.DEFAULT:
            # NOTE(review): unreachable - DEFAULT is already handled above and
            # returns a Milvus service. Decide which branch is intended.
            # kb_exists of default kbservice is False, to make validation easier.
            from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
            return DefaultKBService(kb_name)
        return None

    @staticmethod
    def get_service_by_name(kb_name: str) -> KBService:
        """
        Build the service for a knowledge base that is recorded in the database.

        Returns None when ``load_kb_from_db`` reports no name for ``kb_name``.
        The stored embedding model name is passed through
        ``resolve_embed_model_name`` before the service is created.
        """
        stored_name, vs_type, embed_model = load_kb_from_db(kb_name)
        if stored_name is None:  # kb not in db, just return None
            return None
        from server.utils import resolve_embed_model_name
        resolved_model = resolve_embed_model_name(embed_model)
        return KBServiceFactory.get_service(kb_name, vs_type, resolved_model)

    @staticmethod
    def get_default():
        """Return the service registered for ``SupportedVSType.DEFAULT`` under the name "default"."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
def get_kb_details() -> List[Dict]:
    """
    Merge the knowledge bases found on disk with those recorded in the database.

    Folder entries start with empty placeholder fields and ``in_db=False``;
    database details (from ``get_kb_detail``) overwrite them and set
    ``in_db=True``. Database-only entries get ``in_folder=False``. Each
    returned dict carries a 1-based ``No`` field.
    """
    kbs_in_folder = list_kbs_from_folder()
    kbs_in_db = KBService.list_kbs()

    details = {
        name: {
            "kb_name": name,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for name in kbs_in_folder
    }

    for name in kbs_in_db:
        db_detail = get_kb_detail(name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        if name in details:
            details[name].update(db_detail)
        else:
            db_detail["in_folder"] = False
            details[name] = db_detail

    data = []
    for number, item in enumerate(details.values(), start=1):
        item['No'] = number
        data.append(item)
    return data
def get_kb_file_details(kb_name: str) -> List[Dict]:
    """
    Merge the files found in a knowledge base folder with those recorded in the database.

    Returns an empty list when no service exists for ``kb_name``. Folder
    files start with placeholder fields and ``in_db=False``; database details
    (from ``get_file_detail``) are matched to folder files case-insensitively
    and overwrite them. Database-only files get ``in_folder=False``. Each
    returned dict carries a 1-based ``No`` field.
    """
    kb = KBServiceFactory.get_service_by_name(kb_name)
    if kb is None:
        return []

    files_in_folder = list_files_from_folder(kb_name)
    files_in_db = kb.list_files()

    details = {
        name: {
            "kb_name": kb_name,
            "file_name": name,
            "file_ext": os.path.splitext(name)[-1],
            "file_version": 0,
            "document_loader": "",
            "docs_count": 0,
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for name in files_in_folder
    }
    # Lower-cased folder name -> name as it appears on disk.
    folder_name_by_lower = {name.lower(): name for name in details}

    for db_name in files_in_db:
        db_detail = get_file_detail(kb_name, db_name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        folder_name = folder_name_by_lower.get(db_name.lower())
        if folder_name is not None:
            details[folder_name].update(db_detail)
        else:
            db_detail["in_folder"] = False
            details[db_name] = db_detail

    data = []
    for number, item in enumerate(details.values(), start=1):
        item['No'] = number
        data.append(item)
    return data
class EmbeddingsFunAdapter(Embeddings):
    """
    ``Embeddings`` adapter around ``embed_texts`` / ``aembed_texts``.

    Every vector returned by this class has been passed through ``normalize``.
    NOTE(review): only ``embed_documents`` checks for a missing result; the
    other three methods read ``.data`` unconditionally.
    """

    def __init__(self, embed_model: str = EMBEDDING_MODEL):
        self.embed_model = embed_model

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` and return the normalized vectors; raises ValueError when no embeddings come back."""
        response = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        vectors = getattr(response, 'data', None) if response else None
        if not vectors:
            raise ValueError(f"Failed to get embeddings for texts: {texts[:2]}...")
        return normalize(vectors).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its normalized vector."""
        vectors = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
        # normalize is applied to a 2-D array, so lift the vector to one row first
        as_row = np.reshape(vectors[0], (1, -1))
        # take the single row back out as a flat list
        return normalize(as_row)[0].tolist()

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of ``embed_documents`` (without the empty-result check)."""
        response = await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return normalize(response.data).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        """Async counterpart of ``embed_query``."""
        response = await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        # normalize is applied to a 2-D array, so lift the vector to one row first
        as_row = np.reshape(response.data[0], (1, -1))
        # take the single row back out as a flat list
        return normalize(as_row)[0].tolist()
# def score_threshold_process(score_threshold, k, docs):
# if score_threshold is not None:
# cmp = (
# operator.le
# )
# docs = [
# (doc, similarity)
# for doc, similarity in docs
# if cmp(similarity, score_threshold)
# ]
# return docs[:k]

View File

@@ -0,0 +1,105 @@
import uuid
from typing import Any, Dict, List, Tuple
import chromadb
from chromadb.api.types import (GetResult, QueryResult)
from langchain.docstore.document import Document
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
KBService, SupportedVSType)
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
if not get_result['documents']:
return []
_metadatas = get_result['metadatas'] if get_result['metadatas'] else [{}] * len(get_result['documents'])
document_list = []
for page_content, metadata in zip(get_result['documents'], _metadatas):
document_list.append(Document(**{'page_content': page_content, 'metadata': metadata}))
return document_list
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
"""
from langchain_community.vectorstores.chroma import Chroma
"""
return [
# TODO: Chroma can do batch querying,
(Document(page_content=result[0], metadata=result[1] or {}), result[2])
for result in zip(
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
)
]
class ChromaKBService(KBService):
    """Knowledge base service that stores its vectors in a persistent ChromaDB collection."""

    vs_path: str
    kb_path: str

    # Both are assigned in do_init.
    client = None
    collection = None

    def vs_type(self) -> str:
        return SupportedVSType.CHROMADB

    def get_vs_path(self) -> str:
        return get_vs_path(self.kb_name, self.embed_model)

    def get_kb_path(self) -> str:
        return get_kb_path(self.kb_name)

    def do_init(self) -> None:
        """Resolve the paths, open the persistent client and get (or create) the collection."""
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()
        self.client = chromadb.PersistentClient(path=self.vs_path)
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_create_kb(self) -> None:
        # In ChromaDB, creating a KB is equivalent to creating a collection
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_drop_kb(self):
        """Delete the collection; a collection that does not exist is not an error."""
        # Dropping a KB is equivalent to deleting a collection in ChromaDB
        try:
            self.client.delete_collection(self.kb_name)
        except ValueError as e:
            if str(e) != f"Collection {self.kb_name} does not exist.":
                raise

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  expr: str = None,
                  custom_strategy_config: dict = None,
                  ) -> List[Tuple[Document, float]]:
        """
        Return the ``top_k`` nearest documents together with their distances.

        ``score_threshold``, ``expr`` and ``custom_strategy_config`` are
        accepted so the method can be called the way ``KBService.search_docs``
        calls it; this backend does not use them.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
        return _results_to_docs_and_scores(query_result)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Embed ``docs``, add them one by one under fresh uuid1 ids and return ``[{"id", "metadata"}, ...]``."""
        doc_infos = []
        data = self._docs_to_embeddings(docs)
        ids = [str(uuid.uuid1()) for _ in range(len(data["texts"]))]
        for _id, text, embedding, metadata in zip(ids, data["texts"], data["embeddings"], data["metadatas"]):
            self.collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text)
            doc_infos.append({"id": _id, "metadata": metadata})
        return doc_infos

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        get_result: GetResult = self.collection.get(ids=ids)
        return _get_result_to_documents(get_result)

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        self.collection.delete(ids=ids)
        return True

    def do_clear_vs(self):
        """Remove every vector by dropping the collection and creating it again."""
        self.do_drop_kb()
        # Re-create the collection so that self.collection does not keep
        # pointing at the one that was just deleted.
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete all entries whose ``source`` metadata equals ``kb_file.filepath``."""
        return self.collection.delete(where={"source": kb_file.filepath})

View File

@@ -0,0 +1,38 @@
from typing import List
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from server.knowledge_base.kb_service.base import KBService
class DefaultKBService(KBService):
    """
    Placeholder service whose operations all do nothing and return None.

    The hook signatures accept the arguments that the ``KBService`` wrappers
    pass to them, so calling this service through the base-class API does
    not fail with a TypeError.
    """

    def do_create_kb(self):
        pass

    def do_drop_kb(self):
        pass

    def do_add_doc(self, docs: List[Document], **kwargs):
        pass

    def do_clear_vs(self):
        pass

    def vs_type(self) -> str:
        return "default"

    def do_init(self):
        pass

    def do_search(self, query=None, top_k=None, score_threshold=None, expr=None,
                  custom_strategy_config=None):
        pass

    def do_insert_multi_knowledge(self):
        pass

    def do_insert_one_knowledge(self):
        pass

    def do_delete_doc(self, kb_file=None, **kwargs):
        pass

View File

@@ -0,0 +1,261 @@
from typing import List
import os
import shutil
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores.elasticsearch import ElasticsearchStore
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile
from server.utils import load_local_embeddings
from elasticsearch import Elasticsearch,BadRequestError
from configs import logger
from configs import kbs_config
class ESKBService(KBService):
    """Knowledge base service that stores its vectors in an Elasticsearch index."""

    def _es_url(self) -> str:
        """Return the HTTP URL built from the configured host and port."""
        return f"http://{self.IP}:{self.PORT}"

    def _has_credentials(self) -> bool:
        """Return True when both a user name and a password are configured."""
        return self.user != "" and self.password != ""

    def do_init(self):
        """
        Read the ES settings from ``kbs_config``, connect and make sure the index exists.

        Two clients are created: the plain Python client
        (``self.es_client_python``) and the langchain ``ElasticsearchStore``
        (``self.db_init``). Connection failures are logged and re-raised;
        index creation failures are only logged.
        """
        self.kb_path = self.get_kb_path(self.kb_name)
        self.index_name = os.path.split(self.kb_path)[-1]
        es_config = kbs_config[self.vs_type()]
        self.IP = es_config['host']
        self.PORT = es_config['port']
        self.user = es_config.get("user", '')
        self.password = es_config.get("password", '')
        self.dims_length = es_config.get("dims_length", None)
        self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE)

        try:
            # Python ES client: connection only
            if self._has_credentials():
                self.es_client_python = Elasticsearch(self._es_url(),
                                                      basic_auth=(self.user, self.password))
            else:
                logger.warning("ES未配置用户名和密码")
                self.es_client_python = Elasticsearch(self._es_url())
        except ConnectionError:
            logger.error("连接到 Elasticsearch 失败!")
            # bare raise keeps the original message and traceback
            raise
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            raise

        try:
            # First try to create the index through the Python client
            mappings = {
                "properties": {
                    "dense_vector": {
                        "type": "dense_vector",
                        "dims": self.dims_length,
                        "index": True
                    }
                }
            }
            self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
        except BadRequestError as e:
            logger.error("创建索引失败,重新")
            logger.error(e)

        try:
            # langchain ES store: connects and creates the index
            store_kwargs = dict(
                es_url=self._es_url(),
                index_name=self.index_name,
                query_field="context",
                vector_query_field="dense_vector",
                embedding=self.embeddings_model,
            )
            if self._has_credentials():
                store_kwargs.update(es_user=self.user, es_password=self.password)
            else:
                logger.warning("ES未配置用户名和密码")
            self.db_init = ElasticsearchStore(**store_kwargs)
        except ConnectionError:
            print("### 初始化 Elasticsearch 失败!")
            logger.error("### 初始化 Elasticsearch 失败!")
            raise
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            raise

        try:
            # Then try to create the index through db_init
            self.db_init._create_index_if_not_exists(
                index_name=self.index_name,
                dims_length=self.dims_length
            )
        except Exception as e:
            logger.error("创建索引失败...")
            logger.error(e)
            # raise e

    @staticmethod
    def get_kb_path(knowledge_base_name: str):
        return os.path.join(KB_ROOT_PATH, knowledge_base_name)

    @staticmethod
    def get_vs_path(knowledge_base_name: str):
        return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")

    def do_create_kb(self):
        """Create the ``vector_store`` directory when the document folder exists; warn if it is already there."""
        if os.path.exists(self.doc_path):
            vs_dir = os.path.join(self.kb_path, "vector_store")
            if not os.path.exists(vs_dir):
                os.makedirs(vs_dir)
            else:
                logger.warning("directory `vector_store` already exists.")

    def vs_type(self) -> str:
        return SupportedVSType.ES

    def _load_es(self, docs, embed_model):
        """Write ``docs`` to the index via ``ElasticsearchStore.from_documents``; errors are logged, not raised."""
        try:
            # connect and write the documents in one step
            store_kwargs = dict(
                documents=docs,
                embedding=embed_model,
                es_url=self._es_url(),
                index_name=self.index_name,
                distance_strategy="COSINE",
                query_field="context",
                vector_query_field="dense_vector",
                verify_certs=False,
            )
            if self._has_credentials():
                store_kwargs.update(es_user=self.user, es_password=self.password)
            self.db = ElasticsearchStore.from_documents(**store_kwargs)
        except ConnectionError as ce:
            print(ce)
            print("连接到 Elasticsearch 失败!")
            logger.error("连接到 Elasticsearch 失败!")
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            print(e)

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str,
                  custom_strategy_config: dict = None):
        """
        Run a text similarity search and return the ``top_k`` hits with scores.

        ``score_threshold``, ``expr`` and ``custom_strategy_config`` are
        accepted so the method can be called the way ``KBService.search_docs``
        calls it; this backend does not use them.
        """
        docs = self.db_init.similarity_search_with_score(query=query,
                                                         k=top_k)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by id; ids that cannot be retrieved are logged and skipped."""
        results = []
        for doc_id in ids:
            try:
                response = self.es_client_python.get(index=self.index_name, id=doc_id)
                source = response["_source"]
                # The text is read from "context" and the metadata from "metadata"
                text = source.get("context", "")
                metadata = source.get("metadata", {})
                results.append(Document(page_content=text, metadata=metadata))
            except Exception as e:
                logger.error(f"Error retrieving document from Elasticsearch! {e}")
        return results

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by id; failures are logged. Returns True only if every delete succeeded."""
        all_deleted = True
        for doc_id in ids:
            try:
                self.es_client_python.delete(index=self.index_name,
                                             id=doc_id,
                                             refresh=True)
            except Exception as e:
                all_deleted = False
                logger.error(f"ES Docs Delete Error! {e}")
        return all_deleted

    def do_delete_doc(self, kb_file, **kwargs):
        """
        Delete the entries of this index whose ``metadata.source`` equals ``kb_file.filepath``.

        NOTE(review): at most 50 hits are fetched per call, so a file that was
        split into more chunks is only partially removed.
        """
        if self.es_client_python.indices.exists(index=self.index_name):
            # the document name is matched as a keyword
            query = {
                "query": {
                    "term": {
                        "metadata.source.keyword": kb_file.filepath
                    }
                }
            }
            # size must be set explicitly, the default is 10 hits.
            # The search is restricted to this knowledge base's index; without
            # ``index`` the ids could come from any index on the cluster.
            search_results = self.es_client_python.search(index=self.index_name, body=query, size=50)
            delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
            if len(delete_list) == 0:
                return None
            self.del_doc_by_ids(delete_list)

    def do_add_doc(self, docs: List[Document], **kwargs):
        '''
        Add documents to the knowledge base.

        Returns ``[{"id": str, "metadata": dict}, ...]`` for the entries of
        this index whose ``metadata.source`` equals the source of the first
        document, or an empty list when ``docs`` is empty or the index does
        not exist. Raises ValueError when the index exists but nothing
        matches.

        NOTE(review): at most 50 hits are returned.
        '''
        if not docs:
            return []
        print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}")
        print("*"*100)
        self._load_es(docs=docs, embed_model=self.embeddings_model)
        print("写入数据成功.")
        print("*"*100)

        if not self.es_client_python.indices.exists(index=self.index_name):
            return []
        file_path = docs[0].metadata.get("source")
        # A dict literal cannot hold two "term" keys (the later one silently
        # replaces the earlier one), so the index restriction is passed via
        # ``index=`` and the query only filters on the source.
        query = {
            "query": {
                "term": {
                    "metadata.source.keyword": file_path
                }
            }
        }
        # size must be set explicitly, the default is 10 hits.
        search_results = self.es_client_python.search(index=self.index_name, body=query, size=50)
        hits = search_results["hits"]["hits"]
        if len(hits) == 0:
            raise ValueError("召回元素个数为0")
        return [{"id": hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in hits]

    def do_clear_vs(self):
        """Remove every vector by deleting the index, if it exists."""
        if self.es_client_python.indices.exists(index=self.index_name):
            self.es_client_python.indices.delete(index=self.index_name)

    def do_drop_kb(self):
        """Delete the knowledge base folder (``self.kb_path``) if it exists."""
        if os.path.exists(self.kb_path):
            shutil.rmtree(self.kb_path)
if __name__ == '__main__':
    # Manual smoke test: index README.md into the "test" knowledge base, then query it.
    service = ESKBService("test")
    # service.clear_vs()
    # service.create_kb()
    readme = KnowledgeFile(filename="README.md", knowledge_base_name="test")
    service.add_doc(readme)
    print(service.search_docs("如何启动api服务"))

View File

@@ -0,0 +1,124 @@
import os
import shutil
from configs import SCORE_THRESHOLD, EXPR
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import torch_gc
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple
class FaissKBService(KBService):
    """Knowledge base service that stores its vectors in a FAISS index loaded through ``kb_faiss_pool``."""

    vs_path: str
    kb_path: str
    # Set in do_init; falls back to the embedding model name.
    vector_name: str = None

    def vs_type(self) -> str:
        return SupportedVSType.FAISS

    def get_vs_path(self):
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Return the vector store for this knowledge base from ``kb_faiss_pool``."""
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name=self.vector_name,
                                               embed_model=self.embed_model)

    def save_vector_store(self):
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Look the ids up in the docstore; unknown ids yield None entries."""
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the given ids from the vector store and return True."""
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        return True

    def do_init(self):
        self.vector_name = self.vector_name or self.embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        if not os.path.exists(self.vs_path):
            os.makedirs(self.vs_path)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store, then remove the knowledge base folder (best effort)."""
        self.clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            # best effort: a folder that cannot be removed is ignored
            ...

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  expr: str = EXPR,
                  custom_strategy_config: dict = None,
                  ) -> List[Tuple[Document, float]]:
        """
        Return up to ``top_k`` documents with scores, filtered by ``score_threshold``.

        ``expr`` and ``custom_strategy_config`` are accepted so the method can
        be called the way ``KBService.search_docs`` calls it; this backend
        does not use them.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """
        Embed ``docs`` and add them to the index.

        ``kwargs["ids"]`` (optional) supplies the ids to store the documents
        under; unless ``kwargs["not_refresh_vs_cache"]`` is truthy, the index
        is saved to ``self.vs_path`` afterwards.
        Returns ``[{"id", "metadata"}, ...]``.
        """
        # Embedding outside the lock keeps the vector store locked for a shorter time
        data = self._docs_to_embeddings(docs)

        with self.load_vector_store().acquire() as vs:
            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
                                    metadatas=data["metadatas"],
                                    ids=kwargs.get("ids"))
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        torch_gc()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """
        Delete every entry whose ``source`` metadata equals ``kb_file.filename`` (case-insensitive).

        Entries without a ``source`` never match. Unless
        ``kwargs["not_refresh_vs_cache"]`` is truthy, the index is saved to
        ``self.vs_path`` afterwards. Returns the deleted ids.
        """
        target = kb_file.filename.lower()
        with self.load_vector_store().acquire() as vs:
            ids = [
                doc_id
                for doc_id, doc in vs.docstore._dict.items()
                # ``source`` may be missing; treat that as "no match" instead
                # of failing on None.lower()
                if (doc.metadata.get("source") or "").lower() == target
            ]
            if len(ids) > 0:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Evict the store from the pool, then recreate ``self.vs_path`` as an empty directory."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            # best effort: a folder that cannot be removed is ignored
            ...
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Return "in_db", "in_folder" or False depending on where ``file_name`` is found."""
        if super().exist_doc(file_name):
            return "in_db"

        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        else:
            return False
if __name__ == '__main__':
    # Manual smoke test: add and delete README.md in the "test" knowledge base,
    # drop the knowledge base, then run a search.
    service = FaissKBService("test")
    readme = KnowledgeFile("README.md", "test")
    service.add_doc(readme)
    service.delete_doc(KnowledgeFile("README.md", "test"))
    service.do_drop_kb()
    print(service.search_docs("如何启动api服务"))

View File

@@ -0,0 +1,207 @@
from typing import List, Dict, Optional
from langchain.schema import Document
from langchain.vectorstores.milvus import Milvus
import os
import logging
from configs import kbs_config
from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
import numpy as np
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MilvusKBService(KBService):
    """Knowledge base service that stores its vectors in a Milvus collection."""

    # Stays None until _load_milvus succeeds, so the ``if self.milvus and ...``
    # guards below are safe even after a failed load.
    milvus: Optional[Milvus] = None

    @staticmethod
    def get_collection(milvus_name):
        from pymilvus import Collection
        return Collection(milvus_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Query the collection by primary key; ids are cast to int. The ``text`` field becomes the page content."""
        result = []
        if self.milvus and self.milvus.col:
            # ids = [int(id) for id in ids] # for milvus if needed #pr 2725
            data_list = self.milvus.col.query(expr=f'pk in {[int(_id) for _id in ids]}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        """Query the collection by ``source``. The ``text`` field becomes the page content."""
        result = []
        if self.milvus and self.milvus.col:
            data_list = self.milvus.col.query(expr=f'source in {source_name_list}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """
        Delete entities by primary key and return True.

        The ids are cast to int, the same way ``get_doc_by_ids`` does, so the
        expression compares ``pk`` against numbers rather than quoted strings.
        """
        if self.milvus and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {[int(_id) for _id in ids]}')
        return True

    @staticmethod
    def search(milvus_name, content, limit=3):
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        c = MilvusKBService.get_collection(milvus_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        pass

    def vs_type(self) -> str:
        return SupportedVSType.MILVUS

    def _load_milvus(self):
        """
        Create the langchain ``Milvus`` wrapper and pick the text field that matches the collection schema.

        If the wrapper cannot be created, the error is logged and
        ``_create_collection_if_not_exists`` is called; ``self.milvus`` is not
        (re)assigned in that case.
        """
        try:
            self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                 collection_name=self.kb_name,
                                 connection_args=kbs_config.get("milvus"),
                                 index_params=kbs_config.get("milvus_kwargs")["index_params"],
                                 search_params=kbs_config.get("milvus_kwargs")["search_params"],
                                 auto_id=True
                                 )
            logger.info("成功加载 Milvus 实例 'milvus'")

            # -------- Support collections whose schema names the text field differently --------
            # For a brand-new knowledge base the collection does not exist yet and
            # ``col`` is None; the schema only appears after the first
            # add_documents call, so .col.schema must not be touched here.
            try:
                col = self.milvus.col
                if col is None:
                    logger.debug(
                        "集合 %s 尚未在 Milvus 中建表,跳过文本字段探测(首次写入时会自动创建)",
                        self.kb_name,
                    )
                else:
                    field_names = [f.name for f in col.schema.fields]
                    if self.milvus._text_field not in field_names:
                        if "page_content" in field_names:
                            self.milvus._text_field = "page_content"
                        elif "content" in field_names:
                            self.milvus._text_field = "content"
                        else:
                            # fall back to the first VARCHAR field
                            for f in col.schema.fields:
                                if hasattr(f, "dtype") and str(f.dtype).startswith("DataType.VARCHAR"):
                                    self.milvus._text_field = f.name
                                    break
                    logger.info(f"集合 {self.kb_name} 使用文本字段: {self.milvus._text_field}")
            except Exception as e:
                logger.warning(f"检测并设置文本字段失败: {e}")
        except Exception as e:
            logger.error(f"加载 Milvus 实例 'milvus' 失败: {e}")
            self._create_collection_if_not_exists()
            # reload
            # self._load_milvus()

    def _create_collection_if_not_exists(self):
        """Create the Milvus collection with a fixed schema and build the vector index."""
        from pymilvus import Collection, CollectionSchema, FieldSchema, DataType

        # Field definitions (adjust to your needs)
        fields = [
            FieldSchema(name="pk", dtype=DataType.Int64, is_primary=True, auto_id=True),
            FieldSchema(name="vector", dtype=DataType.FloatVector, dim=768),  # dim must match the embedding model
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=65535),
            # add further custom fields here...
        ]

        schema = CollectionSchema(fields=fields, description=self.kb_name)

        # create the collection
        collection = Collection(name=self.kb_name, schema=schema, using="default")

        # create the index
        index_params = kbs_config.get("milvus_kwargs")["index_params"]
        collection.create_index(field_name="vector", index_params=index_params)

        logger.info(f"成功创建 Milvus 集合: {self.kb_name}")

    def do_init(self):
        self._load_milvus()

    def do_drop_kb(self):
        """Release the collection. Dropping it from here is deliberately disabled."""
        if self.milvus and self.milvus.col:
            self.milvus.col.release()
            # self.milvus.col.drop()  # deleting the collection from chatchat is disabled

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str, custom_strategy_config: dict = {}):
        """
        Search the collection.

        For ``top_k`` above 50 the documents are returned without scores,
        sorted by ascending ``pk``; otherwise the scored hits are passed
        through ``score_threshold_process``. Any error is logged and an empty
        list is returned.
        """
        self._load_milvus()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        try:
            embeddings = embed_func.embed_query(query)
            if top_k > 50:
                # return the full content in document order
                docs = self.milvus.similarity_search_by_vector(embeddings, top_k, expr=expr)
                docs = sorted(docs, key=lambda doc: doc.metadata['pk'])  # ascending pk
                # return score_threshold_process(query,score_threshold, top_k, docs)
                return docs
            else:
                docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k, expr=expr)
                # TODO dynamic score_threshold
                return score_threshold_process(query, score_threshold, top_k, docs)
        except Exception as e:
            logger.error(f"搜索 Milvus 集合 '{self.kb_name}' 失败: {e}")
            return []

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """
        Add ``docs`` to the collection and return ``[{"id", "metadata"}, ...]``.

        Before insertion every metadata value is converted to ``str``, every
        field known to the wrapper gets an empty-string default, and the text
        and vector fields are removed from the metadata.
        """
        for doc in docs:
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.milvus._text_field, None)
            doc.metadata.pop(self.milvus._vector_field, None)

        ids = self.milvus.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete the entities whose ids the database lists for this file."""
        id_list = list_file_num_docs_id_by_kb_name_and_file_name(kb_file.kb_name, kb_file.filename)
        if self.milvus and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {id_list}')

        # Issue 2846, for windows
        # if self.milvus.col:
        #     file_path = kb_file.filepath.replace("\\", "\\\\")
        #     file_name = os.path.basename(file_path)
        #     id_list = [item.get("pk") for item in
        #                self.milvus.col.query(expr=f'source == "{file_name}"', output_fields=["pk"])]
        #     self.milvus.col.delete(expr=f'pk in {id_list}')

    def do_clear_vs(self):
        if self.milvus and self.milvus.col:
            self.do_drop_kb()
            self.do_init()
if __name__ == '__main__':
    # Used to test table creation
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    service = MilvusKBService("t_policy_total_bce_v1")
    # service.add_doc(KnowledgeFile("README.md", "test"))

    # print(service.get_doc_by_ids(["444022434274215486"]))
    # service.delete_doc(KnowledgeFile("README.md", "test"))
    # service.do_drop_kb()
    # print(service.search_docs("如何启动api服务"))

View File

@@ -0,0 +1,96 @@
import json
from typing import List, Dict, Optional
from langchain.schema import Document
from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
from sqlalchemy import text
from configs import kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
import shutil
import sqlalchemy
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
class PGKBService(KBService):
    """Knowledge base service that stores its vectors in PostgreSQL through langchain's ``PGVector``."""

    # One engine shared by every instance; created when the class is defined.
    engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)

    def _load_pg_vector(self):
        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                  collection_name=self.kb_name,
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
                                  connection=PGKBService.engine,
                                  connection_string=kbs_config.get("pg").get("connection_uri"))

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """
        Return the documents of the rows whose ``collection_id`` is in ``ids``.

        NOTE(review): the filter column is ``collection_id`` - confirm that
        the ids passed in are collection ids rather than per-document ids.
        """
        # An expanding bind parameter renders the list as one placeholder per
        # element, which is what ``IN`` needs.
        stmt = text(
            "SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids"
        ).bindparams(sqlalchemy.bindparam("ids", expanding=True))
        with Session(PGKBService.engine) as session:
            rows = session.execute(stmt, {'ids': list(ids)}).fetchall()
            return [Document(page_content=row[0], metadata=row[1]) for row in rows]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        return super().del_doc_by_ids(ids)

    def do_init(self):
        self._load_pg_vector()

    def do_create_kb(self):
        pass

    def vs_type(self) -> str:
        return SupportedVSType.PG

    def do_drop_kb(self):
        """Delete the collection's embeddings and the collection row, then remove ``self.kb_path``."""
        params = {"kb_name": self.kb_name}
        with Session(PGKBService.engine) as session:
            # The name is passed as a bind parameter instead of being pasted
            # into the SQL text.
            # 1) rows of langchain_pg_embedding that belong to the collection
            session.execute(
                text(
                    "DELETE FROM langchain_pg_embedding "
                    "WHERE collection_id IN ("
                    "SELECT uuid FROM langchain_pg_collection WHERE name = :kb_name)"
                ),
                params,
            )
            # 2) the row of langchain_pg_collection itself
            session.execute(
                text("DELETE FROM langchain_pg_collection WHERE name = :kb_name"),
                params,
            )
            session.commit()
            shutil.rmtree(self.kb_path)

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str,
                  custom_strategy_config: dict = None):
        """
        Return the ``top_k`` nearest documents, filtered by ``score_threshold_process``.

        ``expr`` and ``custom_strategy_config`` are accepted so the method can
        be called the way ``KBService.search_docs`` calls it; this backend
        does not use them.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        ids = self.pg_vector.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """
        Delete every embedding whose metadata contains ``{"source": kb_file.filepath}``.

        NOTE(review): the statement is not restricted to this knowledge
        base's collection.
        """
        # json.dumps produces the JSON text (including backslash escaping);
        # it is sent as a bind parameter rather than spliced into the SQL.
        source_filter = json.dumps({"source": kb_file.filepath})
        with Session(PGKBService.engine) as session:
            session.execute(
                text(
                    "DELETE FROM langchain_pg_embedding "
                    "WHERE cmetadata::jsonb @> CAST(:source_filter AS jsonb)"
                ),
                {"source_filter": source_filter},
            )
            session.commit()

    def do_clear_vs(self):
        self.pg_vector.delete_collection()
        self.pg_vector.create_collection()
if __name__ == '__main__':
    # Manual smoke test: look one document up by id in the "test" knowledge base.
    from server.db.base import Base, engine

    # Base.metadata.create_all(bind=engine)
    service = PGKBService("test")
    # service.create_kb()
    # service.add_doc(KnowledgeFile("README.md", "test"))
    # service.delete_doc(KnowledgeFile("README.md", "test"))
    # service.drop_kb()
    found = service.get_doc_by_ids(["f1e51390-3029-4a19-90dc-7118aaa25772"])
    print(found)
    # print(service.search_docs("如何启动api服务"))

View File

@@ -0,0 +1,97 @@
from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Zilliz
from configs import kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
class ZillizKBService(KBService):
    """Knowledge-base service that stores vectors in a Zilliz collection."""

    # langchain Zilliz vector store; assigned by _load_zilliz().
    zilliz: Zilliz

    @staticmethod
    def get_collection(zilliz_name):
        """Return the pymilvus ``Collection`` named ``zilliz_name``."""
        from pymilvus import Collection
        return Collection(zilliz_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by primary key.

        :param ids: primary keys to look up
        :return: one ``Document`` per row found; empty when no collection exists
        """
        result = []
        if self.zilliz.col:
            # ids = [int(id) for id in ids]  # for zilliz if needed #pr 2725
            data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
            for data in data_list:
                # The "text" field becomes the page content; the rest is metadata.
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by primary key.

        :param ids: primary keys to delete
        :return: ``True`` once the delete was issued, ``False`` when there is no
            collection to delete from (the method used to return ``None`` and
            raise ``AttributeError`` in that case)
        """
        if not self.zilliz.col:
            return False
        self.zilliz.col.delete(expr=f'pk in {ids}')
        return True

    @staticmethod
    def search(zilliz_name, content, limit=3):
        """Search the named collection's ``embeddings`` field using the IP metric.

        :param zilliz_name: collection name
        :param content: query data handed to ``Collection.search``
        :param limit: maximum number of hits
        """
        search_params = {
            "metric_type": "IP",
            "params": {},
        }
        c = ZillizKBService.get_collection(zilliz_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """No-op hook: nothing is done here when a knowledge base is created."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type constant (``SupportedVSType.ZILLIZ``)."""
        return SupportedVSType.ZILLIZ

    def _load_zilliz(self):
        """(Re)create ``self.zilliz`` from the ``zilliz`` entry of ``kbs_config``."""
        zilliz_args = kbs_config.get("zilliz")
        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                             collection_name=self.kb_name, connection_args=zilliz_args)

    def do_init(self):
        """Initialise the service by loading the Zilliz store."""
        self._load_zilliz()

    def do_drop_kb(self):
        """Release and drop the collection, if one exists."""
        if self.zilliz.col:
            self.zilliz.col.release()
            self.zilliz.col.drop()

    def do_search(self, query: str, top_k: int, score_threshold: float,expr:str):
        """Embed ``query`` and return the scored nearest documents.

        :param query: text to search for
        :param top_k: number of neighbours requested
        :param score_threshold: passed to ``score_threshold_process``
        :param expr: accepted for interface compatibility; not used here
        """
        # The store is reloaded on every search.
        self._load_zilliz()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Normalise metadata, add ``docs`` to the store and describe what was stored.

        Each document's metadata is modified in place: values are stringified,
        every field in ``self.zilliz.fields`` gets a ``""`` default, and the
        store's text and vector field names are removed.

        :param docs: documents to add
        :return: one ``{"id": ..., "metadata": ...}`` dict per document
        """
        for doc in docs:
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            for field in self.zilliz.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.zilliz._text_field, None)
            doc.metadata.pop(self.zilliz._vector_field, None)
        ids = self.zilliz.add_documents(docs)
        return [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every row whose ``source`` field equals the file's path.

        :param kb_file: knowledge file whose ``filepath`` identifies the rows
        """
        if self.zilliz.col:
            # Backslashes are doubled so they survive the expression's string literal.
            filepath = kb_file.filepath.replace('\\', '\\\\')
            delete_list = [item.get("pk") for item in
                           self.zilliz.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
            # Nothing matched: skip the delete instead of sending `pk in []`.
            if delete_list:
                self.zilliz.col.delete(expr=f'pk in {delete_list}')

    def do_clear_vs(self):
        """Drop the collection, if any, and re-initialise the store."""
        if self.zilliz.col:
            self.do_drop_kb()
            self.do_init()
# Manual smoke test: runs only when this module is executed directly.
if __name__ == '__main__':
    from server.db.base import Base, engine
    # Create the database tables, then instantiate the service for the "test" KB.
    Base.metadata.create_all(bind=engine)
    zillizService = ZillizKBService("test")

View File

@@ -0,0 +1,78 @@
from typing import List
from configs import (
EMBEDDING_MODEL,
KB_ROOT_PATH)
from abc import ABC, abstractmethod
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
import os
import shutil
from server.db.repository.knowledge_metadata_repository import add_summary_to_db, delete_summary_from_db
from langchain.docstore.document import Document
class KBSummaryService(ABC):
    """Stores chunk summaries of one knowledge base in a FAISS store and the database."""

    kb_name: str
    embed_model: str
    vs_path: str
    kb_path: str

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL
                 ):
        """
        :param knowledge_base_name: name of the knowledge base
        :param embed_model: embedding model name used for the summary vector store
        """
        self.kb_name = knowledge_base_name
        self.embed_model = embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.vs_path, exist_ok=True)

    def get_vs_path(self):
        """Return the summary vector-store directory inside the knowledge-base folder."""
        return os.path.join(self.get_kb_path(), "summary_vector_store")

    def get_kb_path(self):
        """Return the knowledge-base folder: ``KB_ROOT_PATH`` joined with the KB name."""
        return os.path.join(KB_ROOT_PATH, self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Load the ``summary_vector_store`` from ``kb_faiss_pool`` with ``create=True``."""
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name="summary_vector_store",
                                               embed_model=self.embed_model,
                                               create=True)

    def add_kb_summary(self, summary_combine_docs: List[Document]):
        """Add summary documents to the vector store and record them in the database.

        :param summary_combine_docs: summary documents; ``metadata['doc_ids']`` of
            each is copied into its database record
        :return: the status returned by ``add_summary_to_db``
        """
        with self.load_vector_store().acquire() as vs:
            ids = vs.add_documents(documents=summary_combine_docs)
            vs.save_local(self.vs_path)

        summary_infos = [{"summary_context": doc.page_content,
                          "summary_id": summary_id,
                          "doc_ids": doc.metadata.get('doc_ids'),
                          "metadata": doc.metadata} for summary_id, doc in zip(ids, summary_combine_docs)]
        status = add_summary_to_db(kb_name=self.kb_name, summary_infos=summary_infos)
        return status

    def create_kb_summary(self):
        """
        Create the knowledge-base chunk summary store (its directory on disk).
        :return:
        """
        os.makedirs(self.vs_path, exist_ok=True)

    def drop_kb_summary(self):
        """
        Delete the knowledge-base chunk summary: evict the cached store, remove
        the directory and delete the database records.
        :return:
        """
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop(self.kb_name)
            # The directory may already be gone (e.g. a repeated drop); do not fail.
            if os.path.isdir(self.vs_path):
                shutil.rmtree(self.vs_path)
        delete_summary_from_db(kb_name=self.kb_name)

View File

@@ -0,0 +1,241 @@
from typing import List, Optional
from langchain.schema.language_model import BaseLanguageModel
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
from configs import (logger)
from langchain.chains import StuffDocumentsChain, LLMChain
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
from langchain.output_parsers.regex import RegexParser
from langchain.chains.combine_documents.map_reduce import ReduceDocumentsChain, MapReduceDocumentsChain
import sys
import asyncio
class SummaryAdapter:
    """Wraps a map-reduce documents chain to summarise knowledge-base chunks."""

    _OVERLAP_SIZE: int
    token_max: int
    _separator: str = "\n\n"
    chain: MapReduceDocumentsChain

    def __init__(self, overlap_size: int, token_max: int,
                 chain: MapReduceDocumentsChain):
        """
        :param overlap_size: overlap size between neighbouring chunks
        :param token_max: token limit handed to the reduce chain
        :param chain: the map-reduce chain that produces the summaries
        """
        self._OVERLAP_SIZE = overlap_size
        self.chain = chain
        self.token_max = token_max

    @classmethod
    def form_summary(cls,
                     llm: BaseLanguageModel,
                     reduce_llm: BaseLanguageModel,
                     overlap_size: int,
                     token_max: int = 1300):
        """
        Build an adapter instance.
        :param reduce_llm: LLM used to merge the summaries
        :param llm: LLM used to generate the per-document summaries
        :param overlap_size: size of the overlapping part
        :param token_max: maximum chunk size; per the original note each chunk must be
            shorter than token_max, and a first-pass summary longer than token_max
            raises an error
        :return: a configured ``SummaryAdapter``
        """
        # This controls how each document will be formatted.
        document_prompt = PromptTemplate(
            input_variables=["page_content"],
            template="{page_content}"
        )

        # The prompt here should take as an input variable the
        # `document_variable_name`
        prompt_template = (
            "根据文本执行任务。以下任务信息"
            "{task_briefing}"
            "文本内容如下: "
            "\r\n"
            "{context}"
        )
        prompt = PromptTemplate(
            template=prompt_template,
            input_variables=["task_briefing", "context"]
        )
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        # We now define how to combine these summaries
        reduce_prompt = PromptTemplate.from_template(
            "Combine these summaries: {context}"
        )
        reduce_llm_chain = LLMChain(llm=reduce_llm, prompt=reduce_prompt)

        document_variable_name = "context"
        combine_documents_chain = StuffDocumentsChain(
            llm_chain=reduce_llm_chain,
            document_prompt=document_prompt,
            document_variable_name=document_variable_name
        )
        reduce_documents_chain = ReduceDocumentsChain(
            token_max=token_max,
            combine_documents_chain=combine_documents_chain,
        )
        chain = MapReduceDocumentsChain(
            llm_chain=llm_chain,
            document_variable_name=document_variable_name,
            reduce_documents_chain=reduce_documents_chain,
            # Also return the intermediate (per-document) steps.
            return_intermediate_steps=True
        )
        return cls(overlap_size=overlap_size,
                   chain=chain,
                   token_max=token_max)

    def summarize(self,
                  file_description: str,
                  docs: Optional[List[DocumentWithVSId]] = None
                  ) -> List[Document]:
        """Synchronously run :meth:`asummarize` on an event loop.

        :param file_description: description stored in the summary's metadata
        :param docs: documents to summarise; ``None`` is treated as an empty list
        :return: the list returned by :meth:`asummarize`
        """
        docs = [] if docs is None else docs
        if sys.version_info < (3, 10):
            loop = asyncio.get_event_loop()
        else:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = asyncio.new_event_loop()

            asyncio.set_event_loop(loop)
        # NOTE(review): if a loop is already running in this thread,
        # run_until_complete raises RuntimeError — confirm callers are sync-only.
        # Run the coroutine synchronously.
        return loop.run_until_complete(self.asummarize(file_description=file_description,
                                                       docs=docs))

    async def asummarize(self,
                         file_description: str,
                         docs: Optional[List[DocumentWithVSId]] = None) -> List[Document]:
        """Summarise ``docs`` with the map-reduce chain.

        The process has two parts (both inside ``self.chain.combine_docs``):
          1. each document is processed to produce its own summary;
          2. the per-document summaries are merged into the final summary;
             ``return_intermediate_steps=True`` returns the intermediate steps too.

        :param file_description: description stored in the summary's metadata
        :param docs: documents to summarise; ``None`` is treated as an empty list
        :return: a one-element list holding the combined summary document, whose
            metadata carries the description, the intermediate steps and the
            comma-joined source document ids
        """
        docs = [] if docs is None else docs
        logger.info("start summary")
        summary_combine, summary_intermediate_steps = self.chain.combine_docs(docs=docs,
                                                                              task_briefing="描述不同方法之间的接近度和相似性,"
                                                                                            "以帮助读者理解它们之间的关系。")
        print(summary_combine)
        print(summary_intermediate_steps)

        # Disabled retry path: if the summary is empty, regenerate from half of
        # the intermediate steps.
        # if len(summary_combine) == 0:
        #     result_docs = [
        #         Document(page_content=question_result_key, metadata=docs[i].metadata)
        #         # This uses metadata from the docs, and the textual results from `results`
        #         for i, question_result_key in enumerate(
        #             summary_intermediate_steps["intermediate_steps"][
        #             :len(summary_intermediate_steps["intermediate_steps"]) // 2
        #             ])
        #     ]
        #     summary_combine, summary_intermediate_steps = self.chain.reduce_documents_chain.combine_docs(
        #         result_docs, token_max=self.token_max
        #     )
        logger.info("end summary")
        doc_ids = ",".join([doc.id for doc in docs])
        _metadata = {
            "file_description": file_description,
            "summary_intermediate_steps": summary_intermediate_steps,
            "doc_ids": doc_ids
        }
        summary_combine_doc = Document(page_content=summary_combine, metadata=_metadata)

        return [summary_combine_doc]

    def _drop_overlap(self, docs: List[DocumentWithVSId]) -> List[str]:
        """
        Remove the text each document's beginning shares with the end of the
        previous document.
        :param docs: documents in their original order
        :return: page contents with the leading overlap removed
        """
        merge_docs = []

        pre_doc = None
        for doc in docs:
            # The first document is added unchanged.
            if len(merge_docs) == 0:
                pre_doc = doc.page_content
                merge_docs.append(doc.page_content)
                continue

            # Shrink pre_doc from the front one character at a time and test
            # whether what remains is a prefix of the current document. The scan
            # stops once pre_doc would be shorter than
            # self._OVERLAP_SIZE // 2 - 2 * len(separator).
            for _ in range(len(pre_doc), self._OVERLAP_SIZE // 2 - 2 * len(self._separator), -1):
                pre_doc = pre_doc[1:]
                if doc.page_content[:len(pre_doc)] == pre_doc:
                    # Keep only the part after the overlap.
                    merge_docs.append(doc.page_content[len(pre_doc):])
                    break
            else:
                # No overlap found: keep the whole document instead of dropping it.
                merge_docs.append(doc.page_content)

            pre_doc = doc.page_content

        return merge_docs

    def _join_docs(self, docs: List[str]) -> Optional[str]:
        """Join ``docs`` with the separator; return ``None`` if the result is blank."""
        text = self._separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text
# Stand-alone demo of the overlap-removal idea on plain strings.
if __name__ == '__main__':
    docs = [
        '梦者有特别的作用,也就是说梦是在预卜未来。因此,梦内容的',
        '梦内容的多彩多姿以及对梦者本身所遗留的特殊印象,使他们很难想象',
        '使他们很难想象出一套系统划一的观念,而需要以其个别的价值与可靠性作各',
        '值与可靠性作各种不同的分化与聚合。因此,古代哲学家们对梦的评价也就完全'
    ]
    _OVERLAP_SIZE = 1
    separator: str = "\n\n"
    merge_docs = []
    # Remove the overlapping sentence fragments between consecutive documents:
    # where the end of the previous item overlaps the start of the next one,
    # drop the overlapping start of the next one.
    pre_doc = None
    for doc in docs:
        # The first document is added unchanged.
        if len(merge_docs) == 0:
            pre_doc = doc
            merge_docs.append(doc)
            continue

        # Shrink pre_doc from the front one character per iteration and look for
        # the overlap, until pre_doc is shorter than
        # _OVERLAP_SIZE - 2 * len(separator).
        for i in range(len(pre_doc), _OVERLAP_SIZE - 2 * len(separator), -1):
            # Drop the leading character on every iteration.
            pre_doc = pre_doc[1:]
            if doc[:len(pre_doc)] == pre_doc:
                # Remove the overlapping start of the next document.
                page_content = doc[len(pre_doc):]
                merge_docs.append(page_content)

                pre_doc = doc
                break

    # Merge the pieces in merge_docs into one text.
    text = separator.join(merge_docs)
    text = text.strip()

    print(text)

View File

@@ -0,0 +1,220 @@
from fastapi import Body
from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
OVERLAP_SIZE,
logger, log_verbose, )
from server.knowledge_base.utils import (list_files_from_folder)
from sse_starlette import EventSourceResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List, Optional
from server.knowledge_base.kb_summary.base import KBSummaryService
from server.knowledge_base.kb_summary.summary_chunk import SummaryAdapter
from server.utils import wrap_done, get_ChatOpenAI, BaseResponse
from configs import LLM_MODELS, TEMPERATURE
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
def recreate_summary_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """
    Rebuild the summaries of every file in one knowledge base, streaming one
    JSON progress message per file as server-sent events.
    :param max_tokens: limit on the number of tokens the LLM generates
    :param model_name: LLM model name
    :param temperature: LLM sampling temperature
    :param file_description: description stored with each summary
    :param knowledge_base_name: knowledge base to summarise
    :param allow_empty_kb: proceed even when the knowledge base does not exist
    :param vs_type: vector-store type
    :param embed_model: embedding model name
    :return: an ``EventSourceResponse`` over the progress messages
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Yield a JSON string like every other event: a raw dict is expanded
            # by sse-starlette as ServerSentEvent keyword arguments, which does
            # not accept "code"/"msg".
            yield json.dumps({"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"},
                             ensure_ascii=False)
            return

        # Recreate the summary store from scratch.
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.drop_kb_summary()
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text-summary adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)
        files = list_files_from_folder(knowledge_base_name)

        for i, file_name in enumerate(files):
            doc_infos = kb.list_docs(file_name=file_name)
            docs = summary.summarize(file_description=file_description,
                                     docs=doc_infos)

            status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
            if status_kb_summary:
                logger.info(f"({i + 1} / {len(files)}): {file_name} 总结完成")
                yield json.dumps({
                    "code": 200,
                    "msg": f"({i + 1} / {len(files)}): {file_name}",
                    "total": len(files),
                    "finished": i + 1,
                    "doc": file_name,
                }, ensure_ascii=False)
            else:
                msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
                logger.error(msg)
                yield json.dumps({
                    "code": 500,
                    "msg": msg,
                }, ensure_ascii=False)

    return EventSourceResponse(output())
def summary_file_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        file_name: str = Body(..., examples=["test.pdf"]),
        allow_empty_kb: bool = Body(True),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
):
    """
    Summarise one file of a knowledge base, selected by file name, and stream
    the outcome as a server-sent event.
    :param model_name: LLM model name
    :param max_tokens: limit on the number of tokens the LLM generates
    :param temperature: LLM sampling temperature
    :param file_description: description stored with the summary
    :param file_name: file whose documents are summarised
    :param knowledge_base_name: knowledge base containing the file
    :param allow_empty_kb: proceed even when the knowledge base does not exist
    :param vs_type: vector-store type
    :param embed_model: embedding model name
    :return: an ``EventSourceResponse`` over the result message
    """

    def output():
        kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
        if not kb.exists() and not allow_empty_kb:
            # Yield a JSON string like every other event: a raw dict is expanded
            # by sse-starlette as ServerSentEvent keyword arguments, which does
            # not accept "code"/"msg".
            yield json.dumps({"code": 404, "msg": f"未找到知识库 {knowledge_base_name}"},
                             ensure_ascii=False)
            return

        # Make sure the summary store exists (it is not dropped here).
        kb_summary = KBSummaryService(knowledge_base_name, embed_model)
        kb_summary.create_kb_summary()

        llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        reduce_llm = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        # Text-summary adapter.
        summary = SummaryAdapter.form_summary(llm=llm,
                                              reduce_llm=reduce_llm,
                                              overlap_size=OVERLAP_SIZE)

        doc_infos = kb.list_docs(file_name=file_name)
        docs = summary.summarize(file_description=file_description,
                                 docs=doc_infos)

        status_kb_summary = kb_summary.add_kb_summary(summary_combine_docs=docs)
        if status_kb_summary:
            logger.info(f" {file_name} 总结完成")
            yield json.dumps({
                "code": 200,
                "msg": f"{file_name} 总结完成",
                "doc": file_name,
            }, ensure_ascii=False)
        else:
            msg = f"知识库'{knowledge_base_name}'总结文件‘{file_name}’时出错。已跳过。"
            logger.error(msg)
            yield json.dumps({
                "code": 500,
                "msg": msg,
            }, ensure_ascii=False)

    return EventSourceResponse(output())
def summary_doc_ids_to_vector_store(
        knowledge_base_name: str = Body(..., examples=["samples"]),
        doc_ids: List = Body([], examples=[["uuid"]]),
        vs_type: str = Body(DEFAULT_VS_TYPE),
        embed_model: str = Body(EMBEDDING_MODEL),
        file_description: str = Body(''),
        model_name: str = Body(LLM_MODELS[0], description="LLM 模型名称。"),
        temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
        max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
) -> BaseResponse:
    """
    Summarise the documents of one knowledge base selected by ``doc_ids``.
    :param knowledge_base_name: knowledge base holding the documents
    :param doc_ids: ids of the documents to summarise
    :param model_name: LLM model name
    :param max_tokens: limit on the number of tokens the LLM generates
    :param temperature: LLM sampling temperature
    :param file_description: description stored with the summary
    :param vs_type: vector-store type
    :param embed_model: embedding model name
    :return: a 404 ``BaseResponse`` when the knowledge base does not exist,
        otherwise a 200 response whose data holds the summaries as dicts
    """
    kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model)
    if not kb.exists():
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data={})

    # The map LLM and the reduce LLM are built from the same settings.
    llm_settings = dict(model_name=model_name, temperature=temperature, max_tokens=max_tokens)
    llm = get_ChatOpenAI(**llm_settings)
    reduce_llm = get_ChatOpenAI(**llm_settings)

    # Text-summary adapter.
    summary = SummaryAdapter.form_summary(llm=llm,
                                          reduce_llm=reduce_llm,
                                          overlap_size=OVERLAP_SIZE)

    fetched_docs = kb.get_doc_by_ids(ids=doc_ids)
    # Wrap each fetched document together with its id as a DocumentWithVSId.
    wrapped_docs = []
    for with_id, doc in zip(doc_ids, fetched_docs):
        wrapped_docs.append(DocumentWithVSId(**doc.dict(), id=with_id))

    summaries = summary.summarize(file_description=file_description,
                                  docs=wrapped_docs)

    # Convert the summary documents to plain dicts for the response.
    resp_summarize = [{**item.dict()} for item in summaries]

    return BaseResponse(code=200, msg="总结完成", data={"summarize": resp_summarize})

View File

@@ -0,0 +1,192 @@
from configs import (
EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
CHUNK_SIZE, OVERLAP_SIZE,
logger, log_verbose
)
from server.knowledge_base.utils import (
get_file_path, list_kbs_from_folder,
list_files_from_folder, files2docs_in_thread,
KnowledgeFile
)
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.models.conversation_model import ConversationModel
from server.db.models.message_model import MessageModel
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
from server.db.repository.knowledge_metadata_repository import add_summary_to_db
from server.db.base import Base, engine
from server.db.session import session_scope
import os
from dateutil.parser import parse
from typing import Literal, List
import time
def create_tables():
    """Create every table registered on ``Base.metadata`` using ``engine``."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
def reset_tables():
    """Drop every table registered on ``Base.metadata`` and create them again."""
    metadata = Base.metadata
    metadata.drop_all(bind=engine)
    create_tables()
def import_from_db(
        sqlite_path: str = None,
        # csv_path: str = None,
) -> bool:
    """
    Import data from a backup database into info.db when the knowledge base and
    the vector store are unchanged.
    Intended for version upgrades where the structure of info.db changed but
    nothing needs to be re-vectorised.
    Make sure the table names match on both sides and the columns to import
    have the same names.
    Only sqlite is supported at present.

    :param sqlite_path: path of the backup sqlite database
    :return: ``True`` on success, ``False`` if anything failed (the error is printed)
    """
    import sqlite3 as sql
    from contextlib import closing
    from pprint import pprint

    models = list(Base.registry.mappers)

    try:
        # closing() guarantees the connection is released even when a row fails
        # to import; previously close() was skipped on any exception.
        with closing(sql.connect(sqlite_path)) as con:
            con.row_factory = sql.Row
            cur = con.cursor()
            tables = [x["name"] for x in cur.execute("select name from sqlite_master where type='table'").fetchall()]
            for model in models:
                table = model.local_table.fullname
                # Skip models that have no table in the backup.
                if table not in tables:
                    continue
                print(f"processing table: {table}")
                with session_scope() as session:
                    for row in cur.execute(f"select * from {table}").fetchall():
                        # Keep only the columns the current model still defines.
                        data = {k: row[k] for k in row.keys() if k in model.columns}
                        if "create_time" in data:
                            data["create_time"] = parse(data["create_time"])
                        pprint(data)
                        session.add(model.class_(**data))
        return True
    except Exception as e:
        print(f"无法读取备份数据库:{sqlite_path}。错误信息:{e}")
        return False
def file_to_kbfile(kb_name: str, files: List[str]) -> List[KnowledgeFile]:
    """Wrap file names as ``KnowledgeFile`` objects, skipping those that fail.

    :param kb_name: knowledge base the files belong to
    :param files: file names to convert
    :return: the ``KnowledgeFile`` objects that could be constructed; files whose
        construction raised are logged and left out
    """
    converted = []
    for name in files:
        try:
            converted.append(KnowledgeFile(filename=name, knowledge_base_name=kb_name))
        except Exception as e:
            msg = f"{e},已跳过"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
    return converted
def folder2db(
        kb_names: List[str],
        mode: Literal["recreate_vs", "update_in_db", "increment"],
        vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
        embed_model: str = EMBEDDING_MODEL,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
):
    """
    Populate the database and/or vector store from files in the local folder.

    ``mode`` selects what happens for each knowledge base:
        recreate_vs: clear the vector store and rebuild it from all local files
        update_in_db: re-vectorise the files that are listed in the database
        increment: vectorise local files that are not yet in the database
    Any other value only prints a message. (A former ``fill_info_only`` mode,
    which stored file info without vectorising, is disabled.)
    An empty ``kb_names`` means every knowledge-base folder.
    """

    def files2vs(kb_name: str, kb_files: List[KnowledgeFile]):
        # Split the files into chunks, then add each successful result to the
        # vector store of `kb` (bound by the loop below).
        chunked = files2docs_in_thread(kb_files,
                                       chunk_size=chunk_size,
                                       chunk_overlap=chunk_overlap,
                                       zh_title_enhance=zh_title_enhance)
        for success, result in chunked:
            if not success:
                print(result)
                continue
            _, filename, docs = result
            print(f"正在将 {kb_name}/{filename} 添加到向量库,共包含{len(docs)}条文档")
            kb_file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            kb_file.splited_docs = docs
            # Vectorise without refreshing the cache; saving happens afterwards.
            kb.add_doc(kb_file=kb_file, not_refresh_vs_cache=True)

    kb_names = kb_names or list_kbs_from_folder()
    for kb_name in kb_names:
        kb = KBServiceFactory.get_service(kb_name, vs_type, embed_model)
        if not kb.exists():
            kb.create_kb()

        if mode == "recreate_vs":
            # Clear the vector store and rebuild it from the local files.
            kb.clear_vs()
            kb.create_kb()
            local_files = list_files_from_folder(kb_name)
            files2vs(kb_name, file_to_kbfile(kb_name, local_files))
            save_started = time.time()
            kb.save_vector_store()
            save_finished = time.time()
            print('向量入库耗时:{}'.format(save_finished - save_started))
        elif mode == "update_in_db":
            # The database file list is the reference; refresh from local files.
            listed_files = kb.list_files()
            files2vs(kb_name, file_to_kbfile(kb_name, listed_files))
            kb.save_vector_store()
        elif mode == "increment":
            # Vectorise only files present locally but missing from the database.
            db_files = kb.list_files()
            folder_files = list_files_from_folder(kb_name)
            missing = list(set(folder_files) - set(db_files))
            files2vs(kb_name, file_to_kbfile(kb_name, missing))
            kb.save_vector_store()
        else:
            print(f"unsupported migrate mode: {mode}")
def prune_db_docs(kb_names: List[str]):
    """
    Delete database docs whose files no longer exist in the local folder.
    Used after a user removed doc files in the file browser.
    """
    for kb_name in kb_names:
        kb = KBServiceFactory.get_service_by_name(kb_name)
        if kb is None:
            continue
        files_in_db = kb.list_files()
        files_in_folder = list_files_from_folder(kb_name)
        orphaned = list(set(files_in_db) - set(files_in_folder))
        for kb_file in file_to_kbfile(kb_name, orphaned):
            kb.delete_doc(kb_file, not_refresh_vs_cache=True)
            print(f"success to delete docs for file: {kb_name}/{kb_file.filename}")
        kb.save_vector_store()
def prune_folder_files(kb_names: List[str]):
    """
    Delete doc files in the local folder that are not recorded in the database.
    Used to free local disk space by removing unused doc files.
    """
    for kb_name in kb_names:
        kb = KBServiceFactory.get_service_by_name(kb_name)
        if kb is None:
            continue
        files_in_db = kb.list_files()
        files_in_folder = list_files_from_folder(kb_name)
        unused = list(set(files_in_folder) - set(files_in_db))
        for name in unused:
            os.remove(get_file_path(kb_name, name))
            print(f"success to delete file: {kb_name}/{name}")

View File

@@ -0,0 +1,10 @@
from langchain.docstore.document import Document
class DocumentWithVSId(Document):
    """
    A vectorised document: a ``Document`` extended with an id and a score.
    """
    # Document id; defaults to None.
    id: str = None
    # Score attached to the document; defaults to 3.0.
    score: float = 3.0

View File

@@ -0,0 +1,16 @@
"""
PDF 转 Markdown 微服务的 HTTP 地址(仅此一处拼 URL不 import configs避免旧 .pyc / 错误包名导致仍用 0.0.0.0)。
"""
from __future__ import annotations
import os
_DEFAULT = "http://127.0.0.1:6006/convert/"


def resolve_pdf_convert_post_url() -> str:
    """Return the HTTP URL of the PDF-to-Markdown conversion service.

    The URL is read from the ``PDF_CONVERT_API_URL`` environment variable and
    falls back to ``_DEFAULT`` when the variable is unset, blank, does not start
    with ``http`` or cannot be parsed. A host of exactly ``0.0.0.0`` is rewritten
    to ``127.0.0.1``. Only the host component is rewritten, so hosts such as
    ``10.0.0.0`` or a path containing ``0.0.0.0`` are left intact.
    """
    from urllib.parse import urlsplit, urlunsplit

    u = (os.environ.get("PDF_CONVERT_API_URL") or "").strip() or _DEFAULT
    if not u.startswith("http"):
        return _DEFAULT
    try:
        parts = urlsplit(u)
    except ValueError:
        # Malformed URL (e.g. unbalanced IPv6 brackets): use the default.
        return _DEFAULT
    # netloc is "[userinfo@]host[:port]"; isolate the host without touching the rest.
    userinfo, at, hostport = parts.netloc.rpartition("@")
    host, colon, port = hostport.partition(":")
    if host != "0.0.0.0":
        return u
    return urlunsplit(parts._replace(netloc=f"{userinfo}{at}127.0.0.1{colon}{port}"))

View File

@@ -0,0 +1,580 @@
import asyncio
import os
import re
from configs import (
KB_ROOT_PATH,
CHUNK_SIZE,
OVERLAP_SIZE,
ZH_TITLE_ENHANCE,
logger,
log_verbose,
text_splitter_dict,
LLM_MODELS,
TEXT_SPLITTER_NAME,
TEXT_SPLITTER_MAP
)
import importlib
from server.chat.policy_fun_iast import get_llm_model_response_async
from server.knowledge_base import kb_service as tr
from server.knowledge_base.TexkRank import TextRank
from text_splitter import zh_title_enhance as func_zh_title_enhance
import langchain.document_loaders
from langchain.docstore.document import Document
from langchain.text_splitter import TextSplitter
from pathlib import Path
from server.utils import run_in_thread_pool, get_model_worker_config
import json
from typing import List, Union,Dict, Tuple, Generator
import chardet
import time
def get_split_time(f):
    """Decorator that prints how long each call to ``f`` took.

    The wrapper forwards all arguments, returns ``f``'s result unchanged and,
    through ``functools.wraps``, keeps ``f``'s name and docstring.

    :param f: callable to time
    :return: the timing wrapper around ``f``
    """
    import functools  # local import: only this decorator needs it

    @functools.wraps(f)
    def inner(*arg, **kwarg):
        s_time = time.time()
        res = f(*arg, **kwarg)
        e_time = time.time()
        print('切片耗时:{}'.format(e_time - s_time))
        return res

    return inner
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return ``False`` if the name contains a parent-directory traversal.

    Besides ``../`` this also rejects the backslash form ``..\\`` and a name
    that is, or ends with, a bare ``..`` component.

    :param knowledge_base_id: knowledge-base name supplied by the caller
    :return: ``True`` when no traversal sequence is present
    """
    # Treat both separator styles alike so "..\" cannot slip through.
    normalized = knowledge_base_id.replace("\\", "/")
    if "../" in normalized:
        return False
    if normalized == ".." or normalized.endswith("/.."):
        return False
    return True
def get_kb_path(knowledge_base_name: str):
    """Return the folder of a knowledge base: ``KB_ROOT_PATH`` joined with its name."""
    kb_dir = os.path.join(KB_ROOT_PATH, knowledge_base_name)
    return kb_dir
def get_doc_path(knowledge_base_name: str):
    """Return the ``content`` folder inside the knowledge base's folder."""
    kb_dir = get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "content")
def get_vs_path(knowledge_base_name: str, vector_name: str):
    """Return ``<kb folder>/vector_store/<vector_name>`` for a knowledge base."""
    kb_dir = get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "vector_store", vector_name)
def get_file_path(knowledge_base_name: str, doc_name: str):
    """Return the path of ``doc_name`` inside the knowledge base's content folder."""
    content_dir = get_doc_path(knowledge_base_name)
    return os.path.join(content_dir, doc_name)
def list_kbs_from_folder():
    """Return the names of the sub-directories found directly under ``KB_ROOT_PATH``."""
    kb_names = []
    for entry in os.listdir(KB_ROOT_PATH):
        if os.path.isdir(os.path.join(KB_ROOT_PATH, entry)):
            kb_names.append(entry)
    return kb_names
def list_files_from_folder(kb_name: str):
    """Recursively list the files in a knowledge base's content folder.

    Entries whose base name starts (case-insensitively) with ``temp``, ``tmp``,
    ``.`` or ``~$`` are skipped, directories included. Symbolic links to
    directories are followed; a symbolic link to a regular file is listed under
    the link's own path, and a broken link is ignored. Previously both of those
    link cases raised from ``os.scandir``.

    :param kb_name: knowledge-base name
    :return: file paths relative to the content folder, in POSIX form
    """
    doc_path = get_doc_path(kb_name)
    result = []
    skipped_prefixes = ("temp", "tmp", ".", "~$")

    def is_skiped_path(path: str):
        tail = os.path.basename(path).lower()
        return tail.startswith(skipped_prefixes)

    def add_file(path: str):
        # Paths are normalised to POSIX form.
        result.append(Path(os.path.relpath(path, doc_path)).as_posix())

    def process_entry(entry):
        if is_skiped_path(entry.path):
            return

        if entry.is_symlink():
            target_path = os.path.realpath(entry.path)
            if os.path.isdir(target_path):
                with os.scandir(target_path) as target_it:
                    for target_entry in target_it:
                        process_entry(target_entry)
            elif os.path.isfile(target_path):
                add_file(entry.path)
            # Otherwise the link is broken: nothing to list.
        elif entry.is_file():
            add_file(entry.path)
        elif entry.is_dir():
            with os.scandir(entry.path) as it:
                for sub_entry in it:
                    process_entry(sub_entry)

    with os.scandir(doc_path) as it:
        for entry in it:
            process_entry(entry)

    return result
# Maps a document-loader class name to the file extensions it handles.
# Several extensions appear under more than one loader; get_LoaderClass walks
# this dict in insertion order and returns the first loader listing the
# extension, so earlier entries take precedence.
LOADER_DICT = {"GCYHTMLLoader": ['.html', '.htm'],
               "GCYWordLoader2": ['.docx', '.doc'],
               # "GCYWordLoader": ['.docx'], # .doc parsing currently has some problems; disabled for now
               "MHTMLLoader": ['.mhtml'],
               "TextLoader": ['.md', '.txt'],
               "JSONLoader": [".json"],
               "JSONLinesLoader": [".jsonl"],
               "RapidOCRCSVLoader": [".csv"],
               # "CSVLoader": [".csv"],
               # "FilteredCSVLoader": [".csv"], use this when a custom CSV split is wanted
               "PyMuPDFLoader": [".pdf"],
               #"RapidOCRDocLoader": ['.docx', '.doc'],
               "RapidOCRPPTLoader": ['.ppt', '.pptx', ],
               "RapidOCRLoader": ['.png', '.jpg', '.jpeg', '.bmp'],
               "UnstructuredFileLoader": ['.eml', '.msg', '.rst',
                                          '.rtf', '.xml',
                                          '.epub', '.odt','.tsv'],
               "UnstructuredEmailLoader": ['.eml', '.msg'],
               "UnstructuredEPubLoader": ['.epub'],
               "ExcelLoader": ['.xlsx', '.xls', '.xlsd'],
               "NotebookLoader": ['.ipynb'],
               "UnstructuredODTLoader": ['.odt'],
               "PythonLoader": ['.py'],
               "UnstructuredRSTLoader": ['.rst'],
               "UnstructuredRTFLoader": ['.rtf'],
               "SRTLoader": ['.srt'],
               "TomlLoader": ['.toml'],
               "UnstructuredTSVLoader": ['.tsv'],
               "UnstructuredXMLLoader": ['.xml'],
               "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
               "EverNoteLoader": ['.enex'],
               }
# Flat list of every extension listed above (duplicates are kept).
SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
# Patch json.dumps process-wide so that ensure_ascii is always False.
# If this module is executed again (e.g. importlib.reload), recover the
# function saved by the earlier patch instead of wrapping the wrapper; the
# former `is not` guard could not detect that case because each execution
# defines a brand-new _new_json_dumps object.
_origin_json_dumps = getattr(json.dumps, "_origin_json_dumps", json.dumps)


def _new_json_dumps(obj, **kwargs):
    """Call the saved ``json.dumps`` with ``ensure_ascii`` forced to ``False``."""
    kwargs["ensure_ascii"] = False
    return _origin_json_dumps(obj, **kwargs)


# Remember the wrapped function on the wrapper so a later re-execution finds it.
_new_json_dumps._origin_json_dumps = _origin_json_dumps
json.dumps = _new_json_dumps
class JSONLinesLoader(langchain.document_loaders.JSONLoader):
    """
    Line-oriented JSON loader; the file extension is expected to be .jsonl
    """

    def __init__(self, *args, **kwargs):
        # Build a regular JSONLoader first, then set its _json_lines flag.
        super().__init__(*args, **kwargs)
        self._json_lines = True


# Expose the class as an attribute of langchain.document_loaders so that it can
# be looked up there by name like the built-in loaders.
setattr(langchain.document_loaders, "JSONLinesLoader", JSONLinesLoader)
def get_LoaderClass(file_extension):
    """Return the name of the first loader in ``LOADER_DICT`` that lists the extension.

    :param file_extension: extension to look up, compared as-is with the entries
    :return: the loader class name, or ``None`` when no loader lists it
    """
    candidates = (
        loader_name
        for loader_name, extensions in LOADER_DICT.items()
        if file_extension in extensions
    )
    return next(candidates, None)
def get_SplitterClass(file_extension):
    """
    Return the text-splitter name registered for a file type.

    ``TEXT_SPLITTER_MAP`` is searched in order for the first entry whose
    extension list contains ``file_extension``. When nothing matches, a notice
    is printed and ``TEXT_SPLITTER_NAME`` is returned.
    """
    for splitter_name in TEXT_SPLITTER_MAP:
        if file_extension in TEXT_SPLITTER_MAP[splitter_name]:
            return splitter_name
    print(f'未找到文件类型"{file_extension}"对应的切分器,使用默认值')
    return TEXT_SPLITTER_NAME
def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
    '''
    Return a document loader instance for ``loader_name`` and ``file_path``.

    The loader class is looked up in the project's ``document_loaders`` module
    for the custom loaders named below and in ``langchain.document_loaders``
    otherwise. If the lookup fails, the error is logged and
    ``UnstructuredFileLoader`` is used instead.

    :param loader_name: loader class name
    :param file_path: path handed to the loader
    :param loader_kwargs: extra keyword arguments for the loader; the dict is
        copied, so the caller's object is never modified
    :return: the constructed loader
    '''
    # Work on a copy: the defaults added below must not leak into the caller's dict.
    loader_kwargs = dict(loader_kwargs or {})
    try:
        if loader_name in ["RapidOCRLoader", "FilteredCSVLoader",
                           "GCYWordLoader","GCYWordLoader2", "GCYHTMLLoader",
                           "RapidOCRPPTLoader", "RapidOCRCSVLoader","ExcelLoader"]:
            document_loaders_module = importlib.import_module('document_loaders')
        else:
            document_loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(document_loaders_module, loader_name)
    except Exception as e:
        msg = f"为文件{file_path}查找加载器{loader_name}时出错:{e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        document_loaders_module = importlib.import_module('langchain.document_loaders')
        DocumentLoader = getattr(document_loaders_module, "UnstructuredFileLoader")

    if loader_name in ["UnstructuredFileLoader", "TextLoader"]:
        loader_kwargs.setdefault("autodetect_encoding", True)
    elif loader_name == "CSVLoader":
        if not loader_kwargs.get("encoding"):
            # No encoding given: detect it from the file content so the
            # langchain loader does not fail with an encoding error.
            with open(file_path, 'rb') as struct_file:
                encode_detect = chardet.detect(struct_file.read())
            # The detected encoding can be missing or None (for instance on an
            # empty file); fall back to UTF-8 in that case.
            detected = (encode_detect or {}).get("encoding")
            loader_kwargs["encoding"] = detected or "utf-8"
    elif loader_name in ("JSONLoader", "JSONLinesLoader"):
        loader_kwargs.setdefault("jq_schema", ".")
        loader_kwargs.setdefault("text_content", False)

    loader = DocumentLoader(file_path, **loader_kwargs)
    return loader
def make_text_splitter(
        splitter_name: str = TEXT_SPLITTER_NAME,
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        llm_model: str = LLM_MODELS[0],
):
    """
    Build the text splitter selected by the arguments.

    :param splitter_name: splitter class name; a falsy value means "SpacyTextSplitter"
    :param chunk_size: chunk size passed to the splitter
    :param chunk_overlap: chunk overlap passed to the splitter
    :param llm_model: model whose worker config supplies the tokenizer path when
        the huggingface entry has an empty ``tokenizer_name_or_path``
    :return: the splitter instance; if anything in the selection logic raises,
        the error is printed and a langchain ``RecursiveCharacterTextSplitter``
        is returned instead
    """
    # print('spliter name', splitter_name)
    splitter_name = splitter_name or "SpacyTextSplitter"
    try:
        # if splitter_name == "GCYMarkdownTextSplitter":  # special case for MarkdownHeaderTextSplitter
        if splitter_name == "MarkdownTextSplitter":  # special case for MarkdownHeaderTextSplitter
            text_splitter_module = importlib.import_module('text_splitter')
            TextSplitter = getattr(text_splitter_module, splitter_name)
            headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
            text_splitter = TextSplitter(
                headers_to_split_on=headers_to_split_on,
                strip_headers=False,  # do not strip the headers from the chunk text
                promote_headers=True
            )
        else:
            try:  ## prefer a user-defined splitter from the project's text_splitter module
                text_splitter_module = importlib.import_module('text_splitter')
                TextSplitter = getattr(text_splitter_module, splitter_name)
            except:  ## otherwise use langchain's text_splitter
                text_splitter_module = importlib.import_module('langchain.text_splitter')
                TextSplitter = getattr(text_splitter_module, splitter_name)

            if text_splitter_dict[splitter_name]["source"] == "tiktoken":  ## load from tiktoken
                # First try with the spaCy pipeline argument; on any failure,
                # retry without it.
                try:
                    text_splitter = TextSplitter.from_tiktoken_encoder(
                        encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
                except:
                    text_splitter = TextSplitter.from_tiktoken_encoder(
                        encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
            elif text_splitter_dict[splitter_name]["source"] == "huggingface":  ## load from huggingface
                # An empty tokenizer path is filled in from the LLM's worker
                # config; note this writes back into the shared text_splitter_dict.
                if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
                    config = get_model_worker_config(llm_model)
                    text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
                        config.get("model_path")

                if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
                    from transformers import GPT2TokenizerFast
                    from langchain.text_splitter import CharacterTextSplitter
                    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
                else:  ## load by character length
                    from transformers import AutoTokenizer
                    tokenizer = AutoTokenizer.from_pretrained(
                        text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                        trust_remote_code=True)
                text_splitter = TextSplitter.from_huggingface_tokenizer(
                    tokenizer=tokenizer,
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
            elif text_splitter_dict[splitter_name]["source"] == "no_tokenizer":  # IAST 0429: no tokenizer needed at present
                text_splitter = TextSplitter(
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap
                )
            else:
                # Any other source: try with the spaCy pipeline argument, then without.
                try:
                    text_splitter = TextSplitter(
                        pipeline="zh_core_web_sm",
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
                except:
                    text_splitter = TextSplitter(
                        chunk_size=chunk_size,
                        chunk_overlap=chunk_overlap
                    )
    except Exception as e:
        # Last resort: report the error and fall back to the recursive character splitter.
        print(e)
        text_splitter_module = importlib.import_module('langchain.text_splitter')
        TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
        text_splitter = TextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

    # If you use SpacyTextSplitter you can use GPU to do split likes Issue #1287
    # text_splitter._tokenizer.max_length = 37016792
    # text_splitter._tokenizer.prefer_gpu()
    return text_splitter
class KnowledgeFile:
    """
    A file stored in a knowledge-base directory.

    Provides loading of the file into langchain Documents (``file2docs``),
    splitting them into chunks (``docs2texts`` / ``file2text``) and generating
    an LLM-based abstract, keywords and chapter overview (``get_llm_result``).
    Loaded documents and split chunks are cached on the instance in
    ``self.docs`` and ``self.splited_docs``.
    """

    def __init__(
            self,
            filename: str,
            knowledge_base_name: str,
            loader_kwargs: Dict = None,
    ):
        """
        Describe a file of the knowledge base ``knowledge_base_name``.

        The file is expected to exist on disk before it can be loaded or
        vectorised; no existence check is made here (see ``file_exist``).

        :param filename: file name of the document; its extension selects the
            loader and splitter.
        :param knowledge_base_name: name of the owning knowledge base.
        :param loader_kwargs: extra keyword arguments for the document loader.
            Defaults to a fresh empty dict per instance.
        :raises ValueError: if the extension is not in ``SUPPORTED_EXTS``.
        """
        self.kb_name = knowledge_base_name
        self.filename = str(Path(filename).as_posix())
        self.ext = os.path.splitext(filename)[-1].lower()
        if self.ext not in SUPPORTED_EXTS:
            raise ValueError(f"暂未支持的文件格式 {self.filename}")
        # A literal ``{}`` default would be one dict shared by all instances.
        self.loader_kwargs = {} if loader_kwargs is None else loader_kwargs
        self.filepath = get_file_path(knowledge_base_name, filename)
        self.docs = None
        self.splited_docs = None
        self.document_loader_name = get_LoaderClass(self.ext)
        self.text_splitter_name = get_SplitterClass(self.ext)

    def get_full_text(self) -> Union[str, Dict[str, str]]:
        """
        Return the file name together with the concatenated text of the file.

        :return: on success a JSON *string* encoding
            ``{"filename": ..., "full_text": ...}``; on failure a plain *dict*
            with the same keys whose ``full_text`` is an error message. The
            error is logged and not raised.
        """
        try:
            docs = self.file2docs()
            full_text = "".join(doc.page_content for doc in docs)
            return json.dumps(
                {
                    "filename": self.filename,
                    "full_text": full_text
                },
                ensure_ascii=False,
            )
        except Exception as e:
            logger.error(f"获取文件全文内容时出错:{e}", exc_info=e if log_verbose else None)
            return {
                "filename": self.filename,
                "full_text": "加载文件失败或文件内容为空"
            }

    async def get_llm_result(self) -> Dict[str, str]:
        """
        Generate an abstract, keywords and a chapter overview for the file.

        The full text is loaded in the loop's default executor, truncated to
        40000 characters, and sent to three LLM requests that run
        concurrently.

        :return: a dict with ``filename``, ``full_text``, ``article_abstract``,
            ``article_keywords`` and ``article_paragraph``. On any error the
            dict instead holds ``filename`` and three failure messages (no
            ``full_text`` key); the error is logged and not raised.
        """
        max_chars = 40000
        try:
            # get_full_text() does blocking file I/O, so run it off the event loop.
            loop = asyncio.get_running_loop()
            full_text_data = await loop.run_in_executor(None, self.get_full_text)

            if isinstance(full_text_data, dict):
                # get_full_text() returns a dict (not JSON) when loading
                # failed; abort via the handler below rather than sending the
                # error message to the model as context.
                raise ValueError(full_text_data.get("full_text", ""))

            try:
                # Parse the JSON string back into a dict.
                full_text_dict = json.loads(full_text_data)
                full_text = full_text_dict.get("full_text", "")
            except json.JSONDecodeError:
                print("解析 JSON 数据时出错")
                full_text = ""

            # Long documents are simply truncated (a TextRank-based
            # compression was sketched here earlier but is not in use).
            if len(full_text) > max_chars:
                full_text = full_text[:max_chars]

            llm_time = time.time()
            abstract_task = get_llm_model_response_async(
                strategy_name="gen_abstract",
                llm_model_name=LLM_MODELS[1],
                template_prompt_name="gen_abstract",
                prompt_param_dict={"context": full_text},
                temperature=0.7,
                max_tokens=4096
            )
            keywords_task = get_llm_model_response_async(
                strategy_name="gen_keywords",
                llm_model_name=LLM_MODELS[1],
                template_prompt_name="gen_keywords",
                prompt_param_dict={"context": full_text},
                temperature=0.7,
                max_tokens=512
            )
            paragraph_task = get_llm_model_response_async(
                strategy_name="gen_paragraph",
                llm_model_name=LLM_MODELS[0],
                template_prompt_name="gen_paragraph",
                prompt_param_dict={"context": full_text},
                temperature=0.7,
                max_tokens=8192
            )
            # Run the three requests concurrently.
            article_abstract, article_keywords, article_paragraph = await asyncio.gather(
                abstract_task, keywords_task, paragraph_task
            )
            logger.info(f'生成导读用时:{time.time() - llm_time}')
            return {
                "filename": self.filename,
                "full_text": full_text,
                "article_abstract": article_abstract,
                "article_keywords": article_keywords,
                "article_paragraph": article_paragraph
            }
        except Exception as e:
            logger.error(f"生成LLM结果时出错{e}", exc_info=e if log_verbose else None)
            return {
                "filename": self.filename,
                "article_abstract": "生成摘要失败",
                "article_keywords": "生成关键词失败",
                "article_paragraph": "生成章节速览失败"
            }

    def file2docs(self, refresh: bool = False):
        """
        Load the file into documents, caching the result in ``self.docs``.

        :param refresh: reload even if documents are already cached.
        :return: the loaded documents.
        :raises Exception: whatever the loader raised. When the configured
            loader is ``GCYWordLoader`` a second attempt is made with
            ``GCYWordLoader2`` first, and only that attempt's error propagates.
        """
        if self.docs is None or refresh:
            try:
                logger.info(f"{self.document_loader_name} used for {self.filepath}")
                loader = get_loader(loader_name=self.document_loader_name,
                                    file_path=self.filepath,
                                    loader_kwargs=self.loader_kwargs)
                self.docs = loader.load()
            except Exception as e:
                if self.document_loader_name != 'GCYWordLoader':
                    logger.error(f"加载文件 {self.filepath} 时出错:{e}", exc_info=e if log_verbose else None)
                    # No alternative loader exists: re-raise instead of
                    # calling load() again on a failed or unbound loader.
                    raise
                # Word documents get a second chance with the alternative loader.
                loader = get_loader(loader_name='GCYWordLoader2',
                                    file_path=self.filepath,
                                    loader_kwargs=self.loader_kwargs)
                self.docs = loader.load()
        return self.docs

    @get_split_time
    def docs2texts(
            self,
            docs: List[Document] = None,
            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        """
        Split documents into chunks and cache them in ``self.splited_docs``.

        :param docs: documents to split; when empty or ``None`` the file is
            loaded through ``file2docs(refresh=refresh)``.
        :param zh_title_enhance: apply ``func_zh_title_enhance`` to the chunks.
        :param refresh: forwarded to ``file2docs`` when ``docs`` is not given.
        :param chunk_size: chunk size for a splitter created here.
        :param chunk_overlap: chunk overlap for a splitter created here.
        :param text_splitter: splitter to use; when ``None`` one is created
            from the splitter name configured for the file extension.
        :return: the list of chunks, or ``[]`` when there is nothing to split
            or splitting produced nothing.
        """
        docs = docs or self.file2docs(refresh=refresh)
        if not docs:
            return []

        if text_splitter is None:
            self.text_splitter_name = get_SplitterClass(self.ext)
            text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size,
                                               chunk_overlap=chunk_overlap)

        if self.text_splitter_name == "MarkdownTextSplitter":
            # The markdown splitter works on the raw text of the first
            # document plus its source path.
            doc_source = (docs[0].metadata)["source"]
            docs = text_splitter.split_markdown_text(docs[0].page_content, doc_source)
        else:
            docs = text_splitter.split_documents(docs)

        if not docs:
            return []

        # Make sure every chunk carries an 'h1' title field (checked on the
        # first chunk only) so that it matches files stored with an h1 field.
        if 'h1' not in docs[0].metadata:
            for doc in docs:
                doc.metadata['h1'] = ''

        print(f"文档切分示例:{docs[0]}")
        if zh_title_enhance:
            docs = func_zh_title_enhance(docs)
        self.splited_docs = docs
        return self.splited_docs

    def file2text(
            self,
            zh_title_enhance: bool = ZH_TITLE_ENHANCE,
            refresh: bool = False,
            chunk_size: int = CHUNK_SIZE,
            chunk_overlap: int = OVERLAP_SIZE,
            text_splitter: TextSplitter = None,
    ):
        """
        Load and split the file, returning the cached chunks when available.

        :param zh_title_enhance: forwarded to ``docs2texts``.
        :param refresh: split again even if chunks are already cached.
        :param chunk_size: forwarded to ``docs2texts``.
        :param chunk_overlap: forwarded to ``docs2texts``.
        :param text_splitter: forwarded to ``docs2texts``.
        :return: the list of chunks.
        """
        if self.splited_docs is None or refresh:
            # NOTE(review): refresh is not forwarded to file2docs() here, so
            # documents already cached in self.docs are re-split, not reloaded.
            docs = self.file2docs()
            self.splited_docs = self.docs2texts(docs=docs,
                                                zh_title_enhance=zh_title_enhance,
                                                refresh=refresh,
                                                chunk_size=chunk_size,
                                                chunk_overlap=chunk_overlap,
                                                text_splitter=text_splitter)
        return self.splited_docs

    def file_exist(self):
        """Return True if ``self.filepath`` is an existing regular file."""
        return os.path.isfile(self.filepath)

    def get_mtime(self):
        """Return the modification time of ``self.filepath``."""
        return os.path.getmtime(self.filepath)

    def get_size(self):
        """Return the size of ``self.filepath`` in bytes."""
        return os.path.getsize(self.filepath)
def files2docs_in_thread(
        files: List[Union[KnowledgeFile, Tuple[str, str], Dict]],
        chunk_size: int = CHUNK_SIZE,
        chunk_overlap: int = OVERLAP_SIZE,
        zh_title_enhance: bool = ZH_TITLE_ENHANCE,
) -> Generator:
    """
    Convert files on disk into langchain Documents in bulk via
    ``run_in_thread_pool``.

    :param files: items to process. Each is a ``KnowledgeFile``, a tuple
        ``(filename, kb_name)``, or a dict with ``filename`` and ``kb_name``
        keys; any further dict keys are passed on as keyword arguments to
        ``KnowledgeFile.file2text``. The caller's dicts are not modified.
    :param chunk_size: chunk size passed to ``file2text``.
    :param chunk_overlap: chunk overlap passed to ``file2text``.
    :param zh_title_enhance: title-enhance flag passed to ``file2text``.
    :return: generator yielding ``status, (kb_name, file_name, docs | error)``.
        Items that cannot be prepared are reported first, with
        ``status=False``; results from the thread pool follow.
    """

    def file2docs(*, file: KnowledgeFile, **kwargs) -> Tuple[bool, Tuple[str, str, List[Document]]]:
        # Worker: load and split a single file, turning any error into a result.
        try:
            return True, (file.kb_name, file.filename, file.file2text(**kwargs))
        except Exception as e:
            msg = f"从文件 {file.kb_name}/{file.filename} 加载文档时出错:{e}"
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)
            return False, (file.kb_name, file.filename, msg)

    kwargs_list = []
    for file in files:
        kwargs = {}
        # Reset per item so the error report below can neither hit an unbound
        # name nor reuse the names of the previous file.
        kb_name = filename = None
        try:
            if isinstance(file, tuple) and len(file) >= 2:
                filename, kb_name = file[0], file[1]
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            elif isinstance(file, dict):
                # Work on a copy so the caller's dict keeps its keys.
                extra = dict(file)
                filename = extra.pop("filename")
                kb_name = extra.pop("kb_name")
                kwargs.update(extra)
                file = KnowledgeFile(filename=filename, knowledge_base_name=kb_name)
            kwargs["file"] = file
            kwargs["chunk_size"] = chunk_size
            kwargs["chunk_overlap"] = chunk_overlap
            kwargs["zh_title_enhance"] = zh_title_enhance
            kwargs_list.append(kwargs)
        except Exception as e:
            yield False, (kb_name, filename, str(e))

    for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
        yield result
if __name__ == "__main__":
    from pprint import pprint

    # Manual smoke test: load one sample CSV through KnowledgeFile.
    sample_path = "/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat/knowledge_base/csv1/content/gm.csv"
    kb_file = KnowledgeFile(filename=sample_path, knowledge_base_name="samples")

    # To try a specific splitter, set it before splitting:
    # kb_file.text_splitter_name = "RecursiveCharacterTextSplitter"
    docs = kb_file.file2docs()

    # Uncomment to inspect the last loaded document:
    # pprint(docs[-1])