[全量] 初始化项目代码、配置、文档及Agent协同harness

This commit is contained in:
2026-04-02 11:36:05 +08:00
parent 0553309cdf
commit 87e571d9ec
1133 changed files with 221948 additions and 0 deletions

View File

@@ -0,0 +1,775 @@
from datetime import datetime
import operator
from abc import ABC, abstractmethod
import re
import os
from pathlib import Path
import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from joblib import Parallel, delayed
import time
from textrank4zh import TextRank4Keyword, TextRank4Sentence
import multiprocessing
from configs.model_config import LLM_MODELS
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.utils import get_personal_knowledge_map, get_similar_documents1
from nltk.tokenize import sent_tokenize
import logging
# Configure the root logger: INFO level, "<time> - <level> - <message>" format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def generate_weights_as_list(length, total_sum=80):
    """Build ``length`` integer weights that decay exponentially and sum to ``total_sum``.

    :param length: number of weights to generate.
    :param total_sum: value the returned weights must add up to (ignored when
        ``length == 1``, see below).
    :return: list of ints, largest weight first; ``[]`` when ``length <= 0``.
    """
    # Nothing to weight: return early instead of hitting ``i % 0`` in the
    # adjustment loop below.
    if length <= 0:
        return []
    # Existing special case: a single item always gets the fixed weight 50,
    # regardless of ``total_sum``.
    if length == 1:
        return [50]
    # Exponentially decaying curve over the positions 0 .. length-1.
    x = np.linspace(0, length - 1, length)
    weights = np.exp(-x / (length / 5))
    # Scale the curve so that it sums to ``total_sum``, then round to integers.
    normalized_weights = weights / sum(weights) * total_sum
    integer_weights = np.round(normalized_weights).astype(int)
    # Rounding can leave the total off by a few units; spread the difference
    # one unit at a time, starting from the first (largest) weight.
    adjustment = total_sum - sum(integer_weights)
    step = 1 if adjustment > 0 else -1
    for i in range(abs(adjustment)):
        integer_weights[i % length] += step
    return integer_weights.tolist()
def score_threshold_process(query,score_threshold, k, docs):
"""
Filter documents by a score threshold, select/merge them (title match or LLM-based selection),
attach a TextRank-based "summary" to each document's metadata and return the first k documents.
:param query: the user query; compared against document titles/sources and passed to the LLM prompt.
:param score_threshold: similarity score threshold; only documents whose score is <= this value are kept (no filtering when None).
:param k: number of top documents to return.
:param docs: list of documents, each one a tuple of (document, similarity score).
:return: list of at most k (document, similarity score) tuples; an empty list when nothing remains.
"""
# If score_threshold is provided, keep only documents whose score is <= the threshold (operator.le).
if score_threshold is not None:
cmp = (
operator.le
)
docs = [
(doc, similarity)
for doc, similarity in docs
if cmp(similarity, score_threshold)
]
# When every recalled result is above score_threshold nothing is left: return the empty list.
if len(docs) == 0:
return docs
# Collect the documents whose title contains the query (both compared with spaces, "\n" and "\r" removed).
# On any exception the same match is retried against metadata["source"] instead.
result = []
try:
for doc in docs:
if query.replace(" ","").replace("\n","").replace("\r","") in doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r",""):
result.append(doc)
except Exception as e:
for doc in docs:
if query.replace(" ","").replace("\n","").replace("\r","") in doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r",""):
result.append(doc)
# Branch taken when the first document has no "h1" metadata key
# (a later comment in this function labels the "h1" case as the personal knowledge base).
if len(docs) > 0 and not "h1" in docs[0][0].metadata:
# If the user's question appears in some titles, de-duplicate those documents and skip relevance matching;
# only the documents whose title contains the question are returned.
if len(result) > 0:
# temp maps the whitespace-stripped title to the first document seen with that title.
temp={}
for doc in result:
if doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") not in temp:
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")] = doc
else:
# Same title seen again: skip when the content is identical, or when the stored content equals the title;
# otherwise append this page_content to the stored document.
if temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
continue
elif temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") :
continue
else:
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
docs = []
for i in temp:
docs.append(temp[i])
else:
# No title matched the query: build the candidate list for the LLM from title + content,
# falling back to source + content on any exception.
try:
sentences = []
sentences_page_content = []
for doc in docs:
meta = doc[0].metadata
# If the title is empty, use page_content as the title.
# NOTE(review): the original comment mentions "first sentence, at most 50 characters",
# but the code assigns the whole page_content — confirm which is intended.
if meta["title"] == "":
meta["title"] = doc[0].page_content
# If a summary exists, replace page_content with it so the downstream text is more concise.
summary = meta.get("summary")
if summary:
doc[0].page_content = summary
sentences = [doc[0].metadata["title"] for doc in docs]
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["title"]+doc[0].page_content+"" for i,doc in enumerate(docs)]
except Exception as e:
sentences = [doc[0].metadata["source"] for doc in docs]
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"" for i,doc in enumerate(docs)]
# Ask the LLM (first entry of LLM_MODELS, "default_similar" prompt) which numbered candidates are relevant.
res = get_llm_model_response(
strategy_name="default_similar",
llm_model_name=LLM_MODELS[0],
template_prompt_name="default_similar",
prompt_param_dict={"input": query, "title": f"{sentences_page_content}", "time": datetime.now().strftime("%Y%m%d")},
temperature=0.01,
max_tokens=512
)
# The response is parsed as a comma-separated list of 1-based positions and converted to 0-based indices.
# Any parsing/selection error falls back to calling get_similar_documents1 with an empty index list.
try:
index =[]
if res == "":
index = []
else:
index = res.split(",")
index = [int(i)-1 for i in index]
docs = get_similar_documents1(index=index,sentences=sentences,query=query, docs=docs, top_k=k)
except Exception as e:
print(e)
docs = get_similar_documents1(index=[],sentences=sentences,query=query, docs=docs, top_k=k)
# De-duplication; applies to the general knowledge base only.
# NOTE(review): the membership test uses the whitespace-stripped title/source, but the entry is stored
# under the raw value; for values containing whitespace the later stripped-key lookups would raise
# KeyError — confirm whether the stored key should be stripped as well.
temp={}
for doc in docs:
try:
if doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") not in temp:
temp[doc[0].metadata["title"]] = doc
else:
if temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
continue
elif temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","") :
continue
else:
temp[doc[0].metadata["title"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
except Exception as e:
# Fallback: run the same de-duplication keyed on metadata["source"].
print(e)
if doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") not in temp:
temp[doc[0].metadata["source"]] = doc
else:
if temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
continue
elif temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") :
continue
else:
temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
docs = []
for i in temp:
docs.append(temp[i])
# Personal knowledge base only (documents carrying an "h1" metadata key).
# NOTE(review): docs[0] is indexed here without an emptiness check — confirm docs cannot be empty at this point.
if "h1" in docs[0][0].metadata:
# Keep the original source under "uuid_name" and replace "source" with the value returned by
# get_personal_knowledge_map when one exists for it.
all_source = [doc[0].metadata["source"] for doc in docs]
unique_source = list(set(all_source))
all_title_map = get_personal_knowledge_map(unique_source)
for doc in docs:
doc[0].metadata["uuid_name"] = doc[0].metadata["source"]
if doc[0].metadata["source"] in all_title_map:
doc[0].metadata["source"] = all_title_map[doc[0].metadata["source"]]
else:
pass
# Build the numbered candidate list from source + content (both branches of this try are identical).
try:
sentences = [doc[0].metadata["source"] for doc in docs]
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"" for i,doc in enumerate(docs)]
except Exception as e:
sentences = [doc[0].metadata["source"] for doc in docs]
sentences_page_content = [str(i+1)+":【"+doc[0].metadata["source"]+doc[0].page_content+"" for i,doc in enumerate(docs)]
# Extra request body forwarded to the LLM call: sets chat_template_kwargs.enable_thinking to True.
kwargs = {}
kwargs["extra_body"] = {"chat_template_kwargs": {"enable_thinking": True}}
res = get_llm_model_response(
strategy_name="default_similar",
llm_model_name=LLM_MODELS[0],
template_prompt_name="default_similar",
prompt_param_dict={"input": query, "title": f"{sentences_page_content}", "time": datetime.now().strftime("%Y%m%d")},
temperature=0.01,
max_tokens=None,
**kwargs
)
# Strip any <think>...</think> section from the model output before parsing it.
res = re.sub(r'<think>.*?</think>', '', res,flags=re.DOTALL)
# Parse the comma-separated 1-based positions into 0-based indices; on any error fall back to an empty index list.
try:
index =[]
if res == "":
index = []
else:
index = res.split(",")
index = [int(i)-1 for i in index]
docs = get_similar_documents1(index=index,sentences=sentences,query=query, docs=docs, top_k=k)
except Exception as e:
print(e)
docs = get_similar_documents1(index=[],sentences=sentences,query=query, docs=docs, top_k=k)
# De-duplication keyed on metadata["source"].
# NOTE(review): the original comment says "general knowledge base only" although this follows the
# personal-knowledge-base code; the raw-key vs stripped-key mismatch noted above applies here too.
temp={}
for doc in docs:
if doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") not in temp:
temp[doc[0].metadata["source"]] = doc
else:
if temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].page_content.replace(" ","").replace("\n","").replace("\r","") :
continue
elif temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content.replace(" ","").replace("\n","").replace("\r","") == doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","") :
continue
else:
temp[doc[0].metadata["source"].replace(" ","").replace("\n","").replace("\r","")][0].page_content += doc[0].page_content
docs = []
for i in temp:
docs.append(temp[i])
if len(docs) == 0:
return docs
# Generate one weight per document for the TextRank algorithm (used as its sentence count).
cont = generate_weights_as_list(len(docs))
# Process each document to extract or assign a summary.
for i, (doc, _) in enumerate(docs):
summary_sources = ['content', 'abstract', 'text'] # metadata fields to try, depending on the knowledge base
for source in summary_sources:
try:
# When the title occurs in page_content, or page_content is shorter than 100 characters,
# summarize the metadata field with TextRank; a summary longer than 15000 characters is
# recomputed with a sentence count of 1. Otherwise page_content itself is used as the summary.
if docs[i][0].metadata["title"] in docs[i][0].page_content or len(docs[i][0].page_content) < 100:
doc.metadata['summary'] = TextRank(doc.metadata[source], cont[i])
if len(doc.metadata['summary']) >15000:
doc.metadata['summary'] = TextRank(doc.metadata[source], 1)
break
else:
doc.metadata['summary'] = docs[i][0].page_content
except KeyError:
doc.metadata['summary'] = docs[i][0].page_content
continue # if the current source field fails, try the next one.
# Return the first k documents.
return docs[:k]
# Monkey patch for compatibility with TextRank (textrank4zh).
import networkx as nx
import numpy as np
# Apply the patch: expose `nx.from_numpy_array` under the name `nx.from_numpy_matrix`.
nx.from_numpy_matrix = nx.from_numpy_array
# NOTE(review): the original comment here mentions adding an input data-type compatibility
# check to the monkey patch, but no such check is present in the code.
def process_text_segment(text_segment, num_sentences):
    """Run TextRank keyword and key-sentence extraction on one text segment.

    :param text_segment: the text to analyze.
    :param num_sentences: number of key sentences requested from TextRank4Sentence.
    :return: tuple ``(keywords, summaries)`` where ``keywords`` is a list of
        ``(word, weight)`` pairs and ``summaries`` is a list of sentences.
    """
    # Keywords: up to 30 words with a minimum length of 4, co-occurrence window of 5.
    keyword_ranker = TextRank4Keyword()
    keyword_ranker.analyze(text=text_segment, lower=True, window=5)
    keywords = []
    for ranked_word in keyword_ranker.get_keywords(30, word_min_len=4):
        keywords.append((ranked_word.word, ranked_word.weight))

    # Key sentences.
    sentence_ranker = TextRank4Sentence()
    sentence_ranker.analyze(text=text_segment, lower=True, source='all_filters')
    summaries = []
    for ranked_sentence in sentence_ranker.get_key_sentences(num=num_sentences):
        summaries.append(ranked_sentence.sentence)

    return keywords, summaries
def split_text_balanced(text, n_parts):
    """Split ``text`` into sentence groups of nearly equal size.

    The text is tokenized into sentences with ``sent_tokenize``; the number of
    parts is capped so that each part holds at least 10 sentences (and is never
    less than 1). Sentences inside a part are joined with a single space.

    :param text: the text to split.
    :param n_parts: requested number of parts (upper bound).
    :return: list of strings, one per part.
    """
    sentences = sent_tokenize(text)
    min_sentences_per_part = 10
    sentence_count = len(sentences)
    part_count = max(1, min(n_parts, sentence_count // min_sentences_per_part))

    # The first ``extra`` parts receive one sentence more than the others.
    base_size, extra = divmod(sentence_count, part_count)
    parts = []
    cursor = 0
    for part_index in range(part_count):
        size = base_size + 1 if part_index < extra else base_size
        parts.append(' '.join(sentences[cursor:cursor + size]))
        cursor += size
    return parts
def TextRank(text,num_sentences, n_cores=multiprocessing.cpu_count()):
    """Summarize ``text`` with TextRank and return the key sentences as one string.

    The text is split into at most ``n_cores`` balanced parts; every part is
    processed sequentially and the key sentences of all parts are concatenated
    without a separator.

    :param text: the text to summarize.
    :param num_sentences: number of key sentences requested per part.
    :param n_cores: upper bound on the number of parts (defaults to the CPU
        count captured when the module is imported).
    :return: the concatenated key sentences.
    """
    start_time = time.time()
    text_parts = split_text_balanced(text, n_cores)
    all_keywords = []
    all_summaries = []
    # Parts are processed one after another (no process pool).
    for part in text_parts:
        keywords, summaries = process_text_segment(part, num_sentences)
        all_keywords.extend(keywords)
        all_summaries.extend(summaries)
    # Keyword dump is diagnostic only: emit it through logging at DEBUG level
    # instead of printing to stdout on every call, and skip the sort entirely
    # when DEBUG is disabled.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        for word, weight in sorted(all_keywords, key=lambda x: x[1], reverse=True):
            logging.debug("%s %s", word, weight)
    end_time = time.time()
    logging.info(f"TextRank耗时: {end_time - start_time:.2f}")
    return "".join(all_summaries)
from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
list_docs_from_db,
)
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,EXPR,
EMBEDDING_MODEL, KB_INFO)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, KnowledgeFile,
list_kbs_from_folder, list_files_from_folder,
)
from typing import List, Union, Dict, Optional, Tuple
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
import time
def get_emb_time(f):
    """Decorator that prints the wall-clock time taken by each call to ``f``.

    The wrapper forwards all arguments, returns the wrapped function's result
    unchanged and preserves its metadata (``__name__``, ``__doc__`` ...).
    Nothing is printed when ``f`` raises.
    """
    import functools

    @functools.wraps(f)
    def inner(*arg,**kwarg):
        s_time = time.time()
        res = f(*arg,**kwarg)
        e_time = time.time()
        print('向量化耗时:{}'.format(e_time - s_time))
        return res
    return inner
def normalize(embeddings: List[List[float]]) -> np.ndarray:
    """L2-normalize every embedding row.

    Replacement for ``sklearn.preprocessing.normalize`` (L2 norm) that avoids
    installing scipy / scikit-learn.

    :param embeddings: sequence of embedding vectors; ``None`` entries are
        dropped before normalization.
    :return: 2-D array with one unit-length row per non-``None`` embedding.
        All-zero rows are returned unchanged (as zeros) instead of NaN.
    :raises ValueError: if no non-``None`` embedding is left.
    """
    # Drop None entries.
    valid = [e for e in embeddings if e is not None]
    if not valid:
        raise ValueError("No valid embeddings found (all are None)")
    matrix = np.array(valid)
    # keepdims=True yields an (n, 1) column that broadcasts over each row.
    norm = np.linalg.norm(matrix, axis=1, keepdims=True)
    # A zero vector has norm 0; divide it by 1 so it stays zero rather than
    # turning into NaN/inf.
    safe_norm = np.where(norm == 0, 1.0, norm)
    return np.divide(matrix, safe_norm)
class SupportedVSType:
    """String identifiers of the vector-store backends known to this module."""

    FAISS = "faiss"
    MILVUS = "milvus"
    DEFAULT = "default"
    ZILLIZ = "zilliz"
    PG = "pg"
    ES = "es"
    CHROMADB = "chromadb"
class KBService(ABC):
    """Abstract base class of a knowledge base backed by a vector store.

    The public methods keep the database records in sync and delegate the
    vector-store specific work to the abstract ``do_*`` hooks implemented by
    each backend.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        """Store the knowledge base settings and run the backend's ``do_init``."""
        self.kb_name = knowledge_base_name
        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        self.do_init()

    def __repr__(self) -> str:
        return f"{self.kb_name} @ {self.embed_model}"

    def save_vector_store(self):
        '''
        Persist the vector store: FAISS saves to disk, Milvus saves to the
        database. PGVector is not supported yet. The base implementation does nothing.
        '''
        pass

    def create_kb(self):
        """
        Create the knowledge base: ensure the document folder exists, run the
        backend hook and register the knowledge base in the database.
        """
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        self.do_create_kb()
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def clear_vs(self):
        """
        Delete everything stored in the vector store and the matching file records.
        """
        self.do_clear_vs()
        status = delete_files_from_db(self.kb_name)
        return status

    def drop_kb(self):
        """
        Delete the knowledge base.
        """
        self.do_drop_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        '''
        Convert List[Document] into the arguments accepted by VectorStore.add_embeddings.
        '''
        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)

    @get_emb_time
    def add_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """
        Add a file to the knowledge base.
        If ``docs`` is given the file is not parsed again (``kb_file.file2text()``
        is skipped) and the database entry is flagged ``custom_docs=True``.
        Returns the status of ``add_file_to_db``, or ``False`` when there is nothing to add.
        """
        if docs:
            custom_docs = True
            for doc in docs:
                doc.metadata.setdefault("source", kb_file.filename)
        else:
            docs = kb_file.file2text()
            custom_docs = False
        if not docs:
            return False

        # Rewrite metadata["source"] as a path relative to the document folder.
        for doc in docs:
            source = None  # keeps the error message below well-defined
            try:
                source = doc.metadata.get("source", "")
                if os.path.isabs(source):
                    rel_path = Path(source).relative_to(self.doc_path)
                    doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
            except Exception as e:
                print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
        # Remove any previous version of the file before inserting the new chunks.
        self.delete_doc(kb_file)
        doc_infos = self.do_add_doc(docs, **kwargs)
        status = add_file_to_db(kb_file,
                                custom_docs=custom_docs,
                                docs_count=len(docs),
                                doc_infos=doc_infos)
        return status

    def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
        """
        Delete a file from the knowledge base; with ``delete_content=True`` the
        file on disk is removed as well.
        """
        self.do_delete_doc(kb_file, **kwargs)
        status = delete_file_from_db(kb_file)
        if delete_content and os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        return status

    def update_info(self, kb_info: str):
        """
        Update the knowledge base description.
        """
        self.kb_info = kb_info
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """
        Update the vector store from the file in the content folder.
        If ``docs`` is given, those custom docs are used and the database entry
        is flagged ``custom_docs=True``.
        """
        if os.path.exists(kb_file.filepath):
            self.delete_doc(kb_file, **kwargs)
            return self.add_doc(kb_file, docs=docs, **kwargs)

    def exist_doc(self, file_name: str):
        return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                               filename=file_name))

    def list_files(self):
        return list_files_from_db(self.kb_name)

    def count_files(self):
        return count_files_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    expr: str = EXPR,
                    custom_strategy_config: Optional[dict] = None
                    ) ->List[Document]:
        """Search the knowledge base by delegating to the backend's ``do_search``."""
        # A fresh dict per call avoids sharing one mutable default between calls.
        if custom_strategy_config is None:
            custom_strategy_config = {}
        docs = self.do_search(query, top_k, score_threshold, expr, custom_strategy_config)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        return []

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        return []

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        raise NotImplementedError

    def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
        '''
        ``docs`` has the form {doc_id: Document, ...}.
        Every listed id is deleted first; entries whose value is None or whose
        page_content is blank are not re-added, i.e. they are removed.
        '''
        self.del_doc_by_ids(list(docs.keys()))
        # Separate local names: rebinding ``docs`` here would discard the input
        # mapping before it is iterated.
        new_docs = []
        new_ids = []
        for doc_id, doc in docs.items():
            if not doc or not doc.page_content.strip():
                continue
            new_ids.append(doc_id)
            new_docs.append(doc)
        # Skip the backend call when every entry was a deletion.
        if new_docs:
            self.do_add_doc(docs=new_docs, ids=new_ids)
        return True

    def list_docs(self, file_name: str = None, metadata: Optional[Dict] = None) -> List[DocumentWithVSId]:
        '''
        Retrieve Documents by file_name or metadata.
        Records whose document cannot be fetched from the vector store are skipped.
        '''
        if metadata is None:
            metadata = {}
        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
        docs = []
        for x in doc_infos:
            # get_doc_by_ids may return an empty list (the base implementation does).
            found = self.get_doc_by_ids([x["id"]])
            doc_info = found[0] if found else None
            if doc_info is None:
                continue
            docs.append(DocumentWithVSId(**doc_info.dict(), id=x["id"]))
        return docs

    @abstractmethod
    def do_create_kb(self):
        """
        Create the knowledge base; implemented by each subclass.
        """
        pass

    @staticmethod
    def list_kbs_type():
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        return list_kbs_from_db()

    def exists(self, kb_name: str = None):
        kb_name = kb_name or self.kb_name
        return kb_exists(kb_name)

    @abstractmethod
    def vs_type(self) -> str:
        pass

    @abstractmethod
    def do_init(self):
        pass

    @abstractmethod
    def do_drop_kb(self):
        """
        Delete the knowledge base; implemented by each subclass.
        """
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float,
                  expr: str,
                  custom_strategy_config: Optional[dict] = None,
                  ) -> List[Tuple[Document, float]]:
        """
        Search the knowledge base; implemented by each subclass.
        """
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """
        Add documents to the knowledge base; implemented by each subclass.
        """
        pass

    @abstractmethod
    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """
        Delete a file's documents from the knowledge base; implemented by each subclass.
        ``delete_doc`` forwards its extra keyword arguments to this hook.
        """
        pass

    @abstractmethod
    def do_clear_vs(self):
        """
        Delete all vectors from the knowledge base; implemented by each subclass.
        """
        pass
class KBServiceFactory:
    """Factory that builds the ``KBService`` implementation for a vector-store type."""

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> KBService:
        """Return the service for ``kb_name`` matching ``vector_store_type``.

        A string type is resolved case-insensitively through the attributes of
        ``SupportedVSType`` (an unknown name raises ``AttributeError``).
        Backend modules are imported lazily so that only the selected backend's
        dependencies are loaded. Returns ``None`` when no branch matches.
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
        if SupportedVSType.FAISS == vector_store_type:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.PG == vector_store_type:
            from server.knowledge_base.kb_service.pg_kb_service import PGKBService
            return PGKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.MILVUS == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name,embed_model=embed_model)
        elif SupportedVSType.ZILLIZ == vector_store_type:
            from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
            return ZillizKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.DEFAULT == vector_store_type:
            # DEFAULT is served by Milvus. A second, later DEFAULT branch that
            # returned DefaultKBService could never be reached and was removed;
            # the effective mapping is unchanged.
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name,
                                   embed_model=embed_model)  # other milvus parameters are set in model_config.kbs_config
        elif SupportedVSType.ES == vector_store_type:
            from server.knowledge_base.kb_service.es_kb_service import ESKBService
            return ESKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.CHROMADB == vector_store_type:
            from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
            return ChromaKBService(kb_name, embed_model=embed_model)
        # No matching backend: same result as before, now explicit.
        return None

    @staticmethod
    def get_service_by_name(kb_name: str) -> KBService:
        """Build the service for a knowledge base registered in the database.

        Returns ``None`` when the first value returned by ``load_kb_from_db``
        is ``None`` (knowledge base not in the database).
        """
        stored_name, vs_type, embed_model = load_kb_from_db(kb_name)
        if stored_name is None:  # kb not in db, just return None
            return None
        from server.utils import resolve_embed_model_name
        embed_model = resolve_embed_model_name(embed_model)
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        """Return the service registered for the name "default" with the DEFAULT type."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
def get_kb_details() -> List[Dict]:
    """List every knowledge base found in the folder and/or in the database.

    Folder entries start with placeholder values and are updated with the
    database detail when one exists; database-only entries are flagged
    ``in_folder=False``. Each returned dict gets a 1-based ``No`` field.
    """
    folder_kbs = list_kbs_from_folder()
    db_kbs = KBService.list_kbs()

    # Placeholder record for every knowledge base present on disk.
    merged = {
        name: {
            "kb_name": name,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for name in folder_kbs
    }

    for name in db_kbs:
        detail = get_kb_detail(name)
        if not detail:
            continue
        detail["in_db"] = True
        if name in merged:
            merged[name].update(detail)
            continue
        detail["in_folder"] = False
        merged[name] = detail

    data = []
    for number, record in enumerate(merged.values(), start=1):
        record['No'] = number
        data.append(record)
    return data
def get_kb_file_details(kb_name: str) -> List[Dict]:
    """List every file of a knowledge base found in its folder and/or in the database.

    Returns ``[]`` when no service exists for ``kb_name``. Database records are
    matched to folder files case-insensitively; database-only files are flagged
    ``in_folder=False``. Each returned dict gets a 1-based ``No`` field.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    if service is None:
        return []

    folder_files = list_files_from_folder(kb_name)
    db_files = service.list_files()

    # Placeholder record for every file present on disk.
    merged = {
        file_name: {
            "kb_name": kb_name,
            "file_name": file_name,
            "file_ext": os.path.splitext(file_name)[-1],
            "file_version": 0,
            "document_loader": "",
            "docs_count": 0,
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for file_name in folder_files
    }

    # Lower-cased name -> actual folder name, built from the folder files only.
    folder_name_by_lower = {name.lower(): name for name in merged}

    for file_name in db_files:
        detail = get_file_detail(kb_name, file_name)
        if not detail:
            continue
        detail["in_db"] = True
        lowered = file_name.lower()
        if lowered in folder_name_by_lower:
            merged[folder_name_by_lower[lowered]].update(detail)
            continue
        detail["in_folder"] = False
        merged[file_name] = detail

    data = []
    for number, record in enumerate(merged.values(), start=1):
        record['No'] = number
        data.append(record)
    return data
class EmbeddingsFunAdapter(Embeddings):
    """Adapter exposing ``embed_texts`` / ``aembed_texts`` through the langchain ``Embeddings`` interface.

    Every returned vector is L2-normalized with ``normalize``. All four methods
    validate the embedding response the same way and raise ``ValueError`` when
    no embeddings were produced.
    """

    def __init__(self, embed_model: str = EMBEDDING_MODEL):
        self.embed_model = embed_model

    @staticmethod
    def _require_embeddings(result, texts: List[str]):
        """Return ``result.data`` or raise ``ValueError`` if it is missing or empty."""
        embeddings = result.data if result and hasattr(result, 'data') else None
        if not embeddings:
            raise ValueError(f"Failed to get embeddings for texts: {texts[:2]}...")
        return embeddings

    @staticmethod
    def _normalize_query(embeddings, text: str) -> List[float]:
        """Normalize the first (only) embedding of a single-text request."""
        query_embed = embeddings[0]
        if query_embed is None:
            raise ValueError(f"Failed to get embeddings for texts: {[text]}...")
        query_embed_2d = np.reshape(query_embed, (1, -1))  # turn the 1-D vector into a 2-D array
        normalized_query_embed = normalize(query_embed_2d)
        return normalized_query_embed[0].tolist()  # back to a 1-D list

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        result = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        embeddings = self._require_embeddings(result, texts)
        return normalize(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        result = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        embeddings = self._require_embeddings(result, [text])
        return self._normalize_query(embeddings, text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        result = await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        embeddings = self._require_embeddings(result, texts)
        return normalize(embeddings).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        result = await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        embeddings = self._require_embeddings(result, [text])
        return self._normalize_query(embeddings, text)
# def score_threshold_process(score_threshold, k, docs):
# if score_threshold is not None:
# cmp = (
# operator.le
# )
# docs = [
# (doc, similarity)
# for doc, similarity in docs
# if cmp(similarity, score_threshold)
# ]
# return docs[:k]

View File

@@ -0,0 +1,105 @@
import uuid
from typing import Any, Dict, List, Tuple
import chromadb
from chromadb.api.types import (GetResult, QueryResult)
from langchain.docstore.document import Document
from configs import SCORE_THRESHOLD
from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
KBService, SupportedVSType)
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
    """Convert a Chroma ``get`` result into a list of langchain ``Document`` objects.

    Returns ``[]`` when the result holds no documents. A missing metadata list
    or a missing individual metadata entry becomes a fresh empty dict, so no
    two documents ever share the same metadata object.
    """
    documents = get_result['documents']
    if not documents:
        return []
    metadatas = get_result['metadatas'] or [None] * len(documents)
    return [
        # ``metadata or {}`` builds a new dict per document (``[{}] * n`` would
        # alias a single dict across all of them).
        Document(page_content=page_content, metadata=metadata or {})
        for page_content, metadata in zip(documents, metadatas)
    ]
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """
    Pair each document of the first query in ``results`` with its distance.

    Adapted from langchain_community.vectorstores.chroma.Chroma.
    """
    # TODO: Chroma can do batch querying; only the first query's results are read.
    contents = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]

    docs_and_scores = []
    for content, metadata, distance in zip(contents, metadatas, distances):
        document = Document(page_content=content, metadata=metadata or {})
        docs_and_scores.append((document, distance))
    return docs_and_scores
class ChromaKBService(KBService):
    """``KBService`` implementation storing vectors in a persistent ChromaDB collection.

    One collection, named after the knowledge base, is kept per service instance.
    """

    vs_path: str
    kb_path: str
    client = None
    collection = None

    def vs_type(self) -> str:
        return SupportedVSType.CHROMADB

    def get_vs_path(self) -> str:
        return get_vs_path(self.kb_name, self.embed_model)

    def get_kb_path(self) -> str:
        return get_kb_path(self.kb_name)

    def do_init(self) -> None:
        """Open the persistent client at the vector-store path and get/create the collection."""
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()
        self.client = chromadb.PersistentClient(path=self.vs_path)
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_create_kb(self) -> None:
        # In ChromaDB, creating a KB is equivalent to creating a collection
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_drop_kb(self):
        # Dropping a KB is equivalent to deleting a collection in ChromaDB.
        # A "does not exist" error is ignored; any other ValueError propagates.
        try:
            self.client.delete_collection(self.kb_name)
        except ValueError as e:
            if str(e) != f"Collection {self.kb_name} does not exist.":
                raise

    def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD, expr: str = None,
                  custom_strategy_config: dict = None) -> List[Tuple[Document, float]]:
        """Return the ``top_k`` nearest documents to ``query`` with their distances.

        ``score_threshold``, ``expr`` and ``custom_strategy_config`` are accepted
        for compatibility with ``KBService.search_docs`` (which passes all five
        arguments positionally) but are not used by this backend.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
        return _results_to_docs_and_scores(query_result)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Embed ``docs``, add them to the collection one by one and return their ids and metadata."""
        doc_infos = []
        data = self._docs_to_embeddings(docs)
        ids = [str(uuid.uuid1()) for _ in range(len(data["texts"]))]
        for _id, text, embedding, metadata in zip(ids, data["texts"], data["embeddings"], data["metadatas"]):
            self.collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text)
            doc_infos.append({"id": _id, "metadata": metadata})
        return doc_infos

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        get_result: GetResult = self.collection.get(ids=ids)
        return _get_result_to_documents(get_result)

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        self.collection.delete(ids=ids)
        return True

    def do_clear_vs(self):
        # Clearing the vector store is done by dropping and recreating the
        # collection; recreating keeps ``self.collection`` usable afterwards
        # instead of pointing at a deleted collection.
        self.do_drop_kb()
        self.collection = self.client.get_or_create_collection(self.kb_name)

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        # NOTE(review): filters on the file's full ``filepath``; confirm this
        # matches the value stored in metadata["source"] when documents are added.
        return self.collection.delete(where={"source": kb_file.filepath})

View File

@@ -0,0 +1,38 @@
from typing import List
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from server.knowledge_base.kb_service.base import KBService
class DefaultKBService(KBService):
    """No-op ``KBService``: every hook does nothing and returns ``None``.

    The hook signatures accept the arguments that ``KBService`` passes to them
    (all optional here), so calling the public base-class methods on this
    service no longer fails with ``TypeError``.
    """

    def do_create_kb(self):
        pass

    def do_drop_kb(self):
        pass

    def do_add_doc(self, docs: List[Document], **kwargs):
        pass

    def do_clear_vs(self):
        pass

    def vs_type(self) -> str:
        return "default"

    def do_init(self):
        pass

    def do_search(self, query=None, top_k=None, score_threshold=None, expr=None,
                  custom_strategy_config=None):
        pass

    def do_insert_multi_knowledge(self):
        pass

    def do_insert_one_knowledge(self):
        pass

    def do_delete_doc(self, kb_file=None, **kwargs):
        pass

View File

@@ -0,0 +1,261 @@
from typing import List
import os
import shutil
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores.elasticsearch import ElasticsearchStore
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
from server.knowledge_base.utils import KnowledgeFile
from server.utils import load_local_embeddings
from elasticsearch import Elasticsearch,BadRequestError
from configs import logger
from configs import kbs_config
class ESKBService(KBService):
def do_init(self):
self.kb_path = self.get_kb_path(self.kb_name)
self.index_name = os.path.split(self.kb_path)[-1]
self.IP = kbs_config[self.vs_type()]['host']
self.PORT = kbs_config[self.vs_type()]['port']
self.user = kbs_config[self.vs_type()].get("user",'')
self.password = kbs_config[self.vs_type()].get("password",'')
self.dims_length = kbs_config[self.vs_type()].get("dims_length",None)
self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE)
try:
# ES python客户端连接仅连接
if self.user != "" and self.password != "":
self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}",
basic_auth=(self.user,self.password))
else:
logger.warning("ES未配置用户名和密码")
self.es_client_python = Elasticsearch(f"http://{self.IP}:{self.PORT}")
except ConnectionError:
logger.error("连接到 Elasticsearch 失败!")
raise ConnectionError
except Exception as e:
logger.error(f"Error 发生 : {e}")
raise e
try:
# 首先尝试通过es_client_python创建
mappings = {
"properties": {
"dense_vector": {
"type": "dense_vector",
"dims": self.dims_length,
"index": True
}
}
}
self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
except BadRequestError as e:
logger.error("创建索引失败,重新")
logger.error(e)
try:
# langchain ES 连接、创建索引
if self.user != "" and self.password != "":
self.db_init = ElasticsearchStore(
es_url=f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
query_field="context",
vector_query_field="dense_vector",
embedding=self.embeddings_model,
es_user=self.user,
es_password=self.password
)
else:
logger.warning("ES未配置用户名和密码")
self.db_init = ElasticsearchStore(
es_url=f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
query_field="context",
vector_query_field="dense_vector",
embedding=self.embeddings_model,
)
except ConnectionError:
print("### 初始化 Elasticsearch 失败!")
logger.error("### 初始化 Elasticsearch 失败!")
raise ConnectionError
except Exception as e:
logger.error(f"Error 发生 : {e}")
raise e
try:
# 尝试通过db_init创建索引
self.db_init._create_index_if_not_exists(
index_name=self.index_name,
dims_length=self.dims_length
)
except Exception as e:
logger.error("创建索引失败...")
logger.error(e)
# raise e
@staticmethod
def get_kb_path(knowledge_base_name: str):
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
@staticmethod
def get_vs_path(knowledge_base_name: str):
return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")
def do_create_kb(self):
if os.path.exists(self.doc_path):
if not os.path.exists(os.path.join(self.kb_path, "vector_store")):
os.makedirs(os.path.join(self.kb_path, "vector_store"))
else:
logger.warning("directory `vector_store` already exists.")
def vs_type(self) -> str:
return SupportedVSType.ES
def _load_es(self, docs, embed_model):
# 将docs写入到ES中
try:
# 连接 + 同时写入文档
if self.user != "" and self.password != "":
self.db = ElasticsearchStore.from_documents(
documents=docs,
embedding=embed_model,
es_url= f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
distance_strategy="COSINE",
query_field="context",
vector_query_field="dense_vector",
verify_certs=False,
es_user=self.user,
es_password=self.password
)
else:
self.db = ElasticsearchStore.from_documents(
documents=docs,
embedding=embed_model,
es_url= f"http://{self.IP}:{self.PORT}",
index_name=self.index_name,
distance_strategy="COSINE",
query_field="context",
vector_query_field="dense_vector",
verify_certs=False)
except ConnectionError as ce:
print(ce)
print("连接到 Elasticsearch 失败!")
logger.error("连接到 Elasticsearch 失败!")
except Exception as e:
logger.error(f"Error 发生 : {e}")
print(e)
def do_search(self, query:str, top_k: int, score_threshold: float,expr:str):
# 文本相似性检索
docs = self.db_init.similarity_search_with_score(query=query,
k=top_k)
return docs
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
results = []
for doc_id in ids:
try:
response = self.es_client_python.get(index=self.index_name, id=doc_id)
source = response["_source"]
# Assuming your document has "text" and "metadata" fields
text = source.get("context", "")
metadata = source.get("metadata", {})
results.append(Document(page_content=text, metadata=metadata))
except Exception as e:
logger.error(f"Error retrieving document from Elasticsearch! {e}")
return results
def del_doc_by_ids(self, ids: List[str]) -> bool:
    """Delete documents from the Elasticsearch index by id.

    Every id is attempted even if an earlier one fails; failures are logged.

    :param ids: Elasticsearch document ids to delete.
    :return: True when every delete succeeded (also for an empty `ids`),
        False when at least one delete raised.
    """
    all_deleted = True
    for doc_id in ids:
        try:
            self.es_client_python.delete(index=self.index_name,
                                         id=doc_id,
                                         refresh=True)
        except Exception as e:
            logger.error(f"ES Docs Delete Error! {e}")
            all_deleted = False
    return all_deleted
def do_delete_doc(self, kb_file, **kwargs):
    """Delete every indexed chunk whose `metadata.source` equals `kb_file.filepath`.

    The search is restricted to `self.index_name`, the same index the deletes
    target. Because each search returns at most 50 hits, it is repeated until
    nothing is left, so files with more than 50 chunks are removed completely.

    :param kb_file: object exposing the source path as `filepath`.
    :return: always None.
    """
    if not self.es_client_python.indices.exists(index=self.index_name):
        return None
    # The source path is matched exactly through its keyword sub-field.
    query = {
        "query": {
            "term": {
                "metadata.source.keyword": kb_file.filepath
            }
        }
    }
    while True:
        search_results = self.es_client_python.search(index=self.index_name,
                                                      body=query,
                                                      size=50)
        delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]]
        if not delete_list:
            return None
        deleted = 0
        for doc_id in delete_list:
            try:
                # refresh=True makes the deletion visible to the next search pass.
                self.es_client_python.delete(index=self.index_name,
                                             id=doc_id,
                                             refresh=True)
                deleted += 1
            except Exception as e:
                logger.error(f"ES Docs Delete Error! {e}")
        if deleted == 0:
            # No progress in this pass; stop instead of looping forever.
            return None
def do_add_doc(self, docs: List[Document], **kwargs):
    '''Add documents to the knowledge base and return their ids and metadata.

    The documents are written through `_load_es`; afterwards the index is
    searched for the chunks whose source equals the first document's
    `source` metadata.

    :param docs: documents to index; the `source` of `docs[0]` selects the
        chunks that are reported back.
    :return: list of `{"id": str, "metadata": dict}`; an empty list when
        `docs` is empty or the index does not exist.
    :raises ValueError: when the index exists but no chunk is found.
    '''
    print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}")
    print("*"*100)
    self._load_es(docs=docs, embed_model=self.embeddings_model)
    print("写入数据成功.")
    print("*"*100)
    if not docs:
        # There is no first document to take the source path from.
        return []
    if not self.es_client_python.indices.exists(index=self.index_name):
        return []
    file_path = docs[0].metadata.get("source")
    # Both conditions go into a bool filter. The previous dict literal used
    # the key "term" twice, so only the `_index` condition survived and the
    # source filter was silently lost.
    query = {
        "query": {
            "bool": {
                "filter": [
                    {"term": {"metadata.source.keyword": file_path}},
                    {"term": {"_index": self.index_name}},
                ]
            }
        }
    }
    # Ask for at least as many hits as documents were written; 10000 is
    # Elasticsearch's default `index.max_result_window`.
    size = min(max(50, len(docs)), 10000)
    search_results = self.es_client_python.search(body=query, size=size)
    hits = search_results["hits"]["hits"]
    if len(hits) == 0:
        raise ValueError("召回元素个数为0")
    return [{"id": hit["_id"], "metadata": hit["_source"]["metadata"]} for hit in hits]
def do_clear_vs(self):
    """Remove all vectors of the knowledge base by deleting its Elasticsearch index."""
    # NOTE(review): the index is addressed through `self.kb_name` here, while
    # the other methods use `self.index_name` — confirm the two always match.
    indices = self.es_client_python.indices
    if indices.exists(index=self.kb_name):
        indices.delete(index=self.kb_name)
def do_drop_kb(self):
    """Delete the knowledge base by removing the directory tree at `self.kb_path`, if it exists."""
    kb_dir = self.kb_path
    if not os.path.exists(kb_dir):
        return
    shutil.rmtree(kb_dir)
if __name__ == '__main__':
    # Manual smoke test: index README.md into the "test" knowledge base, then query it.
    es_kb_service = ESKBService("test")
    # Optionally reset first: es_kb_service.clear_vs() / es_kb_service.create_kb()
    readme = KnowledgeFile(filename="README.md", knowledge_base_name="test")
    es_kb_service.add_doc(readme)
    print(es_kb_service.search_docs("如何启动api服务"))

View File

@@ -0,0 +1,124 @@
import os
import shutil
from configs import SCORE_THRESHOLD, EXPR
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
from server.utils import torch_gc
from langchain.docstore.document import Document
from typing import List, Dict, Optional, Tuple
class FaissKBService(KBService):
    """Knowledge-base service whose vectors live in a FAISS store obtained from `kb_faiss_pool`."""

    vs_path: str
    kb_path: str
    # Falls back to the embedding model name in `do_init` when left unset.
    vector_name: Optional[str] = None

    def vs_type(self) -> str:
        """Return the vector store type identifier for FAISS."""
        return SupportedVSType.FAISS

    def get_vs_path(self):
        """Return the vector store path for this knowledge base and vector name."""
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        """Return the knowledge base path for `self.kb_name`."""
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Load the vector store for this knowledge base through `kb_faiss_pool`."""
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name=self.vector_name,
                                               embed_model=self.embed_model)

    def save_vector_store(self):
        """Save the loaded vector store to `self.vs_path`."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Look up documents in the docstore by id.

        An id that is not in the docstore yields None at its position.
        """
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the given ids from the vector store.

        :return: True once the delete call has completed; errors raised by
            the store propagate to the caller.
        """
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        return True

    def do_init(self):
        """Resolve the vector name and the knowledge-base / vector-store paths."""
        self.vector_name = self.vector_name or self.embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Create the vector store directory if needed and load the store."""
        if not os.path.exists(self.vs_path):
            os.makedirs(self.vs_path)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store and remove the knowledge base directory."""
        self.clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            # Best effort: a directory that cannot be removed is ignored.
            ...

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  expr: str = EXPR,
                  ) -> List[Tuple[Document, float]]:
        """Embed `query` and return the scored nearest documents.

        `expr` is accepted but not used by this backend.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Embed `docs`, add them to the vector store and return their ids with metadata.

        Recognised `kwargs`: `ids` (explicit ids for the new entries) and
        `not_refresh_vs_cache` (skip saving the store to disk when truthy).
        """
        # Embedding happens before the store is acquired to keep the lock time short.
        data = self._docs_to_embeddings(docs)
        with self.load_vector_store().acquire() as vs:
            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
                                    metadatas=data["metadatas"],
                                    ids=kwargs.get("ids"))
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        torch_gc()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete every stored document whose `source` equals `kb_file.filename`, ignoring case.

        Documents without a `source` entry never match instead of raising.

        :return: the list of deleted ids (possibly empty).
        """
        target = kb_file.filename.lower()
        with self.load_vector_store().acquire() as vs:
            ids = [k for k, v in vs.docstore._dict.items()
                   if (v.metadata.get("source") or "").lower() == target]
            if ids:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Drop the cached store and recreate an empty vector store directory."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            # Best effort: the directory is recreated below either way.
            ...
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Report where `file_name` is known.

        :return: "in_db" when the parent class finds it, "in_folder" when the
            file only exists in the `content` folder, otherwise False.
        """
        if super().exist_doc(file_name):
            return "in_db"
        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        return False
if __name__ == '__main__':
    # Manual smoke test: add and delete README.md, drop the knowledge base, then search.
    faiss_service = FaissKBService("test")
    added_file = KnowledgeFile("README.md", "test")
    faiss_service.add_doc(added_file)
    removed_file = KnowledgeFile("README.md", "test")
    faiss_service.delete_doc(removed_file)
    faiss_service.do_drop_kb()
    search_hits = faiss_service.search_docs("如何启动api服务")
    print(search_hits)

View File

@@ -0,0 +1,207 @@
from typing import List, Dict, Optional
from langchain.schema import Document
from langchain.vectorstores.milvus import Milvus
import os
import logging
from configs import kbs_config
from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
import numpy as np
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MilvusKBService(KBService):
    """Knowledge-base service backed by a Milvus collection named after the knowledge base."""

    milvus: Milvus

    @staticmethod
    def get_collection(milvus_name):
        """Return the pymilvus `Collection` called `milvus_name`."""
        from pymilvus import Collection
        return Collection(milvus_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Query documents by primary key.

        The ids are converted to integers before they are placed in the
        `pk in [...]` expression. Returns an empty list when no collection
        is available.
        """
        result = []
        if self.milvus and self.milvus.col:
            data_list = self.milvus.col.query(expr=f'pk in {[int(_id) for _id in ids]}', output_fields=["*"])
            for data in data_list:
                # The "text" field becomes the page content; the rest is metadata.
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        """Query documents whose `source` field is one of `source_name_list`."""
        result = []
        if self.milvus and self.milvus.col:
            data_list = self.milvus.col.query(expr=f'source in {source_name_list}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by primary key.

        The ids are converted to integers, matching `get_doc_by_ids`.

        :return: True when the delete was issued or `ids` is empty, False
            when no collection is available.
        """
        if not ids:
            return True
        if self.milvus and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {[int(_id) for _id in ids]}')
            return True
        return False

    @staticmethod
    def search(milvus_name, content, limit=3):
        """Search the `embeddings` field of collection `milvus_name` with the L2 metric."""
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        c = MilvusKBService.get_collection(milvus_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """No-op for this backend."""
        pass

    def vs_type(self) -> str:
        """Return the vector store type identifier for Milvus."""
        return SupportedVSType.MILVUS

    def _load_milvus(self):
        """Create the Milvus client in `self.milvus` and pick the text field to use.

        When the client cannot be created the error is logged, `self.milvus`
        keeps its previous value (or becomes None if it was never set) and
        `_create_collection_if_not_exists` is called.
        """
        try:
            self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                 collection_name=self.kb_name,
                                 connection_args=kbs_config.get("milvus"),
                                 index_params=kbs_config.get("milvus_kwargs")["index_params"],
                                 search_params=kbs_config.get("milvus_kwargs")["search_params"],
                                 auto_id=True
                                 )
            logger.info("成功加载 Milvus 实例 'milvus'")
            # Cope with collections whose text field has a different name.
            # `col` is None until the collection exists, so the schema must
            # not be touched in that case.
            try:
                col = self.milvus.col
                if col is None:
                    logger.debug(
                        "集合 %s 尚未在 Milvus 中建表,跳过文本字段探测(首次写入时会自动创建)",
                        self.kb_name,
                    )
                else:
                    field_names = [f.name for f in col.schema.fields]
                    if self.milvus._text_field not in field_names:
                        if "page_content" in field_names:
                            self.milvus._text_field = "page_content"
                        elif "content" in field_names:
                            self.milvus._text_field = "content"
                        else:
                            # Fall back to the first VARCHAR field in the schema.
                            for f in col.schema.fields:
                                if hasattr(f, "dtype") and str(f.dtype).startswith("DataType.VARCHAR"):
                                    self.milvus._text_field = f.name
                                    break
                    logger.info(f"集合 {self.kb_name} 使用文本字段: {self.milvus._text_field}")
            except Exception as e:
                logger.warning(f"检测并设置文本字段失败: {e}")
        except Exception as e:
            logger.error(f"加载 Milvus 实例 'milvus' 失败: {e}")
            # Make sure the attribute exists so the `if self.milvus and ...`
            # guards elsewhere in this class work after a failed first load.
            self.milvus = getattr(self, "milvus", None)
            self._create_collection_if_not_exists()

    def _create_collection_if_not_exists(self):
        """Create the Milvus collection for this knowledge base and index its vector field."""
        # NOTE(review): despite the name, nothing here checks whether the
        # collection already exists — confirm the intended behaviour.
        from pymilvus import Collection, CollectionSchema, FieldSchema, DataType
        fields = [
            FieldSchema(name="pk", dtype=DataType.Int64, is_primary=True, auto_id=True),
            # dim has to match the embedding model in use.
            FieldSchema(name="vector", dtype=DataType.FloatVector, dim=768),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=65535),
        ]
        schema = CollectionSchema(fields=fields, description=self.kb_name)
        collection = Collection(name=self.kb_name, schema=schema, using="default")
        index_params = kbs_config.get("milvus_kwargs")["index_params"]
        collection.create_index(field_name="vector", index_params=index_params)
        logger.info(f"成功创建 Milvus 集合: {self.kb_name}")

    def do_init(self):
        """Load the Milvus client."""
        self._load_milvus()

    def do_drop_kb(self):
        """Release the collection; dropping it from here is deliberately disabled."""
        if self.milvus and self.milvus.col:
            self.milvus.col.release()

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str,
                  custom_strategy_config: Optional[dict] = None):
        """Search the collection for `query`.

        For `top_k > 50` the hits are returned as documents sorted by `pk`,
        without scores. Otherwise the scored hits are passed through
        `score_threshold_process`. Any error is logged and yields an empty
        list. `custom_strategy_config` is accepted but not used here.
        """
        self._load_milvus()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        try:
            embeddings = embed_func.embed_query(query)
            if top_k > 50:
                docs = self.milvus.similarity_search_by_vector(embeddings, top_k, expr=expr)
                return sorted(docs, key=lambda doc: doc.metadata['pk'])
            docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k, expr=expr)
            # TODO: dynamic score_threshold
            return score_threshold_process(query, score_threshold, top_k, docs)
        except Exception as e:
            logger.error(f"搜索 Milvus 集合 '{self.kb_name}' 失败: {e}")
            return []

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Normalise the metadata of `docs`, add them to Milvus and return ids with metadata."""
        for doc in docs:
            # Every metadata value is stored as a string.
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            # Every collection field needs a value; missing ones default to "".
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.milvus._text_field, None)
            doc.metadata.pop(self.milvus._vector_field, None)
        ids = self.milvus.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete the vectors recorded in the database for `kb_file`."""
        id_list = list_file_num_docs_id_by_kb_name_and_file_name(kb_file.kb_name, kb_file.filename)
        if not id_list:
            # Nothing is recorded for this file, so there is nothing to delete.
            return
        if self.milvus and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {id_list}')

    def do_clear_vs(self):
        """Release the collection and reload the client."""
        if self.milvus and self.milvus.col:
            self.do_drop_kb()
            self.do_init()
if __name__ == '__main__':
    # Used to test table creation.
    from server.db.base import Base, engine
    Base.metadata.create_all(bind=engine)
    milvus_service = MilvusKBService("t_policy_total_bce_v1")
    # Further manual checks, enable as needed:
    # milvus_service.add_doc(KnowledgeFile("README.md", "test"))
    # print(milvus_service.get_doc_by_ids(["444022434274215486"]))
    # milvus_service.delete_doc(KnowledgeFile("README.md", "test"))
    # milvus_service.do_drop_kb()
    # print(milvus_service.search_docs("如何启动api服务"))

View File

@@ -0,0 +1,96 @@
import json
from typing import List, Dict, Optional
from langchain.schema import Document
from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
from sqlalchemy import text
from configs import kbs_config
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
import shutil
import sqlalchemy
from sqlalchemy.engine.base import Engine
from sqlalchemy.orm import Session
class PGKBService(KBService):
    """Knowledge-base service backed by PostgreSQL through langchain's `PGVector`."""

    # One engine shared by all instances; created when the class is defined.
    engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)

    def _load_pg_vector(self):
        """Create the `PGVector` store for this knowledge base in `self.pg_vector`."""
        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                  collection_name=self.kb_name,
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
                                  connection=PGKBService.engine,
                                  connection_string=kbs_config.get("pg").get("connection_uri"))

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return the documents whose `collection_id` is one of `ids`.

        Returns an empty list for empty `ids` without querying.
        """
        if not ids:
            return []
        # NOTE(review): the filter is on `collection_id`; confirm that this is
        # the column holding the ids handed out by `do_add_doc`.
        # An expanding parameter renders the list as a proper `IN (...)` list.
        stmt = text(
            "SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids"
        ).bindparams(sqlalchemy.bindparam("ids", expanding=True))
        with Session(PGKBService.engine) as session:
            rows = session.execute(stmt, {'ids': list(ids)}).fetchall()
        return [Document(page_content=row[0], metadata=row[1]) for row in rows]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delegate to the parent class implementation."""
        return super().del_doc_by_ids(ids)

    def do_init(self):
        """Load the `PGVector` store."""
        self._load_pg_vector()

    def do_create_kb(self):
        """No-op for this backend."""
        pass

    def vs_type(self) -> str:
        """Return the vector store type identifier for PostgreSQL."""
        return SupportedVSType.PG

    def do_drop_kb(self):
        """Delete the collection's embeddings and its collection row, then remove `self.kb_path`.

        The knowledge base name is passed as a bound parameter instead of
        being interpolated into the SQL text.
        """
        with Session(PGKBService.engine) as session:
            params = {"name": self.kb_name}
            # Embeddings first: they reference the collection row.
            session.execute(text(
                "DELETE FROM langchain_pg_embedding "
                "WHERE collection_id IN ("
                "SELECT uuid FROM langchain_pg_collection WHERE name = :name)"
            ), params)
            session.execute(text(
                "DELETE FROM langchain_pg_collection WHERE name = :name"
            ), params)
            session.commit()
            shutil.rmtree(self.kb_path)

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str):
        """Embed `query`, fetch the `top_k` nearest documents and apply the score threshold.

        `expr` is accepted but not used by this backend.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
        # score_threshold_process takes the query as its first argument.
        return score_threshold_process(query, score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Add `docs` to the store and return their ids together with their metadata."""
        ids = self.pg_vector.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every embedding whose metadata `source` equals `kb_file.filepath`.

        The JSON filter is produced by `json.dumps` and bound as a parameter,
        so backslashes and quotes in the path are escaped correctly.
        """
        source_filter = json.dumps({"source": kb_file.filepath}, ensure_ascii=False)
        with Session(PGKBService.engine) as session:
            # CAST(... AS jsonb) is used because a `::jsonb` suffix directly
            # after a bound parameter is not recognised by `text()`.
            session.execute(
                text("DELETE FROM langchain_pg_embedding "
                     "WHERE cmetadata::jsonb @> CAST(:source_filter AS jsonb)"),
                {"source_filter": source_filter})
            session.commit()

    def do_clear_vs(self):
        """Delete and recreate the collection."""
        self.pg_vector.delete_collection()
        self.pg_vector.create_collection()
if __name__ == '__main__':
    from server.db.base import Base, engine
    # Manual check: look up one document by id in the "test" knowledge base.
    # Enable the commented calls below as needed.
    # Base.metadata.create_all(bind=engine)
    pg_kb_service = PGKBService("test")
    # pg_kb_service.create_kb()
    # pg_kb_service.add_doc(KnowledgeFile("README.md", "test"))
    # pg_kb_service.delete_doc(KnowledgeFile("README.md", "test"))
    # pg_kb_service.drop_kb()
    lookup_ids = ["f1e51390-3029-4a19-90dc-7118aaa25772"]
    print(pg_kb_service.get_doc_by_ids(lookup_ids))
    # print(pg_kb_service.search_docs("如何启动api服务"))

View File

@@ -0,0 +1,97 @@
from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Zilliz
from configs import kbs_config
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
class ZillizKBService(KBService):
    """Knowledge-base service backed by a Zilliz collection named after the knowledge base."""

    zilliz: Zilliz

    @staticmethod
    def get_collection(zilliz_name):
        """Return the pymilvus `Collection` called `zilliz_name`."""
        from pymilvus import Collection
        return Collection(zilliz_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Query documents by primary key; returns an empty list when no collection is available."""
        result = []
        if self.zilliz.col:
            # NOTE(review): ids are placed in the expression as given; convert
            # them to int here if the primary key is numeric.
            data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
            for data in data_list:
                # The "text" field becomes the page content; the rest is metadata.
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by primary key.

        :return: True when the delete was issued or `ids` is empty, False
            when no collection is available.
        """
        if not ids:
            return True
        if self.zilliz.col:
            self.zilliz.col.delete(expr=f'pk in {ids}')
            return True
        return False

    @staticmethod
    def search(zilliz_name, content, limit=3):
        """Search the `embeddings` field of collection `zilliz_name` with the IP metric."""
        search_params = {
            "metric_type": "IP",
            "params": {},
        }
        c = ZillizKBService.get_collection(zilliz_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """No-op for this backend."""
        pass

    def vs_type(self) -> str:
        """Return the vector store type identifier for Zilliz."""
        return SupportedVSType.ZILLIZ

    def _load_zilliz(self):
        """Create the Zilliz client in `self.zilliz` from the `zilliz` section of `kbs_config`."""
        zilliz_args = kbs_config.get("zilliz")
        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                             collection_name=self.kb_name, connection_args=zilliz_args)

    def do_init(self):
        """Load the Zilliz client."""
        self._load_zilliz()

    def do_drop_kb(self):
        """Release and drop the collection, if there is one."""
        if self.zilliz.col:
            self.zilliz.col.release()
            self.zilliz.col.drop()

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str):
        """Embed `query`, fetch the `top_k` nearest documents and apply the score threshold.

        `expr` is accepted but not used by this backend.
        """
        self._load_zilliz()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
        # score_threshold_process takes the query as its first argument.
        return score_threshold_process(query, score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Normalise the metadata of `docs`, add them to Zilliz and return ids with metadata."""
        for doc in docs:
            # Every metadata value is stored as a string.
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            # Every collection field needs a value; missing ones default to "".
            for field in self.zilliz.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.zilliz._text_field, None)
            doc.metadata.pop(self.zilliz._vector_field, None)
        ids = self.zilliz.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every entity whose `source` equals `kb_file.filepath`."""
        if self.zilliz.col:
            # Backslashes are doubled so the path survives inside the quoted expression.
            filepath = kb_file.filepath.replace('\\', '\\\\')
            delete_list = [item.get("pk") for item in
                           self.zilliz.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
            if delete_list:
                # Skip the delete when nothing matched instead of sending `pk in []`.
                self.zilliz.col.delete(expr=f'pk in {delete_list}')

    def do_clear_vs(self):
        """Drop the collection and reload the client."""
        if self.zilliz.col:
            self.do_drop_kb()
            self.do_init()
if __name__ == '__main__':
from server.db.base import Base, engine
Base.metadata.create_all(bind=engine)
zillizService = ZillizKBService("test")