feat: lazy-load local TTS/STT models on first request

Local /synthesize and /transcribe endpoints now auto-load Kokoro/Whisper
models on first call instead of returning 503 when not pre-loaded at
startup. This allows switching between Local and OpenAI providers in the
Flutter test page without requiring server restart.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-25 04:38:49 -08:00
parent 7b71a4f2fc
commit 4456550393
2 changed files with 72 additions and 9 deletions

View File

@ -30,7 +30,7 @@ class _VoiceTestPageState extends ConsumerState<VoiceTestPage> {
bool _playerReady = false; bool _playerReady = false;
bool _recorderReady = false; bool _recorderReady = false;
_VoiceProvider _provider = _VoiceProvider.local; _VoiceProvider _provider = _VoiceProvider.openai;
String _ttsStatus = ''; String _ttsStatus = '';
String _sttStatus = ''; String _sttStatus = '';

View File

@ -5,12 +5,73 @@ import io
import os import os
import struct import struct
import tempfile import tempfile
import threading
import numpy as np import numpy as np
from fastapi import APIRouter, Request, Query, UploadFile, File from fastapi import APIRouter, Request, Query, UploadFile, File
from fastapi.responses import HTMLResponse, Response from fastapi.responses import HTMLResponse, Response
router = APIRouter() router = APIRouter()
_loading_lock = threading.Lock()
def _ensure_local_tts(app) -> bool:
"""Lazy-load Kokoro TTS if not yet loaded. Returns True if available."""
tts = getattr(app.state, "tts", None)
if tts is not None and tts._pipeline is not None:
return True
with _loading_lock:
# Double-check after acquiring lock
tts = getattr(app.state, "tts", None)
if tts is not None and tts._pipeline is not None:
return True
try:
from ..tts.kokoro_service import KokoroTTSService, _patch_misaki_compat
from ..config.settings import settings
_patch_misaki_compat()
from kokoro import KPipeline
svc = KokoroTTSService(model=settings.kokoro_model, voice=settings.kokoro_voice)
svc._pipeline = KPipeline(lang_code='z')
app.state.tts = svc
print(f"[lazy] Local TTS loaded: {settings.kokoro_model} voice={settings.kokoro_voice}", flush=True)
return True
except Exception as e:
print(f"[lazy] Failed to load local TTS: {e}", flush=True)
return False
def _ensure_local_stt(app) -> bool:
"""Lazy-load faster-whisper STT if not yet loaded. Returns True if available."""
stt = getattr(app.state, "stt", None)
if stt is not None and stt._model is not None:
return True
with _loading_lock:
# Double-check after acquiring lock
stt = getattr(app.state, "stt", None)
if stt is not None and stt._model is not None:
return True
try:
from ..stt.whisper_service import WhisperSTTService
from faster_whisper import WhisperModel
from ..config.settings import settings
svc = WhisperSTTService(
model=settings.whisper_model,
device=settings.device,
language=settings.whisper_language,
)
compute_type = "float16" if settings.device == "cuda" else "int8"
try:
svc._model = WhisperModel(svc.model_name, device=svc.device, compute_type=compute_type)
except Exception:
if svc.device != "cpu":
svc._model = WhisperModel(svc.model_name, device="cpu", compute_type="int8")
app.state.stt = svc
print(f"[lazy] Local STT loaded: {settings.whisper_model}", flush=True)
return True
except Exception as e:
print(f"[lazy] Failed to load local STT: {e}", flush=True)
return False
_SAMPLE_RATE = 24000 # Kokoro native output rate _SAMPLE_RATE = 24000 # Kokoro native output rate
@ -210,10 +271,11 @@ async function doRoundTrip(blob) {
@router.get("/tts/synthesize") @router.get("/tts/synthesize")
async def tts_synthesize(request: Request, text: str = Query(..., min_length=1, max_length=500)): async def tts_synthesize(request: Request, text: str = Query(..., min_length=1, max_length=500)):
"""Synthesize text to WAV audio.""" """Synthesize text to WAV audio (lazy-loads local Kokoro model on first call)."""
tts = getattr(request.app.state, "tts", None) loaded = await asyncio.get_event_loop().run_in_executor(None, _ensure_local_tts, request.app)
if tts is None or tts._pipeline is None: if not loaded:
return Response(content="TTS model not loaded", status_code=503) return Response(content="Failed to load local TTS model", status_code=503)
tts = request.app.state.tts
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
def _synth(): def _synth():
@ -242,10 +304,11 @@ async def tts_synthesize(request: Request, text: str = Query(..., min_length=1,
@router.post("/stt/transcribe") @router.post("/stt/transcribe")
async def stt_transcribe(request: Request, audio: UploadFile = File(...)): async def stt_transcribe(request: Request, audio: UploadFile = File(...)):
"""Transcribe uploaded audio to text via faster-whisper.""" """Transcribe uploaded audio to text via faster-whisper (lazy-loads on first call)."""
stt = getattr(request.app.state, "stt", None) loaded = await asyncio.get_event_loop().run_in_executor(None, _ensure_local_stt, request.app)
if stt is None or stt._model is None: if not loaded:
return {"error": "STT model not loaded", "text": ""} return {"error": "Failed to load local STT model", "text": ""}
stt = request.app.state.stt
# Save uploaded file to temp # Save uploaded file to temp
raw = await audio.read() raw = await audio.read()