44 lines
1.6 KiB
Python
44 lines
1.6 KiB
Python
import asyncio
|
|
import json
|
|
from typing import List, Union
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from server.chat.chat import chat
|
|
from server.chat.utils import History
|
|
|
|
|
|
async def chat_with_Yi34B_iter(query: str,
                               stream: bool = False,
                               model_name: str = "qianfan-api",
                               history: Union[int, List[History], None] = None,
                               conversation_id: str = '',
                               temperature: float = 0.7,
                               max_tokens=None,
                               history_len: int = 3,
                               prompt_name: str = "default",
                               ) -> str:
    """Send ``query`` through ``server.chat.chat.chat`` and return the reply text.

    All arguments are forwarded unchanged, by keyword, to ``chat``. The
    response it returns is consumed via ``response.body_iterator``. Each chunk
    is decoded with ``json.loads`` and its ``"text"`` field is appended to the
    result.

    Args:
        query: The user message.
        stream: Forwarded to ``chat``. When True, the reply presumably arrives
            as several chunks; all of them are concatenated here.
        model_name: Name of the model ``chat`` should use.
        history: Either an int or a list of ``History`` objects, per the
            annotation; passed through as-is.
        conversation_id: Conversation identifier passed to ``chat``.
        temperature: Sampling temperature passed to ``chat``.
        max_tokens: Passed to ``chat`` unchanged. It is presumably an int
            limit or None; confirm against ``chat``'s signature.
        history_len: Passed to ``chat``.
        prompt_name: Prompt template name passed to ``chat``.

    Returns:
        The concatenation of the ``"text"`` fields of all response chunks.

    Raises:
        json.JSONDecodeError: If a chunk is not valid JSON.
        KeyError: If a decoded chunk has no ``"text"`` key.
    """
    response = await chat(query=query, history=history,
                          history_len=history_len,
                          conversation_id=conversation_id,
                          stream=stream, model_name=model_name,
                          temperature=temperature,
                          max_tokens=max_tokens, prompt_name=prompt_name)

    # Collect the pieces and join once at the end. Repeated ``str +=`` in a
    # loop is quadratic in the number of chunks, unless the interpreter
    # happens to optimise it in place.
    pieces: List[str] = []
    async for raw_chunk in response.body_iterator:
        # Each chunk is a JSON string carrying the generated text under "text".
        pieces.append(json.loads(raw_chunk)["text"])
    return "".join(pieces)
|
|
|
|
|
|
def chat_with_Yi34B(query: str, model_name: str = "qianfan-api", conversation_id: str = '',
                    history: Union[int, List[History]] = None):
    """Blocking wrapper that runs ``chat_with_Yi34B_iter`` to completion.

    Only ``query``, ``model_name``, ``conversation_id`` and ``history`` are
    forwarded. Every other option of ``chat_with_Yi34B_iter`` keeps its
    default value.

    The coroutine is driven by ``asyncio.run``, which creates a fresh event
    loop. It therefore cannot be called from a thread that already has a
    running event loop; in that case ``asyncio.run`` raises ``RuntimeError``.

    Returns:
        Whatever ``chat_with_Yi34B_iter`` returns (the concatenated reply text).
    """
    pending = chat_with_Yi34B_iter(
        query,
        model_name=model_name,
        conversation_id=conversation_id,
        history=history,
    )
    return asyncio.run(pending)
|
|
|
|
|
|
class ChatWithYi34BInput(BaseModel):
    # Argument schema for the Yi-34B chat tool: one required string field.
    # Comments are used instead of a class docstring, because pydantic would
    # copy a docstring into the generated JSON schema's "description".
    #
    # NOTE(review): the field is named `location`, but its description says
    # it carries a chat query. It was possibly copied from another tool's
    # schema. Confirm how the tool is wired before renaming, because callers
    # may pass it by this name.
    location: str = Field(
        ...,  # explicit "required": there is no default value
        description="Query for any kind of chats and questions",
    )
|