Files
gangyan/langchain-chat/server/model_workers/openai.py

125 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import sys
from typing import List, Dict
from fastchat.conversation import Conversation
from fastchat import conversation as conv
from server.model_workers.base import *
from server.model_workers.base import ApiEmbeddingsParams
from configs import logger, log_verbose
class OpenAIWorker(ApiModelWorker):
    """
    Worker for OpenAI-format (OpenAI-compatible) APIs, used for both
    embeddings and chat.

    Connection settings (API key, base URL, model version, ...) are not held
    on the worker; each request loads them onto ``params`` through
    ``params.load_config(<model name>)`` before the client is created.
    """

    # Embedding model used when the request does not specify ``embed_model``.
    DEFAULT_EMBED_MODEL = "text-embedding-ada-002"
    # Number of texts sent per embeddings request. Subclasses may override it.
    EMBED_BATCH_SIZE = 100

    def __init__(
        self,
        *,
        model_names: List[str] = ["openai-api"],
        controller_addr: str = None,
        worker_addr: str = None,
        **kwargs,
    ):
        """
        Create the worker and forward all settings to ``ApiModelWorker``.

        Args:
            model_names: Names this worker serves; the first one is the
                configuration key used by ``do_chat``. ``None`` or an empty
                list falls back to ``["openai-api"]``.
            controller_addr: Controller address, passed through to the base class.
            worker_addr: Address of this worker, passed through to the base class.
            **kwargs: Extra base-class options. ``context_len`` defaults to
                8192 when the caller does not supply it.
        """
        # Always hand the base class a fresh list: the default argument is a
        # single list object shared by every call, so passing it through
        # directly would let one instance's mutation leak into all others.
        names = list(model_names) if model_names else ["openai-api"]
        kwargs.update(model_names=names, controller_addr=controller_addr, worker_addr=worker_addr)
        kwargs.setdefault("context_len", 8192)
        super().__init__(**kwargs)

    def do_chat(self, params: ApiChatParams) -> Dict:
        """
        Stream a chat completion from the configured endpoint.

        This is a generator: nothing runs until it is iterated.

        Args:
            params: Chat request parameters. Its configuration is loaded for
                ``self.model_names[0]`` before the request is sent.

        Yields:
            ``{"error_code": 0, "text": <chunk text>}`` for every streamed
            chunk that carries content, or a single
            ``{"error_code": 500, "text": <message>}`` if the request fails.
        """
        from openai import OpenAI

        params.load_config(self.model_names[0])
        client = OpenAI(
            api_key=params.api_key,
            base_url=params.api_base_url,
        )
        if log_verbose:
            # NOTE(review): params is logged in full; if its string form
            # includes api_key, verbose logs will contain the secret - confirm.
            logger.info(f'{self.__class__.__name__}:params: {params}')
        try:
            response = client.chat.completions.create(
                model=params.version or self.model_names[0],
                messages=params.messages,
                temperature=params.temperature,
                max_tokens=params.max_tokens,
                top_p=params.top_p,
                stream=True,
            )
            for chunk in response:
                # Skip chunks that carry no choices or no text.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                if delta.content:
                    # NOTE(review): this yields only the new fragment, not the
                    # text accumulated so far - confirm that is what the
                    # consumer of this generator expects.
                    yield {
                        "error_code": 0,
                        "text": delta.content,
                    }
        except Exception as e:
            # Report the failure in-band so the stream consumer always
            # receives a well-formed item instead of an exception.
            logger.error(f"OpenAI API 请求错误: {e}")
            yield {
                "error_code": 500,
                "text": f"请求错误: {e}",
            }

    def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
        """
        Embed ``params.texts`` in batches of ``EMBED_BATCH_SIZE``.

        Args:
            params: Embedding request parameters. ``params.worker_name``, when
                set, selects the configuration entry to load; otherwise
                ``self.model_names[0]`` is used.

        Returns:
            ``{"code": 200, "data": [<embedding>, ...]}`` with one embedding
            per input text in input order, or
            ``{"code": 500, "msg": <message>}`` if any request fails.
        """
        from openai import OpenAI

        # embed_texts instantiates worker_class() with the default model_names
        # ["openai-api"], which would wrongly load the base_url of openai-api.
        # Online embedding should use the ONLINE_LLM_MODEL key (e.g.
        # bge-m3-api), which the caller writes into params.worker_name.
        params.load_config(params.worker_name or self.model_names[0])
        client = OpenAI(
            api_key=params.api_key,
            base_url=params.api_base_url,
        )
        if log_verbose:
            logger.info(f'{self.__class__.__name__}:params: {params}')
        try:
            # The OpenAI embedding API handles at most 2048 texts per request,
            # so the input is sent in smaller batches.
            embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
            batch_size = max(1, self.EMBED_BATCH_SIZE)  # guard against a zero step
            texts = params.texts
            result = []
            for start in range(0, len(texts), batch_size):
                response = client.embeddings.create(
                    model=embed_model,
                    input=texts[start:start + batch_size],
                    encoding_format="float",
                )
                result.extend(item.embedding for item in response.data)
            return {"code": 200, "data": result}
        except Exception as e:
            logger.error(f"OpenAI Embedding API 请求错误: {e}")
            return {
                "code": 500,
                "msg": f"Embedding 请求错误: {e}",
            }

    def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation:
        """
        Build the conversation template for this worker.

        Both arguments are accepted for interface compatibility and ignored;
        the template is always named after ``self.model_names[0]``.
        """
        return conv.Conversation(
            name=self.model_names[0],
            system_message="You are a helpful assistant.",
            messages=[],
            roles=["user", "assistant", "system"],
            sep="\n",
            stop_str="",
        )
if __name__ == "__main__":
    # Stand-alone launcher: serve a single OpenAIWorker through fastchat's
    # model-worker FastAPI app on a fixed local port.
    import uvicorn
    from server.utils import MakeFastAPIOffline
    from fastchat.serve.model_worker import app

    listen_port = 20008
    controller_url = "http://127.0.0.1:20001"
    worker_url = "http://127.0.0.1:20008"

    worker = OpenAIWorker(controller_addr=controller_url, worker_addr=worker_url)

    # Publish the instance as the `worker` attribute of fastchat's
    # model_worker module (presumably where the app's handlers look it up).
    setattr(sys.modules["fastchat.serve.model_worker"], "worker", worker)

    MakeFastAPIOffline(app)
    uvicorn.run(app, port=listen_port)