254 lines
9.4 KiB
Python
254 lines
9.4 KiB
Python
import asyncio
|
||
import os
|
||
import shutil
|
||
import uuid
|
||
import logging
|
||
from datetime import datetime, timedelta, timezone
|
||
from enum import Enum
|
||
from typing import Dict, Optional, Callable
|
||
|
||
from sqlalchemy import (
|
||
create_engine, Column, String, Enum as SAEnum, Float,
|
||
Integer, Text, DateTime, Boolean
|
||
)
|
||
from sqlalchemy.ext.declarative import declarative_base
|
||
from sqlalchemy.orm import sessionmaker, Session
|
||
|
||
from fastapi import BackgroundTasks
|
||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||
|
||
from configs.translate_config import *
|
||
|
||
# Initialise the database engine and session factory.
# ``check_same_thread=False`` is a SQLite-only DBAPI option that allows a
# connection to be used from a thread other than the one that created it.
# Other drivers reject unknown connect arguments, so it is passed only when
# the configured URL targets SQLite.
_connect_args = (
    {"check_same_thread": False}
    if str(SQLALCHEMY_DATABASE_URI).startswith("sqlite")
    else {}
)
engine = create_engine(SQLALCHEMY_DATABASE_URI, connect_args=_connect_args)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()
|
||
|
||
# Logging: one logger per module, emitting records at INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel("INFO")

# Local timezone for every timestamp and schedule in this module:
# Beijing time, a fixed UTC+08:00 offset.
LOCAL_TZ = timezone(timedelta(minutes=8 * 60))
|
||
|
||
# ----------------------
# Task status enum
# ----------------------
class TaskStatusEnum(str, Enum):
    """Lifecycle states of a translation task.

    Members subclass ``str``, so each compares equal to its lowercase value
    (for example ``TaskStatusEnum.QUEUED == "queued"``).
    """

    QUEUED = "queued"          # enqueued, waiting to be processed
    PROCESSING = "processing"  # currently being processed
    COMPLETED = "completed"    # finished successfully
    FAILED = "failed"          # failed
    CANCELLED = "cancelled"    # cancelled
|
||
|
||
# ----------------------
# TranslationTask ORM model
# ----------------------
class TranslationTask(Base):
    """ORM model mapped to the ``file_translate_tasks`` table."""

    __tablename__ = "file_translate_tasks"

    # Task ID (primary key, also indexed).
    id = Column(String, index=True, primary_key=True)
    # Original name of the uploaded file.
    filename = Column(String, nullable=False)
    # Source and target language.
    src_lang = Column(String, nullable=False)
    dst_lang = Column(String, nullable=False)
    # Whether bilingual (dual) output is requested.
    is_dual = Column(Boolean, default=True)
    # Path of the source file, and of the translated file once produced.
    file_path = Column(String, nullable=False)
    output_path = Column(String, nullable=True)
    # Current lifecycle state.
    status = Column(SAEnum(TaskStatusEnum), default=TaskStatusEnum.QUEUED)
    # Progress as a percentage.
    progress = Column(Float, default=0.0)
    # Number of retries performed so far.
    retry_count = Column(Integer, default=0)
    # Last error / status message.
    error_msg = Column(Text, nullable=True)
    # Creation timestamp (LOCAL_TZ).
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(LOCAL_TZ),
    )
    # Last-modified timestamp (LOCAL_TZ); the onupdate hook refreshes it
    # whenever SQLAlchemy issues an UPDATE for the row.
    updated_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(LOCAL_TZ),
        onupdate=lambda: datetime.now(LOCAL_TZ),
    )
|
||
|
||
# Create the database tables declared on ``Base`` at import time.
Base.metadata.create_all(engine)
|
||
|
||
# ----------------------
# TaskManager definition
# ----------------------
class TaskManager:
    """Persist translation tasks and drive their execution, cancellation and recovery.

    Tasks live in the ``file_translate_tasks`` table (``TranslationTask``).
    The actual work is delegated to ``translate_fn``. Delayed retries and the
    periodic stuck-task scan are scheduled on an ``AsyncIOScheduler`` that uses
    ``LOCAL_TZ``.
    """

    def __init__(self, translate_fn: Callable):
        """Create the manager and register the periodic recovery job.

        Args:
            translate_fn: awaited as ``translate_fn(task_id, task, cancel_event)``
                where ``task`` is the ``TranslationTask`` row and ``cancel_event``
                an ``asyncio.Event`` that is set when the task is cancelled.
                It should raise on failure.

        The scheduler is registered but NOT started here; call :meth:`start`.
        """
        self.translate_fn = translate_fn  # translation coroutine function
        # Use the local timezone so scheduled run dates are not shifted.
        self.scheduler = AsyncIOScheduler(timezone=LOCAL_TZ)
        # Periodically scan for and recover stuck tasks.
        self.scheduler.add_job(
            self.recover_stuck_tasks,
            'interval', minutes=RECOVERY_INTERVAL,
            id='recover_jobs', replace_existing=True,
        )
        # task_id -> cancel event, for runs currently executing in this process.
        self._cancel_events: Dict[str, asyncio.Event] = {}
        logger.info("任务管理器已初始化,恢复调度器扫描间隔:%s 分钟", RECOVERY_INTERVAL)

    def start(self):
        """Start the recovery scheduler."""
        self.scheduler.start()
        logger.info("恢复调度器已启动")

    def shutdown(self):
        """Shut down the recovery scheduler; a no-op if it was never started."""
        # APScheduler raises SchedulerNotRunningError when shutting down a
        # scheduler that is not running.
        if self.scheduler.running:
            self.scheduler.shutdown()
        logger.info("恢复调度器已关闭")

    def db(self) -> Session:
        """Open a new database session; the caller must close it."""
        return SessionLocal()

    def generate_task_id(self) -> str:
        """Return a new unique task ID (random UUID4 string)."""
        return str(uuid.uuid4())

    def add_task(
        self,
        filename: str,
        file_path: str,
        src_lang: str,
        dst_lang: str,
        is_dual: bool,
        background_tasks: Optional[BackgroundTasks],
        task_id: Optional[str] = None,
    ) -> str:
        """Persist a new QUEUED translation task and optionally queue its execution.

        Args:
            filename: original file name.
            file_path: path of the source file.
            src_lang: source language.
            dst_lang: target language.
            is_dual: whether bilingual output is requested.
            background_tasks: if given, ``start_task(task_id)`` is added to it;
                if ``None`` the task is only stored and stays QUEUED until
                something else (e.g. the recovery job) starts it.
            task_id: ID to use; a new UUID is generated when empty or ``None``.

        Returns:
            The task ID.
        """
        if not task_id:
            task_id = self.generate_task_id()  # caller did not supply an ID

        db = self.db()
        try:
            db.add(TranslationTask(
                id=task_id,
                filename=filename,
                file_path=file_path,
                src_lang=src_lang,
                dst_lang=dst_lang,
                is_dual=is_dual,
                status=TaskStatusEnum.QUEUED,
            ))
            db.commit()
        finally:
            db.close()  # always release the session, even if commit fails

        if background_tasks is not None:
            background_tasks.add_task(self.start_task, task_id)  # run in background
        return task_id

    @staticmethod
    def _mark_cancelled(task: "TranslationTask") -> None:
        """Flag *task* as cancelled by the user (the caller commits)."""
        task.status = TaskStatusEnum.CANCELLED
        task.error_msg = "用户已取消"

    def _schedule_retry(self, task_id: str) -> None:
        """Schedule a one-off ``start_task(task_id)`` run ``RETRY_DELAY`` seconds from now."""
        run_date = datetime.now(LOCAL_TZ) + timedelta(seconds=RETRY_DELAY)
        self.scheduler.add_job(
            self.start_task, trigger='date', run_date=run_date, args=[task_id]
        )

    async def start_task(self, task_id: str):
        """Execute a QUEUED task; entry point for background runs and retries.

        Only a task currently in ``QUEUED`` is executed. A missing task, or one
        in any other state, is ignored. The task is moved to ``PROCESSING``,
        ``translate_fn`` is awaited, and the outcome is persisted:

        * success -> ``COMPLETED`` (progress 100, error message cleared)
        * cancelled (cancel event set, or ``asyncio.CancelledError``) -> ``CANCELLED``
        * any other exception -> ``retry_count`` is incremented. While it is
          ``<= MAX_RETRIES`` the task goes back to ``QUEUED`` and a retry is
          scheduled ``RETRY_DELAY`` seconds later; otherwise it becomes ``FAILED``.
        """
        cancel_event: Optional[asyncio.Event] = None
        retry = False
        db = self.db()
        try:
            task = db.query(TranslationTask).get(task_id)
            # Only QUEUED tasks may be executed.
            if not task or task.status != TaskStatusEnum.QUEUED:
                return

            # Register the cancel event only for a run that will actually execute,
            # so a duplicate/late call cannot replace the event of a running task.
            cancel_event = asyncio.Event()
            self._cancel_events[task_id] = cancel_event

            task.status = TaskStatusEnum.PROCESSING
            task.updated_at = datetime.now(LOCAL_TZ)
            db.commit()

            try:
                await self.translate_fn(task_id, task, cancel_event)
            except asyncio.CancelledError:
                # Cancellation raised from translate_fn is recorded and
                # deliberately not re-raised (same as before).
                self._mark_cancelled(task)
            except Exception as e:
                if cancel_event.is_set():
                    # cancel_task() already cancelled the task and may have deleted
                    # its files; this failure is a consequence, not a reason to retry.
                    self._mark_cancelled(task)
                else:
                    logger.exception("任务 %s 翻译失败:%s", task_id, e)
                    task.retry_count += 1
                    if task.retry_count <= MAX_RETRIES:
                        task.status = TaskStatusEnum.QUEUED
                        task.error_msg = "意外中断,正在重试"
                        retry = True
                    else:
                        task.status = TaskStatusEnum.FAILED
                        task.error_msg = "意外中断,重试次数已达上限"
            else:
                if cancel_event.is_set():
                    # translate_fn returned early because it honoured the cancel
                    # event; do not overwrite the cancellation with COMPLETED.
                    self._mark_cancelled(task)
                else:
                    task.status = TaskStatusEnum.COMPLETED
                    task.progress = 100.0
                    task.error_msg = None  # drop any message left by a previous retry

            task.updated_at = datetime.now(LOCAL_TZ)
            db.commit()
        finally:
            # Remove only our own event, in case another run registered a newer one.
            if cancel_event is not None and self._cancel_events.get(task_id) is cancel_event:
                del self._cancel_events[task_id]
            db.close()

        # Schedule the retry only after the QUEUED status has been committed.
        if retry:
            self._schedule_retry(task_id)

    def get_task(self, task_id: str) -> Optional[TranslationTask]:
        """Return the task with *task_id*, or ``None`` if it does not exist.

        The session is closed before returning, so the instance is detached.
        """
        db = self.db()
        try:
            return db.query(TranslationTask).get(task_id)
        finally:
            db.close()

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a task that is still QUEUED or PROCESSING.

        For such a task, the status is set to ``CANCELLED`` and committed, and
        the directory containing its source file is then removed. The cancel
        event of an in-process run (if any) is set for any existing task. Tasks
        in other states keep their status and files.

        Returns:
            ``True`` if the task exists (even if it was no longer cancellable),
            ``False`` if no task with *task_id* exists.
        """
        file_dir: Optional[str] = None
        db = self.db()
        try:
            task = db.query(TranslationTask).get(task_id)
            if not task:
                return False
            if task.status in {TaskStatusEnum.QUEUED, TaskStatusEnum.PROCESSING}:
                # Read the path before commit() expires the instance's attributes.
                file_dir = os.path.dirname(task.file_path)
                task.status = TaskStatusEnum.CANCELLED
                task.updated_at = datetime.now(LOCAL_TZ)
                db.commit()
        finally:
            db.close()

        # Signal the running coroutine before its files disappear.
        # NOTE(review): asyncio.Event is not thread-safe; if this is called from a
        # worker thread, consider loop.call_soon_threadsafe(event.set) — confirm.
        cancel_event = self._cancel_events.get(task_id)
        if cancel_event is not None:
            cancel_event.set()

        # Also clean up the task's file directory.
        # NOTE(review): assumes each task's file sits in its own per-task
        # directory — confirm, otherwise this deletes shared files.
        if file_dir and os.path.isdir(file_dir):
            shutil.rmtree(file_dir, ignore_errors=True)
        return True

    def recover_stuck_tasks(self):
        """Re-queue or fail tasks that have not been updated for ``PROCESSING_TIMEOUT``.

        Tasks in PROCESSING, QUEUED or FAILED whose ``updated_at`` is older than
        ``now - PROCESSING_TIMEOUT`` are considered stuck. While
        ``retry_count < MAX_RETRIES`` they are re-queued (retry_count + 1) and a
        delayed ``start_task`` run is scheduled; otherwise they are marked FAILED.
        """
        logger.info(">>> 正在扫描翻译任务。。。")
        # Assumes PROCESSING_TIMEOUT is a timedelta (it is subtracted from a datetime).
        cutoff = datetime.now(LOCAL_TZ) - PROCESSING_TIMEOUT
        logger.info("cutoff 时间为:%s", cutoff)

        retry_ids = []
        db = self.db()
        try:
            stuck_tasks = db.query(TranslationTask).filter(
                TranslationTask.status.in_([
                    TaskStatusEnum.PROCESSING,
                    TaskStatusEnum.QUEUED,
                    TaskStatusEnum.FAILED,
                ]),
                # "Stuck" means not touched since the cutoff, i.e. strictly older.
                TranslationTask.updated_at < cutoff,
            ).all()
            logger.info(">>> 找到 %s 条卡住的翻译任务", len(stuck_tasks))

            for task in stuck_tasks:
                logger.info("正在恢复卡住的任务 %s", task.id)
                # NOTE(review): a PROCESSING task still running in this process
                # (present in self._cancel_events) is re-queued as well, which can
                # start a second concurrent run — confirm this is intended.
                if task.retry_count < MAX_RETRIES:
                    task.retry_count += 1
                    task.status = TaskStatusEnum.QUEUED
                    task.error_msg = "意外中断,正在重试"
                    retry_ids.append(task.id)
                else:
                    task.status = TaskStatusEnum.FAILED
                    task.error_msg = "意外中断,重试次数已达上限"
            db.commit()
        finally:
            db.close()

        # Schedule retries only after the QUEUED status is committed, so a retry
        # can never observe the pre-recovery status and silently skip the task.
        for task_id in retry_ids:
            self._schedule_retry(task_id)