feat: add production inference API (FastAPI + Celery + Redis + NGINX)
- api/: FastAPI app with /v1/generate, /v1/jobs/{id}, /v1/videos/{id}
- api/tasks.py: Celery worker, each GPU gets its own worker process
- deploy/: systemd units (opensora-api, opensora-worker@), nginx.conf, setup.sh
- Architecture: NGINX → Gunicorn/FastAPI → Redis → 8× Celery workers (GPU 0-7)
- Each task runs torchrun --nproc_per_node=1 subprocess, fully isolated per GPU
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3be569a31e
commit
c56ae7bb7a
|
|
@ -0,0 +1,23 @@
|
|||
from pathlib import Path

# Root of the Open-Sora checkout; inference subprocesses run with this as cwd.
PROJECT_ROOT = Path("/home/ceshi/mysora")
# Per-job output directories are created under this path.
OUTPUT_DIR = Path("/data/train-output/api-outputs")
# Virtualenv holding torch/torchrun and the project dependencies.
VENV_DIR = Path("/home/ceshi/venv")
# torchrun binary inside the venv, used to spawn one inference process per job.
TORCHRUN = VENV_DIR / "bin" / "torchrun"

# Celery broker + result backend (DB 0 on the local Redis).
REDIS_URL = "redis://localhost:6379/0"

# One model copy is loaded per GPU; 256px peaks at ~52.5 GB per card,
# which fits comfortably on an A100 80GB.
NUM_GPUS = 8

# Inference configs (reuse the existing 256px/768px configs).
INFERENCE_CONFIGS = {
    "256px": "configs/diffusion/inference/256px.py",
    "768px": "configs/diffusion/inference/768px.py",
}

# Hard per-task timeout in seconds. Typical runtimes: 256px ~60s, 768px ~1700s.
TASK_TIMEOUT = {
    "256px": 300,
    "768px": 3600,
}
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
"""
|
||||
FastAPI 推理服务入口。
|
||||
|
||||
端点:
|
||||
POST /v1/generate 提交生成任务,返回 job_id
|
||||
GET /v1/jobs/{job_id} 查询任务状态
|
||||
GET /v1/videos/{job_id} 下载生成视频
|
||||
GET /health 健康检查
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from celery.result import AsyncResult
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from .schemas import (
|
||||
GenerateRequest,
|
||||
JobResponse,
|
||||
JobStatus,
|
||||
SubmitResponse,
|
||||
)
|
||||
from .tasks import app as celery_app, generate_video
|
||||
|
||||
# ASGI application object; served by Gunicorn/UvicornWorker behind NGINX
# (see deploy/opensora-api.service and deploy/nginx.conf).
api = FastAPI(
    title="Open-Sora Inference API",
    version="1.0.0",
    description="文本/图像生成视频 API,基于 Open-Sora v2.0(11B)",
)
|
||||
|
||||
|
||||
@api.get("/health")
def health():
    """Liveness probe: report service status and the current UTC time."""
    now = datetime.now(timezone.utc)
    return {"status": "ok", "time": now.isoformat()}
|
||||
|
||||
|
||||
@api.post("/v1/generate", response_model=SubmitResponse, status_code=202)
def submit(req: GenerateRequest):
    """Enqueue a generation job and return its job_id immediately.

    Clients poll GET /v1/jobs/{job_id} to track progress and fetch the result.
    """
    payload = {"request_dict": req.model_dump()}
    task = generate_video.apply_async(kwargs=payload)
    return SubmitResponse(job_id=task.id)
|
||||
|
||||
|
||||
@api.get("/v1/jobs/{job_id}", response_model=JobResponse)
def get_job(job_id: str):
    """Look up a job's state: pending / processing / completed / failed."""
    result = AsyncResult(job_id, app=celery_app)
    state = result.state

    if state == "SUCCESS":
        payload = result.result or {}
        return JobResponse(
            job_id=job_id,
            status=JobStatus.completed,
            video_url=f"/v1/videos/{job_id}",
            completed_at=payload.get("completed_at"),
        )

    if state == "FAILURE":
        return JobResponse(
            job_id=job_id,
            status=JobStatus.failed,
            error=str(result.result),
        )

    if state == "PENDING":
        return JobResponse(job_id=job_id, status=JobStatus.pending)

    # STARTED, RETRY, REVOKED and any other transient state map to "processing".
    return JobResponse(job_id=job_id, status=JobStatus.processing)
|
||||
|
||||
|
||||
@api.get("/v1/videos/{job_id}")
def download_video(job_id: str):
    """Serve the finished video file for a completed job."""
    result = AsyncResult(job_id, app=celery_app)
    if result.state != "SUCCESS":
        raise HTTPException(status_code=404, detail="视频尚未生成或任务不存在")

    payload = result.result or {}
    video_file = Path(payload.get("video_path", ""))
    if not video_file.exists():
        raise HTTPException(status_code=404, detail="视频文件不存在")

    return FileResponse(
        path=str(video_file),
        media_type="video/mp4",
        filename=f"{job_id}.mp4",
    )
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Resolution(str, Enum):
    """Supported output resolutions; values key into INFERENCE_CONFIGS/TASK_TIMEOUT."""

    px256 = "256px"
    px768 = "768px"
|
||||
|
||||
|
||||
class AspectRatio(str, Enum):
    """Supported frame aspect ratios, passed through to the inference script."""

    landscape = "16:9"
    portrait = "9:16"
    square = "1:1"
    cinematic = "2.39:1"
|
||||
|
||||
|
||||
class CondType(str, Enum):
    """Conditioning mode: text-to-video, or image-to-video from a first frame."""

    t2v = "t2v"
    i2v_head = "i2v_head"
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
    """Request body for POST /v1/generate.

    Field constraints are enforced by pydantic; the Chinese ``description``
    strings feed the generated OpenAPI schema and must stay as-is.
    """

    prompt: str = Field(..., min_length=1, max_length=2000, description="文本提示词")
    resolution: Resolution = Field(Resolution.px256, description="分辨率")
    aspect_ratio: AspectRatio = Field(AspectRatio.landscape, description="画面比例")
    # NOTE(review): the description suggests frame counts of the form 4k+1,
    # but only the 1..129 range is validated — confirm whether 4k+1 must be enforced.
    num_frames: int = Field(49, ge=1, le=129, description="帧数(4k+1,建议 49=~2s,97=~4s,129=~5s)")
    motion_score: int = Field(4, ge=1, le=7, description="运动幅度(1=静态,7=剧烈)")
    num_steps: int = Field(50, ge=10, le=100, description="扩散步数,越多质量越高但越慢")
    seed: Optional[int] = Field(None, description="随机种子,不填则随机")
    cond_type: CondType = Field(CondType.t2v, description="生成类型")
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
    """Lifecycle states reported by GET /v1/jobs/{job_id}."""

    pending = "pending"
    processing = "processing"
    completed = "completed"
    failed = "failed"
|
||||
|
||||
|
||||
class SubmitResponse(BaseModel):
    """Response for POST /v1/generate: the id to poll for results."""

    job_id: str
    message: str = "任务已提交"
|
||||
|
||||
|
||||
class JobResponse(BaseModel):
    """Response for GET /v1/jobs/{job_id}.

    ``video_url`` is set only when status is completed; ``error`` only when
    failed. Timestamps are ISO-8601 strings.
    """

    job_id: str
    status: JobStatus
    video_url: Optional[str] = None
    error: Optional[str] = None
    created_at: Optional[str] = None
    completed_at: Optional[str] = None
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
"""
|
||||
Celery 任务定义。
|
||||
|
||||
每个 Worker 进程启动时读取 WORKER_GPU_ID 环境变量(由 systemd 注入),
|
||||
将 CUDA_VISIBLE_DEVICES 固定到该 GPU。
|
||||
每个任务通过 subprocess 调用 torchrun --nproc_per_node=1,进程间完全隔离。
|
||||
"""
|
||||
|
||||
import os
|
||||
import glob
|
||||
import random
|
||||
import subprocess
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from celery import Celery
|
||||
|
||||
from .config import (
|
||||
INFERENCE_CONFIGS,
|
||||
OUTPUT_DIR,
|
||||
PROJECT_ROOT,
|
||||
REDIS_URL,
|
||||
TASK_TIMEOUT,
|
||||
TORCHRUN,
|
||||
)
|
||||
|
||||
app = Celery("opensora", broker=REDIS_URL, backend=REDIS_URL)

app.conf.update(
    task_serializer="json",
    result_serializer="json",
    accept_content=["json"],
    result_expires=86400,  # keep results in Redis for 24 hours
    task_acks_late=True,  # re-queue a task if its worker crashes mid-run
    task_reject_on_worker_lost=True,
    worker_prefetch_multiplier=1,  # one task at a time per worker (GPU is serial)
    broker_connection_retry_on_startup=True,
)

# GPU ID this worker is pinned to (injected via systemd's WORKER_GPU_ID env var).
_GPU_ID = os.environ.get("WORKER_GPU_ID", "0")
|
||||
|
||||
|
||||
@app.task(bind=True, name="generate_video", max_retries=0)
def generate_video(self, request_dict: dict) -> dict:
    """
    Run one video-generation inference job in an isolated subprocess.

    Args:
        request_dict: dict serialization of a GenerateRequest.

    Returns:
        dict: {"video_path": str, "completed_at": str}

    Raises:
        RuntimeError: the inference subprocess exited with a non-zero code.
        FileNotFoundError: the subprocess succeeded but produced no .mp4.
        subprocess.TimeoutExpired: the job exceeded its per-resolution timeout.
    """
    job_id = self.request.id
    resolution = request_dict.get("resolution", "256px")
    prompt = request_dict["prompt"]
    num_frames = request_dict.get("num_frames", 49)
    aspect_ratio = request_dict.get("aspect_ratio", "16:9")
    motion_score = request_dict.get("motion_score", 4)
    num_steps = request_dict.get("num_steps", 50)
    # BUG FIX: the previous `get("seed") or random.randint(...)` silently
    # discarded an explicit seed of 0 (0 is falsy). Only fall back to a
    # random seed when the caller did not supply one at all.
    seed = request_dict.get("seed")
    if seed is None:
        seed = random.randint(0, 2 ** 32 - 1)
    cond_type = request_dict.get("cond_type", "t2v")

    job_output_dir = OUTPUT_DIR / job_id
    job_output_dir.mkdir(parents=True, exist_ok=True)

    config_path = INFERENCE_CONFIGS[resolution]
    timeout = TASK_TIMEOUT[resolution]

    cmd = [
        str(TORCHRUN),
        "--nproc_per_node=1",
        "--standalone",
        "scripts/diffusion/inference.py",
        config_path,
        "--save-dir", str(job_output_dir),
        "--save-prefix", f"{job_id}_",
        "--prompt", prompt,
        "--num_frames", str(num_frames),
        "--aspect_ratio", aspect_ratio,
        "--motion-score", str(motion_score),
        "--num-steps", str(num_steps),
        "--sampling_option.seed", str(seed),
        "--cond_type", cond_type,
    ]

    # Pin the subprocess to this worker's GPU (one fixed GPU per worker).
    env = os.environ.copy()
    env["CUDA_VISIBLE_DEVICES"] = _GPU_ID

    proc = subprocess.run(
        cmd,
        cwd=str(PROJECT_ROOT),
        env=env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )

    if proc.returncode != 0:
        # Surface the tail of both streams so failures are debuggable via the API.
        raise RuntimeError(
            f"推理进程退出码 {proc.returncode}\n"
            f"STDOUT: {proc.stdout[-2000:]}\n"
            f"STDERR: {proc.stderr[-2000:]}"
        )

    # Locate the generated video (inference.py writes into a
    # video_{resolution}/ subdirectory).
    pattern = str(job_output_dir / f"video_{resolution}" / f"{job_id}_*.mp4")
    matches = glob.glob(pattern)
    if not matches:
        # Fall back to any mp4 anywhere under the job directory.
        matches = glob.glob(str(job_output_dir / "**" / "*.mp4"), recursive=True)
    if not matches:
        raise FileNotFoundError(f"推理完成但未找到输出视频,pattern: {pattern}")

    # Sort so the chosen file is deterministic when multiple clips exist
    # (glob order is filesystem-dependent).
    video_path = sorted(matches)[0]
    return {
        "video_path": video_path,
        "completed_at": datetime.now(timezone.utc).isoformat(),
    }
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
# BUG FIX: limit_req_zone is only valid in the http{} context; declaring it
# inside server{} makes `nginx -t` fail. This file is included from within
# http{} (sites-enabled), so the zone belongs at the top level of the file.
limit_req_zone $binary_remote_addr zone=api_limit:10m rate=5r/m;

upstream opensora_api {
    server 127.0.0.1:8000;
    keepalive 32;
}

server {
    listen 80;
    server_name _;

    # Allow large request bodies (up to 500 MB, e.g. conditioning uploads).
    client_max_body_size 500M;

    location /health {
        proxy_pass http://opensora_api;
        proxy_set_header Host $host;
    }

    location /v1/ {
        # Throttle per client IP: 5 requests/minute with a burst of 10.
        limit_req zone=api_limit burst=10 nodelay;

        proxy_pass http://opensora_api;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;

        # Video downloads can take a long time; keep proxy timeouts generous.
        proxy_read_timeout 3600;
        proxy_send_timeout 3600;
        proxy_connect_timeout 10;
    }
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
[Unit]
Description=Open-Sora FastAPI Service
# NOTE(review): on Debian/Ubuntu the installed unit is redis-server.service;
# "redis.service" only resolves via the alias created when that unit is
# enabled — confirm the alias exists on this host.
After=network.target redis.service
Requires=redis.service
# Rate-limit restart attempts: at most 5 starts within 60s.
# StartLimit* settings belong in [Unit] (moved out of [Service] in systemd 229).
StartLimitIntervalSec=60
StartLimitBurst=5

[Service]
Type=simple
User=ceshi
Group=ceshi
WorkingDirectory=/home/ceshi/mysora
Environment="PATH=/home/ceshi/venv/bin:/usr/local/bin:/usr/bin:/bin"

ExecStart=/home/ceshi/venv/bin/gunicorn api.main:api \
    --workers 4 \
    --worker-class uvicorn.workers.UvicornWorker \
    --bind 127.0.0.1:8000 \
    --timeout 3600 \
    --keep-alive 5 \
    --log-level info \
    --access-logfile /data/train-output/logs/api-access.log \
    --error-logfile /data/train-output/logs/api-error.log

# Restart automatically on crash.
Restart=always
RestartSec=5

[Install]
WantedBy=multi-user.target
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
[Unit]
Description=Open-Sora Celery Worker GPU%i
# NOTE(review): on Debian/Ubuntu the installed unit is redis-server.service;
# "redis.service" only resolves via the alias created when that unit is
# enabled — confirm the alias exists on this host.
After=network.target redis.service
Requires=redis.service
# Rate-limit restart attempts: at most 5 starts within 120s.
# StartLimit* settings belong in [Unit] (moved out of [Service] in systemd 229).
StartLimitIntervalSec=120
StartLimitBurst=5

[Service]
Type=simple
User=ceshi
Group=ceshi
WorkingDirectory=/home/ceshi/mysora
Environment="PATH=/home/ceshi/venv/bin:/usr/local/bin:/usr/bin:/bin"

# %i is the systemd template parameter, i.e. the GPU ID (0-7).
Environment="WORKER_GPU_ID=%i"

# --concurrency=1: one task at a time (each task owns the whole GPU).
# --max-tasks-per-child=10: recycle the worker process every 10 tasks to
# avoid GPU memory fragmentation build-up.
ExecStart=/home/ceshi/venv/bin/celery \
    --app api.tasks \
    worker \
    --loglevel=info \
    --concurrency=1 \
    --hostname=gpu%i@%%h \
    --logfile=/data/train-output/logs/worker-gpu%i.log \
    --max-tasks-per-child=10

# Restart automatically on crash.
Restart=always
RestartSec=10

[Install]
WantedBy=multi-user.target
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
#!/usr/bin/env bash
# One-shot deployment of the inference API; run as the `ceshi` user on the
# training server.
set -euo pipefail

REPO=/home/ceshi/mysora
VENV=/home/ceshi/venv
LOG_DIR=/data/train-output/logs

echo "=== [1/6] 安装 API 依赖 ==="
# Quote celery[redis]: unquoted brackets are a glob pattern in bash.
"$VENV/bin/pip" install \
    fastapi \
    "uvicorn[standard]" \
    gunicorn \
    "celery[redis]" \
    redis \
    --quiet

echo "=== [2/6] 安装 NGINX ==="
sudo apt-get install -y nginx --quiet

echo "=== [3/6] 配置 NGINX ==="
sudo cp "$REPO/deploy/nginx.conf" /etc/nginx/sites-available/opensora
sudo ln -sf /etc/nginx/sites-available/opensora /etc/nginx/sites-enabled/opensora
sudo rm -f /etc/nginx/sites-enabled/default
# Validate the config before (re)loading so a bad file never takes nginx down.
sudo nginx -t
sudo systemctl enable nginx
sudo systemctl restart nginx

echo "=== [4/6] 安装并启动 Redis ==="
sudo apt-get install -y redis-server --quiet
# Enable AOF persistence so queued jobs survive a Redis restart.
sudo sed -i 's/^appendonly no/appendonly yes/' /etc/redis/redis.conf
# Use the real Debian/Ubuntu unit name: "redis" is only an alias that exists
# after redis-server.service has been enabled, so `systemctl restart redis`
# fails on a fresh install.
sudo systemctl enable redis-server
sudo systemctl restart redis-server

echo "=== [5/6] 创建日志目录 ==="
mkdir -p "$LOG_DIR"
sudo chown -R ceshi:ceshi "$LOG_DIR"

echo "=== [6/6] 注册并启动 systemd 服务 ==="
sudo cp "$REPO/deploy/opensora-api.service" /etc/systemd/system/
sudo cp "$REPO/deploy/opensora-worker@.service" /etc/systemd/system/
sudo systemctl daemon-reload

# Start the API service.
sudo systemctl enable opensora-api
sudo systemctl restart opensora-api

# Start the eight GPU workers (GPU 0-7).
for i in $(seq 0 7); do
    sudo systemctl enable "opensora-worker@$i"
    sudo systemctl restart "opensora-worker@$i"
done

echo ""
echo "=== 部署完成 ==="
echo "API 服务: http://$(hostname -I | awk '{print $1}')/v1/generate"
echo "健康检查: http://$(hostname -I | awk '{print $1}')/health"
echo ""
echo "查看 API 日志: sudo journalctl -u opensora-api -f"
echo "查看 Worker 日志: tail -f $LOG_DIR/worker-gpu0.log"
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Runtime dependencies for the inference API service (api/ and deploy/).
fastapi>=0.115.0
uvicorn[standard]>=0.30.0
gunicorn>=22.0.0
celery[redis]>=5.4.0
redis>=5.0.0
|
||||
Loading…
Reference in New Issue