import json
import sys
from typing import Dict, Iterator, List, Optional
from fastchat.conversation import Conversation
from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
from configs import logger, log_verbose


class OpenAIWorker(ApiModelWorker):
    """Worker for OpenAI-compatible APIs, used for both embeddings and chat."""

    # Embedding model used when the request does not name one.
    DEFAULT_EMBED_MODEL = "text-embedding-ada-002"
    # Number of texts sent per embeddings request.
    EMBED_BATCH_SIZE = 100

    def __init__(
        self,
        *,
        model_names: Optional[List[str]] = None,
        controller_addr: Optional[str] = None,
        worker_addr: Optional[str] = None,
        **kwargs,
    ):
        """Forward the worker settings to the base class.

        Args:
            model_names: Names this worker serves; defaults to ``["openai-api"]``.
            controller_addr: Controller address, passed through to the base class.
            worker_addr: This worker's address, passed through to the base class.
            **kwargs: Extra base-class options; ``context_len`` defaults to 8192.
        """
        # Build the default per call so instances never share one list object.
        if model_names is None:
            model_names = ["openai-api"]
        kwargs.update(
            model_names=model_names,
            controller_addr=controller_addr,
            worker_addr=worker_addr,
        )
        kwargs.setdefault("context_len", 8192)
        super().__init__(**kwargs)

    @staticmethod
    def _create_client(params):
        """Build an OpenAI client from the key and base URL loaded into *params*."""
        # Imported lazily so the module can be imported without `openai` installed.
        from openai import OpenAI

        return OpenAI(
            api_key=params.api_key,
            base_url=params.api_base_url,
        )

    def do_chat(self, params: ApiChatParams) -> Iterator[Dict]:
        """Stream a chat completion.

        Yields ``{"error_code": 0, "text": ...}`` for every non-empty content
        delta. If the request or the stream raises, the error is logged and a
        single ``{"error_code": 500, "text": ...}`` item is yielded instead of
        propagating the exception.
        """
        params.load_config(self.model_names[0])
        client = self._create_client(params)

        if log_verbose:
            # NOTE(review): `params` is formatted whole; it may include the API
            # key - confirm this is acceptable for verbose logs.
            logger.info(f'{self.__class__.__name__}:params: {params}')

        try:
            response = client.chat.completions.create(
                model=params.version or self.model_names[0],
                messages=params.messages,
                temperature=params.temperature,
                max_tokens=params.max_tokens,
                top_p=params.top_p,
                stream=True,
            )

            for chunk in response:
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    # NOTE(review): "text" carries only this chunk's delta, not
                    # the accumulated reply - confirm the consumer expects that.
                    yield {
                        "error_code": 0,
                        "text": delta.content,
                    }
        except Exception as e:
            logger.error(f"OpenAI API 请求错误: {e}")
            yield {
                "error_code": 500,
                "text": f"请求错误: {str(e)}",
            }

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        """Embed ``params.texts`` in batches of ``EMBED_BATCH_SIZE``.

        Returns:
            ``{"code": 200, "data": [...]}`` with one embedding per input text,
            in input order, or ``{"code": 500, "msg": ...}`` if any request
            raises (the error is logged, not propagated).
        """
        # embed_texts instantiates worker_class() with the default model_names
        # (["openai-api"]), which would load openai-api's base_url by mistake.
        # Online embedding should use the ONLINE_LLM_MODEL key (e.g. bge-m3-api),
        # which the caller writes into params.worker_name.
        params.load_config(params.worker_name or self.model_names[0])
        client = self._create_client(params)

        if log_verbose:
            # NOTE(review): `params` is formatted whole; it may include the API
            # key - confirm this is acceptable for verbose logs.
            logger.info(f'{self.__class__.__name__}:params: {params}')

        try:
            # The OpenAI embeddings API accepts at most 2048 texts per request,
            # so the inputs are sent in smaller batches.
            embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
            batch_size = self.EMBED_BATCH_SIZE
            result = []
            for start in range(0, len(params.texts), batch_size):
                response = client.embeddings.create(
                    model=embed_model,
                    input=params.texts[start:start + batch_size],
                    encoding_format="float",
                )
                result.extend(item.embedding for item in response.data)
            return {"code": 200, "data": result}
        except Exception as e:
            logger.error(f"OpenAI Embedding API 请求错误: {e}")
            return {
                "code": 500,
                "msg": f"Embedding 请求错误: {str(e)}",
            }

    def make_conv_template(
        self,
        conv_template: Optional[str] = None,
        model_path: Optional[str] = None,
    ) -> Conversation:
        """Return a fixed conversation template named after the first model.

        Both arguments are accepted for interface compatibility and ignored.
        """
        return conv.Conversation(
            name=self.model_names[0],
            system_message="You are a helpful assistant.",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n",
            stop_str="",
        )


if __name__ == "__main__":
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    worker = OpenAIWorker(
        controller_addr="http://127.0.0.1:20001",
        worker_addr="http://127.0.0.1:20008",
    )
    # Expose the worker as the `worker` attribute of fastchat's model_worker
    # module (presumably what `app`'s handlers look up - verify).
    sys.modules["fastchat.serve.model_worker"].worker = worker
    MakeFastAPIOffline(app)
    uvicorn.run(app, port=20008)