165 lines
7.4 KiB
Python
165 lines
7.4 KiB
Python
|
|
import asyncio
|
|||
|
|
from datetime import datetime
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
from pathlib import Path
|
|||
|
|
import shutil
|
|||
|
|
import uuid
|
|||
|
|
from fastapi import BackgroundTasks, Body, File, HTTPException, Query, UploadFile
|
|||
|
|
from fastapi.responses import FileResponse
|
|||
|
|
import requests
|
|||
|
|
from langchain.chains import LLMChain
|
|||
|
|
from langchain.prompts.chat import ChatPromptTemplate
|
|||
|
|
from configs.kb_config import KB_CHAT_TEMP_DIR
|
|||
|
|
from configs.model_config import LLM_MODELS
|
|||
|
|
from configs.translate_config import *
|
|||
|
|
from configs.basic_config import *
|
|||
|
|
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
|
|||
|
|
from server.chat.utils import History
|
|||
|
|
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
|
|||
|
|
from server.utils import get_ChatOpenAI, get_prompt_template, wrap_done
|
|||
|
|
from typing import Any, AsyncIterable, Optional
|
|||
|
|
from sse_starlette.sse import EventSourceResponse
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
|
|||
|
|
# Target-language code (as accepted by ``to_lang``) -> language name that is
# substituted into the translation prompt. Codes not listed here are passed to
# the prompt unchanged.
_TRANSLATE_LANG_NAMES = {
    "zh-cn": "中文",  # Chinese
    "en": "English",  # English
    "ja": "日本語",  # Japanese
    "ko": "한국어",  # Korean
    "fr": "Français",  # French
    "de": "Deutsch",  # German
    "es": "Español",  # Spanish
    "it": "Italiano",  # Italian
    "pt": "Português",  # Portuguese
    "ru": "Русский",  # Russian
    "ar": "العربية",  # Arabic
    "hi": "हिन्दी",  # Hindi
    "bn": "বাংলা",  # Bengali
    "pa": "ਪੰਜਾਬੀ",  # Punjabi
    "jv": "Basa Jawa",  # Javanese
    "ms": "Bahasa Melayu",  # Malay
    "vi": "Tiếng Việt",  # Vietnamese
    "th": "ไทย",  # Thai
    "tr": "Türkçe",  # Turkish
    "fa": "فارسی",  # Persian
    "pl": "Polski",  # Polish
    "uk": "Українська",  # Ukrainian
    "ro": "Română",  # Romanian
    "nl": "Nederlands",  # Dutch
    "el": "Ελληνικά",  # Greek
}


def _read_temp_file_lines(file_name: str) -> list[str]:
    """Read ``{KB_CHAT_TEMP_DIR}{file_name}.txt`` and return its content split on "\\n".

    Raises:
        HTTPException(400): ``file_name`` contains directory components.
        HTTPException(404): the temp file does not exist.
    """
    # file_name comes straight from the request body; refuse anything with a
    # directory component so a client cannot read .txt files outside the temp dir.
    if os.path.basename(file_name) != file_name:
        raise HTTPException(status_code=400, detail="Invalid file_name")

    specified_dir = KB_CHAT_TEMP_DIR
    os.makedirs(specified_dir, exist_ok=True)  # make sure the directory exists
    # Plain concatenation is kept so the path matches the original layout; this
    # presumably relies on KB_CHAT_TEMP_DIR ending with a path separator — TODO confirm.
    temp_file_path = f"{specified_dir}{file_name}.txt"
    try:
        with open(temp_file_path, "r", encoding="utf-8") as file:
            content = file.read()
    except FileNotFoundError as err:
        raise HTTPException(status_code=404, detail=f"File not found: {file_name}") from err
    # NOTE(review): the temp file is not deleted here (the original removal call
    # was commented out); confirm who is responsible for cleaning it up.
    return content.split("\n")


def tarnslate_text(
        query: str = Body("苹果是红色的", description="翻译语句"),
        to_lang: str = Body("zh-cn", description="目标语言包括参数:zh-cn: 中文, en: English, ja: 日本語, ko: 한국어, fr: Français, de: Deutsch, es: Español, it: Italiano, pt: Português, ru: Русский, ar: العربية, hi: हिन्दी, bn: বাংলা, pa: ਪੰਜਾਬੀ, jv: Basa Jawa, ms: Bahasa Melayu, vi: Tiếng Việt, th: ไทย, tr: Türkçe, fa: فارسی, pl: Polski, uk: Українська, ro: Română, nl: Nederlands, el: Ελληνικά"),
        conversation_id: str = Body("", description="对话框ID"),
        uid: Optional[Any] = Body(None, description="用户ID"),
        file_name: Optional[Any] = Body(None, description="文件名称"),
        stream: bool = True):
    """Translate text line by line with the configured LLM and return it as server-sent events.

    The (misspelled) function name is kept for compatibility with existing callers.

    Args:
        query: Text to translate. It is ignored when ``file_name`` is truthy.
        to_lang: Target-language code. Known codes are mapped to a language name
            via ``_TRANSLATE_LANG_NAMES``; any other value is put into the
            prompt unchanged.
        conversation_id: Forwarded to ``ConversationCallbackHandler``.
        uid: Accepted but not used by this function.
        file_name: If truthy, the text is read from
            ``f"{KB_CHAT_TEMP_DIR}{file_name}.txt"`` instead of ``query``.
        stream: If True, emit one event per LLM token. Otherwise, accumulate
            everything and emit one event at the end. It is not declared with
            ``Body()``, so FastAPI presumably reads it from the query string.

    Returns:
        An ``EventSourceResponse`` whose events are JSON objects with a "text"
        key, plus "message_id" for events produced from LLM output. Each blank
        input line becomes a "\\n". In file mode a "\\n" also follows each
        translated line; outside file mode consecutive translated lines are
        emitted with no separator.

    Raises:
        HTTPException: 400 for a ``file_name`` with path components, 404 when the
            temp file is missing (raised before streaming starts).
    """
    # Read the input up front. This endpoint is sync, so the blocking file read
    # does not run on the event loop, and errors become proper HTTP responses
    # instead of a broken stream.
    if file_name:
        lines = _read_temp_file_lines(str(file_name))
    else:
        lines = query.split("\n")
    lang_name = _TRANSLATE_LANG_NAMES.get(to_lang, to_lang)

    async def chat_iterator() -> AsyncIterable[str]:
        answer = ""
        # Stays None when no line reaches the LLM (e.g. blank input in
        # non-stream mode); previously that case hit an unbound name at the end.
        message_id = None

        prompt_template = get_prompt_template("llm_chat", "translate_text")
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
        today = datetime.now().strftime("%Y%m%d")

        for text in lines:
            if not text.strip():
                # Blank (or whitespace-only) line: keep the line break and skip the LLM call.
                if stream:
                    yield json.dumps({"text": "\n"}, ensure_ascii=False)
                else:
                    answer += "\n"
                continue

            message_id = str(uuid.uuid1()) + "q"
            conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id,
                                                                message_id=message_id,
                                                                chat_type="llm_chat",
                                                                query=text)
            # A fresh token callback per line; wrap_done signals callback.done when
            # this line's chain call finishes, presumably ending callback.aiter().
            callback = AsyncIteratorCallbackHandler()
            model = get_ChatOpenAI(
                model_name=LLM_MODELS[0],
                temperature=0.2,
                max_tokens=512,
                callbacks=[callback, conversation_callback],
            )
            chain = LLMChain(prompt=chat_prompt, llm=model)

            task = asyncio.create_task(wrap_done(
                chain.acall({"input": f"<关键指令>必须翻译如下文本为{lang_name}语言<关键指令/>文本如下:\n{text}",
                             "time": today,
                             "lang": lang_name}),
                callback.done),
            )

            if stream:
                async for token in callback.aiter():
                    # Use server-sent events to stream the response token by token.
                    yield json.dumps(
                        {"text": token, "message_id": message_id},
                        ensure_ascii=False)
                if file_name:
                    yield json.dumps(
                        {"text": "\n", "message_id": message_id},
                        ensure_ascii=False)
            else:
                async for token in callback.aiter():
                    answer += token
                if file_name:
                    answer += "\n"
            await task

        # Non-stream mode: emit the whole accumulated translation once.
        if answer:
            yield json.dumps(
                {"text": answer, "message_id": message_id},
                ensure_ascii=False)

    return EventSourceResponse(chat_iterator())
|