from datetime import datetime
import functools
import operator
from abc import ABC, abstractmethod
import re
import os
from pathlib import Path
import numpy as np
from langchain.embeddings.base import Embeddings
from langchain.docstore.document import Document
from joblib import Parallel, delayed
import time
from textrank4zh import TextRank4Keyword, TextRank4Sentence
import multiprocessing
from configs.model_config import LLM_MODELS
from server.chat.policy_fun_iast import get_llm_model_response
from server.chat.utils import get_personal_knowledge_map, get_similar_documents1
from nltk.tokenize import sent_tokenize
import logging

# Configure logging.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Exceptions raised when a metadata field is missing or is not a string.
_FIELD_ERRORS = (KeyError, AttributeError, TypeError)

# Metadata fields tried, in order, as the text handed to TextRank.
_SUMMARY_FIELDS = ("content", "abstract", "text")

# Reasoning block emitted by the model when `enable_thinking` is requested.
# NOTE(review): the pattern previously used here was r'.*?', which can only
# match empty strings and therefore removed nothing. Presumably the tag markup
# was lost at some point -- confirm against the model's real output format.
# The tag name is interpolated so the markup survives HTML-stripping tools.
_REASONING_TAG = "think"
_REASONING_BLOCK_RE = re.compile(
    rf"<{_REASONING_TAG}>.*?</{_REASONING_TAG}>", re.DOTALL
)


def generate_weights_as_list(length, total_sum=80):
    """Return `length` integer weights that decay exponentially and sum to `total_sum`.

    :param length: number of weights to generate.
    :param total_sum: value the weights must add up to (ignored when
        `length == 1`, which always yields `[50]`).
    :return: list of ints; empty when `length <= 0`.
    """
    if length <= 0:
        # Previously this path ended in a ZeroDivisionError (`i % length`).
        return []
    if length == 1:
        return [50]

    # Exponentially decaying curve over the positions 0..length-1.
    x = np.linspace(0, length - 1, length)
    weights = np.exp(-x / (length / 5))

    # Scale the curve so that it sums to the requested total, then round.
    normalized_weights = weights / weights.sum() * total_sum
    integer_weights = np.round(normalized_weights).astype(int)

    # Rounding can leave the sum slightly off; spread the difference one unit
    # at a time starting from the first position.
    adjustment = int(total_sum - integer_weights.sum())
    step = 1 if adjustment > 0 else -1
    for i in range(abs(adjustment)):
        integer_weights[i % length] += step

    return integer_weights.tolist()


def _strip_ws(text):
    """Remove spaces, line feeds and carriage returns from `text`."""
    return text.replace(" ", "").replace("\n", "").replace("\r", "")


def _dedup_key(doc, fields):
    """Return the whitespace-free value of the first usable metadata field.

    Fields are tried in order; a failure on any field but the last is printed
    and the next one is tried. A failure on the last field propagates.
    """
    for field in fields[:-1]:
        try:
            return _strip_ws(doc.metadata[field])
        except _FIELD_ERRORS as e:
            print(e)
    return _strip_ws(doc.metadata[fields[-1]])


def _merge_duplicates(docs, fields):
    """Collapse `(doc, score)` pairs that share the same normalised key.

    The first pair seen for a key is kept. A later pair with the same key is
    dropped when its text equals the kept text, or when the kept text is just
    the key itself; otherwise its `page_content` is appended to the kept
    document (the kept Document is mutated in place).
    """
    merged = {}
    for pair in docs:
        doc = pair[0]
        key = _dedup_key(doc, fields)
        kept = merged.get(key)
        if kept is None:
            # Store under the normalised key so that the membership test and
            # the later lookup always agree.
            merged[key] = pair
            continue
        kept_text = _strip_ws(kept[0].page_content)
        if kept_text == _strip_ws(doc.page_content) or kept_text == key:
            continue
        kept[0].page_content += doc.page_content
    return list(merged.values())


def _label_documents(docs, field):
    """Build the label list and the numbered `N:【label+content】` list for the prompt."""
    labels = [pair[0].metadata[field] for pair in docs]
    numbered = [
        str(i + 1) + ":【" + pair[0].metadata[field] + pair[0].page_content + "】"
        for i, pair in enumerate(docs)
    ]
    return labels, numbered


def _select_with_llm(query, docs, labels, numbered, k, *, max_tokens,
                     strip_reasoning=False, **llm_kwargs):
    """Ask the LLM which numbered documents are relevant and select them.

    The reply is expected to be either "无" or a comma separated list of
    1-based indices. Any failure while parsing the reply or while selecting
    falls back to calling `get_similar_documents1` with an empty index list.
    """
    res = get_llm_model_response(
        strategy_name="default_similar",
        llm_model_name=LLM_MODELS[0],
        template_prompt_name="default_similar",
        prompt_param_dict={
            "input": query,
            "title": f"{numbered}",
            "time": datetime.now().strftime("%Y%m%d"),
        },
        temperature=0.01,
        max_tokens=max_tokens,
        **llm_kwargs,
    )
    if strip_reasoning:
        res = _REASONING_BLOCK_RE.sub("", res).strip()

    try:
        index = [] if res == "无" else [int(i) - 1 for i in res.split(",")]
        return get_similar_documents1(index=index, sentences=labels, query=query,
                                      docs=docs, top_k=k)
    except Exception as e:  # best effort: the reply format is not guaranteed
        print(e)
        return get_similar_documents1(index=[], sentences=labels, query=query,
                                      docs=docs, top_k=k)


def score_threshold_process(query, score_threshold, k, docs):
    """Filter, re-rank, de-duplicate and summarise retrieved documents.

    :param query: the user question.
    :param score_threshold: when not None, only pairs whose score is less than
        or equal to this value are kept.
    :param k: maximum number of pairs to return.
    :param docs: list of `(Document, score)` tuples.
    :return: at most `k` `(Document, score)` tuples. Each returned document
        has `metadata["summary"]` set. Documents are mutated in place.
    """
    if score_threshold is not None:
        docs = [
            (doc, similarity)
            for doc, similarity in docs
            if similarity <= score_threshold
        ]

    # Nothing passed the threshold (or nothing was retrieved).
    if not docs:
        return docs

    # Documents whose title (or, failing that, source) contains the query.
    needle = _strip_ws(query)
    try:
        result = [p for p in docs if needle in _strip_ws(p[0].metadata["title"])]
    except _FIELD_ERRORS:
        result = [p for p in docs if needle in _strip_ws(p[0].metadata["source"])]

    # Documents without an "h1" metadata key: general knowledge base.
    if "h1" not in docs[0][0].metadata:
        if result:
            # The query appears in a title: return those documents directly,
            # de-duplicated, without asking the LLM for relevance.
            docs = _merge_duplicates(result, ("title", "source"))
        else:
            try:
                for pair in docs:
                    meta = pair[0].metadata
                    # An empty title is replaced by the document text.
                    if meta["title"] == "":
                        meta["title"] = pair[0].page_content
                    # Prefer an existing summary over the full text.
                    summary = meta.get("summary")
                    if summary:
                        pair[0].page_content = summary
                labels, numbered = _label_documents(docs, "title")
            except _FIELD_ERRORS:
                labels, numbered = _label_documents(docs, "source")

            docs = _select_with_llm(query, docs, labels, numbered, k, max_tokens=512)
            docs = _merge_duplicates(docs, ("title", "source"))

    # Documents with an "h1" metadata key: personal knowledge base.
    # `docs` may be empty after re-ranking, hence the guard.
    if docs and "h1" in docs[0][0].metadata:
        unique_source = list({pair[0].metadata["source"] for pair in docs})
        all_title_map = get_personal_knowledge_map(unique_source)
        for pair in docs:
            meta = pair[0].metadata
            # Keep the original source under "uuid_name" before replacing it
            # with the mapped value.
            meta["uuid_name"] = meta["source"]
            if meta["source"] in all_title_map:
                meta["source"] = all_title_map[meta["source"]]

        labels, numbered = _label_documents(docs, "source")
        docs = _select_with_llm(
            query, docs, labels, numbered, k,
            max_tokens=None,
            strip_reasoning=True,
            extra_body={"chat_template_kwargs": {"enable_thinking": True}},
        )
        docs = _merge_duplicates(docs, ("source",))

    if not docs:
        return docs

    # One weight per document, used as the sentence count passed to TextRank.
    weights = generate_weights_as_list(len(docs))

    for weight, (doc, _score) in zip(weights, docs):
        # Default: the summary is the text itself.
        doc.metadata["summary"] = doc.page_content
        try:
            needs_summary = (
                doc.metadata["title"] in doc.page_content
                or len(doc.page_content) < 100
            )
        except KeyError:
            continue
        if not needs_summary:
            continue

        # Try each candidate field until one exists and can be summarised.
        for field in _SUMMARY_FIELDS:
            try:
                summary = TextRank(doc.metadata[field], weight)
                if len(summary) > 15000:
                    summary = TextRank(doc.metadata[field], 1)
            except KeyError:
                continue  # this field failed; try the next one
            doc.metadata["summary"] = summary
            break

    return docs[:k]


# Monkey patch for compatibility with TextRank.
import networkx as nx
import numpy as np

# Alias the `from_numpy_matrix` name to `from_numpy_array`.
nx.from_numpy_matrix = nx.from_numpy_array


def process_text_segment(text_segment, num_sentences):
    """Run keyword and key-sentence TextRank over one text segment.

    :return: `(keywords, summaries)` where `keywords` is a list of
        `(word, weight)` tuples and `summaries` a list of sentences.
    """
    tr4w = TextRank4Keyword()
    tr4w.analyze(text=text_segment, lower=True, window=5)
    keywords = [
        (item.word, item.weight)
        for item in tr4w.get_keywords(30, word_min_len=4)
    ]

    tr4s = TextRank4Sentence()
    tr4s.analyze(text=text_segment, lower=True, source='all_filters')
    summaries = [
        item.sentence for item in tr4s.get_key_sentences(num=num_sentences)
    ]

    return keywords, summaries


def split_text_balanced(text, n_parts):
    """Split `text` into at most `n_parts` chunks of near-equal sentence count.

    The number of chunks is reduced so that each holds at least 10 sentences
    where possible, and is never less than 1.
    """
    sentences = sent_tokenize(text)
    min_sentences_per_part = 10
    n_parts = max(1, min(n_parts, len(sentences) // min_sentences_per_part))
    k, m = divmod(len(sentences), n_parts)
    # The first `m` chunks receive one extra sentence each.
    return [
        ' '.join(sentences[i * k + min(i, m):(i + 1) * k + min(i + 1, m)])
        for i in range(n_parts)
    ]


def TextRank(text, num_sentences, n_cores=multiprocessing.cpu_count()):
    """Summarise `text` by concatenating the key sentences of each chunk.

    :param text: text to summarise.
    :param num_sentences: number of key sentences requested per chunk.
    :param n_cores: upper bound on the number of chunks; chunks are processed
        sequentially.
    :return: the selected sentences joined into one string.
    """
    start_time = time.time()
    text_parts = split_text_balanced(text, n_cores)

    all_keywords = []
    all_summaries = []
    # Chunks are processed one after another (no process pool).
    for part in text_parts:
        keywords, summaries = process_text_segment(part, num_sentences)
        all_keywords.extend(keywords)
        all_summaries.extend(summaries)

    # Print the keywords, highest weight first.
    for word, weight in sorted(all_keywords, key=lambda x: x[1], reverse=True):
        print(word, weight)

    end_time = time.time()
    logging.info(f"TextRank耗时: {end_time - start_time:.2f}秒")

    return "".join(all_summaries)


from server.db.repository.knowledge_base_repository import (
    add_kb_to_db, delete_kb_from_db, list_kbs_from_db, kb_exists,
    load_kb_from_db, get_kb_detail,
)
from server.db.repository.knowledge_file_repository import (
    add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
    count_files_from_db, list_files_from_db, get_file_detail,
    list_docs_from_db,
)

from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, EXPR,
                     EMBEDDING_MODEL, KB_INFO)
from server.knowledge_base.utils import (
    get_kb_path, get_doc_path, KnowledgeFile,
    list_kbs_from_folder, list_files_from_folder,
)

from typing import List, Union, Dict, Optional, Tuple

from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
import time


def get_emb_time(f):
    """Decorator that prints how long the wrapped call took, in seconds."""
    @functools.wraps(f)
    def inner(*arg, **kwarg):
        s_time = time.time()
        res = f(*arg, **kwarg)
        e_time = time.time()
        print('向量化耗时:{}秒'.format(e_time - s_time))
        return res
    return inner


def normalize(embeddings: List[List[float]]) -> np.ndarray:
    """L2-normalise each row of `embeddings`.

    Stand-in for `sklearn.preprocessing.normalize` (L2) that avoids installing
    scipy / scikit-learn. `None` entries are dropped before normalising, so
    the result can have fewer rows than the input. Rows with zero norm are
    returned unchanged instead of becoming NaN.

    :raises ValueError: if no non-None embedding is left.
    """
    # Drop None entries.
    embeddings = [e for e in embeddings if e is not None]
    if not embeddings:
        raise ValueError("No valid embeddings found (all are None)")

    matrix = np.array(embeddings)
    norms = np.linalg.norm(matrix, axis=1, keepdims=True)
    # Avoid 0/0 for all-zero rows.
    norms[norms == 0] = 1.0
    return np.divide(matrix, norms)


class SupportedVSType:
    """Names of the vector-store back ends known to `KBServiceFactory`."""
    FAISS = 'faiss'
    MILVUS = 'milvus'
    DEFAULT = 'default'
    ZILLIZ = 'zilliz'
    PG = 'pg'
    ES = 'es'
    CHROMADB = 'chromadb'


class KBService(ABC):
    """Base class for knowledge-base services.

    Concrete subclasses implement the `do_*` hooks and `vs_type`; this class
    combines them with the bookkeeping calls into the database repositories.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        self.kb_name = knowledge_base_name
        self.kb_info = KB_INFO.get(knowledge_base_name, f"关于{knowledge_base_name}的知识库")
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        self.do_init()

    def __repr__(self) -> str:
        return f"{self.kb_name} @ {self.embed_model}"

    def save_vector_store(self):
        """Persist the vector store.

        FAISS saves to disk, Milvus saves to the database; PGVector is not
        supported yet. The base implementation does nothing.
        """
        pass

    def create_kb(self):
        """Create the knowledge base and register it in the database."""
        if not os.path.exists(self.doc_path):
            os.makedirs(self.doc_path)
        self.do_create_kb()
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def clear_vs(self):
        """Delete everything in the vector store and the file records in the database."""
        self.do_clear_vs()
        status = delete_files_from_db(self.kb_name)
        return status

    def drop_kb(self):
        """Delete the knowledge base and its database record."""
        self.do_drop_kb()
        status = delete_kb_from_db(self.kb_name)
        return status

    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
        """Convert `List[Document]` into the arguments accepted by `VectorStore.add_embeddings`."""
        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)

    @get_emb_time
    def add_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """Add a file to the knowledge base.

        If `docs` is given it is used as-is (`kb_file.file2text()` is skipped)
        and the database entry is flagged `custom_docs=True`.

        :return: the status from `add_file_to_db`, or False when there are no
            documents to add.
        """
        if docs:
            custom_docs = True
            for doc in docs:
                doc.metadata.setdefault("source", kb_file.filename)
        else:
            docs = kb_file.file2text()
            custom_docs = False

        if not docs:
            return False

        # Rewrite absolute metadata["source"] paths relative to the doc folder.
        for doc in docs:
            source = doc.metadata.get("source", "")
            try:
                if os.path.isabs(source):
                    rel_path = Path(source).relative_to(self.doc_path)
                    doc.metadata["source"] = str(rel_path.as_posix().strip("/"))
            except Exception as e:
                print(f"cannot convert absolute path ({source}) to relative path. error is : {e}")

        # Remove any previous version of the file before adding the new one.
        self.delete_doc(kb_file)
        doc_infos = self.do_add_doc(docs, **kwargs)
        return add_file_to_db(kb_file,
                              custom_docs=custom_docs,
                              docs_count=len(docs),
                              doc_infos=doc_infos)

    def delete_doc(self, kb_file: KnowledgeFile, delete_content: bool = False, **kwargs):
        """Remove a file from the knowledge base.

        :param delete_content: also delete the file on disk when it exists.
        """
        self.do_delete_doc(kb_file, **kwargs)
        status = delete_file_from_db(kb_file)
        if delete_content and os.path.exists(kb_file.filepath):
            os.remove(kb_file.filepath)
        return status

    def update_info(self, kb_info: str):
        """Update the knowledge-base description."""
        self.kb_info = kb_info
        status = add_kb_to_db(self.kb_name, self.kb_info, self.vs_type(), self.embed_model)
        return status

    def update_doc(self, kb_file: KnowledgeFile, docs: Optional[List[Document]] = None, **kwargs):
        """Refresh the vector store from the file on disk.

        If `docs` is given, those custom documents are used and the database
        entry is flagged `custom_docs=True`. Returns None when the file does
        not exist on disk.
        """
        if os.path.exists(kb_file.filepath):
            self.delete_doc(kb_file, **kwargs)
            return self.add_doc(kb_file, docs=docs, **kwargs)

    def exist_doc(self, file_name: str):
        return file_exists_in_db(KnowledgeFile(knowledge_base_name=self.kb_name,
                                               filename=file_name))

    def list_files(self):
        return list_files_from_db(self.kb_name)

    def count_files(self):
        return count_files_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    score_threshold: float = SCORE_THRESHOLD,
                    expr: str = EXPR,
                    custom_strategy_config: Optional[dict] = None
                    ) -> List[Document]:
        """Search the knowledge base by delegating to `do_search`."""
        if custom_strategy_config is None:
            # A fresh dict per call; subclasses always receive a dict.
            custom_strategy_config = {}
        docs = self.do_search(query, top_k, score_threshold, expr, custom_strategy_config)
        return docs

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        return []

    def get_doc_by_sources_name(self, source_name_list: List[str]) -> List[Document]:
        return []

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        raise NotImplementedError

    def update_doc_by_ids(self, docs: Dict[str, Document]) -> bool:
        """Replace documents by id.

        :param docs: `{doc_id: Document, ...}`. Every id is deleted first; an
            entry whose value is None or whose `page_content` is blank is not
            re-added, i.e. it is deleted.
        """
        self.del_doc_by_ids(list(docs.keys()))
        # Separate names: rebinding `docs` here used to shadow the parameter
        # and break the `.items()` call below.
        kept_docs = []
        kept_ids = []
        for doc_id, doc in docs.items():
            if not doc or not doc.page_content.strip():
                continue
            kept_ids.append(doc_id)
            kept_docs.append(doc)
        self.do_add_doc(docs=kept_docs, ids=kept_ids)
        return True

    def list_docs(self, file_name: str = None, metadata: Optional[Dict] = None) -> List[DocumentWithVSId]:
        """Look up Documents by `file_name` or `metadata`.

        Ids for which `get_doc_by_ids` returns nothing (or None) are skipped.
        """
        if metadata is None:
            metadata = {}
        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
        docs = []
        for x in doc_infos:
            found = self.get_doc_by_ids([x["id"]])
            doc_info = found[0] if found else None
            if doc_info is None:
                continue
            docs.append(DocumentWithVSId(**doc_info.dict(), id=x["id"]))
        return docs

    @abstractmethod
    def do_create_kb(self):
        """Subclass hook: create the knowledge base."""
        pass

    @staticmethod
    def list_kbs_type():
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        return list_kbs_from_db()

    def exists(self, kb_name: str = None):
        kb_name = kb_name or self.kb_name
        return kb_exists(kb_name)

    @abstractmethod
    def vs_type(self) -> str:
        pass

    @abstractmethod
    def do_init(self):
        pass

    @abstractmethod
    def do_drop_kb(self):
        """Subclass hook: delete the knowledge base."""
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float,
                  expr: str,
                  custom_strategy_config: dict = {},
                  ) -> List[Tuple[Document, float]]:
        """Subclass hook: search the knowledge base."""
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Subclass hook: add documents to the knowledge base."""
        pass

    @abstractmethod
    def do_delete_doc(self,
                      kb_file: KnowledgeFile):
        """Subclass hook: delete a file's documents from the knowledge base."""
        pass

    @abstractmethod
    def do_clear_vs(self):
        """Subclass hook: delete all vectors from the knowledge base."""
        pass


class KBServiceFactory:
    """Creates `KBService` instances for a given vector-store type."""

    @staticmethod
    def get_service(kb_name: str,
                    vector_store_type: Union[str, SupportedVSType],
                    embed_model: str = EMBEDDING_MODEL,
                    ) -> KBService:
        """Return the service for `vector_store_type`.

        Back-end modules are imported lazily so only the selected one is
        loaded. Returns None when no branch matches.

        :raises AttributeError: if `vector_store_type` is a string that is not
            an attribute name of `SupportedVSType` (case-insensitive).
        """
        if isinstance(vector_store_type, str):
            vector_store_type = getattr(SupportedVSType, vector_store_type.upper())

        if SupportedVSType.FAISS == vector_store_type:
            from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
            return FaissKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.PG == vector_store_type:
            from server.knowledge_base.kb_service.pg_kb_service import PGKBService
            return PGKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.MILVUS == vector_store_type:
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.ZILLIZ == vector_store_type:
            from server.knowledge_base.kb_service.zilliz_kb_service import ZillizKBService
            return ZillizKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.DEFAULT == vector_store_type:
            # DEFAULT is served by Milvus; other milvus parameters are set in
            # model_config.kbs_config.
            # NOTE(review): a second DEFAULT branch returning DefaultKBService
            # used to follow the CHROMADB branch; it could never be reached
            # and has been dropped.
            from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
            return MilvusKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.ES == vector_store_type:
            from server.knowledge_base.kb_service.es_kb_service import ESKBService
            return ESKBService(kb_name, embed_model=embed_model)
        elif SupportedVSType.CHROMADB == vector_store_type:
            from server.knowledge_base.kb_service.chromadb_kb_service import ChromaKBService
            return ChromaKBService(kb_name, embed_model=embed_model)

    @staticmethod
    def get_service_by_name(kb_name: str) -> Optional[KBService]:
        """Return the service for a knowledge base stored in the database, or None."""
        _, vs_type, embed_model = load_kb_from_db(kb_name)
        if _ is None:  # kb not in db, just return None
            return None
        from server.utils import resolve_embed_model_name
        embed_model = resolve_embed_model_name(embed_model)
        return KBServiceFactory.get_service(kb_name, vs_type, embed_model)

    @staticmethod
    def get_default():
        return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)


def get_kb_details() -> List[Dict]:
    """Merge knowledge bases found on disk with those in the database.

    :return: one dict per knowledge base, with `in_folder` / `in_db` flags and
        a 1-based `No` field.
    """
    kbs_in_folder = list_kbs_from_folder()
    kbs_in_db = KBService.list_kbs()
    result = {}

    for kb in kbs_in_folder:
        result[kb] = {
            "kb_name": kb,
            "vs_type": "",
            "kb_info": "",
            "embed_model": "",
            "file_count": 0,
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }

    for kb in kbs_in_db:
        kb_detail = get_kb_detail(kb)
        if kb_detail:
            kb_detail["in_db"] = True
            if kb in result:
                result[kb].update(kb_detail)
            else:
                kb_detail["in_folder"] = False
                result[kb] = kb_detail

    data = []
    for i, v in enumerate(result.values()):
        v['No'] = i + 1
        data.append(v)

    return data


def get_kb_file_details(kb_name: str) -> List[Dict]:
    """Merge the files found on disk with those in the database for one knowledge base.

    File names are matched case-insensitively between folder and database.

    :return: one dict per file with `in_folder` / `in_db` flags and a 1-based
        `No` field; an empty list when the knowledge base is not in the
        database.
    """
    kb = KBServiceFactory.get_service_by_name(kb_name)
    if kb is None:
        return []

    files_in_folder = list_files_from_folder(kb_name)
    files_in_db = kb.list_files()
    result = {}

    for doc in files_in_folder:
        result[doc] = {
            "kb_name": kb_name,
            "file_name": doc,
            "file_ext": os.path.splitext(doc)[-1],
            "file_version": 0,
            "document_loader": "",
            "docs_count": 0,
            "text_splitter": "",
            "create_time": None,
            "in_folder": True,
            "in_db": False,
        }

    # Lower-cased folder name -> actual folder name.
    lower_names = {x.lower(): x for x in result}
    for doc in files_in_db:
        doc_detail = get_file_detail(kb_name, doc)
        if doc_detail:
            doc_detail["in_db"] = True
            if doc.lower() in lower_names:
                result[lower_names[doc.lower()]].update(doc_detail)
            else:
                doc_detail["in_folder"] = False
                result[doc] = doc_detail

    data = []
    for i, v in enumerate(result.values()):
        v['No'] = i + 1
        data.append(v)

    return data


class EmbeddingsFunAdapter(Embeddings):
    """`Embeddings` implementation backed by `embed_texts` / `aembed_texts`.

    All returned vectors are L2-normalised with `normalize`.
    """

    def __init__(self, embed_model: str = EMBEDDING_MODEL):
        self.embed_model = embed_model

    @staticmethod
    def _normalize_query(embeddings) -> List[float]:
        """Normalise the first embedding and return it as a flat list."""
        # Reshape the 1-D vector to 2-D for `normalize`, then flatten again.
        query_embed_2d = np.reshape(embeddings[0], (1, -1))
        return normalize(query_embed_2d)[0].tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        result = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
        embeddings = result.data if result and hasattr(result, 'data') else None
        if not embeddings:
            raise ValueError(f"Failed to get embeddings for texts: {texts[:2]}...")
        return normalize(embeddings).tolist()

    def embed_query(self, text: str) -> List[float]:
        embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
        return self._normalize_query(embeddings)

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
        return normalize(embeddings).tolist()

    async def aembed_query(self, text: str) -> List[float]:
        embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
        return self._normalize_query(embeddings)