[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
124
langchain-chat/server/embeddings_api.py
Normal file
124
langchain-chat/server/embeddings_api.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from langchain.docstore.document import Document
|
||||
from configs import EMBEDDING_MODEL, logger
|
||||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
from server.utils import (
|
||||
BaseResponse,
|
||||
get_model_worker_config,
|
||||
list_embed_models,
|
||||
list_online_embed_models,
|
||||
resolve_embed_model_name,
|
||||
)
|
||||
from fastapi import Body
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from typing import Dict, List
|
||||
|
||||
online_embed_models = list_online_embed_models()
|
||||
|
||||
|
||||
def embed_texts(
|
||||
texts: List[str],
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
to_query: bool = False,
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
|
||||
'''
|
||||
try:
|
||||
orig = embed_model
|
||||
embed_model = resolve_embed_model_name(embed_model)
|
||||
if embed_model != orig:
|
||||
logger.info(f"embed_texts 嵌入名解析: {orig} -> {embed_model}")
|
||||
logger.info(f"embed_texts called with model={embed_model}, texts_count={len(texts)}")
|
||||
# bge-m3-api 等同时写在 MODEL_PATH 与 ONLINE_LLM_MODEL 时,须优先走内网 OpenAI 兼容 embedding API
|
||||
if embed_model in list_online_embed_models(): # 使用在线API
|
||||
logger.info(f"Using online embeddings model: {embed_model}")
|
||||
config = get_model_worker_config(embed_model)
|
||||
logger.info(f"Config: {config}")
|
||||
worker_class = config.get("worker_class")
|
||||
embed_model_name = config.get("embed_model") or config.get("model_name")
|
||||
logger.info(f"worker_class: {worker_class}, embed_model_name: {embed_model_name}")
|
||||
if worker_class is None:
|
||||
return BaseResponse(code=500, msg=f"未找到 {embed_model} 的 worker_class 配置")
|
||||
worker = worker_class()
|
||||
if worker_class.can_embedding():
|
||||
params = ApiEmbeddingsParams(
|
||||
texts=texts,
|
||||
to_query=to_query,
|
||||
embed_model=embed_model_name,
|
||||
worker_name=embed_model,
|
||||
)
|
||||
logger.info(f"Calling do_embeddings with params: {params}")
|
||||
resp = worker.do_embeddings(params)
|
||||
logger.info(f"do_embeddings response: {resp}")
|
||||
return BaseResponse(**resp)
|
||||
else:
|
||||
return BaseResponse(code=500, msg=f"模型 {embed_model} 不支持嵌入功能")
|
||||
|
||||
if embed_model in list_embed_models(): # 使用本地Embeddings模型
|
||||
logger.info(f"Using local embeddings model: {embed_model}")
|
||||
from server.utils import load_local_embeddings
|
||||
|
||||
embeddings = load_local_embeddings(model=embed_model)
|
||||
return BaseResponse(data=embeddings.embed_documents(texts))
|
||||
|
||||
return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
|
||||
except Exception as e:
|
||||
logger.error(f"embed_texts error: {e}", exc_info=True)
|
||||
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||
|
||||
|
||||
async def aembed_texts(
|
||||
texts: List[str],
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
to_query: bool = False,
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
|
||||
'''
|
||||
try:
|
||||
embed_model = resolve_embed_model_name(embed_model)
|
||||
if embed_model in list_online_embed_models(): # 与 embed_texts 一致:内网 API 优先
|
||||
return await run_in_threadpool(embed_texts,
|
||||
texts=texts,
|
||||
embed_model=embed_model,
|
||||
to_query=to_query)
|
||||
|
||||
if embed_model in list_embed_models(): # 使用本地Embeddings模型
|
||||
from server.utils import load_local_embeddings
|
||||
|
||||
embeddings = load_local_embeddings(model=embed_model)
|
||||
return BaseResponse(data=await embeddings.aembed_documents(texts))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
|
||||
|
||||
|
||||
def embed_texts_endpoint(
|
||||
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
|
||||
embed_model: str = Body(EMBEDDING_MODEL,
|
||||
description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
|
||||
to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
|
||||
) -> BaseResponse:
|
||||
'''
|
||||
对文本进行向量化,返回 BaseResponse(data=List[List[float]])
|
||||
'''
|
||||
return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
|
||||
|
||||
|
||||
def embed_documents(
|
||||
docs: List[Document],
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
to_query: bool = False,
|
||||
) -> Dict:
|
||||
"""
|
||||
将 List[Document] 向量化,转化为 VectorStore.add_embeddings 可以接受的参数
|
||||
"""
|
||||
texts = [x.page_content for x in docs]
|
||||
metadatas = [x.metadata for x in docs]
|
||||
embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
|
||||
if embeddings is not None:
|
||||
return {
|
||||
"texts": texts,
|
||||
"embeddings": embeddings,
|
||||
"metadatas": metadatas,
|
||||
}
|
||||
Reference in New Issue
Block a user