[全量] 初始化项目代码、配置、文档及Agent协同harness
This commit is contained in:
254
langchain-chat/server/translator_service/task_manager.py
Normal file
254
langchain-chat/server/translator_service/task_manager.py
Normal file
@@ -0,0 +1,254 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, Callable
|
||||
|
||||
from sqlalchemy import (
|
||||
create_engine, Column, String, Enum as SAEnum, Float,
|
||||
Integer, Text, DateTime, Boolean
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
|
||||
from fastapi import BackgroundTasks
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
|
||||
from configs.translate_config import *
|
||||
|
||||
# Initialise the database engine and the session factory.
# NOTE(review): "check_same_thread" is a SQLite driver option that lets one
# connection be used from several threads; this presumably means
# SQLALCHEMY_DATABASE_URI points at SQLite -- confirm before switching backends.
engine = create_engine(
    SQLALCHEMY_DATABASE_URI,
    connect_args={"check_same_thread": False}
)
SessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
Base = declarative_base()

# Logging setup.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Timezone used for every timestamp in this module (Beijing time, UTC+8).
LOCAL_TZ = timezone(timedelta(hours=8))
|
||||
|
||||
# ----------------------
# Task status enumeration
# ----------------------
class TaskStatusEnum(str, Enum):
    """Lifecycle states of a translation task.

    Members are ``str`` subclasses, so they compare equal to their raw
    string values.

    QUEUED     - enqueued, waiting to be processed
    PROCESSING - currently being processed
    COMPLETED  - finished successfully
    FAILED     - failed
    CANCELLED  - cancelled
    """

    QUEUED = "queued"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
||||
|
||||
# ----------------------
# TranslationTask ORM model
# ----------------------
class TranslationTask(Base):
    """ORM model mapped to the ``file_translate_tasks`` table."""

    __tablename__ = "file_translate_tasks"

    id = Column(String, primary_key=True, index=True)  # task ID
    filename = Column(String, nullable=False)  # original file name
    src_lang = Column(String, nullable=False)  # source language
    dst_lang = Column(String, nullable=False)  # target language
    is_dual = Column(Boolean, default=True)  # whether bilingual mode is used
    file_path = Column(String, nullable=False)  # path of the source file
    output_path = Column(String, nullable=True)  # path of the translated file
    status = Column(SAEnum(TaskStatusEnum), default=TaskStatusEnum.QUEUED)  # current state
    progress = Column(Float, default=0.0)  # progress percentage
    retry_count = Column(Integer, default=0)  # number of retries so far
    error_msg = Column(Text, nullable=True)  # error message
    # Timestamps are generated in LOCAL_TZ; callables are used so the value is
    # computed per row rather than once at import time.
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(LOCAL_TZ),
    )  # creation time
    updated_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(LOCAL_TZ),
        onupdate=lambda: datetime.now(LOCAL_TZ),
    )  # last update time
|
||||
|
||||
# Create the database tables for every model registered on Base
# (runs at import time).
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# ----------------------
# TaskManager definition
# ----------------------
class TaskManager:
    """Persist translation tasks and drive their execution, retry and cancellation.

    Tasks are stored as ``TranslationTask`` rows. Execution happens in
    ``start_task``; failed runs are re-scheduled through an APScheduler
    ``AsyncIOScheduler``, which also runs ``recover_stuck_tasks`` periodically.
    """

    def __init__(self, translate_fn: Callable):
        """Create the manager and register the periodic recovery job.

        Args:
            translate_fn: Awaitable callable invoked as
                ``translate_fn(task_id, task, cancel_event)``.

        The scheduler is only configured here; call ``start()`` to run it.
        """
        self.translate_fn = translate_fn  # reference to the translation function
        # Pin the scheduler to the local timezone to avoid run-date offsets.
        self.scheduler = AsyncIOScheduler(timezone=LOCAL_TZ)
        # Periodically scan for stuck tasks and recover them.
        self.scheduler.add_job(
            self.recover_stuck_tasks,
            'interval', minutes=RECOVERY_INTERVAL,
            id='recover_jobs', replace_existing=True
        )
        # task_id -> cancellation event of the run currently in progress.
        self._cancel_events: Dict[str, asyncio.Event] = {}
        # The scheduler is deliberately NOT started here (see start()), so the
        # log line only reports that the recovery job has been registered.
        logger.info("任务管理器已初始化,恢复任务已注册,扫描间隔:%s 分钟", RECOVERY_INTERVAL)

    def start(self):
        """Start the scheduler (recovery scans and delayed retries)."""
        self.scheduler.start()
        logger.info("恢复调度器已启动")

    def shutdown(self):
        """Shut the scheduler down."""
        self.scheduler.shutdown()
        logger.info("恢复调度器已关闭")

    def db(self) -> Session:
        """Return a new database session; the caller must close it."""
        return SessionLocal()

    def generate_task_id(self) -> str:
        """Return a new unique task ID (random UUID4 string)."""
        return str(uuid.uuid4())

    def _schedule_retry(self, task_id: str) -> None:
        """Schedule a one-off ``start_task`` run RETRY_DELAY seconds from now."""
        run_date = datetime.now(LOCAL_TZ) + timedelta(seconds=RETRY_DELAY)
        self.scheduler.add_job(
            self.start_task, 'date', run_date=run_date, args=[task_id]
        )

    @staticmethod
    def _mark_cancelled(task) -> None:
        """Put ``task`` into the CANCELLED state with the user-cancel message."""
        task.status = TaskStatusEnum.CANCELLED
        task.error_msg = "用户已取消"

    def add_task(
        self,
        filename: str,
        file_path: str,
        src_lang: str,
        dst_lang: str,
        is_dual: bool,
        background_tasks: Optional[BackgroundTasks],
        task_id: Optional[str] = None,
    ) -> str:
        """Persist a new translation task in the QUEUED state and enqueue it.

        Args:
            filename: Original file name.
            file_path: Path of the source file.
            src_lang: Source language.
            dst_lang: Target language.
            is_dual: Whether bilingual output is requested.
            background_tasks: If given, ``start_task`` is added to it so the
                task runs in the background; if falsy the task is only stored.
            task_id: Optional externally supplied ID; generated when omitted.

        Returns:
            The ID of the stored task.

        Raises:
            Exception: Any database error is re-raised after a rollback.
        """
        if not task_id:
            task_id = self.generate_task_id()  # no external ID supplied

        db = self.db()
        try:
            db.add(TranslationTask(
                id=task_id,
                filename=filename,
                file_path=file_path,
                src_lang=src_lang,
                dst_lang=dst_lang,
                is_dual=is_dual,
                status=TaskStatusEnum.QUEUED,
            ))
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            # Always release the session, even when the commit failed.
            db.close()

        if background_tasks:
            background_tasks.add_task(self.start_task, task_id)  # run in background
        return task_id

    async def start_task(self, task_id: str):
        """Background entry point that runs one translation task.

        Only tasks in the QUEUED state are executed; anything else returns
        silently. The outcome is written back to the task row:

        * COMPLETED (progress 100) when ``translate_fn`` returns normally;
        * CANCELLED when ``asyncio.CancelledError`` is raised or the task's
          cancel event was set while it was running (even if ``translate_fn``
          returned or failed afterwards);
        * QUEUED plus a delayed retry when another exception is raised and the
          retry budget (MAX_RETRIES) is not exhausted, otherwise FAILED.
        """
        cancel_event = asyncio.Event()
        needs_retry = False
        db = self.db()
        try:
            task = db.query(TranslationTask).get(task_id)
            # Only QUEUED tasks may run.
            if not task or task.status != TaskStatusEnum.QUEUED:
                return

            # Register the event only now, so a rejected duplicate invocation
            # never replaces/removes the event of the run already in progress.
            self._cancel_events[task_id] = cancel_event

            # Mark as processing and refresh the update time.
            task.status = TaskStatusEnum.PROCESSING
            task.updated_at = datetime.now(LOCAL_TZ)
            db.commit()

            try:
                # The translation function is expected to honour cancel_event.
                await self.translate_fn(task_id, task, cancel_event)
                if cancel_event.is_set():
                    # Cancelled by the user: do not report it as completed.
                    self._mark_cancelled(task)
                else:
                    task.status = TaskStatusEnum.COMPLETED
                    task.progress = 100.0  # finished
            except asyncio.CancelledError:
                # Cancellation signal received.
                self._mark_cancelled(task)
            except Exception as e:
                if cancel_event.is_set():
                    # Failure caused by (or racing with) a user cancellation,
                    # e.g. the working directory was removed: never retry.
                    self._mark_cancelled(task)
                else:
                    # Any other failure goes through the retry logic.
                    logger.error("任务 %s 翻译失败:%s", task_id, e)
                    task.retry_count += 1
                    if task.retry_count <= MAX_RETRIES:
                        task.status = TaskStatusEnum.QUEUED
                        task.error_msg = "意外中断,正在重试"
                        needs_retry = True
                    else:
                        task.status = TaskStatusEnum.FAILED
                        task.error_msg = "意外中断,重试次数已达上限"
            finally:
                # Persist the outcome whatever it was.
                task.updated_at = datetime.now(LOCAL_TZ)
                db.commit()

            # Schedule the retry only after the QUEUED state is committed, so
            # the retried run is guaranteed to see it.
            if needs_retry:
                self._schedule_retry(task_id)
        finally:
            # Drop the event only if it is still the one this run registered.
            if self._cancel_events.get(task_id) is cancel_event:
                del self._cancel_events[task_id]
            db.close()

    def get_task(self, task_id: str) -> Optional[TranslationTask]:
        """Return the task with ``task_id``, or ``None`` if it does not exist.

        The returned object is detached (its session is already closed).
        """
        db = self.db()
        try:
            return db.query(TranslationTask).get(task_id)
        finally:
            db.close()

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a QUEUED or PROCESSING task and signal its running coroutine.

        Returns:
            ``False`` if the task does not exist, ``True`` otherwise (also
            when the task was already in a terminal state and is unchanged).
        """
        db = self.db()
        try:
            task = db.query(TranslationTask).get(task_id)
            if not task:
                return False

            cancellable = task.status in {
                TaskStatusEnum.QUEUED, TaskStatusEnum.PROCESSING
            }
            if cancellable:
                # Mark as cancelled.
                task.status = TaskStatusEnum.CANCELLED
                task.updated_at = datetime.now(LOCAL_TZ)
                db.commit()

            # Signal the translation coroutine, if one is running. This is done
            # before the files are removed so that a failure caused by the
            # missing files is recognised as a cancellation, not retried.
            event = self._cancel_events.get(task_id)
            if event is not None:
                event.set()

            if cancellable:
                # Also clean up the task's working directory.
                work_dir = os.path.dirname(task.file_path)
                if os.path.exists(work_dir):
                    shutil.rmtree(work_dir, ignore_errors=True)
            return True
        finally:
            db.close()

    def recover_stuck_tasks(self):
        """Scan for tasks that timed out or got stuck and recover them.

        A task is considered stuck when it is PROCESSING, QUEUED or FAILED and
        has not been updated for longer than PROCESSING_TIMEOUT. Stuck tasks
        with retries left are re-queued and scheduled to run after RETRY_DELAY
        seconds; the others are marked FAILED.
        """
        logger.info(">>> 正在扫描翻译任务。。。")
        retry_ids = []
        db = self.db()
        try:
            cutoff = datetime.now(LOCAL_TZ) - PROCESSING_TIMEOUT
            logger.info("cutoff 时间为:%s", cutoff)
            # NOTE(review): FAILED is kept in the filter as in the original;
            # confirm whether exhausted FAILED tasks should really be rescanned.
            stuck_tasks = db.query(TranslationTask).filter(
                TranslationTask.status.in_([
                    TaskStatusEnum.PROCESSING,
                    TaskStatusEnum.QUEUED,
                    TaskStatusEnum.FAILED
                ]),
                # Stuck means "last touched BEFORE the cutoff"; tasks updated
                # after it are still making progress and must be left alone.
                TranslationTask.updated_at < cutoff
            ).all()
            logger.info(">>> 找到 %s 条卡住的翻译任务", len(stuck_tasks))
            for task in stuck_tasks:
                logger.info("正在恢复卡住的任务 %s", task.id)
                if task.retry_count < MAX_RETRIES:
                    task.retry_count += 1
                    task.status = TaskStatusEnum.QUEUED
                    task.error_msg = "意外中断,正在重试"
                    retry_ids.append(task.id)
                else:
                    task.status = TaskStatusEnum.FAILED
                    task.error_msg = "意外中断,重试次数已达上限"
            db.commit()
        finally:
            db.close()

        # Schedule retries only after the QUEUED state has been committed.
        for task_id in retry_ids:
            self._schedule_retry(task_id)
|
||||
Reference in New Issue
Block a user