254 lines
9.4 KiB
Python
254 lines
9.4 KiB
Python
|
|
import asyncio
|
|||
|
|
import os
|
|||
|
|
import shutil
|
|||
|
|
import uuid
|
|||
|
|
import logging
|
|||
|
|
from datetime import datetime, timedelta, timezone
|
|||
|
|
from enum import Enum
|
|||
|
|
from typing import Dict, Optional, Callable
|
|||
|
|
|
|||
|
|
from sqlalchemy import (
|
|||
|
|
create_engine, Column, String, Enum as SAEnum, Float,
|
|||
|
|
Integer, Text, DateTime, Boolean
|
|||
|
|
)
|
|||
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|||
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|||
|
|
|
|||
|
|
from fastapi import BackgroundTasks
|
|||
|
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
|||
|
|
|
|||
|
|
from configs.translate_config import *
|
|||
|
|
|
|||
|
|
# Module logger; records at INFO level and above are processed.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Fixed UTC+8 offset (Beijing time) used for every timestamp in this module.
LOCAL_TZ = timezone(timedelta(hours=8))

# Database engine, session factory and declarative base.
# `check_same_thread=False` is a sqlite3 driver option that allows a
# connection to be used from a thread other than the one that created it.
# NOTE(review): presumably SQLALCHEMY_DATABASE_URI points at SQLite — confirm.
engine = create_engine(
    SQLALCHEMY_DATABASE_URI,
    connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
|
|||
|
|
|
|||
|
|
# ----------------------
|
|||
|
|
# 任务状态枚举
|
|||
|
|
# ----------------------
|
|||
|
|
class TaskStatusEnum(str, Enum):
    """Lifecycle states of a translation task.

    Members are also ``str`` instances equal to their lowercase value.
    """

    QUEUED = "queued"          # enqueued, waiting to be processed
    PROCESSING = "processing"  # currently being processed
    COMPLETED = "completed"    # finished
    FAILED = "failed"          # failed
    CANCELLED = "cancelled"    # cancelled
|
|||
|
|
|
|||
|
|
# ----------------------
|
|||
|
|
# TranslationTask ORM 模型
|
|||
|
|
# ----------------------
|
|||
|
|
def _local_now():
    """Return the current time as a timezone-aware datetime in LOCAL_TZ."""
    return datetime.now(LOCAL_TZ)


class TranslationTask(Base):
    """ORM model mapped to the ``file_translate_tasks`` table."""

    __tablename__ = "file_translate_tasks"

    # --- identity and request parameters ---
    id = Column(String, primary_key=True, index=True)  # task ID
    filename = Column(String, nullable=False)  # original file name
    src_lang = Column(String, nullable=False)  # source language
    dst_lang = Column(String, nullable=False)  # target language
    is_dual = Column(Boolean, default=True)  # whether bilingual mode is on

    # --- file locations ---
    file_path = Column(String, nullable=False)  # path of the source file
    output_path = Column(String, nullable=True)  # path of the translated file

    # --- execution state ---
    status = Column(
        SAEnum(TaskStatusEnum),
        default=TaskStatusEnum.QUEUED,
    )  # current status
    progress = Column(Float, default=0.0)  # progress percentage
    retry_count = Column(Integer, default=0)  # retries performed so far
    error_msg = Column(Text, nullable=True)  # error message, if any

    # --- timestamps (LOCAL_TZ) ---
    created_at = Column(
        DateTime(timezone=True),
        default=_local_now,
    )  # creation time
    updated_at = Column(
        DateTime(timezone=True),
        default=_local_now,
        onupdate=_local_now,
    )  # last update time
|
|||
|
|
|
|||
|
|
# Create the database tables for every model registered on Base.
# This runs as a side effect of importing the module.
Base.metadata.create_all(engine)
|
|||
|
|
|
|||
|
|
# ----------------------
|
|||
|
|
# TaskManager 定义
|
|||
|
|
# ----------------------
|
|||
|
|
class TaskManager:
    """Create, run, cancel and recover file-translation tasks.

    Task state is persisted as ``TranslationTask`` rows; each method opens its
    own session through :meth:`db` and always closes it.  Delayed retries and
    the periodic recovery scan are driven by an ``AsyncIOScheduler`` which
    only runs once :meth:`start` has been called.
    """

    def __init__(self, translate_fn: Callable):
        """Store the translation callback and register the recovery job.

        Args:
            translate_fn: Awaited by :meth:`start_task` as
                ``translate_fn(task_id, task, cancel_event)``.
        """
        self.translate_fn = translate_fn
        # Run the scheduler in the local timezone so job run dates line up
        # with the LOCAL_TZ timestamps computed in this class.
        self.scheduler = AsyncIOScheduler(timezone=LOCAL_TZ)
        # Periodic scan that re-queues stalled tasks.
        self.scheduler.add_job(
            self.recover_stuck_tasks,
            'interval', minutes=RECOVERY_INTERVAL,
            id='recover_jobs', replace_existing=True,
        )
        # Cancellation events of the tasks currently running, keyed by task ID.
        self._cancel_events: Dict[str, asyncio.Event] = {}
        # NOTE(review): the scheduler is NOT started here (see start()), even
        # though the message below says it is; the text is kept unchanged.
        logger.info("任务管理器已初始化,恢复调度器已启动,扫描间隔:%s 分钟", RECOVERY_INTERVAL)

    def start(self):
        """Start the scheduler (recovery scan and delayed retries)."""
        self.scheduler.start()
        logger.info("恢复调度器已启动")

    def shutdown(self):
        """Shut the scheduler down."""
        self.scheduler.shutdown()
        logger.info("恢复调度器已关闭")

    def db(self) -> Session:
        """Return a new database session; the caller must close it."""
        return SessionLocal()

    def generate_task_id(self) -> str:
        """Return a new random UUID4 string to use as a task ID."""
        return str(uuid.uuid4())

    def _schedule_retry(self, task_id: str) -> None:
        """Schedule one ``start_task(task_id)`` run RETRY_DELAY seconds from now."""
        run_date = datetime.now(LOCAL_TZ) + timedelta(seconds=RETRY_DELAY)
        self.scheduler.add_job(
            self.start_task, 'date', run_date=run_date, args=[task_id]
        )

    def add_task(
        self,
        filename: str,
        file_path: str,
        src_lang: str,
        dst_lang: str,
        is_dual: bool,
        background_tasks: Optional[BackgroundTasks],
        task_id: Optional[str] = None,
    ) -> str:
        """Persist a new task in the QUEUED state and optionally enqueue it.

        Args:
            filename: Original file name.
            file_path: Path of the source file.
            src_lang: Source language.
            dst_lang: Target language.
            is_dual: Whether bilingual mode is requested.
            background_tasks: When truthy, :meth:`start_task` is registered on
                it; otherwise the task is only stored.
            task_id: ID to use; a new one is generated when empty or omitted.

        Returns:
            The ID of the stored task.
        """
        if not task_id:
            task_id = self.generate_task_id()

        db = self.db()
        try:
            db.add(TranslationTask(
                id=task_id,
                filename=filename,
                file_path=file_path,
                src_lang=src_lang,
                dst_lang=dst_lang,
                is_dual=is_dual,
                status=TaskStatusEnum.QUEUED,
            ))
            db.commit()
        finally:
            # Release the session even when the insert fails.
            db.close()

        if background_tasks:
            background_tasks.add_task(self.start_task, task_id)
        return task_id

    async def start_task(self, task_id: str):
        """Run one translation task; does nothing unless the task is QUEUED.

        Outcome written to the task row:
            * COMPLETED (progress 100) when ``translate_fn`` returns,
            * CANCELLED when ``asyncio.CancelledError`` is raised or the
              task's cancel event was set while it ran,
            * QUEUED plus a delayed retry on any other exception while
              retries remain, otherwise FAILED.
        """
        cancel_event = asyncio.Event()
        db = self.db()
        try:
            task = db.query(TranslationTask).get(task_id)
            # Only QUEUED tasks may run.  The check happens BEFORE the event
            # is registered so that a duplicate invocation cannot replace or
            # remove the event of a run that is already in progress.
            if not task or task.status != TaskStatusEnum.QUEUED:
                return

            self._cancel_events[task_id] = cancel_event
            task.status = TaskStatusEnum.PROCESSING
            task.updated_at = datetime.now(LOCAL_TZ)
            db.commit()

            try:
                await self.translate_fn(task_id, task, cancel_event)
            except asyncio.CancelledError:
                task.status = TaskStatusEnum.CANCELLED
                task.error_msg = "用户已取消"
            except Exception as e:
                if cancel_event.is_set():
                    # cancel_task() removes the task's directory, so a failure
                    # after a cancel request is a consequence of the cancel:
                    # keep the task CANCELLED instead of retrying it.
                    task.status = TaskStatusEnum.CANCELLED
                    task.error_msg = "用户已取消"
                else:
                    logger.exception("任务 %s 翻译失败:%s", task_id, e)
                    task.retry_count += 1
                    if task.retry_count <= MAX_RETRIES:
                        task.status = TaskStatusEnum.QUEUED
                        task.error_msg = "意外中断,正在重试"
                        # The job cannot fire before the commit below: this
                        # coroutine does not yield again until it returns.
                        self._schedule_retry(task_id)
                    else:
                        task.status = TaskStatusEnum.FAILED
                        task.error_msg = "意外中断,重试次数已达上限"
            else:
                if cancel_event.is_set():
                    # Do not overwrite a cancellation recorded by cancel_task().
                    task.status = TaskStatusEnum.CANCELLED
                    task.error_msg = "用户已取消"
                else:
                    task.status = TaskStatusEnum.COMPLETED
                    task.progress = 100.0

            task.updated_at = datetime.now(LOCAL_TZ)
            db.commit()
        finally:
            # Remove only the event this invocation registered, then release
            # the session even if a query or commit raised.
            if self._cancel_events.get(task_id) is cancel_event:
                del self._cancel_events[task_id]
            db.close()

    def get_task(self, task_id: str) -> Optional[TranslationTask]:
        """Return the task with the given ID, or ``None`` if it does not exist.

        The session is closed before returning, so the returned object is no
        longer attached to a session.
        """
        db = self.db()
        try:
            return db.query(TranslationTask).get(task_id)
        finally:
            db.close()

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a QUEUED or PROCESSING task and signal its running coroutine.

        A cancellable task is marked CANCELLED and the directory containing
        its source file is removed.  The cancel event registered by
        :meth:`start_task`, if any, is set.

        Returns:
            ``False`` when no task with this ID exists, ``True`` otherwise —
            including when the task was already in a non-cancellable state.
        """
        db = self.db()
        try:
            task = db.query(TranslationTask).get(task_id)
            if not task:
                return False

            if task.status in {TaskStatusEnum.QUEUED, TaskStatusEnum.PROCESSING}:
                task_dir = os.path.dirname(task.file_path)
                task.status = TaskStatusEnum.CANCELLED
                task.updated_at = datetime.now(LOCAL_TZ)
                db.commit()
                # Also remove the task's file directory.
                # NOTE(review): assumes each task has its own directory —
                # confirm against the upload code.
                if os.path.exists(task_dir):
                    shutil.rmtree(task_dir, ignore_errors=True)
        finally:
            db.close()

        # Wake up the translation coroutine, if one is running.
        # NOTE(review): asyncio.Event is not thread-safe — confirm this method
        # is only called from the event-loop thread.
        running_event = self._cancel_events.get(task_id)
        if running_event is not None:
            running_event.set()
        return True

    def recover_stuck_tasks(self):
        """Re-queue tasks that have not been updated for PROCESSING_TIMEOUT.

        Candidates are PROCESSING, QUEUED or FAILED tasks whose ``updated_at``
        is older than ``now - PROCESSING_TIMEOUT``.  A candidate with retries
        left is re-queued and a delayed :meth:`start_task` run is scheduled;
        otherwise it is marked FAILED.
        """
        logger.info(">>> 正在扫描翻译任务。。。")
        cutoff = datetime.now(LOCAL_TZ) - PROCESSING_TIMEOUT
        logger.info("cutoff 时间为:%s", cutoff)

        db = self.db()
        try:
            # FAILED rows whose retries are exhausted are terminal; leaving
            # them out keeps every scan from re-reading them forever.
            still_actionable = (
                (TranslationTask.status != TaskStatusEnum.FAILED)
                | (TranslationTask.retry_count < MAX_RETRIES)
            )
            stuck_tasks = db.query(TranslationTask).filter(
                TranslationTask.status.in_([
                    TaskStatusEnum.PROCESSING,
                    TaskStatusEnum.QUEUED,
                    TaskStatusEnum.FAILED,
                ]),
                # Stalled means NOT touched since the cutoff; tasks updated
                # more recently may still be running and must be left alone.
                TranslationTask.updated_at < cutoff,
                still_actionable,
            ).all()
            logger.info(">>> 找到 %s 条卡住的翻译任务", len(stuck_tasks))

            requeued_ids = []
            for task in stuck_tasks:
                logger.info("正在恢复卡住的任务 %s", task.id)
                if task.retry_count < MAX_RETRIES:
                    task.retry_count += 1
                    task.status = TaskStatusEnum.QUEUED
                    task.error_msg = "意外中断,正在重试"
                    requeued_ids.append(task.id)
                else:
                    task.status = TaskStatusEnum.FAILED
                    task.error_msg = "意外中断,重试次数已达上限"
            db.commit()
        finally:
            db.close()

        # Schedule only after the QUEUED state is committed, so a retry job
        # can never observe the stale status and skip the task.
        for requeued_id in requeued_ids:
            self._schedule_retry(requeued_id)
|