Files
gangyan/langchain-chat/server/knowledge_base/kb_service/base.py

776 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from datetime import datetime
import operator
from abc import ABC, abstractmethod
import re
import os
from pathlib import Path
import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from joblib import Parallel, delayed
import time
from textrank4zh import TextRank4Keyword, TextRank4Sentence
import multiprocessing
from configs.model_config import LLM_MODELS
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.utils import get_personal_knowledge_map, get_similar_documents1
from nltk.tokenize import sent_tokenize
import logging
# Configure logging.
# NOTE(review): this configures the *root* logger as a side effect of importing
# this module. logging.basicConfig does nothing when the root logger already has
# handlers, so whichever module calls it first decides the format/level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
def generate_weights_as_list(length, total_sum=80):
    """Return ``length`` exponentially decaying integer weights that sum to ``total_sum``.

    The weights are used as per-document sentence budgets for TextRank, so the
    first (best-ranked) positions receive the largest budgets.

    :param length: number of weights to produce.
    :param total_sum: target sum of the weights (applies when ``length >= 2``).
    :return: list of ``int`` with ``length`` elements. ``length <= 0`` yields
        ``[]``; ``length == 1`` yields ``[50]`` regardless of ``total_sum``
        (kept as-is for backward compatibility).
    """
    if length <= 0:
        # Without this guard the rounding fix-up below would evaluate ``i % 0``.
        return []
    if length == 1:
        return [50]
    positions = np.arange(length, dtype=float)
    decay = np.exp(-positions / (length / 5))
    scaled = decay / decay.sum() * total_sum
    weights = np.round(scaled).astype(int)
    # Rounding can leave the total a few units off; spread the difference one
    # unit at a time starting from the front.
    adjustment = total_sum - int(weights.sum())
    step = 1 if adjustment > 0 else -1
    for i in range(abs(adjustment)):
        weights[i % length] += step
    return weights.tolist()
def _squash(text):
    """Return ``text`` with all spaces, newlines and carriage returns removed.

    Used as a whitespace-insensitive key for matching and de-duplication.
    """
    return text.replace(" ", "").replace("\n", "").replace("\r", "")


def _docs_matching_query(query, docs):
    """Return the ``(doc, score)`` pairs whose title contains ``query`` (whitespace-insensitive).

    If any title is missing or is not a string, matching is redone from scratch
    against ``metadata["source"]`` instead.
    """
    needle = _squash(query)
    try:
        return [pair for pair in docs if needle in _squash(pair[0].metadata["title"])]
    except (KeyError, AttributeError):
        # Rebuild from scratch so pairs matched before the failure are not added twice.
        return [pair for pair in docs if needle in _squash(pair[0].metadata["source"])]


def _merge_duplicate_docs(docs, key_field, fallback_field=None):
    """Collapse ``(doc, score)`` pairs sharing the same whitespace-insensitive label.

    The label is ``metadata[key_field]``. When ``fallback_field`` is given and
    the primary label is missing or not a string, ``metadata[fallback_field]``
    is used instead. For each label the first pair is kept. A later pair is
    dropped when its body equals the kept body, or when the kept body is only
    the label. Otherwise its ``page_content`` is appended to the kept document,
    which is mutated in place. First-occurrence order is preserved.
    """
    merged = {}
    for pair in docs:
        document = pair[0]
        metadata = document.metadata
        if fallback_field is not None and not isinstance(metadata.get(key_field), str):
            label = metadata[fallback_field]
        else:
            label = metadata[key_field]
        # Store AND look up under the same squashed key.
        key = _squash(label)
        kept = merged.get(key)
        if kept is None:
            merged[key] = pair
            continue
        kept_body = _squash(kept[0].page_content)
        if kept_body == _squash(document.page_content) or kept_body == key:
            continue
        kept[0].page_content += document.page_content
    return list(merged.values())


def _numbered_listing(docs, field):
    """Build the candidate list sent to the LLM re-ranker.

    :return: ``(labels, listing)``. ``labels`` is ``metadata[field]`` of each doc;
        ``listing`` holds ``"<n>:【<label><page_content>"`` strings numbered from 1.
    :raises KeyError: when a doc has no ``field``.
    :raises TypeError: when a label is not a string.
    """
    labels = [pair[0].metadata[field] for pair in docs]
    listing = [
        str(number) + ":【" + label + pair[0].page_content
        for number, (label, pair) in enumerate(zip(labels, docs), start=1)
    ]
    return labels, listing


def _parse_llm_indices(response):
    """Turn an LLM reply such as ``"3,1,2"`` into zero-based indices ``[2, 0, 1]``.

    An empty or ``None`` reply yields ``[]``. Blank items (e.g. a trailing comma)
    are ignored. Non-numeric items raise ValueError.
    """
    if not response:
        return []
    return [int(token) - 1 for token in response.split(",") if token.strip()]


def _rerank_with_llm(query, k, docs, sentences, listing, max_tokens, strip_think=False, **llm_kwargs):
    """Ask the LLM which numbered candidates are relevant, then select docs with ``get_similar_documents1``.

    If the reply cannot be parsed, or the selection call fails, selection is
    retried with an empty index list.
    """
    response = get_llm_model_response(
        strategy_name="default_similar",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="default_similar",
        prompt_param_dict={
            "input": query,
            "title": f"{listing}",
            "time": datetime.now().strftime("%Y%m%d"),
        },
        temperature=0.01,
        max_tokens=max_tokens,
        **llm_kwargs,
    )
    if strip_think and response:
        # With thinking enabled the reply may contain a <think>...</think> section; drop it.
        response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
    try:
        index = _parse_llm_indices(response)
        return get_similar_documents1(index=index, sentences=sentences, query=query, docs=docs, top_k=k)
    except Exception as exc:  # best effort: fall back to selection without LLM hints
        logging.warning("LLM re-ranking failed (%s); falling back to default selection", exc)
        return get_similar_documents1(index=[], sentences=sentences, query=query, docs=docs, top_k=k)


def _rerank_general_docs(query, k, docs):
    """LLM re-ranking for the general knowledge base (docs without ``"h1"`` metadata)."""
    try:
        for pair in docs:
            metadata = pair[0].metadata
            # Empty title: use the chunk text as the title.
            # NOTE(review): an older comment said "first sentence, at most 50
            # characters", but the full page_content is used - confirm intent.
            if metadata["title"] == "":
                metadata["title"] = pair[0].page_content
            # Prefer the stored summary so the prompt stays short.
            summary = metadata.get("summary")
            if summary:
                pair[0].page_content = summary
        sentences, listing = _numbered_listing(docs, "title")
    except (KeyError, TypeError):
        sentences, listing = _numbered_listing(docs, "source")
    return _rerank_with_llm(query, k, docs, sentences, listing, max_tokens=512)


def _rerank_personal_docs(query, k, docs):
    """LLM re-ranking and de-duplication for the personal knowledge base (docs with ``"h1"`` metadata).

    The original ``metadata["source"]`` is copied to ``metadata["uuid_name"]``.
    It is replaced by its value from ``get_personal_knowledge_map`` when one exists.
    """
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    unique_sources = list(dict.fromkeys(pair[0].metadata["source"] for pair in docs))
    title_map = get_personal_knowledge_map(unique_sources)
    for pair in docs:
        metadata = pair[0].metadata
        metadata["uuid_name"] = metadata["source"]
        if metadata["source"] in title_map:
            metadata["source"] = title_map[metadata["source"]]
    sentences, listing = _numbered_listing(docs, "source")
    docs = _rerank_with_llm(
        query, k, docs, sentences, listing,
        max_tokens=None,
        strip_think=True,
        extra_body={"chat_template_kwargs": {"enable_thinking": True}},
    )
    return _merge_duplicate_docs(docs, "source")


def _attach_summaries(docs):
    """Store a TextRank summary, or the raw chunk text, in ``metadata["summary"]`` of every doc."""
    budgets = generate_weights_as_list(len(docs))
    for position, pair in enumerate(docs):
        document = pair[0]
        body = document.page_content
        title = document.metadata.get("title")
        # Default summary: the chunk text itself.
        document.metadata["summary"] = body
        # Only chunks that contain their own title, or are shorter than 100
        # characters, are re-summarised from the full text kept in metadata.
        if not isinstance(title, str) or not (title in body or len(body) < 100):
            continue
        # The full-text field name depends on the knowledge base; use the first one present.
        for field in ("content", "abstract", "text"):
            if field not in document.metadata:
                continue
            summary = TextRank(document.metadata[field], budgets[position])
            if len(summary) > 15000:
                # Still too long: keep a single key sentence per text part.
                summary = TextRank(document.metadata[field], 1)
            document.metadata["summary"] = summary
            break


def score_threshold_process(query, score_threshold, k, docs):
    """Filter, re-rank, de-duplicate and summarise retrieved documents.

    :param query: the user question.
    :param score_threshold: a pair is kept when ``score <= score_threshold`` (the
        score behaves like a distance); ``None`` disables the filter.
    :param k: maximum number of documents to return (also passed to re-ranking).
    :param docs: list of ``(Document, score)`` pairs.
    :return: at most ``k`` ``(Document, score)`` pairs. Documents are mutated in
        place (merged ``page_content``, ``metadata["summary"]``, ...).

    General knowledge base (no ``"h1"`` metadata): if the query occurs in some
    titles only those documents are kept, otherwise an LLM re-ranks the
    candidates. Duplicates are then merged by title (or source). Personal
    knowledge base (``"h1"`` metadata): sources are mapped to display names, an
    LLM re-ranks the candidates and duplicates are merged by source. Finally
    every document receives a summary.
    """
    if score_threshold is not None:
        docs = [(doc, similarity) for doc, similarity in docs if similarity <= score_threshold]
    # Nothing retrieved, or nothing passed the threshold.
    if not docs:
        return docs

    if "h1" not in docs[0][0].metadata:
        title_hits = _docs_matching_query(query, docs)
        if title_hits:
            # The question appears in some titles: keep only those, no LLM needed.
            docs = title_hits
        else:
            docs = _rerank_general_docs(query, k, docs)
        docs = _merge_duplicate_docs(docs, "title", "source")

    if docs and "h1" in docs[0][0].metadata:
        docs = _rerank_personal_docs(query, k, docs)

    if not docs:
        return docs
    _attach_summaries(docs)
    return docs[:k]
# Monkey patch for compatibility with TextRank (textrank4zh).
# NOTE(review): presumably needed because textrank4zh calls
# `networkx.from_numpy_matrix`, which newer networkx releases no longer provide;
# aliasing it to `from_numpy_array` restores the name. Confirm against the
# pinned networkx version.
import networkx as nx
# NOTE(review): numpy is already imported at the top of the file; this re-import is harmless.
import numpy as np
# Apply the monkey patch.
nx.from_numpy_matrix = nx.from_numpy_array
# NOTE(review): the original comment here read "monkey patch, add data-type
# compatibility check", but no such check follows in this file - likely stale.
def process_text_segment(text_segment, num_sentences):
    """Run TextRank keyword and key-sentence extraction on one chunk of text.

    :param text_segment: text to analyse.
    :param num_sentences: number of key sentences to request.
    :return: ``(keywords, summaries)``. ``keywords`` is a list of ``(word, weight)``
        for the top 30 keywords reported by textrank4zh (``word_min_len=4``);
        ``summaries`` is the list of key sentences.
    """
    keyword_ranker = TextRank4Keyword()
    keyword_ranker.analyze(text=text_segment, lower=True, window=5)
    top_keywords = keyword_ranker.get_keywords(30, word_min_len=4)

    sentence_ranker = TextRank4Sentence()
    sentence_ranker.analyze(text=text_segment, lower=True, source='all_filters')
    top_sentences = sentence_ranker.get_key_sentences(num=num_sentences)

    keyword_pairs = [(entry.word, entry.weight) for entry in top_keywords]
    sentence_texts = [entry.sentence for entry in top_sentences]
    return keyword_pairs, sentence_texts
def split_text_balanced(text, n_parts):
    """Split ``text`` into sentence-aligned chunks of near-equal size.

    :param text: text to split (sentences are found with ``nltk.sent_tokenize``).
    :param n_parts: requested number of chunks. It is capped so that each chunk
        holds at least 10 sentences (when the text has at least 10), and it is
        never below 1.
    :return: list of chunks whose sentence counts differ by at most one;
        sentences inside a chunk are joined by a single space.
    """
    sentence_list = sent_tokenize(text)
    total = len(sentence_list)
    parts = min(n_parts, total // 10)
    if parts < 1:
        parts = 1
    base, extra = divmod(total, parts)
    chunks = []
    start = 0
    for idx in range(parts):
        # The first `extra` chunks take one additional sentence.
        size = base + (1 if idx < extra else 0)
        chunks.append(' '.join(sentence_list[start:start + size]))
        start += size
    return chunks
def TextRank(text, num_sentences, n_cores=multiprocessing.cpu_count()):
    """Summarise ``text`` with TextRank and return the key sentences concatenated.

    The text is split into at most ``n_cores`` sentence-balanced parts (see
    ``split_text_balanced``). ``num_sentences`` key sentences are requested from
    *each* part, so the result can contain up to ``num_sentences`` times the
    number of parts. Parts are processed sequentially in this process;
    ``n_cores`` only controls the split.

    :return: the extracted sentences joined without a separator.
    """
    start_time = time.time()
    all_keywords = []
    all_summaries = []
    # Parts are processed one after another here (no process pool).
    for part in split_text_balanced(text, n_cores):
        keywords, summaries = process_text_segment(part, num_sentences)
        all_keywords.extend(keywords)
        all_summaries.extend(summaries)
    # Keywords are diagnostics only. They used to be print()ed to stdout on
    # every call; log them at DEBUG instead and skip the sort when DEBUG is off.
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        for word, weight in sorted(all_keywords, key=lambda item: item[1], reverse=True):
            logging.debug("%s %s", word, weight)
    end_time = time.time()
    logging.info(f"TextRank耗时: {end_time - start_time:.2f}")
    return "".join(all_summaries)
from server.db.repository.knowledge_base_repository import (
add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
list_docs_from_db,
)
from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,EXPR,
EMBEDDING_MODEL, KB_INFO)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, KnowledgeFile,
list_kbs_from_folder, list_files_from_folder,
)
from typing import List, Union, Dict, Optional, Tuple
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
import time
def get_emb_time(f):
    """Decorator that prints the elapsed time (seconds) of each call to ``f``.

    The wrapper keeps ``f``'s ``__name__``, ``__doc__`` and ``__wrapped__``
    via ``functools.wraps``, so decorated methods stay introspectable.
    """
    from functools import wraps

    @wraps(f)
    def inner(*arg, **kwarg):
        s_time = time.perf_counter()
        res = f(*arg, **kwarg)
        e_time = time.perf_counter()
        print('向量化耗时:{}'.format(e_time - s_time))
        return res

    return inner
def normalize(embeddings: List[List[float]]) -> np.ndarray:
    """L2-normalise every row of ``embeddings``.

    Stand-in for ``sklearn.preprocessing.normalize`` (L2 norm) that avoids
    installing scipy / scikit-learn. As in sklearn, rows whose norm is zero are
    returned unchanged (all zeros) instead of becoming NaN.

    ``None`` rows are dropped before normalising, so the result can have fewer
    rows than the input; a warning is logged when that happens.

    :raises ValueError: if no non-``None`` row remains.
    """
    rows = list(embeddings)
    valid_rows = [row for row in rows if row is not None]
    if not valid_rows:
        raise ValueError("No valid embeddings found (all are None)")
    dropped = len(rows) - len(valid_rows)
    if dropped:
        logging.warning("normalize: dropped %d None embedding(s) out of %d", dropped, len(rows))
    matrix = np.array(valid_rows)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Dividing zero rows by 1 keeps them at zero instead of producing NaN.
    norms[norms == 0] = 1
    return matrix / norms
class SupportedVSType:
    """Lower-case string identifiers of the vector-store backends.

    Attribute names are the upper-case form of the identifiers, so a name can
    be resolved with ``getattr(SupportedVSType, name.upper())``.
    """

    FAISS: str = 'faiss'
    MILVUS: str = 'milvus'
    DEFAULT: str = 'default'
    ZILLIZ: str = 'zilliz'
    PG: str = 'pg'
    ES: str = 'es'
    CHROMADB: str = 'chromadb'
class KBService(ABC):
    """Abstract knowledge base backed by a vector store.

    The public methods implement the shared workflow (document folder on disk
    plus records in the metadata database). Subclasses provide the
    vector-store specific ``do_*`` hooks and ``vs_type``.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        self.kb_name = knowledge_base_name
        # Fallback description reads "knowledge base about <name>".
        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        self.do_init()

    def __repr__(self) -> str:
        return f"{self.kb_name} @ {self.embed_model}"

    def save_vector_store(self):
        """Persist the vector store (FAISS: to disk; Milvus: to its database; PGVector: not supported yet).

        No-op in the base class.
        """
        pass

    def create_kb(self):
        """Create the knowledge base: document folder, vector store and DB record."""
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        self.do_create_kb()
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def clear_vs(self):
        """Delete all vectors of this knowledge base and its file records in the DB."""
        self.do_clear_vs()
        status = delete_files_from_db(self.kb_name)
        return status

    def drop_kb(self):
        """Delete the knowledge base (vector store and DB record)."""
        self.do_drop_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        """Convert ``docs`` into the arguments accepted by ``VectorStore.add_embeddings``."""
        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)

    @get_emb_time
    def add_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """Add a file to the knowledge base.

        If ``docs`` is given, those documents are stored instead of parsing the
        file, and the DB record is flagged ``custom_docs=True``. Otherwise the
        file is parsed with ``kb_file.file2text()``. Existing entries for the
        file are deleted before the new ones are added.

        :return: the status returned by ``add_file_to_db``, or ``False`` when
            there is nothing to add.
        """
        if docs:
            custom_docs = True
            for doc in docs:
                doc.metadata.setdefault("source", kb_file.filename)
        else:
            docs = kb_file.file2text()
            custom_docs = False
        if not docs:
            return False
        # Store metadata["source"] relative to the knowledge base's document folder.
        for doc in docs:
            source = doc.metadata.get("source", "")
            try:
                if os.path.isabs(source):
                    rel_path = Path(source).relative_to(self.doc_path)
                    doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
            except Exception as e:
                print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")
        self.delete_doc(kb_file)
        doc_infos = self.do_add_doc(docs, **kwargs)
        status = add_file_to_db(kb_file,
                                custom_docs=custom_docs,
                                docs_count=len(docs),
                                doc_infos=doc_infos)
        return status

    def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
        """Remove a file from the knowledge base; also delete it from disk when ``delete_content``."""
        self.do_delete_doc(kb_file, **kwargs)
        status = delete_file_from_db(kb_file)
        if delete_content and os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        return status

    def update_info(self, kb_info: str):
        """Update the knowledge base description."""
        self.kb_info = kb_info
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """Re-index a file from the content folder.

        If ``docs`` is given those documents are used instead (``custom_docs=True``).
        Returns ``None`` when the file does not exist on disk.
        """
        if os.path.exists(kb_file.filepath):
            self.delete_doc(kb_file, **kwargs)
            return self.add_doc(kb_file, docs=docs, **kwargs)

    def exist_doc(self, file_name: str):
        return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                               filename=file_name))

    def list_files(self):
        return list_files_from_db(self.kb_name)

    def count_files(self):
        return count_files_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    expr: str = EXPR,
                    custom_strategy_config: Optional[dict] = None,
                    ) -> List[Document]:
        """Search the vector store through the subclass's ``do_search``."""
        if custom_strategy_config is None:
            # Fresh dict per call instead of a shared mutable default.
            custom_strategy_config = {}
        docs = self.do_search(query, top_k, score_threshold, expr, custom_strategy_config)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch documents by vector-store id; the base class returns ``[]``."""
        return []

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        return []

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        raise NotImplementedError

    def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
        """Replace documents by id.

        ``docs`` maps ``doc_id -> Document``. Every given id is deleted first;
        ids whose value is ``None`` or has blank ``page_content`` are not re-added.
        """
        self.del_doc_by_ids(list(docs.keys()))
        # Separate names: rebinding `docs` here would lose the input mapping.
        new_docs = []
        new_ids = []
        for doc_id, doc in docs.items():
            if not doc or not doc.page_content.strip():
                continue
            new_ids.append(doc_id)
            new_docs.append(doc)
        if new_docs:
            self.do_add_doc(docs=new_docs, ids=new_ids)
        return True

    def list_docs(self, file_name: Optional[str] = None, metadata: Optional[Dict] = None) -> List[DocumentWithVSId]:
        """Return stored documents selected by ``file_name`` and/or ``metadata``.

        Ids recorded in the DB but not returned by ``get_doc_by_ids`` are skipped.
        """
        if metadata is None:
            metadata = {}
        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
        docs = []
        for x in doc_infos:
            found = self.get_doc_by_ids([x["id"]])
            # get_doc_by_ids may return [] (as the base class does) or a None entry.
            doc_info = found[0] if found else None
            if doc_info is None:
                continue
            docs.append(DocumentWithVSId(**doc_info.dict(), id=x["id"]))
        return docs

    @abstractmethod
    def do_create_kb(self):
        """Create the knowledge base; implemented by subclasses."""
        pass

    @staticmethod
    def list_kbs_type():
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        return list_kbs_from_db()

    def exists(self, kb_name: Optional[str] = None):
        kb_name = kb_name or self.kb_name
        return kb_exists(kb_name)

    @abstractmethod
    def vs_type(self) -> str:
        pass

    @abstractmethod
    def do_init(self):
        pass

    @abstractmethod
    def do_drop_kb(self):
        """Drop the knowledge base; implemented by subclasses."""
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float,
                  expr: str,
                  custom_strategy_config: Optional[dict] = None,
                  ) -> List[Tuple[Document, float]]:
        """Search the knowledge base; implemented by subclasses."""
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Add documents to the knowledge base; implemented by subclasses."""
        pass

    @abstractmethod
    def do_delete_doc(self,
                      kb_file: KnowledgeFile):
        """Delete a file's documents from the knowledge base; implemented by subclasses."""
        pass

    @abstractmethod
    def do_clear_vs(self):
        """Delete all vectors from the knowledge base; implemented by subclasses."""
        pass
class KBServiceFactory:
    """Builds the KBService implementation that matches a vector-store type."""

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> KBService:
        """Instantiate the KBService subclass for ``vector_store_type``.

        ``vector_store_type`` is either a SupportedVSType value or the name of a
        SupportedVSType attribute given as a string in any case (e.g. ``"faiss"``);
        an unknown name raises AttributeError. Each backend module is imported
        only inside its own branch. ``SupportedVSType.DEFAULT`` is served by
        Milvus. Returns ``None`` when no backend matches.
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())
        if SupportedVSType.FAISS == vector_store_type:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.PG == vector_store_type:
            from server.knowledge_base.kb_service.pg_kb_service import PGKBService
            return PGKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.MILVUS == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.ZILLIZ == vector_store_type:
            from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
            return ZillizKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.DEFAULT == vector_store_type:
            # DEFAULT maps to Milvus; other milvus parameters are set in model_config.kbs_config.
            # A later DefaultKBService branch for the same value was unreachable and was removed.
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.ES == vector_store_type:
            from server.knowledge_base.kb_service.es_kb_service import ESKBService
            return ESKBService(kb_name, embed_model=embed_model)
        if SupportedVSType.CHROMADB == vector_store_type:
            from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
            return ChromaKBService(kb_name, embed_model=embed_model)
        # No matching backend.
        return None

    @staticmethod
    def get_service_by_name(kb_name: str) -> KBService:
        """Build the service for a knowledge base registered in the DB; ``None`` if it is not registered."""
        stored_name, vs_type, embed_model = load_kb_from_db(kb_name)
        if stored_name is None:
            return None
        from server.utils import resolve_embed_model_name
        embed_model = resolve_embed_model_name(embed_model)
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        """Return the service for the knowledge base named ``"default"`` (DEFAULT vector-store type)."""
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
def get_kb_details() -> List[Dict]:
    """Merge knowledge bases found on disk with those registered in the DB.

    Every entry carries ``in_folder`` / ``in_db`` flags, the DB details when
    available, and a 1-based ``No`` field in listing order.
    """
    folder_names = list_kbs_from_folder()
    db_names = KBService.list_kbs()

    details_by_name = {
        name: {
            "kb_name": name,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }
        for name in folder_names
    }

    for name in db_names:
        db_detail = get_kb_detail(name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        existing = details_by_name.get(name)
        if existing is None:
            db_detail["in_folder"] = False
            details_by_name[name] = db_detail
        else:
            existing.update(db_detail)

    rows = list(details_by_name.values())
    for number, row in enumerate(rows, start=1):
        row['No'] = number
    return rows
def get_kb_file_details(kb_name: str) -> List[Dict]:
    """List the files of ``kb_name`` found on disk and/or recorded in the DB.

    DB records are matched to folder files case-insensitively. Each entry gets
    ``in_folder`` / ``in_db`` flags and a 1-based ``No`` field. Returns ``[]``
    when the knowledge base is not registered.
    """
    service = KBServiceFactory.get_service_by_name(kb_name)
    if service is None:
        return []
    folder_files = list_files_from_folder(kb_name)
    db_files = service.list_files()

    details = {}
    for file_name in folder_files:
        details[file_name] = dict(
            kb_name=kb_name,
            file_name=file_name,
            file_ext=os.path.splitext(file_name)[-1],
            file_version=0,
            document_loader="",
            docs_count=0,
            text_splitter="",
            create_time=None,
            in_folder=True,
            in_db=False,
        )
    # Lower-cased name -> folder file name (built from folder files only).
    by_lower_name = {name.lower(): name for name in details}

    for file_name in db_files:
        db_detail = get_file_detail(kb_name, file_name)
        if not db_detail:
            continue
        db_detail["in_db"] = True
        folder_name = by_lower_name.get(file_name.lower())
        if folder_name is not None:
            details[folder_name].update(db_detail)
        else:
            db_detail["in_folder"] = False
            details[file_name] = db_detail

    result_rows = []
    for number, row in enumerate(details.values(), start=1):
        row['No'] = number
        result_rows.append(row)
    return result_rows
class EmbeddingsFunAdapter(Embeddings):
    """LangChain ``Embeddings`` adapter around the project's ``embed_texts`` / ``aembed_texts``.

    Every returned vector is L2-normalised with ``normalize``.
    """

    def __init__(self, embed_model: str = EMBEDDING_MODEL):
        self.embed_model = embed_model

    @staticmethod
    def _document_vectors(result, texts: List[str]) -> List[List[float]]:
        """Validate an embedding result for ``texts`` and return one normalised vector per text.

        :raises ValueError: if no data came back, or if any vector is missing.
            ``normalize`` drops ``None`` rows, so continuing would return fewer
            vectors than texts and pair texts with the wrong vectors downstream.
        """
        embeddings = result.data if result and hasattr(result, 'data') else None
        if not embeddings:
            raise ValueError(f"Failed to get embeddings for texts: {texts[:2]}...")
        missing = sum(1 for vector in embeddings if vector is None)
        if missing or len(embeddings) != len(texts):
            raise ValueError(
                f"Expected {len(texts)} embeddings, got {len(embeddings) - missing} valid "
                f"({missing} missing) for texts: {texts[:2]}..."
            )
        return normalize(embeddings).tolist()

    @staticmethod
    def _query_vector(result, text: str) -> List[float]:
        """Return the normalised vector of a single query from an embedding result.

        :raises ValueError: if no vector came back.
        """
        embeddings = result.data if result and hasattr(result, 'data') else None
        if not embeddings or embeddings[0] is None:
            raise ValueError(f"Failed to get embedding for query: {text[:50]}...")
        query_embed_2d = np.reshape(embeddings[0], (1, -1))  # 1-D vector -> single-row 2-D array
        return normalize(query_embed_2d)[0].tolist()  # back to a 1-D list

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        result = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return self._document_vectors(result, texts)

    def embed_query(self, text: str) -> List[float]:
        result = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        return self._query_vector(result, text)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        result = await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        return self._document_vectors(result, texts)

    async def aembed_query(self, text: str) -> List[float]:
        result = await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
        return self._query_vector(result, text)
# def score_threshold_process(score_threshold, k, docs):
# if score_threshold is not None:
# cmp = (
# operator.le
# )
# docs = [
# (doc, similarity)
# for doc, similarity in docs
# if cmp(similarity, score_threshold)
# ]
# return docs[:k]