mysora/api/tasks.py

121 lines
3.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

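"""
Celery task definitions.

Each worker process reads the WORKER_GPU_ID environment variable (injected by
systemd) at startup and pins CUDA_VISIBLE_DEVICES to that GPU.
Each task invokes torchrun --nproc_per_node=1 through subprocess, so tasks are
fully isolated from one another at the process level.
"""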
import os
import glob
import random
import subprocess
from datetime import datetime, timezone
from pathlib import Path
from celery import Celery
from .config import (
INFERENCE_CONFIGS,
OUTPUT_DIR,
PROJECT_ROOT,
REDIS_URL,
TASK_TIMEOUT,
TORCHRUN,
)
# Celery application; the same Redis URL serves as broker and result backend.
app = Celery("opensora", broker=REDIS_URL, backend=REDIS_URL)

app.conf.update(
    # Serialization: JSON only, for task messages and results alike.
    accept_content=["json"],
    task_serializer="json",
    result_serializer="json",
    # Keep results in Redis for 24 hours.
    result_expires=86400,
    # Acknowledge late so a task is re-queued if its worker crashes mid-run.
    task_acks_late=True,
    task_reject_on_worker_lost=True,
    # Prefetch one task at a time so GPU work on a worker runs serially.
    worker_prefetch_multiplier=1,
    broker_connection_retry_on_startup=True,
)
# GPU ID this worker is bound to, taken from the WORKER_GPU_ID environment
# variable (injected by systemd); falls back to "0" when the variable is unset.
_GPU_ID = os.getenv("WORKER_GPU_ID", "0")
@app.task(bind=True, name="generate_video", max_retries=0)
def generate_video(self, request_dict: dict) -> dict:
    """Run video-generation inference in a subprocess and return the video path.

    The inference script is launched through torchrun with a single process,
    with ``CUDA_VISIBLE_DEVICES`` set to this worker's GPU ID.

    Args:
        request_dict: Dict serialization of a GenerateRequest. ``prompt`` is
            required. ``resolution``, ``num_frames``, ``aspect_ratio``,
            ``motion_score``, ``num_steps``, ``seed`` and ``cond_type`` are
            optional and fall back to the defaults used below.

    Returns:
        dict: ``{"video_path": str, "completed_at": str}``, where
        ``completed_at`` is a timezone-aware UTC timestamp in ISO format.

    Raises:
        KeyError: If ``prompt`` is missing, or ``resolution`` has no entry in
            ``INFERENCE_CONFIGS`` / ``TASK_TIMEOUT``.
        subprocess.TimeoutExpired: If inference runs longer than the timeout
            configured for the resolution.
        RuntimeError: If the inference process exits with a non-zero status.
        FileNotFoundError: If inference succeeds but no ``.mp4`` file is found.
    """
    job_id = self.request.id

    resolution = request_dict.get("resolution", "256px")
    prompt = request_dict["prompt"]
    num_frames = request_dict.get("num_frames", 49)
    aspect_ratio = request_dict.get("aspect_ratio", "16:9")
    motion_score = request_dict.get("motion_score", 4)
    num_steps = request_dict.get("num_steps", 50)
    cond_type = request_dict.get("cond_type", "t2v")

    # Only a missing/None seed is randomized. The previous ``or`` fallback
    # also discarded an explicit seed of 0, making that request irreproducible.
    seed = request_dict.get("seed")
    if seed is None:
        seed = random.randint(0, 2 ** 32 - 1)

    # Resolve per-resolution settings before touching the filesystem so an
    # unsupported resolution fails without leaving an empty output directory.
    config_path = INFERENCE_CONFIGS[resolution]
    timeout = TASK_TIMEOUT[resolution]

    job_output_dir = OUTPUT_DIR / job_id
    job_output_dir.mkdir(parents=True, exist_ok=True)

    cmd = [
        str(TORCHRUN),
        "--nproc_per_node=1",
        "--standalone",
        "scripts/diffusion/inference.py",
        str(config_path),
        "--save-dir", str(job_output_dir),
        "--save-prefix", f"{job_id}_",
        "--prompt", prompt,
        "--num_frames", str(num_frames),
        "--aspect_ratio", aspect_ratio,
        "--motion-score", str(motion_score),
        "--num-steps", str(num_steps),
        "--sampling_option.seed", str(seed),
        "--cond_type", cond_type,
    ]

    # Run the child with a copy of the environment restricted to this
    # worker's GPU.
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = _GPU_ID

    # The argument list is passed without a shell, so the prompt is never
    # shell-interpreted. On timeout, subprocess.run kills the child and
    # raises subprocess.TimeoutExpired.
    proc = subprocess.run(
        cmd,
        cwd=str(PROJECT_ROOT),
        env=env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )
    if proc.returncode != 0:
        # Only the tail of each stream is kept to bound the error size.
        raise RuntimeError(
            f"推理进程退出码 {proc.returncode}\n"
            f"STDOUT: {proc.stdout[-2000:]}\n"
            f"STDERR: {proc.stderr[-2000:]}"
        )

    # Locate the generated video. Per the original comment, inference.py
    # writes into a video_{resolution}/ subdirectory. Results are sorted so
    # the chosen file does not depend on directory listing order.
    pattern = str(job_output_dir / f"video_{resolution}" / f"{job_id}_*.mp4")
    matches = sorted(glob.glob(pattern))
    if not matches:
        # Fall back to any mp4 anywhere under the job directory.
        matches = sorted(
            glob.glob(str(job_output_dir / "**" / "*.mp4"), recursive=True)
        )
    if not matches:
        raise FileNotFoundError(f"推理完成但未找到输出视频pattern: {pattern}")

    video_path = matches[0]
    return {
        "video_path": video_path,
        "completed_at": datetime.now(timezone.utc).isoformat(),
    }