feat: lazy-load local TTS/STT models on first request
Local /synthesize and /transcribe endpoints now auto-load Kokoro/Whisper models on first call instead of returning 503 when not pre-loaded at startup. This allows switching between Local and OpenAI providers in the Flutter test page without requiring server restart. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
7b71a4f2fc
commit
4456550393
|
|
@ -30,7 +30,7 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
|
||||||
bool _playerReady = false;
|
bool _playerReady = false;
|
||||||
bool _recorderReady = false;
|
bool _recorderReady = false;
|
||||||
|
|
||||||
_VoiceProvider _provider = _VoiceProvider.local;
|
_VoiceProvider _provider = _VoiceProvider.openai;
|
||||||
|
|
||||||
String _ttsStatus = '';
|
String _ttsStatus = '';
|
||||||
String _sttStatus = '';
|
String _sttStatus = '';
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,73 @@ import io
|
||||||
import os
|
import os
|
||||||
import struct
|
import struct
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import threading
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from fastapi import APIRouter, Request, Query, UploadFile, File
|
from fastapi import APIRouter, Request, Query, UploadFile, File
|
||||||
from fastapi.responses import HTMLResponse, Response
|
from fastapi.responses import HTMLResponse, Response
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
_loading_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_local_tts(app) -> bool:
|
||||||
|
"""Lazy-load Kokoro TTS if not yet loaded. Returns True if available."""
|
||||||
|
tts = getattr(app.state, "tts", None)
|
||||||
|
if tts is not None and tts._pipeline is not None:
|
||||||
|
return True
|
||||||
|
with _loading_lock:
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
tts = getattr(app.state, "tts", None)
|
||||||
|
if tts is not None and tts._pipeline is not None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
from ..tts.kokoro_service import KokoroTTSService, _patch_misaki_compat
|
||||||
|
from ..config.settings import settings
|
||||||
|
_patch_misaki_compat()
|
||||||
|
from kokoro import KPipeline
|
||||||
|
svc = KokoroTTSService(model=settings.kokoro_model, voice=settings.kokoro_voice)
|
||||||
|
svc._pipeline = KPipeline(lang_code='z')
|
||||||
|
app.state.tts = svc
|
||||||
|
print(f"[lazy] Local TTS loaded: {settings.kokoro_model} voice={settings.kokoro_voice}", flush=True)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[lazy] Failed to load local TTS: {e}", flush=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_local_stt(app) -> bool:
|
||||||
|
"""Lazy-load faster-whisper STT if not yet loaded. Returns True if available."""
|
||||||
|
stt = getattr(app.state, "stt", None)
|
||||||
|
if stt is not None and stt._model is not None:
|
||||||
|
return True
|
||||||
|
with _loading_lock:
|
||||||
|
# Double-check after acquiring lock
|
||||||
|
stt = getattr(app.state, "stt", None)
|
||||||
|
if stt is not None and stt._model is not None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
from ..stt.whisper_service import WhisperSTTService
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
from ..config.settings import settings
|
||||||
|
svc = WhisperSTTService(
|
||||||
|
model=settings.whisper_model,
|
||||||
|
device=settings.device,
|
||||||
|
language=settings.whisper_language,
|
||||||
|
)
|
||||||
|
compute_type = "float16" if settings.device == "cuda" else "int8"
|
||||||
|
try:
|
||||||
|
svc._model = WhisperModel(svc.model_name, device=svc.device, compute_type=compute_type)
|
||||||
|
except Exception:
|
||||||
|
if svc.device != "cpu":
|
||||||
|
svc._model = WhisperModel(svc.model_name, device="cpu", compute_type="int8")
|
||||||
|
app.state.stt = svc
|
||||||
|
print(f"[lazy] Local STT loaded: {settings.whisper_model}", flush=True)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[lazy] Failed to load local STT: {e}", flush=True)
|
||||||
|
return False
|
||||||
|
|
||||||
_SAMPLE_RATE = 24000 # Kokoro native output rate
|
_SAMPLE_RATE = 24000 # Kokoro native output rate
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -210,10 +271,11 @@ async function doRoundTrip(blob) {
|
||||||
|
|
||||||
@router.get("/tts/synthesize")
|
@router.get("/tts/synthesize")
|
||||||
async def tts_synthesize(request: Request, text: str = Query(..., min_length=1, max_length=500)):
|
async def tts_synthesize(request: Request, text: str = Query(..., min_length=1, max_length=500)):
|
||||||
"""Synthesize text to WAV audio."""
|
"""Synthesize text to WAV audio (lazy-loads local Kokoro model on first call)."""
|
||||||
tts = getattr(request.app.state, "tts", None)
|
loaded = await asyncio.get_event_loop().run_in_executor(None, _ensure_local_tts, request.app)
|
||||||
if tts is None or tts._pipeline is None:
|
if not loaded:
|
||||||
return Response(content="TTS model not loaded", status_code=503)
|
return Response(content="Failed to load local TTS model", status_code=503)
|
||||||
|
tts = request.app.state.tts
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
def _synth():
|
def _synth():
|
||||||
|
|
@ -242,10 +304,11 @@ async def tts_synthesize(request: Request, text: str = Query(..., min_length=1,
|
||||||
|
|
||||||
@router.post("/stt/transcribe")
|
@router.post("/stt/transcribe")
|
||||||
async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
|
async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
|
||||||
"""Transcribe uploaded audio to text via faster-whisper."""
|
"""Transcribe uploaded audio to text via faster-whisper (lazy-loads on first call)."""
|
||||||
stt = getattr(request.app.state, "stt", None)
|
loaded = await asyncio.get_event_loop().run_in_executor(None, _ensure_local_stt, request.app)
|
||||||
if stt is None or stt._model is None:
|
if not loaded:
|
||||||
return {"error": "STT model not loaded", "text": ""}
|
return {"error": "Failed to load local STT model", "text": ""}
|
||||||
|
stt = request.app.state.stt
|
||||||
|
|
||||||
# Save uploaded file to temp
|
# Save uploaded file to temp
|
||||||
raw = await audio.read()
|
raw = await audio.read()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue