47 lines
1.4 KiB
Python
47 lines
1.4 KiB
Python
import os
import tempfile

from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Header
from faster_whisper import WhisperModel
from pydantic import BaseModel
|
|
|
|
# FastAPI application instance; the route decorators in this module register on it.
app = FastAPI()
|
|
|
|
# Supported models: maps the model identifier a client sends in the request
# to the model size name handed to WhisperModel when the model is loaded.
SUPPORTED_MODELS = {
    "FunAudioLLM/SenseVoiceSmall": "small",
    "FunAudioLLM/SenseVoiceMedium": "medium",
    "FunAudioLLM/SenseVoiceLarge": "large",
}
|
|
|
|
# Model cache (kept resident in memory): loaded model instances keyed by the
# client-facing model identifier, so each model is loaded at most once.
MODEL_CACHE: dict = {}
|
|
|
|
def get_model(model_name: str):
    """Return the loaded model for *model_name*, loading it on first use.

    Args:
        model_name: Client-facing model identifier; must be a key of
            ``SUPPORTED_MODELS``.

    Returns:
        The ``WhisperModel`` instance stored in ``MODEL_CACHE`` for this name.

    Raises:
        HTTPException: 400 if *model_name* is not a supported model.
    """
    whisper_size = SUPPORTED_MODELS.get(model_name)
    if whisper_size is None:
        raise HTTPException(status_code=400, detail="Unsupported model")

    # Load lazily and keep the instance so later requests reuse it.
    if model_name not in MODEL_CACHE:
        loaded = WhisperModel(whisper_size, compute_type="int8")
        MODEL_CACHE[model_name] = loaded

    return MODEL_CACHE[model_name]
|
|
|
|
@app.post("/audio/transcriptions")
async def transcribe_audio(
    file: UploadFile = File(...),
    model: str = Form(...),
    authorization: str = Header(None),
):
    """Transcribe an uploaded audio file and return the recognized text.

    Args:
        file: The uploaded audio file (multipart form field ``file``).
        model: Client-facing model identifier (form field ``model``); must be
            one of the keys of ``SUPPORTED_MODELS``.
        authorization: Value of the ``Authorization`` request header.

    Returns:
        ``{"text": <transcript>}`` where the transcript is the text of all
        segments joined with single spaces.

    Raises:
        HTTPException: 401 if the header is missing or does not start with
            ``"Bearer "``; 400 (from ``get_model``) if the model is unsupported.
    """
    # NOTE(review): only the "Bearer " scheme prefix is checked here; the token
    # value itself is never validated.
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Unauthorized")

    # Resolve the model first so an unsupported name is rejected before any
    # upload data is written to disk.
    model_runner = get_model(model)

    # The client-supplied filename is untrusted and may be missing. Keep only a
    # plain alphanumeric extension so it cannot inject path components.
    client_name = os.path.basename(file.filename or "")
    suffix = os.path.splitext(client_name)[1]
    if not suffix[1:].isalnum():
        suffix = ""

    # Save the audio to a uniquely named temporary file so concurrent requests
    # never overwrite or delete each other's upload.
    tmp = tempfile.NamedTemporaryFile(
        prefix="temp_audio_", suffix=suffix, delete=False
    )
    audio_path = tmp.name
    try:
        with tmp:
            tmp.write(await file.read())

        # Model inference. The segments are consumed inside the try block so
        # the audio file still exists while they are being read.
        # NOTE(review): this call is synchronous inside a coroutine — consider
        # offloading it to a worker thread if request concurrency matters.
        segments, _ = model_runner.transcribe(audio_path)
        text = " ".join(seg.text for seg in segments)
    finally:
        # Always clean up, including when writing or transcription fails.
        os.remove(audio_path)

    return {"text": text}
|