95 lines
2.7 KiB
Python
95 lines
2.7 KiB
Python
"""
FastAPI entry point for the inference service.

Endpoints:
    POST /v1/generate          Submit a generation job; returns a job_id
    GET  /v1/jobs/{job_id}     Query the status of a job
    GET  /v1/videos/{job_id}   Download the generated video
    GET  /health               Health check
"""
|
||
|
||
from datetime import datetime, timezone
|
||
from pathlib import Path
|
||
|
||
from celery.result import AsyncResult
|
||
from fastapi import FastAPI, HTTPException
|
||
from fastapi.responses import FileResponse
|
||
|
||
from .schemas import (
|
||
GenerateRequest,
|
||
JobResponse,
|
||
JobStatus,
|
||
SubmitResponse,
|
||
)
|
||
from .tasks import app as celery_app, generate_video
|
||
|
||
# The FastAPI application object. Every route below registers on it via
# decorators; title/description/version are the app metadata given to FastAPI.
api = FastAPI(
    description="文本/图像生成视频 API,基于 Open-Sora v2.0(11B)",
    title="Open-Sora Inference API",
    version="1.0.0",
)
|
||
|
||
|
||
@api.get("/health")
def health():
    """Health-check endpoint.

    Returns:
        A dict with a constant ``"status": "ok"`` and ``"time"``, the current
        UTC time as an ISO-8601 string.
    """
    now_utc = datetime.now(timezone.utc)
    payload = {"status": "ok", "time": now_utc.isoformat()}
    return payload
|
||
|
||
|
||
@api.post("/v1/generate", response_model=SubmitResponse, status_code=202)
def submit(req: GenerateRequest):
    """Enqueue a generation job and return its id right away (HTTP 202).

    The validated request is dumped to a plain dict with ``model_dump()`` and
    passed as ``request_dict`` to the ``generate_video`` task. The task id
    becomes the public ``job_id``; clients poll ``/v1/jobs/{job_id}`` for the
    outcome.
    """
    task_kwargs = {"request_dict": req.model_dump()}
    async_result = generate_video.apply_async(kwargs=task_kwargs)
    return SubmitResponse(job_id=async_result.id)
|
||
|
||
|
||
@api.get("/v1/jobs/{job_id}", response_model=JobResponse)
def get_job(job_id: str):
    """Query the status of a job.

    Maps the Celery task state onto ``JobStatus``:

    * ``PENDING``  -> pending (Celery also reports unknown task ids as
      PENDING, so a nonexistent ``job_id`` looks like a queued job)
    * ``SUCCESS``  -> completed, with ``video_url`` and ``completed_at``
    * ``FAILURE``  -> failed, with the stringified exception as ``error``
    * ``REVOKED``  -> failed (terminal: the job will never complete)
    * anything else (``STARTED``, ``RECEIVED``, ``RETRY``, ...) -> processing
    """
    result = AsyncResult(job_id, app=celery_app)
    # Read the state once. Every ``.state`` access may hit the result backend
    # again while the task is unfinished, and the state could change between
    # successive comparisons.
    state = result.state

    if state == "PENDING":
        return JobResponse(job_id=job_id, status=JobStatus.pending)

    if state == "SUCCESS":
        info = result.result
        # The task is expected to return a dict; tolerate None or other
        # shapes instead of crashing on ``.get``.
        if not isinstance(info, dict):
            info = {}
        return JobResponse(
            job_id=job_id,
            status=JobStatus.completed,
            video_url=f"/v1/videos/{job_id}",
            completed_at=info.get("completed_at"),
        )

    if state == "FAILURE":
        return JobResponse(
            job_id=job_id,
            status=JobStatus.failed,
            error=str(result.result),
        )

    if state == "REVOKED":
        # A revoked task is terminal. Reporting it as "processing" would make
        # clients poll forever.
        return JobResponse(
            job_id=job_id,
            status=JobStatus.failed,
            error="任务已被取消",
        )

    # STARTED / RECEIVED / RETRY and any custom in-flight states.
    return JobResponse(job_id=job_id, status=JobStatus.processing)
|
||
|
||
|
||
@api.get("/v1/videos/{job_id}")
def download_video(job_id: str):
    """Download the video file of a completed job as ``{job_id}.mp4``.

    Raises:
        HTTPException: 404 if the job has not succeeded (or is unknown).
            Also 404 if its result carries no ``video_path``, or if that
            path is not an existing regular file.
    """
    result = AsyncResult(job_id, app=celery_app)
    if result.state != "SUCCESS":
        raise HTTPException(status_code=404, detail="视频尚未生成或任务不存在")

    info = result.result
    # The task is expected to return a dict; tolerate None or other shapes.
    if not isinstance(info, dict):
        info = {}

    raw_path = info.get("video_path")
    # Path("") is Path("."), the current directory, which always "exists".
    # Reject a missing or empty path before building a Path. Use is_file()
    # so a directory is never passed to FileResponse.
    if not raw_path or not Path(raw_path).is_file():
        raise HTTPException(status_code=404, detail="视频文件不存在")

    video_path = Path(raw_path)
    return FileResponse(
        path=str(video_path),
        media_type="video/mp4",
        filename=f"{job_id}.mp4",
    )
|