Files

208 lines
8.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List, Dict, Optional
from langchain.schema import Document
from langchain.vectorstores.milvus import Milvus
import os
import logging
from configs import kbs_config
from server.db.repository import list_file_num_docs_id_by_kb_name_and_file_name
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
import numpy as np
# Configure logging: INFO level with a "timestamp - level - message" format,
# and a module-level logger named after this module.
# NOTE(review): basicConfig is executed at import time and targets the root
# logger — confirm that configuring it from this module is intended.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class MilvusKBService(KBService):
    """Knowledge-base service that keeps document chunks in a Milvus collection.

    The collection is named after ``self.kb_name``. Query and delete helpers
    check ``self.milvus`` and ``self.milvus.col`` before touching Milvus, so
    they degrade to "nothing found / nothing done" when the wrapper could not
    be loaded or the collection has not been created yet.
    """

    # LangChain Milvus wrapper. Defaults to None so the
    # ``if self.milvus and self.milvus.col`` guards can be evaluated even when
    # ``_load_milvus`` failed before assigning the attribute.
    milvus: Optional[Milvus] = None

    @staticmethod
    def get_collection(milvus_name):
        """Return the pymilvus ``Collection`` object for ``milvus_name``."""
        from pymilvus import Collection
        return Collection(milvus_name)

    def _rows_to_documents(self, rows) -> List[Document]:
        """Turn raw Milvus query rows (dicts) into ``Document`` objects.

        The column named by the wrapper's ``_text_field`` becomes the page
        content; every remaining column of the row is kept as metadata.
        ``_text_field`` is used instead of a literal ``"text"`` because
        ``_load_milvus`` may have switched it to another column name.
        """
        text_field = self.milvus._text_field
        documents = []
        for row in rows:
            content = row.pop(text_field)
            documents.append(Document(page_content=content, metadata=row))
        return documents

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents whose primary key ``pk`` is in ``ids``.

        Ids are converted to ``int`` before being placed in the filter
        expression. Returns an empty list when ``ids`` is empty or when the
        Milvus wrapper / collection is not available.
        """
        if not ids or not (self.milvus and self.milvus.col):
            return []
        pk_list = [int(_id) for _id in ids]
        rows = self.milvus.col.query(expr=f'pk in {pk_list}', output_fields=["*"])
        return self._rows_to_documents(rows)

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        """Fetch documents whose ``source`` column is in ``source_name_list``.

        Returns an empty list when the list is empty or when the Milvus
        wrapper / collection is not available.
        """
        if not source_name_list or not (self.milvus and self.milvus.col):
            return []
        rows = self.milvus.col.query(expr=f'source in {source_name_list}', output_fields=["*"])
        return self._rows_to_documents(rows)

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the entities whose primary key ``pk`` is in ``ids``.

        Ids are converted to ``int`` (as ``get_doc_by_ids`` does) so the
        expression contains integer literals rather than quoted strings.

        Returns:
            ``True`` when the delete was issued (or there was nothing to
            delete), ``False`` when the Milvus wrapper / collection is not
            available.
        """
        if not (self.milvus and self.milvus.col):
            return False
        if not ids:
            # Nothing to delete; avoid sending an empty ``in []`` expression.
            return True
        pk_list = [int(_id) for _id in ids]
        self.milvus.col.delete(expr=f'pk in {pk_list}')
        return True

    @staticmethod
    def search(milvus_name, content, limit=3):
        """Run a raw pymilvus vector search on collection ``milvus_name``.

        Searches the ``embeddings`` field with the L2 metric (``nprobe`` 10)
        and asks for the ``content`` field in the results.
        NOTE(review): these field names differ from the ``vector`` / ``text``
        columns used elsewhere in this class — confirm which schema this
        helper targets.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        collection = MilvusKBService.get_collection(milvus_name)
        return collection.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """No-op: nothing is created here."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type identifier for Milvus."""
        return SupportedVSType.MILVUS

    def _detect_text_field(self) -> None:
        """Point the wrapper's ``_text_field`` at a column that really exists.

        Supports collections created with a different schema: when the
        configured text field is absent, ``page_content`` is preferred, then
        ``content``, then the first VARCHAR column. Any failure is logged as a
        warning and otherwise ignored (best effort).
        """
        try:
            col = self.milvus.col
            if col is None:
                # A new knowledge base has no Milvus collection yet; it is
                # created on the first add_documents call, so there is no
                # schema to inspect here.
                logger.debug(
                    "集合 %s 尚未在 Milvus 中建表,跳过文本字段探测(首次写入时会自动创建)",
                    self.kb_name,
                )
                return

            schema_fields = col.schema.fields
            field_names = [f.name for f in schema_fields]
            if self.milvus._text_field not in field_names:
                for candidate in ("page_content", "content"):
                    if candidate in field_names:
                        self.milvus._text_field = candidate
                        break
                else:
                    from pymilvus import DataType
                    # Compare enum members instead of their str() form:
                    # str() of an IntEnum member is not guaranteed to be
                    # "DataType.VARCHAR" on every Python / pymilvus version.
                    for f in schema_fields:
                        if getattr(f, "dtype", None) == DataType.VARCHAR:
                            self.milvus._text_field = f.name
                            break
            logger.info(f"集合 {self.kb_name} 使用文本字段: {self.milvus._text_field}")
        except Exception as e:
            logger.warning(f"检测并设置文本字段失败: {e}")

    def _load_milvus(self):
        """(Re)build the LangChain Milvus wrapper for this knowledge base.

        On success ``self.milvus`` is replaced and the text field is probed.
        On failure the error is logged, ``self.milvus`` keeps its previous
        value (``None`` if it was never loaded) and the collection is created
        if it does not exist yet. The wrapper is not reloaded automatically
        after that fallback.
        """
        try:
            milvus_kwargs = kbs_config.get("milvus_kwargs")
            self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                 collection_name=self.kb_name,
                                 connection_args=kbs_config.get("milvus"),
                                 index_params=milvus_kwargs["index_params"],
                                 search_params=milvus_kwargs["search_params"],
                                 auto_id=True
                                 )
            logger.info("成功加载 Milvus 实例 'milvus'")
            self._detect_text_field()
        except Exception as e:
            logger.error(f"加载 Milvus 实例 'milvus' 失败: {e}")
            self._create_collection_if_not_exists()

    def _create_collection_if_not_exists(self, dim: int = 768):
        """Create the Milvus collection for this knowledge base if it is missing.

        Args:
            dim: Dimension of the ``vector`` field; it has to match the
                embedding model's output size. Defaults to 768.

        The collection gets an auto-id Int64 primary key ``pk``, a float
        ``vector`` field and VARCHAR ``text`` / ``source`` / ``metadata``
        fields, plus an index on ``vector`` built from the configured
        ``index_params``. Uses the ``default`` pymilvus connection alias.
        """
        from pymilvus import Collection, CollectionSchema, FieldSchema, DataType, utility

        if utility.has_collection(self.kb_name, using="default"):
            # Honour the "if not exists" contract: leave an existing
            # collection and its index untouched.
            logger.info(f"Milvus 集合 {self.kb_name} 已存在,跳过创建")
            return

        fields = [
            FieldSchema(name="pk", dtype=DataType.Int64, is_primary=True, auto_id=True),
            FieldSchema(name="vector", dtype=DataType.FloatVector, dim=dim),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="source", dtype=DataType.VARCHAR, max_length=1024),
            FieldSchema(name="metadata", dtype=DataType.VARCHAR, max_length=65535),
        ]
        schema = CollectionSchema(fields=fields, description=self.kb_name)
        collection = Collection(name=self.kb_name, schema=schema, using="default")

        index_params = kbs_config.get("milvus_kwargs")["index_params"]
        collection.create_index(field_name="vector", index_params=index_params)
        logger.info(f"成功创建 Milvus 集合: {self.kb_name}")

    def do_init(self):
        """Initialise the service by loading the Milvus wrapper."""
        self._load_milvus()

    def do_drop_kb(self):
        """Release the collection; it is deliberately not dropped.

        Dropping the collection from here is disabled on purpose, so the data
        stays in Milvus.
        """
        if self.milvus and self.milvus.col:
            self.milvus.col.release()

    def do_search(self, query: str, top_k: int, score_threshold: float, expr: str,
                  custom_strategy_config: Optional[dict] = None):
        """Embed ``query`` and search the collection.

        Args:
            query: Text to embed and search for.
            top_k: Number of hits to request. Values above 50 switch to a
                mode that returns the hits ordered by ``pk`` without applying
                the score threshold.
            score_threshold: Passed to ``score_threshold_process`` when
                ``top_k`` is 50 or less.
            expr: Milvus filter expression forwarded to the search call.
            custom_strategy_config: Not used by this implementation.

        Returns:
            The matching documents, or ``[]`` when embedding or searching
            raises (the error is logged).
        """
        # The wrapper is rebuilt before every search.
        self._load_milvus()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        try:
            embeddings = embed_func.embed_query(query)
            if top_k > 50:
                # Large requests: return the content in pk (ascending) order.
                docs = self.milvus.similarity_search_by_vector(embeddings, top_k, expr=expr)
                return sorted(docs, key=lambda doc: doc.metadata['pk'])
            docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k, expr=expr)
            # TODO: dynamic score_threshold
            return score_threshold_process(query, score_threshold, top_k, docs)
        except Exception as e:
            logger.error(f"搜索 Milvus 集合 '{self.kb_name}' 失败: {e}")
            return []

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Normalise metadata in place and insert ``docs`` into Milvus.

        Every metadata value is converted to ``str``, every field known to the
        wrapper gets an empty-string default, and the text / vector field
        names are removed from the metadata before insertion.

        Returns:
            One ``{"id": ..., "metadata": ...}`` dict per inserted document.
        """
        text_field = self.milvus._text_field
        vector_field = self.milvus._vector_field
        for doc in docs:
            # Only existing keys are reassigned, so mutating during iteration
            # is safe; the dict object itself is kept.
            for key, value in doc.metadata.items():
                doc.metadata[key] = str(value)
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            doc.metadata.pop(text_field, None)
            doc.metadata.pop(vector_field, None)

        ids = self.milvus.add_documents(docs)
        return [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete from Milvus every entity recorded for ``kb_file``.

        The primary keys come from the repository lookup by knowledge-base
        name and file name. Nothing is sent to Milvus when no ids are found.
        """
        id_list = list_file_num_docs_id_by_kb_name_and_file_name(kb_file.kb_name, kb_file.filename)
        if id_list and self.milvus and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {id_list}')
        # Alternative kept for reference (issue 2846, Windows paths): query the
        # pks with `source == "<file name>"` and delete those instead.

    def do_clear_vs(self):
        """Release the collection and reload the wrapper, if a collection exists."""
        if self.milvus and self.milvus.col:
            self.do_drop_kb()
            self.do_init()
if __name__ == '__main__':
    # Manual smoke test: make sure the ORM tables exist, then build a service
    # instance for one knowledge base.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)
    kb_service = MilvusKBService("t_policy_total_bce_v1")
    # Further manual checks, enable as needed:
    #   kb_service.add_doc(KnowledgeFile("README.md", "test"))
    #   print(kb_service.get_doc_by_ids(["444022434274215486"]))
    #   kb_service.delete_doc(KnowledgeFile("README.md", "test"))
    #   kb_service.do_drop_kb()
    #   print(kb_service.search_docs("如何启动api服务"))