[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
775
langchain-chat/server/knowledge_base/kb_service/base.py
Normal file
775
langchain-chat/server/knowledge_base/kb_service/base.py
Normal file
@@ -0,0 +1,775 @@
|
||||
from datetime import datetime
|
||||
import operator
|
||||
from abc import ABC, abstractmethod
|
||||
import re
|
||||
import os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
from joblib import Parallel, delayed
|
||||
import time
|
||||
from textrank4zh import TextRank4Keyword, TextRank4Sentence
|
||||
import multiprocessing
|
||||
|
||||
from configs.model_config import LLM_MODELS
|
||||
from server.chat.policy_fun_iast import get_llm_model_response
|
||||
from server.chat.utils import get_personal_knowledge_map, get_similar_documents1
|
||||
from nltk.tokenize import sent_tokenize
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
||||
|
||||
def generate_weights_as_list(length, total_sum=80):
    """Build a list of decreasing integer weights that add up to ``total_sum``.

    The weights follow the exponential decay ``exp(-x / (length / 5))`` over
    ``x = 0 .. length - 1``. They are scaled so they sum to ``total_sum`` and
    rounded to integers; any drift introduced by rounding is corrected by
    adding or subtracting 1 from the leading entries.

    :param length: number of weights to generate.
    :param total_sum: value the returned weights sum to (for ``length > 1``).
    :return: list of ints. ``[]`` when ``length <= 0``; the fixed value
        ``[50]`` when ``length == 1`` (``total_sum`` is not applied there).
    """
    if length <= 0:
        # Nothing to weight. Without this guard the adjustment loop below
        # would evaluate ``i % 0`` and raise ZeroDivisionError.
        return []
    if length == 1:
        return [50]

    # Exponentially decaying raw weights.
    x = np.linspace(0, length - 1, length)
    weights = np.exp(-x / (length / 5))

    # Scale to the requested total, then round to integers.
    normalized_weights = weights / weights.sum() * total_sum
    integer_weights = np.round(normalized_weights).astype(int)

    # Rounding can over- or undershoot the total; spread the difference over
    # the first entries, one unit each.
    adjustment = int(total_sum - integer_weights.sum())
    step = 1 if adjustment > 0 else -1
    for i in range(abs(adjustment)):
        integer_weights[i % length] += step

    return integer_weights.tolist()
|
||||
|
||||
def _strip_ws(text):
    """Remove spaces, newlines and carriage returns so texts compare loosely."""
    return text.replace(" ", "").replace("\n", "").replace("\r", "")


def _dedupe_key(doc, fields):
    """Return the whitespace-stripped value of the first usable metadata field."""
    for field in fields:
        value = doc.metadata.get(field)
        if isinstance(value, str):
            return _strip_ws(value)
    raise KeyError(fields[-1])


def _merge_duplicates(docs, fields):
    """Collapse (doc, score) pairs that share a key, concatenating new content."""
    merged = {}
    for item in docs:
        doc = item[0]
        key = _dedupe_key(doc, fields)
        if key not in merged:
            merged[key] = item
            continue
        kept = merged[key][0]
        kept_text = _strip_ws(kept.page_content)
        # Skip exact repeats, and chunks whose kept content is just the key.
        if kept_text == _strip_ws(doc.page_content) or kept_text == key:
            continue
        kept.page_content += doc.page_content
    return list(merged.values())


def _label_docs(docs, field):
    """Return (labels, numbered passages) built from ``metadata[field]``."""
    sentences = [item[0].metadata[field] for item in docs]
    numbered = [
        str(i + 1) + ":【" + item[0].metadata[field] + item[0].page_content + "】"
        for i, item in enumerate(docs)
    ]
    return sentences, numbered


def _llm_select(query, k, docs, sentences, numbered, thinking=False):
    """Ask the LLM which numbered passages are relevant and re-rank ``docs``."""
    extra = {}
    if thinking:
        extra["extra_body"] = {"chat_template_kwargs": {"enable_thinking": True}}
    res = get_llm_model_response(
        strategy_name="default_similar",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="default_similar",
        prompt_param_dict={"input": query, "title": f"{numbered}", "time": datetime.now().strftime("%Y%m%d")},
        temperature=0.01,
        max_tokens=None if thinking else 512,
        **extra
    )
    if thinking:
        # Drop the model's reasoning block before parsing the answer.
        res = re.sub(r'<think>.*?</think>', '', res, flags=re.DOTALL)
    try:
        # The reply is either "无" (none) or comma-separated 1-based indices.
        index = [] if res == "无" else [int(i) - 1 for i in res.split(",")]
        return get_similar_documents1(index=index, sentences=sentences, query=query, docs=docs, top_k=k)
    except Exception as e:
        logging.warning("could not use LLM selection %r, falling back to no selection: %s", res, e)
        return get_similar_documents1(index=[], sentences=sentences, query=query, docs=docs, top_k=k)


def score_threshold_process(query,score_threshold, k, docs):
    """Filter, re-rank, de-duplicate and summarise retrieved documents.

    Steps, in order:

    1. Keep only pairs whose score is ``<= score_threshold`` (the comparison
       is ``operator.le``; presumably the score is a distance where smaller
       means closer -- confirm against the vector store in use).
    2. For documents without an ``"h1"`` metadata key: if the query occurs in
       any title, return those documents merged by title; otherwise ask the
       LLM to pick relevant passages and merge duplicates by title (falling
       back to ``source``).
    3. For documents with an ``"h1"`` metadata key: map ``source`` through
       ``get_personal_knowledge_map`` (the original value is kept in
       ``metadata["uuid_name"]``), ask the LLM to pick passages and merge
       duplicates by ``source``.
    4. Attach ``metadata["summary"]`` to every remaining document.

    Document objects are modified in place (``page_content`` and ``metadata``).

    :param query: the user question.
    :param score_threshold: score limit, or ``None`` to skip filtering.
    :param k: maximum number of documents to return.
    :param docs: list of ``(document, score)`` tuples.
    :return: at most ``k`` ``(document, score)`` tuples; ``[]`` when nothing
        survives filtering or selection.
    """
    if score_threshold is not None:
        docs = [
            (doc, similarity)
            for doc, similarity in docs
            if operator.le(similarity, score_threshold)
        ]
    # Every recalled result exceeded the threshold (or there was none).
    if len(docs) == 0:
        return docs

    # Collect documents whose title (or, failing that, source) contains the query.
    needle = _strip_ws(query)
    result = []
    try:
        for doc in docs:
            if needle in _strip_ws(doc[0].metadata["title"]):
                result.append(doc)
    except Exception:
        for doc in docs:
            if needle in _strip_ws(doc[0].metadata["source"]):
                result.append(doc)

    if "h1" not in docs[0][0].metadata:
        if result:
            # The query matches a title: no relevance ranking needed, just
            # hand back the matching documents with duplicates merged.
            docs = _merge_duplicates(result, ("title", "source"))
        else:
            try:
                for doc in docs:
                    meta = doc[0].metadata
                    # An empty title is replaced by the page content.
                    if meta["title"] == "":
                        meta["title"] = doc[0].page_content
                    # Prefer an existing summary as the passage text.
                    summary = meta.get("summary")
                    if summary:
                        doc[0].page_content = summary
                sentences, numbered = _label_docs(docs, "title")
            except Exception:
                sentences, numbered = _label_docs(docs, "source")
            docs = _llm_select(query, k, docs, sentences, numbered)
            docs = _merge_duplicates(docs, ("title", "source"))

    # ``docs`` may have become empty above, so guard before indexing.
    if docs and "h1" in docs[0][0].metadata:
        unique_source = list({doc[0].metadata["source"] for doc in docs})
        all_title_map = get_personal_knowledge_map(unique_source)
        for doc in docs:
            meta = doc[0].metadata
            meta["uuid_name"] = meta["source"]
            if meta["source"] in all_title_map:
                meta["source"] = all_title_map[meta["source"]]
        sentences, numbered = _label_docs(docs, "source")
        docs = _llm_select(query, k, docs, sentences, numbered, thinking=True)
        docs = _merge_duplicates(docs, ("source",))

    if len(docs) == 0:
        return docs

    # Per-position sentence budgets for the TextRank summaries.
    cont = generate_weights_as_list(len(docs))

    for i, (doc, _) in enumerate(docs):
        text = doc.page_content
        # Default summary is the passage itself.
        doc.metadata['summary'] = text
        try:
            needs_summary = doc.metadata["title"] in text or len(text) < 100
        except KeyError:
            needs_summary = False
        if not needs_summary:
            continue
        # Different knowledge bases keep the full text under different keys.
        for source in ('content', 'abstract', 'text'):
            try:
                doc.metadata['summary'] = TextRank(doc.metadata[source], cont[i])
                if len(doc.metadata['summary']) > 15000:
                    doc.metadata['summary'] = TextRank(doc.metadata[source], 1)
                break
            except KeyError:
                # This field is missing: keep the passage and try the next one.
                doc.metadata['summary'] = text
                continue

    return docs[:k]
|
||||
|
||||
# Compatibility shim for textrank4zh, which calls
# ``networkx.from_numpy_matrix``; that name is missing from newer networkx
# releases, where ``from_numpy_array`` is the replacement.
import networkx as nx
import numpy as np

# Alias only when the old name is absent so an existing implementation is
# never overwritten.
if not hasattr(nx, "from_numpy_matrix"):
    nx.from_numpy_matrix = nx.from_numpy_array
# NOTE(review): the original comment here announced an additional data-type
# compatibility check, but no code for it is present.
|
||||
|
||||
def process_text_segment(text_segment, num_sentences):
    """Extract keywords and key sentences from one chunk of text with TextRank.

    :param text_segment: the text to analyse.
    :param num_sentences: number of key sentences requested from the ranker.
    :return: ``(keywords, summaries)``; ``keywords`` is a list of
        ``(word, weight)`` pairs (30 requested, ``word_min_len=4``) and
        ``summaries`` is the list of selected sentences.
    """
    keyword_ranker = TextRank4Keyword()
    keyword_ranker.analyze(text=text_segment, lower=True, window=5)
    keywords = []
    for entry in keyword_ranker.get_keywords(30, word_min_len=4):
        keywords.append((entry.word, entry.weight))

    sentence_ranker = TextRank4Sentence()
    sentence_ranker.analyze(text=text_segment, lower=True, source='all_filters')
    ranked_sentences = sentence_ranker.get_key_sentences(num=num_sentences)
    summaries = [entry.sentence for entry in ranked_sentences]

    return keywords, summaries
|
||||
|
||||
def split_text_balanced(text, n_parts):
    """Split ``text`` into chunks holding a near-equal number of sentences.

    The number of chunks is capped so that each holds at least 10 sentences
    on average, and is never below 1. When the sentences do not divide
    evenly, the earlier chunks receive one extra sentence each.

    :param text: the text to split (tokenised with ``sent_tokenize``).
    :param n_parts: the desired number of chunks.
    :return: list of chunk strings, sentences joined by a single space.
    """
    sentences = sent_tokenize(text)
    total = len(sentences)
    min_sentences_per_part = 10
    parts = max(1, min(n_parts, total // min_sentences_per_part))
    base, extra = divmod(total, parts)

    chunks = []
    start = 0
    for idx in range(parts):
        # The first ``extra`` chunks take one more sentence than the rest.
        size = base + (1 if idx < extra else 0)
        chunks.append(' '.join(sentences[start:start + size]))
        start += size
    return chunks
|
||||
|
||||
|
||||
def TextRank(text,num_sentences, n_cores=multiprocessing.cpu_count()):
    """Summarise ``text`` by concatenating the key sentences of each chunk.

    The text is split with ``split_text_balanced`` and every chunk is run
    through ``process_text_segment``. Chunks are handled sequentially;
    ``n_cores`` only sets the requested number of chunks.

    :param text: the text to summarise.
    :param num_sentences: key sentences requested per chunk.
    :param n_cores: requested number of chunks (defaults to the CPU count
        captured when this module was imported).
    :return: the selected sentences joined into one string with no separator.
    """
    start_time = time.time()
    all_keywords = []
    all_summaries = []
    for part in split_text_balanced(text, n_cores):
        keywords, summaries = process_text_segment(part, num_sentences)
        all_keywords.extend(keywords)
        all_summaries.extend(summaries)

    # Keyword weights are diagnostic only: emit them at DEBUG level instead
    # of printing one line per keyword to stdout on every call.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        for word, weight in sorted(all_keywords, key=lambda kw: kw[1], reverse=True):
            logging.debug("%s %s", word, weight)

    end_time = time.time()
    logging.info(f"TextRank耗时: {end_time - start_time:.2f}秒")
    return "".join(all_summaries)
|
||||
|
||||
|
||||
from server.db.repository.knowledge_base_repository import (
|
||||
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
|
||||
load_kb_from_db, get_kb_detail,
|
||||
)
|
||||
from server.db.repository.knowledge_file_repository import (
|
||||
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
|
||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
||||
list_docs_from_db,
|
||||
)
|
||||
|
||||
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,EXPR,
|
||||
EMBEDDING_MODEL, KB_INFO)
|
||||
from server.knowledge_base.utils import (
|
||||
get_kb_path, get_doc_path, KnowledgeFile,
|
||||
list_kbs_from_folder, list_files_from_folder,
|
||||
)
|
||||
|
||||
from typing import List, Union, Dict, Optional, Tuple
|
||||
|
||||
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
|
||||
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
|
||||
import time
|
||||
|
||||
|
||||
def get_emb_time(f):
    """Decorator that prints how long each call to ``f`` took.

    The elapsed wall-clock time in seconds is printed after the call
    returns; the wrapped function's return value is passed through. If ``f``
    raises, nothing is printed and the exception propagates.

    :param f: the callable to time.
    :return: a wrapper with the same name, docstring and signature metadata
        as ``f``.
    """
    import functools

    @functools.wraps(f)  # keep __name__/__doc__ of the decorated function
    def inner(*arg,**kwarg):
        s_time = time.time()
        res = f(*arg,**kwarg)
        e_time = time.time()
        print('向量化耗时:{}秒'.format(e_time - s_time))
        return res

    return inner
|
||||
|
||||
def normalize(embeddings: List[List[float]]) -> np.ndarray:
    '''
    Row-wise L2 normalisation, a replacement for
    sklearn.preprocessing.normalize that avoids installing scipy and
    scikit-learn.

    ``None`` entries are dropped before normalising, so the result can have
    fewer rows than the input. All-zero rows are returned unchanged (as
    sklearn does) instead of dividing by zero and producing NaN.

    :param embeddings: sequence of equally long vectors; entries may be None.
    :return: 2-D array with one unit-length row per non-None input vector.
    :raises ValueError: if no non-None embedding is supplied.
    '''
    # Filter out None values.
    embeddings = [e for e in embeddings if e is not None]
    if not embeddings:
        raise ValueError("No valid embeddings found (all are None)")
    matrix = np.array(embeddings)
    # keepdims gives shape (rows, 1), which broadcasts across each row.
    norm = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Avoid 0/0 for zero vectors: dividing by 1 leaves them as zeros.
    norm[norm == 0] = 1
    return np.divide(matrix, norm)
|
||||
|
||||
|
||||
class SupportedVSType:
    """String identifiers of the vector store back ends.

    ``KBServiceFactory.get_service`` resolves a type name by upper-casing it
    and reading the attribute of that name from this class.
    """

    # Fallback entry used by KBServiceFactory.get_default().
    DEFAULT = 'default'

    # Concrete vector stores.
    CHROMADB = 'chromadb'
    ES = 'es'
    FAISS = 'faiss'
    MILVUS = 'milvus'
    PG = 'pg'
    ZILLIZ = 'zilliz'
|
||||
|
||||
|
||||
class KBService(ABC):
    """Abstract base class of a knowledge base backed by a vector store.

    The public methods keep the vector store and the database records in
    step; subclasses supply the vector-store side through the ``do_*`` hooks
    and ``vs_type``.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        """Store the names and paths of the knowledge base, then call ``do_init``."""
        self.kb_name = knowledge_base_name
        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        self.do_init()

    def __repr__(self) -> str:
        return f"{self.kb_name} @ {self.embed_model}"

    def save_vector_store(self):
        '''
        Persist the vector store: FAISS saves to disk, milvus saves to the
        database. PGVector is not supported yet. No-op in the base class.
        '''
        pass

    def create_kb(self):
        """
        Create the knowledge base: make the document folder if needed, let
        the subclass create its store, then record the KB in the database.
        """
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        self.do_create_kb()
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def clear_vs(self):
        """
        Delete everything in the vector store and the file records of this KB.
        """
        self.do_clear_vs()
        status = delete_files_from_db(self.kb_name)
        return status

    def drop_kb(self):
        """
        Delete the knowledge base from the vector store and the database.
        """
        self.do_drop_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        '''
        Convert List[Document] into the arguments accepted by
        VectorStore.add_embeddings.
        '''
        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)

    @get_emb_time
    def add_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """
        Add a file to the knowledge base.

        If ``docs`` is given, the file is not loaded and split again and the
        database entry is marked ``custom_docs=True``.

        :return: the status from ``add_file_to_db``, or ``False`` when there
            is nothing to add.
        """
        if docs:
            custom_docs = True
            for doc in docs:
                doc.metadata.setdefault("source", kb_file.filename)
        else:
            docs = kb_file.file2text()
            custom_docs = False

        if not docs:
            return False

        # Rewrite absolute metadata["source"] paths relative to the doc folder.
        for doc in docs:
            source = ""  # bound up front so the error message can always use it
            try:
                source = doc.metadata.get("source", "")
                if os.path.isabs(source):
                    rel_path = Path(source).relative_to(self.doc_path)
                    doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
            except Exception as e:
                print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
        # Remove any previous version of the file before adding it again.
        self.delete_doc(kb_file)
        doc_infos = self.do_add_doc(docs, **kwargs)
        status = add_file_to_db(kb_file,
                                custom_docs=custom_docs,
                                docs_count=len(docs),
                                doc_infos=doc_infos)
        return status

    def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
        """
        Delete a file from the knowledge base.

        :param delete_content: also remove the file from disk when True.
        """
        self.do_delete_doc(kb_file, **kwargs)
        status = delete_file_from_db(kb_file)
        if delete_content and os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        return status

    def update_info(self, kb_info: str):
        """
        Update the description of the knowledge base.
        """
        self.kb_info = kb_info
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """
        Refresh the vector store from the file in the content folder.

        If ``docs`` is given, those documents are used and the database entry
        is marked ``custom_docs=True``. Returns ``None`` when the file does
        not exist on disk.
        """
        if os.path.exists(kb_file.filepath):
            self.delete_doc(kb_file, **kwargs)
            return self.add_doc(kb_file, docs=docs, **kwargs)

    def exist_doc(self, file_name: str):
        """Return whether ``file_name`` is recorded in the database for this KB."""
        return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                               filename=file_name))

    def list_files(self):
        """Return the files recorded in the database for this KB."""
        return list_files_from_db(self.kb_name)

    def count_files(self):
        """Return the number of files recorded in the database for this KB."""
        return count_files_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    expr: str = EXPR,
                    custom_strategy_config: Optional[dict] = None
                    ) ->List[Document]:
        """Search the knowledge base by delegating to ``do_search``.

        NOTE(review): ``do_search`` is annotated as returning
        ``List[Tuple[Document, float]]``, which differs from the annotation
        here -- confirm which one callers rely on.
        """
        if custom_strategy_config is None:
            # Fresh dict per call instead of a shared mutable default.
            custom_strategy_config = {}
        docs = self.do_search(query, top_k, score_threshold, expr, custom_strategy_config)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return the documents with the given ids; the base class returns none."""
        return []

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        """Return the documents with the given sources; the base class returns none."""
        return []

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by id; must be provided by subclasses that support it."""
        raise NotImplementedError

    def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
        '''
        Replace documents by id. The argument is {doc_id: Document, ...}.

        Every listed id is deleted first; an entry whose value is None or
        whose page_content is blank is not added back, i.e. it stays deleted.
        '''
        self.del_doc_by_ids(list(docs.keys()))
        # Separate names: rebinding ``docs`` to a list here made the
        # ``.items()`` call below fail.
        kept_docs = []
        kept_ids = []
        for doc_id, doc in docs.items():
            if not doc or not doc.page_content.strip():
                continue
            kept_ids.append(doc_id)
            kept_docs.append(doc)
        if kept_docs:
            self.do_add_doc(docs=kept_docs, ids=kept_ids)
        return True

    def list_docs(self, file_name: str = None, metadata: Optional[Dict] = None) -> List[DocumentWithVSId]:
        '''
        Look up Documents by file_name or metadata.

        Records whose document can no longer be fetched from the vector
        store are skipped.
        '''
        if metadata is None:
            metadata = {}
        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
        docs = []
        for x in doc_infos:
            found = self.get_doc_by_ids([x["id"]])
            # get_doc_by_ids may return an empty list (the base class does).
            doc_info = found[0] if found else None
            if doc_info is not None:
                doc_with_id = DocumentWithVSId(**doc_info.dict(), id=x["id"])
                docs.append(doc_with_id)
        return docs

    @abstractmethod
    def do_create_kb(self):
        """
        Subclass hook: create the knowledge base in the vector store.
        """
        pass

    @staticmethod
    def list_kbs_type():
        """Return the vector store type names configured in ``kbs_config``."""
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        """Return the knowledge bases recorded in the database."""
        return list_kbs_from_db()

    def exists(self, kb_name: str = None):
        """Return whether ``kb_name`` (default: this KB) is recorded in the database."""
        kb_name = kb_name or self.kb_name
        return kb_exists(kb_name)

    @abstractmethod
    def vs_type(self) -> str:
        """Subclass hook: return the vector store type identifier."""
        pass

    @abstractmethod
    def do_init(self):
        """Subclass hook: called at the end of ``__init__``."""
        pass

    @abstractmethod
    def do_drop_kb(self):
        """
        Subclass hook: delete the knowledge base from the vector store.
        """
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float,
                  expr: str,
                  custom_strategy_config: Optional[dict] = None,
                  ) -> List[Tuple[Document, float]]:
        """
        Subclass hook: search the vector store.
        """
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """
        Subclass hook: add documents to the vector store.
        """
        pass

    @abstractmethod
    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """
        Subclass hook: delete a file's documents from the vector store.
        ``delete_doc`` forwards its extra keyword arguments here.
        """
        pass

    @abstractmethod
    def do_clear_vs(self):
        """
        Subclass hook: delete all vectors of the knowledge base.
        """
        pass
|
||||
|
||||
|
||||
class KBServiceFactory:
    """Creates the ``KBService`` implementation for a vector store type."""

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> KBService:
        """Return the service for ``kb_name`` on the given vector store.

        Implementations are imported lazily, so only the chosen back end's
        module is loaded.

        :param vector_store_type: a ``SupportedVSType`` value or its name in
            any letter case.
        :return: the service instance, or ``None`` if no branch matches.
        :raises AttributeError: if a string name is not an attribute of
            ``SupportedVSType``.
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
        if SupportedVSType.FAISS == vector_store_type:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.PG == vector_store_type:
            from server.knowledge_base.kb_service.pg_kb_service import PGKBService
            return PGKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.MILVUS == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name,embed_model=embed_model)
        elif SupportedVSType.ZILLIZ == vector_store_type:
            from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
            return ZillizKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.DEFAULT == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            # Other milvus parameters are set in model_config.kbs_config.
            return MilvusKBService(kb_name,
                                   embed_model=embed_model)
        elif SupportedVSType.ES == vector_store_type:
            from server.knowledge_base.kb_service.es_kb_service import ESKBService
            return ESKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.CHROMADB == vector_store_type:
            from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
            return ChromaKBService(kb_name, embed_model=embed_model)
        # NOTE(review): a second ``SupportedVSType.DEFAULT`` branch returning
        # DefaultKBService used to follow here. It could never run because
        # the DEFAULT branch above already returns, so it was removed;
        # confirm which service DEFAULT is meant to map to.
        return None

    @staticmethod
    def get_service_by_name(kb_name: str) -> KBService:
        """Return the service for a KB recorded in the database, else ``None``."""
        stored_name, vs_type, embed_model = load_kb_from_db(kb_name)
        if stored_name is None:  # kb not in db, just return None
            return None
        from server.utils import resolve_embed_model_name

        embed_model = resolve_embed_model_name(embed_model)
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        """Return the service named "default" on the DEFAULT vector store type."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
|
||||
|
||||
|
||||
def get_kb_details() -> List[Dict]:
    """List every knowledge base found in the KB folder or in the database.

    Folder entries start from empty placeholder details; when the database
    has details for the same name, they are merged over the placeholders.
    Each entry carries ``in_folder`` / ``in_db`` flags and a 1-based ``No``.

    :return: one detail dict per knowledge base, folder entries first.
    """
    kbs_in_folder = list_kbs_from_folder()
    kbs_in_db = KBService.list_kbs()

    result = {
        kb: {
            "kb_name": kb,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for kb in kbs_in_folder
    }

    for kb in kbs_in_db:
        kb_detail = get_kb_detail(kb)
        if not kb_detail:
            continue
        kb_detail["in_db"] = True
        if kb in result:
            result[kb].update(kb_detail)
            continue
        # Known to the database only.
        kb_detail["in_folder"] = False
        result[kb] = kb_detail

    data = list(result.values())
    for number, entry in enumerate(data, start=1):
        entry['No'] = number
    return data
|
||||
|
||||
|
||||
def get_kb_file_details(kb_name: str) -> List[Dict]:
    """List every file of a knowledge base found in its folder or the database.

    Folder files start from empty placeholder details. Database records are
    matched to folder files by case-insensitive name and merged over the
    placeholders; records without a folder file are added with
    ``in_folder=False``. Each entry gets a 1-based ``No``.

    :param kb_name: name of the knowledge base.
    :return: one detail dict per file, or ``[]`` if no service exists for
        ``kb_name``.
    """
    kb = KBServiceFactory.get_service_by_name(kb_name)
    if kb is None:
        return []

    files_in_folder = list_files_from_folder(kb_name)
    files_in_db = kb.list_files()

    result = {
        doc: {
            "kb_name": kb_name,
            "file_name": doc,
            "file_ext": os.path.splitext(doc)[-1],
            "file_version": 0,
            "document_loader": "",
            "docs_count": 0,
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for doc in files_in_folder
    }
    # Lower-cased name -> name as it appears in the folder.
    lower_names = {name.lower(): name for name in result}

    for doc in files_in_db:
        doc_detail = get_file_detail(kb_name, doc)
        if not doc_detail:
            continue
        doc_detail["in_db"] = True
        lowered = doc.lower()
        if lowered in lower_names:
            result[lower_names[lowered]].update(doc_detail)
            continue
        # Recorded in the database but missing from the folder.
        doc_detail["in_folder"] = False
        result[doc] = doc_detail

    data = list(result.values())
    for number, entry in enumerate(data, start=1):
        entry['No'] = number
    return data
|
||||
|
||||
|
||||
class EmbeddingsFunAdapter(Embeddings):
    """Adapts ``embed_texts`` / ``aembed_texts`` to the ``Embeddings`` interface.

    Every returned vector is L2-normalised with ``normalize``. All four
    methods raise ``ValueError`` when the embedding call yields no data.
    """

    def __init__(self, embed_model: str = EMBEDDING_MODEL):
        self.embed_model = embed_model

    @staticmethod
    def _unwrap(result, texts: List[str]):
        """Return ``result.data``, or raise ValueError if it is missing or empty."""
        embeddings = result.data if result and hasattr(result, 'data') else None
        if embeddings is None or len(embeddings) == 0:
            raise ValueError(f"Failed to get embeddings for texts: {texts[:2]}...")
        return embeddings

    @staticmethod
    def _normalize_one(query_embed) -> List[float]:
        """Normalise a single vector and return it as a flat list."""
        # normalize() works on 2-D input, so wrap the vector as one row.
        query_embed_2d = np.reshape(query_embed, (1, -1))
        return normalize(query_embed_2d)[0].tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` as documents and return normalised vectors."""
        result = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return normalize(self._unwrap(result, texts)).tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed ``text`` as a query and return one normalised vector."""
        result = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        return self._normalize_one(self._unwrap(result, [text])[0])

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Async variant of ``embed_documents``."""
        result = await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return normalize(self._unwrap(result, texts)).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        """Async variant of ``embed_query``."""
        result = await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        return self._normalize_one(self._unwrap(result, [text])[0])
|
||||
|
||||
|
||||
# def score_threshold_process(score_threshold, k, docs):
|
||||
# if score_threshold is not None:
|
||||
# cmp = (
|
||||
# operator.le
|
||||
# )
|
||||
# docs = [
|
||||
# (doc, similarity)
|
||||
# for doc, similarity in docs
|
||||
# if cmp(similarity, score_threshold)
|
||||
# ]
|
||||
# return docs[:k]
|
||||
@@ -0,0 +1,105 @@
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.types import (GetResult, QueryResult)
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from configs import SCORE_THRESHOLD
|
||||
from server.knowledge_base.kb_service.base import (EmbeddingsFunAdapter,
|
||||
KBService, SupportedVSType)
|
||||
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
|
||||
|
||||
def _get_result_to_documents(get_result: GetResult) -> List[Document]:
    """Convert a Chroma ``get`` result into a list of ``Document`` objects.

    :param get_result: mapping with ``'documents'`` and ``'metadatas'``.
    :return: one ``Document`` per returned text; ``[]`` when there are none.
        Missing metadata (the whole list or a single entry) becomes a new
        empty dict for that document.
    """
    documents = get_result['documents']
    if not documents:
        return []

    metadatas = get_result['metadatas'] or [None] * len(documents)

    # ``metadata or {}`` builds a separate dict per document; ``[{}] * n``
    # would hand the same dict object to every document.
    return [
        Document(page_content=page_content, metadata=metadata or {})
        for page_content, metadata in zip(documents, metadatas)
    ]
|
||||
|
||||
|
||||
def _results_to_docs_and_scores(results: Any) -> List[Tuple[Document, float]]:
    """
    Pair each document of the first query in ``results`` with its distance.

    Adapted from langchain_community.vectorstores.chroma (Chroma).
    Only index 0 of ``documents`` / ``metadatas`` / ``distances`` is read.
    """
    # TODO: Chroma can do batch querying,
    texts = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]

    docs_and_scores = []
    for text, metadata, distance in zip(texts, metadatas, distances):
        document = Document(page_content=text, metadata=metadata or {})
        docs_and_scores.append((document, distance))
    return docs_and_scores
|
||||
|
||||
|
||||
class ChromaKBService(KBService):
|
||||
vs_path: str
|
||||
kb_path: str
|
||||
|
||||
client = None
|
||||
collection = None
|
||||
|
||||
def vs_type(self) -> str:
|
||||
return SupportedVSType.CHROMADB
|
||||
|
||||
def get_vs_path(self) -> str:
|
||||
return get_vs_path(self.kb_name, self.embed_model)
|
||||
|
||||
def get_kb_path(self) -> str:
|
||||
return get_kb_path(self.kb_name)
|
||||
|
||||
def do_init(self) -> None:
|
||||
self.kb_path = self.get_kb_path()
|
||||
self.vs_path = self.get_vs_path()
|
||||
self.client = chromadb.PersistentClient(path=self.vs_path)
|
||||
self.collection = self.client.get_or_create_collection(self.kb_name)
|
||||
|
||||
def do_create_kb(self) -> None:
|
||||
# In ChromaDB, creating a KB is equivalent to creating a collection
|
||||
self.collection = self.client.get_or_create_collection(self.kb_name)
|
||||
|
||||
def do_drop_kb(self):
|
||||
# Dropping a KB is equivalent to deleting a collection in ChromaDB
|
||||
try:
|
||||
self.client.delete_collection(self.kb_name)
|
||||
except ValueError as e:
|
||||
if not str(e) == f"Collection {self.kb_name} does not exist.":
|
||||
raise e
|
||||
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float = SCORE_THRESHOLD, expr: str) -> List[
|
||||
Tuple[Document, float]]:
|
||||
embed_func = EmbeddingsFunAdapter(self.embed_model)
|
||||
embeddings = embed_func.embed_query(query)
|
||||
query_result: QueryResult = self.collection.query(query_embeddings=embeddings, n_results=top_k)
|
||||
return _results_to_docs_and_scores(query_result)
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
|
||||
doc_infos = []
|
||||
data = self._docs_to_embeddings(docs)
|
||||
ids = [str(uuid.uuid1()) for _ in range(len(data["texts"]))]
|
||||
for _id, text, embedding, metadata in zip(ids, data["texts"], data["embeddings"], data["metadatas"]):
|
||||
self.collection.add(ids=_id, embeddings=embedding, metadatas=metadata, documents=text)
|
||||
doc_infos.append({"id": _id, "metadata": metadata})
|
||||
return doc_infos
|
||||
|
||||
def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
    """Fetch the entries with the given ids and convert them to documents."""
    fetched = self.collection.get(ids=ids)
    return _get_result_to_documents(fetched)
|
||||
|
||||
def del_doc_by_ids(self, ids: List[str]) -> bool:
    """Delete the entries with the given ids from the collection.

    Returns:
        Always ``True``.
    """
    target_ids = ids
    self.collection.delete(ids=target_ids)
    return True
|
||||
|
||||
def do_clear_vs(self):
    """Clear the vector store by dropping the knowledge base.

    Clearing is implemented as dropping the collection; this method does
    not recreate it.
    """
    drop = self.do_drop_kb
    drop()
|
||||
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
    """Delete every entry whose ``source`` metadata equals the file's path.

    Returns:
        Whatever ``collection.delete`` returns.
    """
    source_filter = {"source": kb_file.filepath}
    return self.collection.delete(where=source_filter)
|
||||
@@ -0,0 +1,38 @@
|
||||
from typing import List
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService
|
||||
|
||||
|
||||
class DefaultKBService(KBService):
    """Placeholder knowledge-base service: every hook is a no-op."""

    def do_create_kb(self):
        """Do nothing."""
        return None

    def do_drop_kb(self):
        """Do nothing."""
        return None

    def do_add_doc(self, docs: List[Document]):
        """Ignore ``docs`` and do nothing."""
        return None

    def do_clear_vs(self):
        """Do nothing."""
        return None

    def vs_type(self) -> str:
        """Return the literal type name ``"default"``."""
        return "default"

    def do_init(self):
        """Do nothing."""
        return None

    def do_search(self):
        """Do nothing."""
        return None

    def do_insert_multi_knowledge(self):
        """Do nothing."""
        return None

    def do_insert_one_knowledge(self):
        """Do nothing."""
        return None

    def do_delete_doc(self):
        """Do nothing."""
        return None
|
||||
261
langchain-chat/server/knowledge_base/kb_service/es_kb_service.py
Normal file
261
langchain-chat/server/knowledge_base/kb_service/es_kb_service.py
Normal file
@@ -0,0 +1,261 @@
|
||||
from typing import List
|
||||
import os
|
||||
import shutil
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.elasticsearch import ElasticsearchStore
|
||||
from configs import KB_ROOT_PATH, EMBEDDING_MODEL, EMBEDDING_DEVICE, CACHED_VS_NUM
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
from server.utils import load_local_embeddings
|
||||
from elasticsearch import Elasticsearch,BadRequestError
|
||||
from configs import logger
|
||||
from configs import kbs_config
|
||||
|
||||
class ESKBService(KBService):
    """Knowledge-base service that stores document vectors in an Elasticsearch index."""

    def do_init(self):
        """Read the ES settings, connect to the server and prepare the index.

        Sets ``kb_path``, ``index_name``, the connection settings, the
        embedding model, the raw client (``es_client_python``) and the
        LangChain store (``db_init``) on the instance.

        Raises:
            ConnectionError: re-raised (original exception preserved) when a
                client cannot be constructed because of a connection error.
            Exception: any other construction error is logged and re-raised.
        """
        self.kb_path = self.get_kb_path(self.kb_name)
        # The index is named after the last component of the KB path.
        self.index_name = os.path.split(self.kb_path)[-1]
        es_config = kbs_config[self.vs_type()]
        self.IP = es_config['host']
        self.PORT = es_config['port']
        self.user = es_config.get("user", '')
        self.password = es_config.get("password", '')
        self.dims_length = es_config.get("dims_length", None)
        self.embeddings_model = load_local_embeddings(self.embed_model, EMBEDDING_DEVICE)

        es_url = f"http://{self.IP}:{self.PORT}"
        has_credentials = self.user != "" and self.password != ""

        try:
            # Plain Elasticsearch python client (connection only).
            if has_credentials:
                self.es_client_python = Elasticsearch(es_url,
                                                      basic_auth=(self.user, self.password))
            else:
                logger.warning("ES未配置用户名和密码")
                self.es_client_python = Elasticsearch(es_url)
        except ConnectionError:
            logger.error("连接到 Elasticsearch 失败!")
            # Bare raise keeps the original exception and traceback.
            raise
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            raise

        try:
            # First try to create the index through the raw client.
            mappings = {
                "properties": {
                    "dense_vector": {
                        "type": "dense_vector",
                        "dims": self.dims_length,
                        "index": True
                    }
                }
            }
            self.es_client_python.indices.create(index=self.index_name, mappings=mappings)
        except BadRequestError as e:
            # Best effort: log and continue (presumably the index already exists).
            logger.error("创建索引失败,重新")
            logger.error(e)

        try:
            # LangChain ES store: connect and create the index.
            store_kwargs = dict(
                es_url=es_url,
                index_name=self.index_name,
                query_field="context",
                vector_query_field="dense_vector",
                embedding=self.embeddings_model,
            )
            if has_credentials:
                store_kwargs.update(es_user=self.user, es_password=self.password)
            else:
                logger.warning("ES未配置用户名和密码")
            self.db_init = ElasticsearchStore(**store_kwargs)
        except ConnectionError:
            print("### 初始化 Elasticsearch 失败!")
            logger.error("### 初始化 Elasticsearch 失败!")
            raise
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            raise

        try:
            # Try to create the index through the LangChain store as well.
            self.db_init._create_index_if_not_exists(
                index_name=self.index_name,
                dims_length=self.dims_length
            )
        except Exception as e:
            # Deliberately not re-raised: initialization continues.
            logger.error("创建索引失败...")
            logger.error(e)

    @staticmethod
    def get_kb_path(knowledge_base_name: str):
        """Return the directory of the named knowledge base under ``KB_ROOT_PATH``."""
        return os.path.join(KB_ROOT_PATH, knowledge_base_name)

    @staticmethod
    def get_vs_path(knowledge_base_name: str):
        """Return the ``vector_store`` directory inside the knowledge-base directory."""
        return os.path.join(ESKBService.get_kb_path(knowledge_base_name), "vector_store")

    def do_create_kb(self):
        """Create the local ``vector_store`` directory when the document path exists."""
        if os.path.exists(self.doc_path):
            vector_store_dir = os.path.join(self.kb_path, "vector_store")
            if not os.path.exists(vector_store_dir):
                os.makedirs(vector_store_dir)
            else:
                logger.warning("directory `vector_store` already exists.")

    def vs_type(self) -> str:
        """Return the vector-store type identifier used by this service."""
        return SupportedVSType.ES

    def _load_es(self, docs, embed_model):
        """Write ``docs`` into the index through ``ElasticsearchStore.from_documents``.

        Errors are printed and logged but not raised (best effort); on
        success the resulting store is kept in ``self.db``.
        """
        try:
            # Connect and write the documents in one step.
            store_kwargs = dict(
                documents=docs,
                embedding=embed_model,
                es_url=f"http://{self.IP}:{self.PORT}",
                index_name=self.index_name,
                distance_strategy="COSINE",
                query_field="context",
                vector_query_field="dense_vector",
                verify_certs=False,
            )
            if self.user != "" and self.password != "":
                store_kwargs.update(es_user=self.user, es_password=self.password)
            self.db = ElasticsearchStore.from_documents(**store_kwargs)
        except ConnectionError as ce:
            print(ce)
            print("连接到 Elasticsearch 失败!")
            logger.error("连接到 Elasticsearch 失败!")
        except Exception as e:
            logger.error(f"Error 发生 : {e}")
            print(e)

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str):
        """Run a similarity search and return ``(document, score)`` results.

        ``score_threshold`` and ``expr`` are accepted for interface
        compatibility and are not applied here.
        """
        # Text similarity search.
        docs = self.db_init.similarity_search_with_score(query=query,
                                                         k=top_k)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by id; ids that fail to load are logged and skipped."""
        results = []
        for doc_id in ids:
            try:
                response = self.es_client_python.get(index=self.index_name, id=doc_id)
                source = response["_source"]
                # The text lives in the "context" field, the metadata in "metadata".
                text = source.get("context", "")
                metadata = source.get("metadata", {})
                results.append(Document(page_content=text, metadata=metadata))
            except Exception as e:
                logger.error(f"Error retrieving document from Elasticsearch! {e}")
        return results

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by id, logging individual failures.

        Returns:
            ``True`` when every delete succeeded, ``False`` when at least one
            failed (the original returned ``None`` despite the ``bool``
            annotation).
        """
        all_deleted = True
        for doc_id in ids:
            try:
                self.es_client_python.delete(index=self.index_name,
                                             id=doc_id,
                                             refresh=True)
            except Exception as e:
                all_deleted = False
                logger.error(f"ES Docs Delete Error! {e}")
        return all_deleted

    def do_delete_doc(self, kb_file, **kwargs):
        """Delete the indexed chunks whose ``metadata.source`` equals the file's path.

        Only acts when the index exists. Individual delete failures are
        logged and skipped.
        """
        if self.es_client_python.indices.exists(index=self.index_name):
            # Look the chunks up by source (the document name is a keyword field).
            query = {
                "query": {
                    "term": {
                        "metadata.source.keyword": kb_file.filepath
                    }
                }
            }
            # Set size explicitly: the default is 10 hits.
            # The search is scoped to this service's index so ids from other
            # indices are never picked up.
            # NOTE(review): at most 50 chunks are removed per call.
            search_results = self.es_client_python.search(index=self.index_name, body=query, size=50)
            delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
            if len(delete_list) == 0:
                return None
            else:
                for doc_id in delete_list:
                    try:
                        self.es_client_python.delete(index=self.index_name,
                                                     id=doc_id,
                                                     refresh=True)
                    except Exception as e:
                        logger.error(f"ES Docs Delete Error! {e}")

    def do_add_doc(self, docs: List[Document], **kwargs):
        """Add documents to the knowledge base and report what was stored.

        Returns:
            ``[{"id": str, "metadata": dict}, ...]`` for the chunks found in
            this index whose source equals the source of ``docs[0]``; an
            empty list when ``docs`` is empty; ``None`` when the index does
            not exist.

        Raises:
            ValueError: when the lookup after writing finds no chunk.
        """
        if not docs:
            # Nothing to write; avoids an IndexError on docs[0] below.
            return []

        print(f"server.knowledge_base.kb_service.es_kb_service.do_add_doc 输入的docs参数长度为:{len(docs)}")
        print("*" * 100)
        self._load_es(docs=docs, embed_model=self.embeddings_model)
        # Collect id and source, format: [{"id": str, "metadata": dict}, ...]
        print("写入数据成功.")
        print("*" * 100)

        if self.es_client_python.indices.exists(index=self.index_name):
            file_path = docs[0].metadata.get("source")
            # The original dict literal repeated the "term" key, so the source
            # filter was silently overwritten by the "_index" one. The index
            # restriction is now expressed through ``index=`` instead.
            query = {
                "query": {
                    "term": {
                        "metadata.source.keyword": file_path
                    }
                }
            }
            # Set size explicitly: the default is 10 hits.
            search_results = self.es_client_python.search(index=self.index_name, body=query, size=50)
            if len(search_results["hits"]["hits"]) == 0:
                raise ValueError("召回元素个数为0")
            info_docs = [{"id": hit["_id"], "metadata": hit["_source"]["metadata"]}
                         for hit in search_results["hits"]["hits"]]
            return info_docs

    def do_clear_vs(self):
        """Delete all vectors of the knowledge base by deleting its index."""
        # Use the same index name as every other method of this class.
        if self.es_client_python.indices.exists(index=self.index_name):
            self.es_client_python.indices.delete(index=self.index_name)

    def do_drop_kb(self):
        """Delete the knowledge base by removing its local directory, if present."""
        # self.kb_path: path of the knowledge base
        if os.path.exists(self.kb_path):
            shutil.rmtree(self.kb_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test against the "test" knowledge base.
    esKBService = ESKBService("test")
    # esKBService.clear_vs()
    # esKBService.create_kb()
    readme_file = KnowledgeFile(filename="README.md", knowledge_base_name="test")
    esKBService.add_doc(readme_file)
    search_result = esKBService.search_docs("如何启动api服务")
    print(search_result)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,124 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from configs import SCORE_THRESHOLD, EXPR
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter
|
||||
from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
|
||||
from server.knowledge_base.utils import KnowledgeFile, get_kb_path, get_vs_path
|
||||
from server.utils import torch_gc
|
||||
from langchain.docstore.document import Document
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
|
||||
|
||||
class FaissKBService(KBService):
    """Knowledge-base service backed by a FAISS vector store from ``kb_faiss_pool``."""

    vs_path: str
    kb_path: str
    # Name of the vector store; falls back to the embed model name in do_init.
    vector_name: str = None

    def vs_type(self) -> str:
        """Return the vector-store type identifier used by this service."""
        return SupportedVSType.FAISS

    def get_vs_path(self):
        """Return the vector-store path for this knowledge base and vector name."""
        return get_vs_path(self.kb_name, self.vector_name)

    def get_kb_path(self):
        """Return the path of this knowledge base."""
        return get_kb_path(self.kb_name)

    def load_vector_store(self) -> ThreadSafeFaiss:
        """Load this knowledge base's vector store through ``kb_faiss_pool``."""
        return kb_faiss_pool.load_vector_store(kb_name=self.kb_name,
                                               vector_name=self.vector_name,
                                               embed_model=self.embed_model)

    def save_vector_store(self):
        """Save the vector store to ``vs_path``."""
        self.load_vector_store().save(self.vs_path)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return the stored documents for ``ids`` (``None`` for unknown ids)."""
        with self.load_vector_store().acquire() as vs:
            return [vs.docstore._dict.get(doc_id) for doc_id in ids]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the entries with the given ids from the vector store.

        Returns:
            ``True`` after the delete call (the original returned ``None``
            despite the ``bool`` annotation).
        """
        with self.load_vector_store().acquire() as vs:
            vs.delete(ids)
        return True

    def do_init(self):
        """Resolve the vector name and the knowledge-base / vector-store paths."""
        self.vector_name = self.vector_name or self.embed_model
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Create the vector-store directory if needed and load the store."""
        if not os.path.exists(self.vs_path):
            os.makedirs(self.vs_path)
        self.load_vector_store()

    def do_drop_kb(self):
        """Clear the vector store, then remove the knowledge-base directory (best effort)."""
        self.clear_vs()
        try:
            shutil.rmtree(self.kb_path)
        except Exception:
            # Deliberate best effort: a failed removal is ignored.
            ...

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  expr: str = EXPR,
                  ) -> List[Tuple[Document, float]]:
        """Embed ``query`` and return up to ``top_k`` ``(document, score)`` pairs.

        ``score_threshold`` is forwarded to the vector store; ``expr`` is
        accepted for interface compatibility and is not used here.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        with self.load_vector_store().acquire() as vs:
            docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Embed ``docs`` and add them to the vector store.

        Keyword Args:
            ids: Optional ids forwarded to ``add_embeddings``.
            not_refresh_vs_cache: When truthy, the store is not saved to disk.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        # Embedding outside the lock shortens the time the store is locked.
        data = self._docs_to_embeddings(docs)

        with self.load_vector_store().acquire() as vs:
            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
                                    metadatas=data["metadatas"],
                                    ids=kwargs.get("ids"))
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        torch_gc()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete every entry whose ``source`` matches the file name (case-insensitive).

        Keyword Args:
            not_refresh_vs_cache: When truthy, the store is not saved to disk.

        Returns:
            The list of deleted ids (possibly empty).
        """
        target_name = kb_file.filename.lower()
        with self.load_vector_store().acquire() as vs:
            # Entries without a "source" are treated as non-matching instead
            # of raising AttributeError on None.
            ids = [k for k, v in vs.docstore._dict.items()
                   if (v.metadata.get("source") or "").lower() == target_name]
            if len(ids) > 0:
                vs.delete(ids)
            if not kwargs.get("not_refresh_vs_cache"):
                vs.save_local(self.vs_path)
        return ids

    def do_clear_vs(self):
        """Evict the store from the pool and recreate an empty vector-store directory."""
        with kb_faiss_pool.atomic:
            kb_faiss_pool.pop((self.kb_name, self.vector_name))
        try:
            shutil.rmtree(self.vs_path)
        except Exception:
            # Deliberate best effort: a failed removal is ignored.
            ...
        os.makedirs(self.vs_path, exist_ok=True)

    def exist_doc(self, file_name: str):
        """Report where ``file_name`` is known.

        Returns:
            ``"in_db"`` when the base class reports the document exists,
            ``"in_folder"`` when the file is only present in the ``content``
            directory, otherwise ``False``.
        """
        if super().exist_doc(file_name):
            return "in_db"

        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        else:
            return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: add, delete, drop, then search the "test" knowledge base.
    faissService = FaissKBService("test")
    file_to_add = KnowledgeFile("README.md", "test")
    faissService.add_doc(file_to_add)
    file_to_delete = KnowledgeFile("README.md", "test")
    faissService.delete_doc(file_to_delete)
    faissService.do_drop_kb()
    search_result = faissService.search_docs("如何启动api服务")
    print(search_result)
|
||||
@@ -0,0 +1,207 @@
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.milvus import Milvus
|
||||
import os
|
||||
import logging
|
||||
|
||||
from configs import kbs_config
|
||||
from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
import numpy as np
|
||||
|
||||
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)
|
||||
|
||||
class MilvusKBService(KBService):
    """Knowledge-base service backed by a Milvus collection named after the knowledge base."""

    # Stays None until _load_milvus succeeds, so the ``if self.milvus`` guards
    # below work instead of raising AttributeError.
    milvus: Optional[Milvus] = None

    @staticmethod
    def get_collection(milvus_name):
        """Return the pymilvus ``Collection`` with the given name."""
        from pymilvus import Collection
        return Collection(milvus_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by primary key.

        The ids are converted to ``int`` before querying. The ``text`` field
        becomes the page content and the remaining fields the metadata.
        Returns an empty list when no collection is loaded.
        """
        result = []
        if self.milvus and self.milvus.col:
            data_list = self.milvus.col.query(expr=f'pk in {[int(_id) for _id in ids]}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        """Fetch every document whose ``source`` field is in ``source_name_list``.

        Returns an empty list when no collection is loaded.
        """
        result = []
        if self.milvus and self.milvus.col:
            data_list = self.milvus.col.query(expr=f'source in {source_name_list}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by primary key.

        The ids are converted to ``int``, matching ``get_doc_by_ids`` (the
        collection created by this class declares ``pk`` as Int64).

        Returns:
            ``True`` when a delete was issued, ``False`` when no collection
            is loaded (the original returned ``None`` in both cases).
        """
        if self.milvus and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {[int(_id) for _id in ids]}')
            return True
        return False

    @staticmethod
    def search(milvus_name, content, limit=3):
        """Search the ``embeddings`` field of the named collection with the L2 metric."""
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        c = MilvusKBService.get_collection(milvus_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """Do nothing."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type identifier used by this service."""
        return SupportedVSType.MILVUS

    def _load_milvus(self):
        """Build the LangChain ``Milvus`` store and align its text field with the schema.

        When construction fails, the error is logged and the collection is
        created through ``_create_collection_if_not_exists``.
        """
        try:
            self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                 collection_name=self.kb_name,
                                 connection_args=kbs_config.get("milvus"),
                                 index_params=kbs_config.get("milvus_kwargs")["index_params"],
                                 search_params=kbs_config.get("milvus_kwargs")["search_params"],
                                 auto_id=True
                                 )
            logger.info("成功加载 Milvus 实例 'milvus'。")

            # -------- Support collections whose text field has a different name --------
            # For a new knowledge base without a Milvus collection, ``col`` is
            # None; the schema only exists after the first add_documents call,
            # so ``.col.schema`` must not be touched here.
            try:
                col = self.milvus.col
                if col is None:
                    logger.debug(
                        "集合 %s 尚未在 Milvus 中建表,跳过文本字段探测(首次写入时会自动创建)",
                        self.kb_name,
                    )
                else:
                    field_names = [f.name for f in col.schema.fields]
                    if self.milvus._text_field not in field_names:
                        if "page_content" in field_names:
                            self.milvus._text_field = "page_content"
                        elif "content" in field_names:
                            self.milvus._text_field = "content"
                        else:
                            # Fall back to the first VARCHAR field. Compare the
                            # enum member name: str() of an IntEnum is the
                            # integer value from Python 3.11 on, so the old
                            # "DataType.VARCHAR" prefix test stopped matching.
                            for f in col.schema.fields:
                                if getattr(getattr(f, "dtype", None), "name", None) == "VARCHAR":
                                    self.milvus._text_field = f.name
                                    break
                    logger.info(f"集合 {self.kb_name} 使用文本字段: {self.milvus._text_field}")
            except Exception as e:
                logger.warning(f"检测并设置文本字段失败: {e}")
        except Exception as e:
            logger.error(f"加载 Milvus 实例 'milvus' 失败: {e}")
            self._create_collection_if_not_exists()
            # Reload
            # self._load_milvus()

    def _create_collection_if_not_exists(self):
        """Create the Milvus collection with a fixed schema and index its vector field."""
        from pymilvus import Collection, CollectionSchema, FieldSchema, DataType

        # Field definitions (adjust to your needs).
        fields = [
            FieldSchema(name="pk", dtype=DataType.Int64, is_primary=True, auto_id=True),
            # dim must match the embedding model.
            FieldSchema(name="vector", dtype=DataType.FloatVector, dim=768),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=65535),
            # Add further custom fields here...
        ]

        schema = CollectionSchema(fields=fields, description=self.kb_name)

        # Create the collection.
        collection = Collection(name=self.kb_name, schema=schema, using="default")

        # Create the index.
        index_params = kbs_config.get("milvus_kwargs")["index_params"]
        collection.create_index(field_name="vector", index_params=index_params)

        logger.info(f"成功创建 Milvus 集合: {self.kb_name}")

    def do_init(self):
        """Load the Milvus store."""
        self._load_milvus()

    def do_drop_kb(self):
        """Release the collection; dropping it is intentionally disabled."""
        if self.milvus and self.milvus.col:
            self.milvus.col.release()
            # self.milvus.col.drop()  # deleting the collection from chatchat is disabled

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str,
                  custom_strategy_config: Optional[dict] = None):
        """Search the collection for ``query``.

        Args:
            query: Text to search for.
            top_k: Number of results requested.
            score_threshold: Passed to ``score_threshold_process`` when
                ``top_k`` is 50 or less.
            expr: Filter expression forwarded to the Milvus search.
            custom_strategy_config: Accepted but not used by this
                implementation; defaults to ``None`` instead of a shared
                mutable ``{}``.

        Returns:
            For ``top_k > 50``: the documents (without scores) sorted by
            ascending ``pk``. Otherwise: the result of
            ``score_threshold_process``. An empty list when the search fails.
        """
        self._load_milvus()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        try:
            embeddings = embed_func.embed_query(query)
            if top_k > 50:
                # Return the full content in storage order.
                docs = self.milvus.similarity_search_by_vector(embeddings, top_k, expr=expr)
                docs = sorted(docs, key=lambda doc: doc.metadata['pk'])  # ascending pk
                # return score_threshold_process(query,score_threshold, top_k, docs)
                return docs
            else:
                docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k, expr=expr)
                # TODO dynamic score_threshold
                return score_threshold_process(query, score_threshold, top_k, docs)
        except Exception as e:
            logger.error(f"搜索 Milvus 集合 '{self.kb_name}' 失败: {e}")
            return []

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Normalize the metadata of ``docs`` and add them to the collection.

        Every metadata value is converted to ``str``, fields known to the
        store but missing from the metadata are filled with ``""``, and the
        store's text / vector field names are removed from the metadata.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        for doc in docs:
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.milvus._text_field, None)
            doc.metadata.pop(self.milvus._vector_field, None)

        ids = self.milvus.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete the file's chunks, using the ids recorded in the database for it."""
        id_list = list_file_num_docs_id_by_kb_name_and_file_name(kb_file.kb_name, kb_file.filename)
        if self.milvus and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {id_list}')

        # Issue 2846, for windows
        # if self.milvus.col:
        #     file_path = kb_file.filepath.replace("\\", "\\\\")
        #     file_name = os.path.basename(file_path)
        #     id_list = [item.get("pk") for item in
        #                self.milvus.col.query(expr=f'source == "{file_name}"', output_fields=["pk"])]
        #     self.milvus.col.delete(expr=f'pk in {id_list}')

    def do_clear_vs(self):
        """Release the collection and reload the store, when a collection is loaded."""
        if self.milvus and self.milvus.col:
            self.do_drop_kb()
            self.do_init()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Used to test table creation.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    milvusService = MilvusKBService("t_policy_total_bce_v1")
    # milvusService.add_doc(KnowledgeFile("README.md", "test"))

    # print(milvusService.get_doc_by_ids(["444022434274215486"]))
    # milvusService.delete_doc(KnowledgeFile("README.md", "test"))
    # milvusService.do_drop_kb()
    # print(milvusService.search_docs("如何启动api服务"))
|
||||
@@ -0,0 +1,96 @@
|
||||
import json
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
|
||||
from sqlalchemy import text
|
||||
|
||||
from configs import kbs_config
|
||||
|
||||
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
import shutil
|
||||
import sqlalchemy
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class PGKBService(KBService):
    """Knowledge-base service backed by PostgreSQL through LangChain's ``PGVector``."""

    # One engine shared by every instance; created when the class is defined.
    engine: Engine = sqlalchemy.create_engine(kbs_config.get("pg").get("connection_uri"), pool_size=10)

    def _load_pg_vector(self):
        """Build the ``PGVector`` store for this knowledge base (Euclidean distance)."""
        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                  collection_name=self.kb_name,
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
                                  connection=PGKBService.engine,
                                  connection_string=kbs_config.get("pg").get("connection_uri"))

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Return the documents of ``langchain_pg_embedding`` rows matching ``ids``.

        NOTE(review): the filter is on ``collection_id``, as in the original
        query — confirm this is the intended column.
        """
        with Session(PGKBService.engine) as session:
            # An expanding bind parameter renders the list as a proper
            # ``IN (v1, v2, ...)`` clause.
            stmt = text(
                "SELECT document, cmetadata FROM langchain_pg_embedding WHERE collection_id in :ids"
            ).bindparams(sqlalchemy.bindparam("ids", expanding=True))
            results = [Document(page_content=row[0], metadata=row[1]) for row in
                       session.execute(stmt, {'ids': list(ids)}).fetchall()]
            return results

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delegate to the base-class implementation."""
        return super().del_doc_by_ids(ids)

    def do_init(self):
        """Load the ``PGVector`` store."""
        self._load_pg_vector()

    def do_create_kb(self):
        """Do nothing."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type identifier used by this service."""
        return SupportedVSType.PG

    def do_drop_kb(self):
        """Delete the collection's rows and remove the knowledge-base directory.

        The knowledge-base name is passed as a bound parameter instead of
        being interpolated into the SQL text.
        """
        params = {"name": self.kb_name}
        with Session(PGKBService.engine) as session:
            # Delete the langchain_pg_embedding rows that belong to the collection.
            session.execute(text(
                "DELETE FROM langchain_pg_embedding "
                "WHERE collection_id IN ("
                "SELECT uuid FROM langchain_pg_collection WHERE name = :name)"
            ), params)
            # Delete the langchain_pg_collection row itself.
            session.execute(text(
                "DELETE FROM langchain_pg_collection WHERE name = :name"
            ), params)
            session.commit()
            shutil.rmtree(self.kb_path)

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str):
        """Embed ``query``, search the store and apply ``score_threshold_process``.

        ``expr`` is accepted for interface compatibility and is not used here.
        """
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.pg_vector.similarity_search_with_score_by_vector(embeddings, top_k)
        # score_threshold_process takes the query as its first argument.
        return score_threshold_process(query, score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Add ``docs`` to the store.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        ids = self.pg_vector.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every embedding row whose metadata ``source`` equals the file's path.

        The JSON filter is built with ``json.dumps`` and passed as a bound
        parameter, which handles backslashes and quotes in the path.
        """
        source_filter = json.dumps({"source": kb_file.filepath})
        with Session(PGKBService.engine) as session:
            session.execute(
                text(
                    "DELETE FROM langchain_pg_embedding "
                    "WHERE cmetadata::jsonb @> CAST(:source_filter AS jsonb)"
                ),
                {"source_filter": source_filter},
            )
            session.commit()

    def do_clear_vs(self):
        """Delete and recreate the collection."""
        self.pg_vector.delete_collection()
        self.pg_vector.create_collection()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from server.db.base import Base, engine

    # Base.metadata.create_all(bind=engine)
    pGKBService = PGKBService("test")
    # pGKBService.create_kb()
    # pGKBService.add_doc(KnowledgeFile("README.md", "test"))
    # pGKBService.delete_doc(KnowledgeFile("README.md", "test"))
    # pGKBService.drop_kb()
    fetched_docs = pGKBService.get_doc_by_ids(["f1e51390-3029-4a19-90dc-7118aaa25772"])
    print(fetched_docs)
    # print(pGKBService.search_docs("如何启动api服务"))
|
||||
@@ -0,0 +1,97 @@
|
||||
from typing import List, Dict, Optional
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
from langchain.vectorstores import Zilliz
|
||||
from configs import kbs_config
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
|
||||
score_threshold_process
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
|
||||
class ZillizKBService(KBService):
    """Knowledge-base service backed by a Zilliz collection named after the knowledge base."""

    zilliz: Zilliz

    @staticmethod
    def get_collection(zilliz_name):
        """Return the pymilvus ``Collection`` with the given name."""
        from pymilvus import Collection
        return Collection(zilliz_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by primary key.

        The ``text`` field becomes the page content and the remaining fields
        the metadata. Returns an empty list when no collection is loaded.
        """
        result = []
        if self.zilliz.col:
            # ids = [int(id) for id in ids]  # for zilliz if needed #pr 2725
            data_list = self.zilliz.col.query(expr=f'pk in {ids}', output_fields=["*"])
            for data in data_list:
                text = data.pop("text")
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete documents by primary key.

        Returns:
            ``True`` after the delete call (the original returned ``None``
            despite the ``bool`` annotation).
        """
        self.zilliz.col.delete(expr=f'pk in {ids}')
        return True

    @staticmethod
    def search(zilliz_name, content, limit=3):
        """Search the ``embeddings`` field of the named collection with the IP metric."""
        search_params = {
            "metric_type": "IP",
            "params": {},
        }
        c = ZillizKBService.get_collection(zilliz_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """Do nothing."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type identifier used by this service."""
        return SupportedVSType.ZILLIZ

    def _load_zilliz(self):
        """Build the LangChain ``Zilliz`` store from the ``zilliz`` settings."""
        zilliz_args = kbs_config.get("zilliz")
        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                             collection_name=self.kb_name, connection_args=zilliz_args)

    def do_init(self):
        """Load the Zilliz store."""
        self._load_zilliz()

    def do_drop_kb(self):
        """Release and drop the collection, when one is loaded."""
        if self.zilliz.col:
            self.zilliz.col.release()
            self.zilliz.col.drop()

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str):
        """Embed ``query``, search the store and apply ``score_threshold_process``.

        ``expr`` is accepted for interface compatibility and is not used here.
        """
        self._load_zilliz()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.zilliz.similarity_search_with_score_by_vector(embeddings, top_k)
        # score_threshold_process takes the query as its first argument.
        return score_threshold_process(query, score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Normalize the metadata of ``docs`` and add them to the collection.

        Every metadata value is converted to ``str``, fields known to the
        store but missing from the metadata are filled with ``""``, and the
        store's text / vector field names are removed from the metadata.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per added document.
        """
        for doc in docs:
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            for field in self.zilliz.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(self.zilliz._text_field, None)
            doc.metadata.pop(self.zilliz._vector_field, None)

        ids = self.zilliz.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every entry whose ``source`` equals the file's path."""
        if self.zilliz.col:
            # Backslashes are doubled so the path survives inside the expression string.
            filepath = kb_file.filepath.replace('\\', '\\\\')
            delete_list = [item.get("pk") for item in
                           self.zilliz.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
            self.zilliz.col.delete(expr=f'pk in {delete_list}')

    def do_clear_vs(self):
        """Drop the collection and reload the store, when a collection is loaded."""
        if self.zilliz.col:
            self.do_drop_kb()
            self.do_init()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    from server.db.base import Base, engine

    # Create the database tables, then build the service for the "test" knowledge base.
    Base.metadata.create_all(bind=engine)
    zillizService = ZillizKBService("test")
|
||||
Reference in New Issue
Block a user