Files
gangyan/langchain-chat/server/translator_service/task_manager.py

254 lines
9.4 KiB
Python
Raw Normal View History

import asyncio
import os
import shutil
import uuid
import logging
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Dict, Optional, Callable
from sqlalchemy import (
create_engine, Column, String, Enum as SAEnum, Float,
Integer, Text, DateTime, Boolean
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, Session
from fastapi import BackgroundTasks
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from configs.translate_config import *
# Module-level logger; INFO and above are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Timezone used for every task timestamp and for the scheduler (Beijing time, UTC+8).
LOCAL_TZ = timezone(timedelta(hours=8))

# Database engine, session factory and declarative base.
# NOTE(review): "check_same_thread" looks like a SQLite driver option —
# confirm SQLALCHEMY_DATABASE_URI always points at SQLite.
engine = create_engine(
    SQLALCHEMY_DATABASE_URI,
    connect_args={"check_same_thread": False},
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()
# ----------------------
# Task status enumeration
# ----------------------
class TaskStatusEnum(str, Enum):
    """Lifecycle states of a translation task.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values.
    """

    # Enqueued and waiting to be picked up.
    QUEUED = "queued"
    # Currently being processed.
    PROCESSING = "processing"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
    # Cancelled before completion.
    CANCELLED = "cancelled"
# ----------------------
# TranslationTask ORM model
# ----------------------
class TranslationTask(Base):
    """ORM model mapped to the ``file_translate_tasks`` table."""

    __tablename__ = "file_translate_tasks"

    # --- identity and request parameters ---
    id = Column(String, primary_key=True, index=True)  # task ID
    filename = Column(String, nullable=False)  # original file name
    src_lang = Column(String, nullable=False)  # source language
    dst_lang = Column(String, nullable=False)  # target language
    is_dual = Column(Boolean, default=True)  # whether bilingual mode is on

    # --- file locations ---
    file_path = Column(String, nullable=False)  # path of the source file
    output_path = Column(String, nullable=True)  # path of the translated file

    # --- execution state ---
    status = Column(
        SAEnum(TaskStatusEnum),
        default=TaskStatusEnum.QUEUED,
    )  # current state
    progress = Column(Float, default=0.0)  # progress percentage
    retry_count = Column(Integer, default=0)  # retries performed so far
    error_msg = Column(Text, nullable=True)  # error message, if any

    # --- timestamps, generated in LOCAL_TZ ---
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(LOCAL_TZ),
    )  # creation time
    updated_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(LOCAL_TZ),
        onupdate=lambda: datetime.now(LOCAL_TZ),
    )  # last update time
# Create the tables declared on Base's metadata using the module-level engine.
# This runs at import time.
Base.metadata.create_all(bind=engine)
# ----------------------
# TaskManager definition
# ----------------------
class TaskManager:
    """Persist, run, cancel and recover file-translation tasks.

    Task rows are stored through ``SessionLocal``. Work is executed by the
    injected ``translate_fn`` coroutine, and an ``AsyncIOScheduler`` is used
    both for delayed retries and for the periodic recovery scan.
    """

    def __init__(self, translate_fn: Callable):
        """Create the manager and register the periodic recovery job.

        Args:
            translate_fn: Coroutine function awaited as
                ``translate_fn(task_id, task, cancel_event)``.
        """
        self.translate_fn = translate_fn
        # The scheduler uses the local timezone so run dates are not shifted.
        self.scheduler = AsyncIOScheduler(timezone=LOCAL_TZ)
        # Periodically scan for stuck tasks and re-queue them.
        self.scheduler.add_job(
            self.recover_stuck_tasks,
            'interval', minutes=RECOVERY_INTERVAL,
            id='recover_jobs', replace_existing=True
        )
        # Cancel events of the runs currently in flight, keyed by task ID.
        self._cancel_events: Dict[str, asyncio.Event] = {}
        # NOTE: the scheduler is only configured here; it starts in start().
        logger.info("任务管理器已初始化,恢复调度器已启动,扫描间隔:%s 分钟", RECOVERY_INTERVAL)

    def start(self):
        """Start the scheduler (recovery scan and delayed retries)."""
        self.scheduler.start()
        logger.info("恢复调度器已启动")

    def shutdown(self):
        """Shut the scheduler down."""
        self.scheduler.shutdown()
        logger.info("恢复调度器已关闭")

    def db(self) -> Session:
        """Return a new database session; the caller must close it."""
        return SessionLocal()

    def generate_task_id(self) -> str:
        """Return a new random UUID4 string to use as a task ID."""
        return str(uuid.uuid4())

    def _schedule_retry(self, task_id: str) -> None:
        """Schedule a one-off ``start_task`` run RETRY_DELAY seconds from now."""
        run_date = datetime.now(LOCAL_TZ) + timedelta(seconds=RETRY_DELAY)
        self.scheduler.add_job(
            self.start_task, 'date', run_date=run_date, args=[task_id]
        )

    @staticmethod
    def _mark_cancelled(task) -> None:
        """Put ``task`` into the CANCELLED state with the user-cancel message."""
        task.status = TaskStatusEnum.CANCELLED
        task.error_msg = "用户已取消"

    def add_task(
        self,
        filename: str,
        file_path: str,
        src_lang: str,
        dst_lang: str,
        is_dual: bool,
        background_tasks: Optional[BackgroundTasks],
        task_id: Optional[str] = None,
    ) -> str:
        """Persist a new QUEUED task and optionally enqueue its execution.

        Args:
            filename: Original file name.
            file_path: Path of the source file.
            src_lang: Source language.
            dst_lang: Target language.
            is_dual: Whether bilingual output is requested.
            background_tasks: If truthy, ``start_task`` is added to it.
            task_id: ID to use; a new one is generated when falsy.

        Returns:
            The ID of the stored task.
        """
        if not task_id:
            task_id = self.generate_task_id()
        db = self.db()
        try:
            db.add(TranslationTask(
                id=task_id,
                filename=filename,
                file_path=file_path,
                src_lang=src_lang,
                dst_lang=dst_lang,
                is_dual=is_dual,
                status=TaskStatusEnum.QUEUED,
            ))
            db.commit()
        finally:
            # Release the session even if the insert or commit fails.
            db.close()
        if background_tasks:
            background_tasks.add_task(self.start_task, task_id)
        return task_id

    async def start_task(self, task_id: str):
        """Run one translation attempt for ``task_id``; supports cancellation.

        Only a task currently in the QUEUED state is executed; anything else
        returns immediately. The final state is COMPLETED, CANCELLED, QUEUED
        (a retry has been scheduled) or FAILED (retries exhausted).
        """
        db = self.db()
        try:
            task = db.query(TranslationTask).get(task_id)
            if not task or task.status != TaskStatusEnum.QUEUED:
                return

            # Register the cancel event only once this invocation owns the
            # task, so a duplicate invocation cannot clobber or remove the
            # event belonging to a run that is already in progress.
            cancel_event = asyncio.Event()
            self._cancel_events[task_id] = cancel_event

            task.status = TaskStatusEnum.PROCESSING
            task.updated_at = datetime.now(LOCAL_TZ)
            db.commit()
            try:
                await self.translate_fn(task_id, task, cancel_event)
                if cancel_event.is_set():
                    # The translate function returned after a cancel request:
                    # do not overwrite the cancellation with COMPLETED.
                    self._mark_cancelled(task)
                else:
                    task.status = TaskStatusEnum.COMPLETED
                    task.progress = 100.0
            except asyncio.CancelledError:
                self._mark_cancelled(task)
            except Exception as e:
                if cancel_event.is_set():
                    # cancel_task removes the task's file directory, so an
                    # error after a cancel request is expected and must not
                    # trigger a retry.
                    self._mark_cancelled(task)
                else:
                    logger.exception("任务 %s 翻译失败:%s", task_id, e)
                    task.retry_count += 1
                    if task.retry_count <= MAX_RETRIES:
                        task.status = TaskStatusEnum.QUEUED
                        task.error_msg = "意外中断,正在重试"
                        self._schedule_retry(task_id)
                    else:
                        task.status = TaskStatusEnum.FAILED
                        task.error_msg = "意外中断,重试次数已达上限"
            finally:
                task.updated_at = datetime.now(LOCAL_TZ)
                # Drop the event only if it is still the one for this run.
                if self._cancel_events.get(task_id) is cancel_event:
                    del self._cancel_events[task_id]
                db.commit()
        finally:
            db.close()

    def get_task(self, task_id: str) -> Optional[TranslationTask]:
        """Return the task row for ``task_id``, or ``None`` if it is absent.

        The session is closed before returning, so the returned object is
        detached.
        """
        db = self.db()
        try:
            return db.query(TranslationTask).get(task_id)
        finally:
            db.close()

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a QUEUED or PROCESSING task and signal its running coroutine.

        Returns:
            ``False`` if the task does not exist, ``True`` otherwise (also
            when the task was in a state that is left untouched).
        """
        db = self.db()
        try:
            task = db.query(TranslationTask).get(task_id)
            if not task:
                return False
            if task.status in {TaskStatusEnum.QUEUED, TaskStatusEnum.PROCESSING}:
                task.status = TaskStatusEnum.CANCELLED
                task.updated_at = datetime.now(LOCAL_TZ)
                db.commit()
                # Remove the directory holding the source file, but only for
                # tasks that were actually cancelled.
                # NOTE(review): assumes each task has its own directory —
                # confirm against the upload code.
                task_dir = os.path.dirname(task.file_path)
                if os.path.exists(task_dir):
                    shutil.rmtree(task_dir, ignore_errors=True)
        finally:
            db.close()
        # Wake up the translation coroutine, if one is running for this task.
        cancel_event = self._cancel_events.get(task_id)
        if cancel_event is not None:
            cancel_event.set()
        return True

    def recover_stuck_tasks(self):
        """Re-queue tasks not updated for longer than PROCESSING_TIMEOUT.

        Tasks in PROCESSING, QUEUED or FAILED state whose ``updated_at`` is
        older than the cutoff are either scheduled for another attempt (while
        ``retry_count`` is below MAX_RETRIES) or marked FAILED.
        """
        logger.info(">>> 正在扫描翻译任务。。。")
        db = self.db()
        try:
            cutoff = datetime.now(LOCAL_TZ) - PROCESSING_TIMEOUT
            logger.info("cutoff 时间为:%s", cutoff)
            # A task is stuck when its last update is OLDER than the cutoff;
            # recently updated tasks are still making progress.
            stuck_tasks = db.query(TranslationTask).filter(
                TranslationTask.status.in_([
                    TaskStatusEnum.PROCESSING,
                    TaskStatusEnum.QUEUED,
                    TaskStatusEnum.FAILED
                ]),
                TranslationTask.updated_at < cutoff
            ).all()
            logger.info(">>> 找到 %s 条卡住的翻译任务", len(stuck_tasks))
            for task in stuck_tasks:
                logger.info("正在恢复卡住的任务 %s", task.id)
                if task.retry_count < MAX_RETRIES:
                    task.retry_count += 1
                    task.status = TaskStatusEnum.QUEUED
                    task.error_msg = "意外中断,正在重试"
                    self._schedule_retry(task.id)
                else:
                    task.status = TaskStatusEnum.FAILED
                    task.error_msg = "意外中断,重试次数已达上限"
            db.commit()
        finally:
            db.close()