feat: add production inference API (FastAPI + Celery + Redis + NGINX)
- api/: FastAPI app with /v1/generate, /v1/jobs/{id}, /v1/videos/{id}
- api/tasks.py: Celery worker, each GPU gets its own worker process
- deploy/: systemd units (opensora-api, opensora-worker@), nginx.conf, setup.sh
- Architecture: NGINX → Gunicorn/FastAPI → Redis → 8× Celery workers (GPU 0-7)
- Each task runs torchrun --nproc_per_node=1 subprocess, fully isolated per GPU
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
3be569a31e
commit
c56ae7bb7a
|
|
@ -0,0 +1,23 @@
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
PROJECT_ROOT = Path("/home/ceshi/mysora")
|
||||||
|
OUTPUT_DIR = Path("/data/train-output/api-outputs")
|
||||||
|
VENV_DIR = Path("/home/ceshi/venv")
|
||||||
|
TORCHRUN = VENV_DIR / "bin" / "torchrun"
|
||||||
|
|
||||||
|
REDIS_URL = "redis://localhost:6379/0"
|
||||||
|
|
||||||
|
# 每张 GPU 加载一份模型,256px 单卡峰值 ~52.5GB,A100 80GB 完全可行
|
||||||
|
NUM_GPUS = 8
|
||||||
|
|
||||||
|
# 推理配置(复用现有 256px/768px config)
|
||||||
|
INFERENCE_CONFIGS = {
|
||||||
|
"256px": "configs/diffusion/inference/256px.py",
|
||||||
|
"768px": "configs/diffusion/inference/768px.py",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 每个任务最长等待时间(秒)。256px ~60s,768px ~1700s
|
||||||
|
TASK_TIMEOUT = {
|
||||||
|
"256px": 300,
|
||||||
|
"768px": 3600,
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""
|
||||||
|
FastAPI 推理服务入口。
|
||||||
|
|
||||||
|
端点:
|
||||||
|
POST /v1/generate 提交生成任务,返回 job_id
|
||||||
|
GET /v1/jobs/{job_id} 查询任务状态
|
||||||
|
GET /v1/videos/{job_id} 下载生成视频
|
||||||
|
GET /health 健康检查
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from celery.result import AsyncResult
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.responses import FileResponse
|
||||||
|
|
||||||
|
from .schemas import (
|
||||||
|
GenerateRequest,
|
||||||
|
JobResponse,
|
||||||
|
JobStatus,
|
||||||
|
SubmitResponse,
|
||||||
|
)
|
||||||
|
from .tasks import app as celery_app, generate_video
|
||||||
|
|
||||||
|
api = FastAPI(
|
||||||
|
title="Open-Sora Inference API",
|
||||||
|
version="1.0.0",
|
||||||
|
description="文本/图像生成视频 API,基于 Open-Sora v2.0(11B)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@api.get("/health")
|
||||||
|
def health():
|
||||||
|
return {"status": "ok", "time": datetime.now(timezone.utc).isoformat()}
|
||||||
|
|
||||||
|
|
||||||
|
@api.post("/v1/generate", response_model=SubmitResponse, status_code=202)
|
||||||
|
def submit(req: GenerateRequest):
|
||||||
|
"""提交生成任务,立即返回 job_id,客户端轮询 /v1/jobs/{job_id} 获取结果。"""
|
||||||
|
task = generate_video.apply_async(
|
||||||
|
kwargs={"request_dict": req.model_dump()},
|
||||||
|
)
|
||||||
|
return SubmitResponse(job_id=task.id)
|
||||||
|
|
||||||
|
|
||||||
|
@api.get("/v1/jobs/{job_id}", response_model=JobResponse)
|
||||||
|
def get_job(job_id: str):
|
||||||
|
"""查询任务状态。status 可能是 pending / processing / completed / failed。"""
|
||||||
|
result = AsyncResult(job_id, app=celery_app)
|
||||||
|
|
||||||
|
if result.state == "PENDING":
|
||||||
|
return JobResponse(job_id=job_id, status=JobStatus.pending)
|
||||||
|
|
||||||
|
if result.state == "STARTED":
|
||||||
|
return JobResponse(job_id=job_id, status=JobStatus.processing)
|
||||||
|
|
||||||
|
if result.state == "SUCCESS":
|
||||||
|
info = result.result or {}
|
||||||
|
return JobResponse(
|
||||||
|
job_id=job_id,
|
||||||
|
status=JobStatus.completed,
|
||||||
|
video_url=f"/v1/videos/{job_id}",
|
||||||
|
completed_at=info.get("completed_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.state == "FAILURE":
|
||||||
|
return JobResponse(
|
||||||
|
job_id=job_id,
|
||||||
|
status=JobStatus.failed,
|
||||||
|
error=str(result.result),
|
||||||
|
)
|
||||||
|
|
||||||
|
# RETRY / REVOKED 等其他状态
|
||||||
|
return JobResponse(job_id=job_id, status=JobStatus.processing)
|
||||||
|
|
||||||
|
|
||||||
|
@api.get("/v1/videos/{job_id}")
|
||||||
|
def download_video(job_id: str):
|
||||||
|
"""下载已完成任务的视频文件。"""
|
||||||
|
result = AsyncResult(job_id, app=celery_app)
|
||||||
|
if result.state != "SUCCESS":
|
||||||
|
raise HTTPException(status_code=404, detail="视频尚未生成或任务不存在")
|
||||||
|
|
||||||
|
info = result.result or {}
|
||||||
|
video_path = Path(info.get("video_path", ""))
|
||||||
|
if not video_path.exists():
|
||||||
|
raise HTTPException(status_code=404, detail="视频文件不存在")
|
||||||
|
|
||||||
|
return FileResponse(
|
||||||
|
path=str(video_path),
|
||||||
|
media_type="video/mp4",
|
||||||
|
filename=f"{job_id}.mp4",
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,53 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class Resolution(str, Enum):
|
||||||
|
px256 = "256px"
|
||||||
|
px768 = "768px"
|
||||||
|
|
||||||
|
|
||||||
|
class AspectRatio(str, Enum):
|
||||||
|
landscape = "16:9"
|
||||||
|
portrait = "9:16"
|
||||||
|
square = "1:1"
|
||||||
|
cinematic = "2.39:1"
|
||||||
|
|
||||||
|
|
||||||
|
class CondType(str, Enum):
|
||||||
|
t2v = "t2v"
|
||||||
|
i2v_head = "i2v_head"
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., min_length=1, max_length=2000, description="文本提示词")
|
||||||
|
resolution: Resolution = Field(Resolution.px256, description="分辨率")
|
||||||
|
aspect_ratio: AspectRatio = Field(AspectRatio.landscape, description="画面比例")
|
||||||
|
num_frames: int = Field(49, ge=1, le=129, description="帧数(4k+1,建议 49=~2s,97=~4s,129=~5s)")
|
||||||
|
motion_score: int = Field(4, ge=1, le=7, description="运动幅度(1=静态,7=剧烈)")
|
||||||
|
num_steps: int = Field(50, ge=10, le=100, description="扩散步数,越多质量越高但越慢")
|
||||||
|
seed: Optional[int] = Field(None, description="随机种子,不填则随机")
|
||||||
|
cond_type: CondType = Field(CondType.t2v, description="生成类型")
|
||||||
|
|
||||||
|
|
||||||
|
class JobStatus(str, Enum):
|
||||||
|
pending = "pending"
|
||||||
|
processing = "processing"
|
||||||
|
completed = "completed"
|
||||||
|
failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class SubmitResponse(BaseModel):
|
||||||
|
job_id: str
|
||||||
|
message: str = "任务已提交"
|
||||||
|
|
||||||
|
|
||||||
|
class JobResponse(BaseModel):
|
||||||
|
job_id: str
|
||||||
|
status: JobStatus
|
||||||
|
video_url: Optional[str] = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
created_at: Optional[str] = None
|
||||||
|
completed_at: Optional[str] = None
|
||||||
|
|
@ -0,0 +1,120 @@
|
||||||
|
"""
|
||||||
|
Celery 任务定义。
|
||||||
|
|
||||||
|
每个 Worker 进程启动时读取 WORKER_GPU_ID 环境变量(由 systemd 注入),
|
||||||
|
将 CUDA_VISIBLE_DEVICES 固定到该 GPU。
|
||||||
|
每个任务通过 subprocess 调用 torchrun --nproc_per_node=1,进程间完全隔离。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
import random
|
||||||
|
import subprocess
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from .config import (
|
||||||
|
INFERENCE_CONFIGS,
|
||||||
|
OUTPUT_DIR,
|
||||||
|
PROJECT_ROOT,
|
||||||
|
REDIS_URL,
|
||||||
|
TASK_TIMEOUT,
|
||||||
|
TORCHRUN,
|
||||||
|
)
|
||||||
|
|
||||||
|
app = Celery("opensora", broker=REDIS_URL, backend=REDIS_URL)
|
||||||
|
|
||||||
|
app.conf.update(
|
||||||
|
task_serializer="json",
|
||||||
|
result_serializer="json",
|
||||||
|
accept_content=["json"],
|
||||||
|
result_expires=86400, # 结果在 Redis 中保留 24 小时
|
||||||
|
task_acks_late=True, # Worker 崩溃时任务重新入队
|
||||||
|
task_reject_on_worker_lost=True,
|
||||||
|
worker_prefetch_multiplier=1, # 每个 Worker 同时只处理一个任务(GPU 串行)
|
||||||
|
broker_connection_retry_on_startup=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Worker 绑定的 GPU ID(由 systemd 环境变量注入)
|
||||||
|
_GPU_ID = os.environ.get("WORKER_GPU_ID", "0")
|
||||||
|
|
||||||
|
|
||||||
|
@app.task(bind=True, name="generate_video", max_retries=0)
|
||||||
|
def generate_video(self, request_dict: dict) -> dict:
|
||||||
|
"""
|
||||||
|
执行视频生成推理,返回视频文件路径。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request_dict: GenerateRequest 的 dict 序列化
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {"video_path": str, "completed_at": str}
|
||||||
|
"""
|
||||||
|
job_id = self.request.id
|
||||||
|
resolution = request_dict.get("resolution", "256px")
|
||||||
|
prompt = request_dict["prompt"]
|
||||||
|
num_frames = request_dict.get("num_frames", 49)
|
||||||
|
aspect_ratio = request_dict.get("aspect_ratio", "16:9")
|
||||||
|
motion_score = request_dict.get("motion_score", 4)
|
||||||
|
num_steps = request_dict.get("num_steps", 50)
|
||||||
|
seed = request_dict.get("seed") or random.randint(0, 2 ** 32 - 1)
|
||||||
|
cond_type = request_dict.get("cond_type", "t2v")
|
||||||
|
|
||||||
|
job_output_dir = OUTPUT_DIR / job_id
|
||||||
|
job_output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
config_path = INFERENCE_CONFIGS[resolution]
|
||||||
|
timeout = TASK_TIMEOUT[resolution]
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
str(TORCHRUN),
|
||||||
|
"--nproc_per_node=1",
|
||||||
|
"--standalone",
|
||||||
|
"scripts/diffusion/inference.py",
|
||||||
|
config_path,
|
||||||
|
"--save-dir", str(job_output_dir),
|
||||||
|
"--save-prefix", f"{job_id}_",
|
||||||
|
"--prompt", prompt,
|
||||||
|
"--num_frames", str(num_frames),
|
||||||
|
"--aspect_ratio", aspect_ratio,
|
||||||
|
"--motion-score", str(motion_score),
|
||||||
|
"--num-steps", str(num_steps),
|
||||||
|
"--sampling_option.seed", str(seed),
|
||||||
|
"--cond_type", cond_type,
|
||||||
|
]
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["CUDA_VISIBLE_DEVICES"] = _GPU_ID
|
||||||
|
|
||||||
|
proc = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
cwd=str(PROJECT_ROOT),
|
||||||
|
env=env,
|
||||||
|
capture_output=True,
|
||||||
|
text=True,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"推理进程退出码 {proc.returncode}\n"
|
||||||
|
f"STDOUT: {proc.stdout[-2000:]}\n"
|
||||||
|
f"STDERR: {proc.stderr[-2000:]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 找到生成的视频文件(inference.py 输出到 video_{resolution}/ 子目录)
|
||||||
|
pattern = str(job_output_dir / f"video_{resolution}" / f"{job_id}_*.mp4")
|
||||||
|
matches = glob.glob(pattern)
|
||||||
|
if not matches:
|
||||||
|
# 退回到任意 mp4
|
||||||
|
matches = glob.glob(str(job_output_dir / "**" / "*.mp4"), recursive=True)
|
||||||
|
if not matches:
|
||||||
|
raise FileNotFoundError(f"推理完成但未找到输出视频,pattern: {pattern}")
|
||||||
|
|
||||||
|
video_path = matches[0]
|
||||||
|
return {
|
||||||
|
"video_path": video_path,
|
||||||
|
"completed_at": datetime.now(timezone.utc).isoformat(),
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
upstream opensora_api {
|
||||||
|
server 127.0.0.1:8000;
|
||||||
|
keepalive 32;
|
||||||
|
}
|
||||||
|
|
||||||
|
server {
|
||||||
|
listen 80;
|
||||||
|
server_name _;
|
||||||
|
|
||||||
|
# 视频文件最大 500MB
|
||||||
|
client_max_body_size 500M;
|
||||||
|
|
||||||
|
# 防止单 IP 提交过多任务
|
||||||
|
limit_req_zone $binary_remote_addr zone=api_limit:10m rate=5r/m;
|
||||||
|
|
||||||
|
location /health {
|
||||||
|
proxy_pass http://opensora_api;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
}
|
||||||
|
|
||||||
|
location /v1/ {
|
||||||
|
limit_req zone=api_limit burst=10 nodelay;
|
||||||
|
|
||||||
|
proxy_pass http://opensora_api;
|
||||||
|
proxy_set_header Host $host;
|
||||||
|
proxy_set_header X-Real-IP $remote_addr;
|
||||||
|
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||||
|
|
||||||
|
# 视频下载可能需要较长时间
|
||||||
|
proxy_read_timeout 3600;
|
||||||
|
proxy_send_timeout 3600;
|
||||||
|
proxy_connect_timeout 10;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
[Unit]
|
||||||
|
Description=Open-Sora FastAPI Service
|
||||||
|
After=network.target redis.service
|
||||||
|
Requires=redis.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
User=ceshi
|
||||||
|
Group=ceshi
|
||||||
|
WorkingDirectory=/home/ceshi/mysora
|
||||||
|
Environment="PATH=/home/ceshi/venv/bin:/usr/local/bin:/usr/bin:/bin"
|
||||||
|
|
||||||
|
ExecStart=/home/ceshi/venv/bin/gunicorn api.main:api \
|
||||||
|
--workers 4 \
|
||||||
|
--worker-class uvicorn.workers.UvicornWorker \
|
||||||
|
--bind 127.0.0.1:8000 \
|
||||||
|
--timeout 3600 \
|
||||||
|
--keep-alive 5 \
|
||||||
|
--log-level info \
|
||||||
|
--access-logfile /data/train-output/logs/api-access.log \
|
||||||
|
--error-logfile /data/train-output/logs/api-error.log
|
||||||
|
|
||||||
|
# 崩溃自动重启
|
||||||
|
Restart=always
|
||||||
|
RestartSec=5
|
||||||
|
StartLimitIntervalSec=60
|
||||||
|
StartLimitBurst=5
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
[Unit]
|
||||||
|
Description=Open-Sora Celery Worker GPU%i
|
||||||
|
After=network.target redis.service
|
||||||
|
Requires=redis.service
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
User=ceshi
|
||||||
|
Group=ceshi
|
||||||
|
WorkingDirectory=/home/ceshi/mysora
|
||||||
|
Environment="PATH=/home/ceshi/venv/bin:/usr/local/bin:/usr/bin:/bin"
|
||||||
|
|
||||||
|
# %i 是 systemd 模板参数,即 GPU ID (0-7)
|
||||||
|
Environment="WORKER_GPU_ID=%i"
|
||||||
|
|
||||||
|
ExecStart=/home/ceshi/venv/bin/celery \
|
||||||
|
--app api.tasks \
|
||||||
|
worker \
|
||||||
|
--loglevel=info \
|
||||||
|
--concurrency=1 \
|
||||||
|
--hostname=gpu%i@%%h \
|
||||||
|
--logfile=/data/train-output/logs/worker-gpu%i.log \
|
||||||
|
--max-tasks-per-child=10
|
||||||
|
|
||||||
|
# 每处理 10 个任务重启 Worker,防止 GPU 内存碎片积累
|
||||||
|
# 崩溃自动重启
|
||||||
|
Restart=always
|
||||||
|
RestartSec=10
|
||||||
|
StartLimitIntervalSec=120
|
||||||
|
StartLimitBurst=5
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# 在训练服务器(ceshi 用户)上执行,完成 API 服务的一键部署
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
REPO=/home/ceshi/mysora
|
||||||
|
VENV=/home/ceshi/venv
|
||||||
|
LOG_DIR=/data/train-output/logs
|
||||||
|
|
||||||
|
echo "=== [1/6] 安装 API 依赖 ==="
|
||||||
|
$VENV/bin/pip install \
|
||||||
|
fastapi \
|
||||||
|
"uvicorn[standard]" \
|
||||||
|
gunicorn \
|
||||||
|
celery[redis] \
|
||||||
|
redis \
|
||||||
|
--quiet
|
||||||
|
|
||||||
|
echo "=== [2/6] 安装 NGINX ==="
|
||||||
|
sudo apt-get install -y nginx --quiet
|
||||||
|
|
||||||
|
echo "=== [3/6] 配置 NGINX ==="
|
||||||
|
sudo cp $REPO/deploy/nginx.conf /etc/nginx/sites-available/opensora
|
||||||
|
sudo ln -sf /etc/nginx/sites-available/opensora /etc/nginx/sites-enabled/opensora
|
||||||
|
sudo rm -f /etc/nginx/sites-enabled/default
|
||||||
|
sudo nginx -t
|
||||||
|
sudo systemctl enable nginx
|
||||||
|
sudo systemctl restart nginx
|
||||||
|
|
||||||
|
echo "=== [4/6] 安装并启动 Redis ==="
|
||||||
|
sudo apt-get install -y redis-server --quiet
|
||||||
|
# 开启 AOF 持久化,防止重启丢失任务
|
||||||
|
sudo sed -i 's/^appendonly no/appendonly yes/' /etc/redis/redis.conf
|
||||||
|
sudo systemctl enable redis
|
||||||
|
sudo systemctl restart redis
|
||||||
|
|
||||||
|
echo "=== [5/6] 创建日志目录 ==="
|
||||||
|
mkdir -p $LOG_DIR
|
||||||
|
sudo chown -R ceshi:ceshi $LOG_DIR
|
||||||
|
|
||||||
|
echo "=== [6/6] 注册并启动 systemd 服务 ==="
|
||||||
|
sudo cp $REPO/deploy/opensora-api.service /etc/systemd/system/
|
||||||
|
sudo cp $REPO/deploy/opensora-worker@.service /etc/systemd/system/
|
||||||
|
sudo systemctl daemon-reload
|
||||||
|
|
||||||
|
# 启动 API 服务
|
||||||
|
sudo systemctl enable opensora-api
|
||||||
|
sudo systemctl restart opensora-api
|
||||||
|
|
||||||
|
# 启动 8 个 GPU Worker(GPU 0-7)
|
||||||
|
for i in $(seq 0 7); do
|
||||||
|
sudo systemctl enable opensora-worker@$i
|
||||||
|
sudo systemctl restart opensora-worker@$i
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 部署完成 ==="
|
||||||
|
echo "API 服务: http://$(hostname -I | awk '{print $1}')/v1/generate"
|
||||||
|
echo "健康检查: http://$(hostname -I | awk '{print $1}')/health"
|
||||||
|
echo ""
|
||||||
|
echo "查看 API 日志: sudo journalctl -u opensora-api -f"
|
||||||
|
echo "查看 Worker 日志: tail -f $LOG_DIR/worker-gpu0.log"
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.30.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
celery[redis]>=5.4.0
|
||||||
|
redis>=5.0.0
|
||||||
Loading…
Reference in New Issue