feat: add OpenAI TTS/STT provider support in voice pipeline

- Add STT_PROVIDER/TTS_PROVIDER config (local or openai) in settings
- Pipeline uses OpenAI API for STT/TTS when provider is "openai"
- Skip loading local models (Kokoro/faster-whisper) when using OpenAI
- VAD (Silero) always loads for speech detection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-24 09:27:38 -08:00
parent f7d39d8544
commit c02c2a9a11
5 changed files with 177 additions and 51 deletions

View File

@ -327,8 +327,13 @@ services:
- KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m}
- KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao}
- DEVICE=${VOICE_DEVICE:-cpu}
- STT_PROVIDER=${STT_PROVIDER:-local}
- TTS_PROVIDER=${TTS_PROVIDER:-local}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- OPENAI_BASE_URL=${OPENAI_BASE_URL}
- OPENAI_STT_MODEL=${OPENAI_STT_MODEL:-gpt-4o-transcribe}
- OPENAI_TTS_MODEL=${OPENAI_TTS_MODEL:-tts-1}
- OPENAI_TTS_VOICE=${OPENAI_TTS_VOICE:-alloy}
healthcheck:
test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""]
interval: 30s

View File

@ -81,47 +81,73 @@ def _load_models_sync():
def _p(msg):
print(msg, flush=True)
_p(f"[bg] Loading models (device={settings.device}, whisper={settings.whisper_model})...")
_p(f"[bg] Loading models (stt={settings.stt_provider}, tts={settings.tts_provider}, device={settings.device})...")
# STT
try:
from ..stt.whisper_service import WhisperSTTService
from faster_whisper import WhisperModel
stt = WhisperSTTService(
model=settings.whisper_model,
device=settings.device,
language=settings.whisper_language,
)
compute_type = "float16" if settings.device == "cuda" else "int8"
# OpenAI client (shared by STT and TTS if either uses OpenAI)
if settings.stt_provider == "openai" or settings.tts_provider == "openai":
try:
stt._model = WhisperModel(stt.model_name, device=stt.device, compute_type=compute_type)
from openai import OpenAI
import httpx as _httpx
kwargs = {"api_key": settings.openai_api_key}
if settings.openai_base_url:
kwargs["base_url"] = settings.openai_base_url
kwargs["http_client"] = _httpx.Client(verify=False)
app.state.openai_client = OpenAI(**kwargs)
_p(f"[bg] OpenAI client initialized (stt_model={settings.openai_stt_model}, tts_model={settings.openai_tts_model})")
except Exception as e:
_p(f"[bg] Whisper fallback to CPU: {e}")
if stt.device != "cpu":
stt._model = WhisperModel(stt.model_name, device="cpu", compute_type="int8")
app.state.stt = stt
_p(f"[bg] STT loaded: {settings.whisper_model}")
except Exception as e:
app.state.openai_client = None
_p(f"[bg] WARNING: OpenAI client init failed: {e}")
else:
app.state.openai_client = None
# STT — only load local model if provider is "local"
if settings.stt_provider == "local":
try:
from ..stt.whisper_service import WhisperSTTService
from faster_whisper import WhisperModel
stt = WhisperSTTService(
model=settings.whisper_model,
device=settings.device,
language=settings.whisper_language,
)
compute_type = "float16" if settings.device == "cuda" else "int8"
try:
stt._model = WhisperModel(stt.model_name, device=stt.device, compute_type=compute_type)
except Exception as e:
_p(f"[bg] Whisper fallback to CPU: {e}")
if stt.device != "cpu":
stt._model = WhisperModel(stt.model_name, device="cpu", compute_type="int8")
app.state.stt = stt
_p(f"[bg] STT loaded: {settings.whisper_model}")
except Exception as e:
app.state.stt = None
_p(f"[bg] WARNING: STT failed: {e}")
else:
app.state.stt = None
_p(f"[bg] WARNING: STT failed: {e}")
_p(f"[bg] STT: using OpenAI ({settings.openai_stt_model})")
# TTS
try:
from ..tts.kokoro_service import KokoroTTSService, _patch_misaki_compat
# TTS — only load local model if provider is "local"
if settings.tts_provider == "local":
try:
from ..tts.kokoro_service import KokoroTTSService, _patch_misaki_compat
_patch_misaki_compat()
from kokoro import KPipeline
_patch_misaki_compat()
from kokoro import KPipeline
tts = KokoroTTSService(model=settings.kokoro_model, voice=settings.kokoro_voice)
tts._pipeline = KPipeline(lang_code='z')
app.state.tts = tts
_p(f"[bg] TTS loaded: {settings.kokoro_model} voice={settings.kokoro_voice}")
except Exception as e:
tts = KokoroTTSService(model=settings.kokoro_model, voice=settings.kokoro_voice)
tts._pipeline = KPipeline(lang_code='z')
app.state.tts = tts
_p(f"[bg] TTS loaded: {settings.kokoro_model} voice={settings.kokoro_voice}")
except Exception as e:
app.state.tts = None
_p(f"[bg] WARNING: TTS failed: {e}")
else:
app.state.tts = None
_p(f"[bg] WARNING: TTS failed: {e}")
_p(f"[bg] TTS: using OpenAI ({settings.openai_tts_model}, voice={settings.openai_tts_voice})")
# VAD
# VAD — always load (needed for speech detection regardless of provider)
try:
from ..vad.silero_service import SileroVADService
import torch

View File

@ -231,6 +231,7 @@ async def voice_websocket(websocket: WebSocket, session_id: str):
stt=getattr(app.state, "stt", None),
tts=getattr(app.state, "tts", None),
vad=getattr(app.state, "vad", None),
openai_client=getattr(app.state, "openai_client", None),
)
# Run the pipeline task in the background

View File

@ -17,6 +17,10 @@ class Settings(BaseSettings):
# Agent Service
agent_service_url: str = "http://agent-service:3002"
# Voice provider: "local" (Kokoro+faster-whisper) or "openai"
stt_provider: str = "local" # "local" or "openai"
tts_provider: str = "local" # "local" or "openai"
# STT (faster-whisper)
whisper_model: str = "base"
whisper_language: str = "zh"
@ -25,6 +29,13 @@ class Settings(BaseSettings):
kokoro_model: str = "kokoro-82m"
kokoro_voice: str = "zf_xiaoxiao"
# OpenAI voice
openai_api_key: str = ""
openai_base_url: str = ""
openai_stt_model: str = "gpt-4o-transcribe"
openai_tts_model: str = "tts-1"
openai_tts_voice: str = "alloy"
# Device (cpu or cuda)
device: str = "cpu"

View File

@ -10,9 +10,13 @@ back as speech via TTS.
"""
import asyncio
import io
import json
import logging
import os
import re
import struct
import tempfile
import time
from typing import AsyncGenerator, Optional
@ -60,12 +64,14 @@ class VoicePipelineTask:
stt: Optional[WhisperSTTService] = None,
tts: Optional[KokoroTTSService] = None,
vad: Optional[SileroVADService] = None,
openai_client=None,
):
self.websocket = websocket
self.session_context = session_context
self.stt = stt
self.tts = tts
self.vad = vad
self.openai_client = openai_client
# Agent session ID (reused across turns for conversation continuity)
self._agent_session_id: Optional[str] = None
@ -272,7 +278,10 @@ class VoicePipelineTask:
print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms (first audio at {first_audio_ms or 0}ms, {tts_count} chunks) =====", flush=True)
async def _transcribe(self, audio_data: bytes) -> str:
"""Transcribe audio using STT service."""
"""Transcribe audio using STT service (local or OpenAI)."""
if settings.stt_provider == "openai":
return await self._transcribe_openai(audio_data)
if self.stt is None or self.stt._model is None:
logger.warning("STT not available")
return ""
@ -282,6 +291,35 @@ class VoicePipelineTask:
logger.error("STT error: %s", exc)
return ""
async def _transcribe_openai(self, audio_data: bytes) -> str:
    """Transcribe raw 16 kHz mono PCM audio via the OpenAI transcription API.

    Args:
        audio_data: Raw 16-bit mono PCM bytes at ``_SAMPLE_RATE``.

    Returns:
        The transcribed text, or ``""`` when the client is unavailable or
        the request fails (the pipeline treats an empty turn as a no-op).
    """
    if self.openai_client is None:
        logger.warning("OpenAI client not available for STT")
        return ""
    tmp_path = None
    try:
        # The API expects a named file with a recognizable extension, so
        # wrap the PCM in a WAV container and stage it in a temp file.
        wav_bytes = _pcm_to_wav(audio_data, _SAMPLE_RATE)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            tmp.write(wav_bytes)
            tmp_path = tmp.name
        client = self.openai_client
        model = settings.openai_stt_model

        def _do():
            # Blocking SDK call; runs in the default thread pool so the
            # event loop stays responsive while the request is in flight.
            with open(tmp_path, "rb") as f:
                result = client.audio.transcriptions.create(
                    model=model, file=f, language="zh",
                )
            return result.text

        text = await asyncio.get_event_loop().run_in_executor(None, _do)
        return text or ""
    except Exception as exc:
        logger.error("OpenAI STT error: %s", exc)
        return ""
    finally:
        # Always remove the staged temp file — the original only deleted it
        # on the success path, leaking a file per failed request.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
"""Stream text from agent-service, yielding chunks as they arrive.
@ -406,15 +444,18 @@ class VoicePipelineTask:
async def _synthesize_chunk(self, text: str):
"""Synthesize a single sentence/chunk and send audio over WebSocket."""
if self.tts is None or self.tts._pipeline is None:
return
if not text.strip():
return
try:
audio_bytes = await asyncio.get_event_loop().run_in_executor(
None, self._tts_sync, text
)
if settings.tts_provider == "openai":
audio_bytes = await self._tts_openai(text)
else:
if self.tts is None or self.tts._pipeline is None:
return
audio_bytes = await asyncio.get_event_loop().run_in_executor(
None, self._tts_sync, text
)
if not audio_bytes or self._cancelled_tts:
return
@ -433,8 +474,38 @@ class VoicePipelineTask:
except Exception as exc:
logger.error("TTS chunk error: %s", exc)
async def _tts_openai(self, text: str) -> bytes:
    """Synthesize *text* with the OpenAI TTS API as 16 kHz 16-bit mono PCM.

    Returns ``b""`` when the client is unavailable, the API returns no
    audio, or the request fails.
    """
    if self.openai_client is None:
        logger.warning("OpenAI client not available for TTS")
        return b""
    try:
        client = self.openai_client
        model = settings.openai_tts_model
        voice = settings.openai_tts_voice

        def _request() -> bytes:
            # Blocking SDK call; response_format="pcm" yields raw 24 kHz
            # 16-bit mono samples with no container.
            response = client.audio.speech.create(
                model=model,
                voice=voice,
                input=text,
                response_format="pcm",  # raw 24kHz 16-bit mono PCM
            )
            return response.content

        raw = await asyncio.get_event_loop().run_in_executor(None, _request)
        samples = np.frombuffer(raw, dtype=np.int16).astype(np.float32)
        if len(samples) == 0:
            return b""
        # Downsample 24 kHz -> _SAMPLE_RATE via linear interpolation.
        out_len = int(len(samples) / 24000 * _SAMPLE_RATE)
        positions = np.linspace(0, len(samples) - 1, out_len)
        resampled = np.interp(positions, np.arange(len(samples)), samples)
        return resampled.astype(np.int16).tobytes()
    except Exception as exc:
        logger.error("OpenAI TTS error: %s", exc)
        return b""
def _tts_sync(self, text: str) -> bytes:
"""Synchronous TTS synthesis (runs in thread pool)."""
"""Synchronous TTS synthesis with Kokoro (runs in thread pool)."""
try:
samples = []
for _, _, audio in self.tts._pipeline(text, voice=self.tts.voice):
@ -457,6 +528,27 @@ class VoicePipelineTask:
return b""
def _pcm_to_wav(pcm_bytes: bytes, sample_rate: int) -> bytes:
"""Wrap raw 16-bit mono PCM into a WAV container."""
buf = io.BytesIO()
data_size = len(pcm_bytes)
buf.write(b"RIFF")
buf.write(struct.pack("<I", 36 + data_size))
buf.write(b"WAVE")
buf.write(b"fmt ")
buf.write(struct.pack("<I", 16))
buf.write(struct.pack("<H", 1)) # PCM
buf.write(struct.pack("<H", 1)) # mono
buf.write(struct.pack("<I", sample_rate))
buf.write(struct.pack("<I", sample_rate * 2)) # byte rate
buf.write(struct.pack("<H", 2)) # block align
buf.write(struct.pack("<H", 16)) # bits per sample
buf.write(b"data")
buf.write(struct.pack("<I", data_size))
buf.write(pcm_bytes)
return buf.getvalue()
async def create_voice_pipeline(
websocket: WebSocket,
session_context: dict,
@ -464,23 +556,14 @@ async def create_voice_pipeline(
stt=None,
tts=None,
vad=None,
openai_client=None,
) -> VoicePipelineTask:
"""Create a voice pipeline task for the given WebSocket connection.
Args:
websocket: FastAPI WebSocket connection (already accepted)
session_context: Session metadata dict
stt: Pre-initialized STT service
tts: Pre-initialized TTS service
vad: Pre-initialized VAD service
Returns:
VoicePipelineTask ready to run
"""
"""Create a voice pipeline task for the given WebSocket connection."""
return VoicePipelineTask(
websocket,
session_context,
stt=stt,
tts=tts,
vad=vad,
openai_client=openai_client,
)