from langchain.docstore.document import Document
from configs import EMBEDDING_MODEL, logger
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import (
    BaseResponse,
    get_model_worker_config,
    list_embed_models,
    list_online_embed_models,
    resolve_embed_model_name,
)
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List, Optional

# Evaluated once at import time; only used to render the endpoint description.
online_embed_models = list_online_embed_models()

# Substrings that mark a config key as holding a credential; such values are
# masked before the config is written to the log.
_SENSITIVE_CONFIG_MARKERS = ("key", "secret", "token", "password")


def _redact_config(config: Dict) -> Dict:
    """Return a shallow copy of *config* with credential-like values masked."""
    return {
        k: "***" if any(m in str(k).lower() for m in _SENSITIVE_CONFIG_MARKERS) else v
        for k, v in config.items()
    }


def embed_texts(
    texts: List[str],
    embed_model: str = EMBEDDING_MODEL,
    to_query: bool = False,
) -> BaseResponse:
    """
    Vectorize a list of texts.

    The model name is first passed through ``resolve_embed_model_name``. If the
    resolved name is listed by ``list_online_embed_models()`` the online worker
    is used; otherwise, if it is listed by ``list_embed_models()``, the local
    embeddings object is used.

    Args:
        texts: The texts to embed.
        embed_model: Name of the embedding model (resolved before use).
        to_query: Forwarded to the online worker via ``ApiEmbeddingsParams``;
            not used by the local branch.

    Returns:
        ``BaseResponse(data=List[List[float]])`` on success, or a
        ``BaseResponse`` with ``code=500`` and an explanatory ``msg`` on
        failure. Exceptions are caught and reported the same way.
    """
    try:
        orig = embed_model
        embed_model = resolve_embed_model_name(embed_model)
        if embed_model != orig:
            logger.info(f"embed_texts 嵌入名解析: {orig} -> {embed_model}")

        logger.info(f"embed_texts called with model={embed_model}, texts_count={len(texts)}")

        # Models such as bge-m3-api may be declared both in MODEL_PATH and in
        # ONLINE_LLM_MODEL; the online (OpenAI-compatible) embedding API must
        # take priority, so this check comes before the local one.
        if embed_model in list_online_embed_models():
            logger.info(f"Using online embeddings model: {embed_model}")
            config = get_model_worker_config(embed_model)
            # The worker config may carry API credentials; never log them raw.
            logger.info(f"Config: {_redact_config(config)}")
            worker_class = config.get("worker_class")
            embed_model_name = config.get("embed_model") or config.get("model_name")
            logger.info(f"worker_class: {worker_class}, embed_model_name: {embed_model_name}")

            if worker_class is None:
                return BaseResponse(code=500, msg=f"未找到 {embed_model} 的 worker_class 配置")

            if not worker_class.can_embedding():
                return BaseResponse(code=500, msg=f"模型 {embed_model} 不支持嵌入功能")

            # Only build the worker once we know it can actually embed.
            worker = worker_class()
            params = ApiEmbeddingsParams(
                texts=texts,
                to_query=to_query,
                embed_model=embed_model_name,
                worker_name=embed_model,
            )
            logger.info(f"Calling do_embeddings with params: {params}")
            resp = worker.do_embeddings(params)
            logger.info(f"do_embeddings response: {resp}")
            return BaseResponse(**resp)

        if embed_model in list_embed_models():
            # Local embeddings model.
            logger.info(f"Using local embeddings model: {embed_model}")
            # Imported lazily, as in the original code (presumably to avoid an
            # import cycle or eager model loading -- TODO confirm).
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=embeddings.embed_documents(texts))

        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(f"embed_texts error: {e}", exc_info=True)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")


async def aembed_texts(
    texts: List[str],
    embed_model: str = EMBEDDING_MODEL,
    to_query: bool = False,
) -> BaseResponse:
    """
    Asynchronous counterpart of ``embed_texts``.

    Online models are handled by running ``embed_texts`` in a thread pool;
    local models use the embeddings object's ``aembed_documents``.

    Args:
        texts: The texts to embed.
        embed_model: Name of the embedding model (resolved before use).
        to_query: Forwarded to ``embed_texts`` for online models.

    Returns:
        ``BaseResponse(data=List[List[float]])`` on success, or a
        ``BaseResponse`` with ``code=500`` on failure or when the model is
        neither an online nor a local embedding model.
    """
    try:
        embed_model = resolve_embed_model_name(embed_model)

        # Same priority as embed_texts: the online API wins over a local model.
        if embed_model in list_online_embed_models():
            return await run_in_threadpool(
                embed_texts,
                texts=texts,
                embed_model=embed_model,
                to_query=to_query,
            )

        if embed_model in list_embed_models():
            # Local embeddings model.
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=await embeddings.aembed_documents(texts))

        # Previously this path fell off the end and returned None; report the
        # unsupported model explicitly, matching embed_texts.
        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e, exc_info=True)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")


def embed_texts_endpoint(
    texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
    embed_model: str = Body(
        EMBEDDING_MODEL,
        description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。",
    ),
    to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
) -> BaseResponse:
    """
    HTTP endpoint wrapper: vectorize texts and return
    ``BaseResponse(data=List[List[float]])`` by delegating to ``embed_texts``.
    """
    return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)


def embed_documents(
    docs: List[Document],
    embed_model: str = EMBEDDING_MODEL,
    to_query: bool = False,
) -> Optional[Dict]:
    """
    Vectorize a list of ``Document`` objects into the keyword arguments that
    ``VectorStore.add_embeddings`` accepts.

    Args:
        docs: Documents whose ``page_content`` is embedded.
        embed_model: Name of the embedding model.
        to_query: Forwarded to ``embed_texts``.

    Returns:
        A dict with keys ``"texts"``, ``"embeddings"`` and ``"metadatas"``, or
        ``None`` when ``embed_texts`` produced no data (e.g. it failed).
    """
    texts = [x.page_content for x in docs]
    metadatas = [x.metadata for x in docs]
    embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
    if embeddings is None:
        # Embedding failed; callers must handle the missing result.
        return None
    return {
        "texts": texts,
        "embeddings": embeddings,
        "metadatas": metadatas,
    }