Files
gangyan/langchain-chat/server/chat/translate.py

165 lines
7.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
from datetime import datetime
import json
import os
from pathlib import Path
import shutil
import uuid
from fastapi import BackgroundTasks, Body, File, HTTPException, Query, UploadFile
from fastapi.responses import FileResponse
import requests
from langchain.chains import LLMChain
from langchain.prompts.chat import ChatPromptTemplate
from configs.kb_config import KB_CHAT_TEMP_DIR
from configs.model_config import LLM_MODELS
from configs.translate_config import *
from configs.basic_config import *
from server.callback_handler.conversation_callback_handler import ConversationCallbackHandler
from server.chat.utils import History
from server.custom.AsyncIteratorCallbackHandlerNew import AsyncIteratorCallbackHandler
from server.utils import get_ChatOpenAI, get_prompt_template, wrap_done
from typing import Any, AsyncIterable, Optional
from sse_starlette.sse import EventSourceResponse
from pydantic import BaseModel, Field
# Maps the language codes accepted by ``tarnslate_text`` (listed in its
# ``to_lang`` description) to the language name injected into the LLM prompt.
# Codes not in this table are passed to the prompt unchanged.
_TRANSLATE_LANG_NAMES = {
    "zh-cn": "中文",  # Chinese
    "en": "English",  # English
    "ja": "日本語",  # Japanese
    "ko": "한국어",  # Korean
    "fr": "Français",  # French
    "de": "Deutsch",  # German
    "es": "Español",  # Spanish
    "it": "Italiano",  # Italian
    "pt": "Português",  # Portuguese
    "ru": "Русский",  # Russian
    "ar": "العربية",  # Arabic
    "hi": "हिन्दी",  # Hindi
    "bn": "বাংলা",  # Bengali
    "pa": "ਪੰਜਾਬੀ",  # Punjabi
    "jv": "Basa Jawa",  # Javanese
    "ms": "Bahasa Melayu",  # Malay
    "vi": "Tiếng Việt",  # Vietnamese
    "th": "ไทย",  # Thai
    "tr": "Türkçe",  # Turkish
    "fa": "فارسی",  # Persian
    "pl": "Polski",  # Polish
    "uk": "Українська",  # Ukrainian
    "ro": "Română",  # Romanian
    "nl": "Nederlands",  # Dutch
    "el": "Ελληνικά",  # Greek
}


def _read_uploaded_text(file_name: Any) -> str:
    """Read the temporary text file ``f"{KB_CHAT_TEMP_DIR}{file_name}.txt"``.

    Args:
        file_name: Bare name (without ``.txt``) of a file previously placed in
            ``KB_CHAT_TEMP_DIR``. Taken directly from the request body.

    Returns:
        The decoded UTF-8 contents of the file.

    Raises:
        HTTPException: 400 if ``file_name`` contains path components,
            404 if the file does not exist.
    """
    name = str(file_name)
    # file_name is untrusted request input: only accept a bare file name so
    # that values like "../../etc/x" cannot escape the temp directory.
    if os.path.basename(name) != name:
        raise HTTPException(status_code=400, detail="Invalid file_name")

    specified_dir = KB_CHAT_TEMP_DIR
    os.makedirs(specified_dir, exist_ok=True)  # make sure the directory exists
    # Plain concatenation (not os.path.join) is kept so the path matches the
    # one used before; NOTE(review): assumes KB_CHAT_TEMP_DIR ends with a
    # path separator — confirm against the code that writes these files.
    temp_file_path = f"{specified_dir}{name}.txt"
    if not os.path.isfile(temp_file_path):
        raise HTTPException(status_code=404, detail=f"File not found: {name}")
    with open(temp_file_path, "r", encoding="utf-8") as file:
        return file.read()


def tarnslate_text(
        query: str = Body("苹果是红色的", description="翻译语句"),
        to_lang: str = Body("zh-cn", description="目标语言包括参数zh-cn: 中文, en: English, ja: 日本語, ko: 한국어, fr: Français, de: Deutsch, es: Español, it: Italiano, pt: Português, ru: Русский, ar: العربية, hi: हिन्दी, bn: বাংলা, pa: ਪੰਜਾਬੀ, jv: Basa Jawa, ms: Bahasa Melayu, vi: Tiếng Việt, th: ไทย, tr: Türkçe, fa: فارسی, pl: Polski, uk: Українська, ro: Română, nl: Nederlands, el: Ελληνικά"),
        conversation_id: str = Body("", description="对话框ID"),
        uid: Optional[Any] = Body(None, description="用户ID"),
        file_name: Optional[Any] = Body(None, description="文件名称"),
        stream: bool = True):
    """Translate text line by line with the LLM and return it as server-sent events.

    The source text is ``query``, or, when ``file_name`` is given, the contents
    of ``f"{KB_CHAT_TEMP_DIR}{file_name}.txt"``. The text is split on ``"\\n"``
    and every non-empty line is translated by its own LLM call; empty lines
    are reproduced as ``"\\n"``.

    Args:
        query: Text to translate (ignored when ``file_name`` is set).
        to_lang: Target language code (see ``_TRANSLATE_LANG_NAMES``); unknown
            values are used verbatim as the language name in the prompt.
        conversation_id: Conversation the per-line messages are attached to.
        uid: User id; currently unused by this function.
        file_name: Name of an uploaded temp file to translate instead of ``query``.
        stream: If True, emit one event per LLM token; otherwise emit a single
            event with the whole translation once all lines are done.

    Returns:
        An ``EventSourceResponse`` whose events are JSON objects with a
        ``"text"`` field and, for translated content, a ``"message_id"``.

    Raises:
        HTTPException: 400 for an unsafe ``file_name``, 404 if the file is missing.
    """
    from_file = bool(file_name)
    if from_file:
        # Read (and validate) up front: this sync endpoint is off the event
        # loop, and errors can still be reported as proper HTTP statuses
        # instead of breaking an already-started SSE stream.
        query = _read_uploaded_text(file_name)
    lines = query.split("\n")
    target_lang = _TRANSLATE_LANG_NAMES.get(to_lang, to_lang)

    async def chat_iterator() -> AsyncIterable[str]:
        prompt_name = "translate_text"
        # The prompt is the same for every line, so build it once.
        prompt_template = get_prompt_template("llm_chat", prompt_name)
        input_prompt = History(role="system", content=prompt_template).to_msg_template(False)
        chat_prompt = ChatPromptTemplate.from_messages([input_prompt])
        today = datetime.now().strftime("%Y%m%d")

        answer = ""  # accumulated output, only used when stream is False
        message_id: Optional[str] = None  # id of the most recent LLM call
        for text in lines:
            if text == "":
                # Blank lines are not sent to the LLM; just keep the break.
                if stream:
                    yield json.dumps({"text": "\n"}, ensure_ascii=False)
                else:
                    answer += "\n"
                continue

            message_id = str(uuid.uuid1()) + "q"
            # Each line is its own message; presumably the conversation
            # callback records it under conversation_id — see its definition.
            conversation_callback = ConversationCallbackHandler(
                conversation_id=conversation_id,
                message_id=message_id,
                chat_type="llm_chat",
                query=text,
            )
            # A fresh token iterator per line: it is consumed until this
            # line's chain call signals completion.
            callback = AsyncIteratorCallbackHandler()
            model = get_ChatOpenAI(
                model_name=LLM_MODELS[0],
                temperature=0.2,
                max_tokens=512,  # NOTE(review): caps the output for a single line
                callbacks=[callback, conversation_callback],
            )
            chain = LLMChain(prompt=chat_prompt, llm=model)
            # wrap_done is expected to set callback.done when the chain call
            # finishes, which ends callback.aiter() below.
            task = asyncio.create_task(wrap_done(
                chain.acall({
                    "input": f"<关键指令>必须翻译如下文本为{target_lang}语言<关键指令/>文本如下:\n{text}",
                    "time": today,
                    "lang": target_lang,
                }),
                callback.done,
            ))

            if stream:
                async for token in callback.aiter():
                    # Use server-sent events to stream the response.
                    yield json.dumps(
                        {"text": token, "message_id": message_id},
                        ensure_ascii=False)
                # NOTE(review): the line break is re-inserted only for file
                # input, as before; multi-line `query` text gets no separator.
                if from_file:
                    yield json.dumps(
                        {"text": "\n", "message_id": message_id},
                        ensure_ascii=False)
            else:
                async for token in callback.aiter():
                    answer += token
                if from_file:
                    answer += "\n"
            await task

        # Only non-stream mode accumulates into `answer`.
        if answer:
            payload = {"text": answer}
            # message_id is None when every line was blank (no LLM call made);
            # previously this raised UnboundLocalError.
            if message_id is not None:
                payload["message_id"] = message_id
            yield json.dumps(payload, ensure_ascii=False)

    return EventSourceResponse(chat_iterator())