mysora/api/main.py

95 lines
2.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
FastAPI entry point for the inference service.
Endpoints:
POST /v1/generate          submit a generation job; returns a job_id
GET  /v1/jobs/{job_id}     query the job's status
GET  /v1/videos/{job_id}   download the generated video
GET  /health               health check
"""
from datetime import datetime, timezone
from pathlib import Path
from celery.result import AsyncResult
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from .schemas import (
GenerateRequest,
JobResponse,
JobStatus,
SubmitResponse,
)
from .tasks import app as celery_app, generate_video
# The FastAPI application; every route below is registered on it via
# ``@api.get`` / ``@api.post`` decorators.
api = FastAPI(
    title="Open-Sora Inference API",
    version="1.0.0",
    # Plain ASCII punctuation on purpose: full-width characters here were
    # previously lost, which turned "v2.0 (11B)" into the bogus "v2.011B".
    description="文本/图像生成视频 API, 基于 Open-Sora v2.0 (11B)",
)
@api.get("/health")
def health():
    """Liveness probe.

    Returns an object whose ``status`` is always ``"ok"`` and whose ``time``
    is the current server time as an ISO-8601 string in UTC.
    """
    now_utc = datetime.now(timezone.utc)
    return {"status": "ok", "time": now_utc.isoformat()}
@api.post("/v1/generate", response_model=SubmitResponse, status_code=202)
def submit(req: GenerateRequest):
    """Enqueue a video-generation job and return its id without waiting.

    The validated request is dumped with ``model_dump()`` and handed to the
    ``generate_video`` Celery task as its ``request_dict`` keyword argument.
    Clients then poll ``/v1/jobs/{job_id}`` for the outcome.
    """
    task_kwargs = {"request_dict": req.model_dump()}
    async_result = generate_video.apply_async(kwargs=task_kwargs)
    return SubmitResponse(job_id=async_result.id)
@api.get("/v1/jobs/{job_id}", response_model=JobResponse)
def get_job(job_id: str):
    """Report the status of a generation job.

    Celery task states are mapped onto ``JobStatus`` as follows:

    - ``PENDING``  -> ``pending``
    - ``STARTED``  -> ``processing``
    - ``SUCCESS``  -> ``completed`` (plus ``video_url`` and ``completed_at``)
    - ``FAILURE``  -> ``failed`` (``error`` is the stringified exception)
    - ``REVOKED``  -> ``failed`` (the task was cancelled and will never finish)
    - anything else (e.g. ``RETRY``) -> ``processing``

    Note: Celery reports ``PENDING`` for task ids it has never seen, so an
    unknown ``job_id`` cannot be told apart from a queued job here.
    """
    result = AsyncResult(job_id, app=celery_app)
    # Read the state exactly once: for unfinished tasks every ``.state``
    # access is a fresh result-backend query, and the state may change
    # between successive reads.
    state = result.state

    if state == "PENDING":
        return JobResponse(job_id=job_id, status=JobStatus.pending)
    if state == "STARTED":
        return JobResponse(job_id=job_id, status=JobStatus.processing)
    if state == "SUCCESS":
        info = result.result
        if not isinstance(info, dict):
            # Expected a dict from the worker; tolerate anything else
            # instead of crashing on ``.get``.
            info = {}
        return JobResponse(
            job_id=job_id,
            status=JobStatus.completed,
            video_url=f"/v1/videos/{job_id}",
            completed_at=info.get("completed_at"),
        )
    if state == "FAILURE":
        return JobResponse(
            job_id=job_id,
            status=JobStatus.failed,
            error=str(result.result),
        )
    if state == "REVOKED":
        # Terminal state: reporting "processing" would make clients poll forever.
        return JobResponse(
            job_id=job_id,
            status=JobStatus.failed,
            error="任务已被取消",
        )
    # RETRY and any other transitional/custom states: still in flight.
    return JobResponse(job_id=job_id, status=JobStatus.processing)
@api.get("/v1/videos/{job_id}")
def download_video(job_id: str):
    """Return the MP4 produced by a successfully completed job.

    Raises:
        HTTPException(404): if the task has not reached ``SUCCESS`` (unknown,
            queued, running, failed, ...), or if its result does not name an
            existing regular file via ``video_path``.
    """
    result = AsyncResult(job_id, app=celery_app)
    if result.state != "SUCCESS":
        raise HTTPException(status_code=404, detail="视频尚未生成或任务不存在")

    info = result.result
    if not isinstance(info, dict):
        info = {}
    raw_path = info.get("video_path")
    # A missing/empty path must not fall through to Path(""), which is
    # Path(".") -- an always-existing directory that FileResponse cannot
    # serve. Require a non-empty path that points at a regular file.
    if not raw_path or not Path(raw_path).is_file():
        raise HTTPException(status_code=404, detail="视频文件不存在")
    video_path = Path(raw_path)

    return FileResponse(
        path=str(video_path),
        media_type="video/mp4",
        filename=f"{job_id}.mp4",
    )