feat: add OpenAI TTS/STT provider support in voice pipeline

- Add STT_PROVIDER/TTS_PROVIDER config (local or openai) in settings
- Pipeline uses OpenAI API for STT/TTS when provider is "openai"
- Skip loading local models (Kokoro/faster-whisper) when using OpenAI
- VAD (Silero) always loads for speech detection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Author: hailin
Date:   2026-02-24 09:27:38 -08:00
Commit: c02c2a9a11 (parent f7d39d8544)
5 changed files with 177 additions and 51 deletions


@@ -327,8 +327,13 @@ services:
       - KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m}
       - KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao}
       - DEVICE=${VOICE_DEVICE:-cpu}
+      - STT_PROVIDER=${STT_PROVIDER:-local}
+      - TTS_PROVIDER=${TTS_PROVIDER:-local}
       - OPENAI_API_KEY=${OPENAI_API_KEY}
       - OPENAI_BASE_URL=${OPENAI_BASE_URL}
+      - OPENAI_STT_MODEL=${OPENAI_STT_MODEL:-gpt-4o-transcribe}
+      - OPENAI_TTS_MODEL=${OPENAI_TTS_MODEL:-tts-1}
+      - OPENAI_TTS_VOICE=${OPENAI_TTS_VOICE:-alloy}
     healthcheck:
       test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""]
       interval: 30s
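
Usage note (not part of the diff): with the defaults above the service keeps using the local models, and switching to OpenAI is just a matter of overriding these variables, e.g. in the project's .env file. The API key below is a placeholder:

    # Illustrative .env overrides; variable names match the compose entries above.
    STT_PROVIDER=openai
    TTS_PROVIDER=openai
    OPENAI_API_KEY=sk-placeholder
    # OPENAI_BASE_URL=            # leave unset to use the default OpenAI endpoint
    OPENAI_STT_MODEL=gpt-4o-transcribe
    OPENAI_TTS_MODEL=tts-1
    OPENAI_TTS_VOICE=alloy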


@@ -81,9 +81,28 @@ def _load_models_sync():
     def _p(msg):
         print(msg, flush=True)
-    _p(f"[bg] Loading models (device={settings.device}, whisper={settings.whisper_model})...")
-    # STT
+    _p(f"[bg] Loading models (stt={settings.stt_provider}, tts={settings.tts_provider}, device={settings.device})...")
+    # OpenAI client (shared by STT and TTS if either uses OpenAI)
+    if settings.stt_provider == "openai" or settings.tts_provider == "openai":
+        try:
+            from openai import OpenAI
+            import httpx as _httpx
+            kwargs = {"api_key": settings.openai_api_key}
+            if settings.openai_base_url:
+                kwargs["base_url"] = settings.openai_base_url
+                kwargs["http_client"] = _httpx.Client(verify=False)
+            app.state.openai_client = OpenAI(**kwargs)
+            _p(f"[bg] OpenAI client initialized (stt_model={settings.openai_stt_model}, tts_model={settings.openai_tts_model})")
+        except Exception as e:
+            app.state.openai_client = None
+            _p(f"[bg] WARNING: OpenAI client init failed: {e}")
+    else:
+        app.state.openai_client = None
+    # STT — only load local model if provider is "local"
+    if settings.stt_provider == "local":
         try:
             from ..stt.whisper_service import WhisperSTTService
             from faster_whisper import WhisperModel
@@ -105,8 +124,12 @@
         except Exception as e:
             app.state.stt = None
             _p(f"[bg] WARNING: STT failed: {e}")
-    # TTS
+    else:
+        app.state.stt = None
+        _p(f"[bg] STT: using OpenAI ({settings.openai_stt_model})")
+    # TTS — only load local model if provider is "local"
+    if settings.tts_provider == "local":
         try:
             from ..tts.kokoro_service import KokoroTTSService, _patch_misaki_compat
@@ -120,8 +143,11 @@
         except Exception as e:
             app.state.tts = None
             _p(f"[bg] WARNING: TTS failed: {e}")
-    # VAD
+    else:
+        app.state.tts = None
+        _p(f"[bg] TTS: using OpenAI ({settings.openai_tts_model}, voice={settings.openai_tts_voice})")
+    # VAD — always load (needed for speech detection regardless of provider)
     try:
         from ..vad.silero_service import SileroVADService
         import torch
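
Review note: the loader logs a warning but does not fail fast when a provider is set to "openai" while OPENAI_API_KEY is empty; the error only surfaces on the first API call. If desired, a small startup check along these lines could catch that earlier (helper name and placement are illustrative, not part of this commit):

    def _validate_provider_settings(settings) -> list[str]:
        """Return a list of configuration problems; an empty list means the voice config looks sane."""
        problems = []
        for name, value in (("stt_provider", settings.stt_provider), ("tts_provider", settings.tts_provider)):
            if value not in ("local", "openai"):
                problems.append(f"{name} must be 'local' or 'openai', got '{value}'")
        if "openai" in (settings.stt_provider, settings.tts_provider) and not settings.openai_api_key:
            problems.append("a provider is set to 'openai' but OPENAI_API_KEY is empty")
        return problems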


@@ -231,6 +231,7 @@ async def voice_websocket(websocket: WebSocket, session_id: str):
         stt=getattr(app.state, "stt", None),
         tts=getattr(app.state, "tts", None),
         vad=getattr(app.state, "vad", None),
+        openai_client=getattr(app.state, "openai_client", None),
     )
     # Run the pipeline task in the background


@@ -17,6 +17,10 @@ class Settings(BaseSettings):
     # Agent Service
     agent_service_url: str = "http://agent-service:3002"
 
+    # Voice provider: "local" (Kokoro+faster-whisper) or "openai"
+    stt_provider: str = "local"  # "local" or "openai"
+    tts_provider: str = "local"  # "local" or "openai"
+
     # STT (faster-whisper)
     whisper_model: str = "base"
     whisper_language: str = "zh"
@@ -25,6 +29,13 @@ class Settings(BaseSettings):
     kokoro_model: str = "kokoro-82m"
     kokoro_voice: str = "zf_xiaoxiao"
 
+    # OpenAI voice
+    openai_api_key: str = ""
+    openai_base_url: str = ""
+    openai_stt_model: str = "gpt-4o-transcribe"
+    openai_tts_model: str = "tts-1"
+    openai_tts_voice: str = "alloy"
+
     # Device (cpu or cuda)
     device: str = "cpu"
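
Since Settings extends pydantic's BaseSettings, the new fields are filled from environment variables of the same name (matching is case-insensitive by default), which is how the compose entries above reach the service. A quick sanity check, assuming the import path shown here (the module path is not part of the diff):

    import os
    os.environ["STT_PROVIDER"] = "openai"
    os.environ["OPENAI_TTS_VOICE"] = "nova"

    from voice_service.config import Settings  # assumed import path

    s = Settings()
    print(s.stt_provider)      # "openai"  (from STT_PROVIDER)
    print(s.openai_tts_voice)  # "nova"    (from OPENAI_TTS_VOICE)
    print(s.tts_provider)      # "local"   (default, TTS_PROVIDER not set)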


@@ -10,9 +10,13 @@ back as speech via TTS.
 """
 import asyncio
+import io
 import json
 import logging
+import os
 import re
+import struct
+import tempfile
 import time
 from typing import AsyncGenerator, Optional
@@ -60,12 +64,14 @@ class VoicePipelineTask:
         stt: Optional[WhisperSTTService] = None,
         tts: Optional[KokoroTTSService] = None,
         vad: Optional[SileroVADService] = None,
+        openai_client=None,
     ):
         self.websocket = websocket
         self.session_context = session_context
         self.stt = stt
         self.tts = tts
         self.vad = vad
+        self.openai_client = openai_client
 
         # Agent session ID (reused across turns for conversation continuity)
         self._agent_session_id: Optional[str] = None
@@ -272,7 +278,10 @@ class VoicePipelineTask:
         print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms (first audio at {first_audio_ms or 0}ms, {tts_count} chunks) =====", flush=True)
 
     async def _transcribe(self, audio_data: bytes) -> str:
-        """Transcribe audio using STT service."""
+        """Transcribe audio using STT service (local or OpenAI)."""
+        if settings.stt_provider == "openai":
+            return await self._transcribe_openai(audio_data)
         if self.stt is None or self.stt._model is None:
             logger.warning("STT not available")
             return ""
@@ -282,6 +291,35 @@
             logger.error("STT error: %s", exc)
             return ""
 
+    async def _transcribe_openai(self, audio_data: bytes) -> str:
+        """Transcribe audio using OpenAI API."""
+        if self.openai_client is None:
+            logger.warning("OpenAI client not available for STT")
+            return ""
+        try:
+            # Write PCM to a temp WAV file for the API
+            wav_bytes = _pcm_to_wav(audio_data, _SAMPLE_RATE)
+            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+            tmp.write(wav_bytes)
+            tmp.close()
+
+            client = self.openai_client
+            model = settings.openai_stt_model
+
+            def _do():
+                with open(tmp.name, "rb") as f:
+                    result = client.audio.transcriptions.create(
+                        model=model, file=f, language="zh",
+                    )
+                return result.text
+
+            text = await asyncio.get_event_loop().run_in_executor(None, _do)
+            os.unlink(tmp.name)
+            return text or ""
+        except Exception as exc:
+            logger.error("OpenAI STT error: %s", exc)
+            return ""
+
     async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
         """Stream text from agent-service, yielding chunks as they arrive.
@@ -406,12 +444,15 @@
     async def _synthesize_chunk(self, text: str):
         """Synthesize a single sentence/chunk and send audio over WebSocket."""
-        if self.tts is None or self.tts._pipeline is None:
-            return
         if not text.strip():
             return
         try:
+            if settings.tts_provider == "openai":
+                audio_bytes = await self._tts_openai(text)
+            else:
+                if self.tts is None or self.tts._pipeline is None:
+                    return
                 audio_bytes = await asyncio.get_event_loop().run_in_executor(
                     None, self._tts_sync, text
                 )
@@ -433,8 +474,38 @@
         except Exception as exc:
             logger.error("TTS chunk error: %s", exc)
 
+    async def _tts_openai(self, text: str) -> bytes:
+        """Synthesize text to 16kHz PCM via OpenAI TTS API."""
+        if self.openai_client is None:
+            logger.warning("OpenAI client not available for TTS")
+            return b""
+        try:
+            client = self.openai_client
+            model = settings.openai_tts_model
+            voice = settings.openai_tts_voice
+
+            def _do():
+                response = client.audio.speech.create(
+                    model=model, voice=voice, input=text,
+                    response_format="pcm",  # raw 24kHz 16-bit mono PCM
+                )
+                return response.content
+
+            pcm_24k = await asyncio.get_event_loop().run_in_executor(None, _do)
+
+            # OpenAI returns 24kHz 16-bit mono PCM, resample to 16kHz
+            audio_np = np.frombuffer(pcm_24k, dtype=np.int16).astype(np.float32)
+            if len(audio_np) > 0:
+                target_samples = int(len(audio_np) / 24000 * _SAMPLE_RATE)
+                indices = np.linspace(0, len(audio_np) - 1, target_samples)
+                resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
+                return resampled.astype(np.int16).tobytes()
+            return b""
+        except Exception as exc:
+            logger.error("OpenAI TTS error: %s", exc)
+            return b""
+
     def _tts_sync(self, text: str) -> bytes:
-        """Synchronous TTS synthesis (runs in thread pool)."""
+        """Synchronous TTS synthesis with Kokoro (runs in thread pool)."""
         try:
             samples = []
             for _, _, audio in self.tts._pipeline(text, voice=self.tts.voice):
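
The linear-interpolation resample in _tts_openai is cheap and adequate for speech. If scipy were available in the image (it is not added by this commit), a polyphase resampler would give a cleaner 24 kHz to 16 kHz conversion, since the ratio is exactly 2/3. A possible sketch:

    import numpy as np
    from scipy.signal import resample_poly  # assumed dependency, illustrative only

    def pcm_24k_to_16k(pcm_24k: bytes) -> bytes:
        """Resample 24 kHz 16-bit mono PCM to 16 kHz with a polyphase filter (up=2, down=3)."""
        audio = np.frombuffer(pcm_24k, dtype=np.int16).astype(np.float32)
        if audio.size == 0:
            return b""
        resampled = resample_poly(audio, up=2, down=3)
        return np.clip(resampled, -32768, 32767).astype(np.int16).tobytes()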
@@ -457,6 +528,27 @@
             return b""
 
 
+def _pcm_to_wav(pcm_bytes: bytes, sample_rate: int) -> bytes:
+    """Wrap raw 16-bit mono PCM into a WAV container."""
+    buf = io.BytesIO()
+    data_size = len(pcm_bytes)
+    buf.write(b"RIFF")
+    buf.write(struct.pack("<I", 36 + data_size))
+    buf.write(b"WAVE")
+    buf.write(b"fmt ")
+    buf.write(struct.pack("<I", 16))
+    buf.write(struct.pack("<H", 1))   # PCM
+    buf.write(struct.pack("<H", 1))   # mono
+    buf.write(struct.pack("<I", sample_rate))
+    buf.write(struct.pack("<I", sample_rate * 2))  # byte rate
+    buf.write(struct.pack("<H", 2))   # block align
+    buf.write(struct.pack("<H", 16))  # bits per sample
+    buf.write(b"data")
+    buf.write(struct.pack("<I", data_size))
+    buf.write(pcm_bytes)
+    return buf.getvalue()
+
+
 async def create_voice_pipeline(
     websocket: WebSocket,
     session_context: dict,
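
For what it's worth, the same container could be produced with the standard library's wave module instead of hand-packed header fields; an equivalent sketch (not what the commit ships):

    import io
    import wave

    def _pcm_to_wav(pcm_bytes: bytes, sample_rate: int) -> bytes:
        """Wrap raw 16-bit mono PCM into a WAV container via the stdlib wave module."""
        buf = io.BytesIO()
        with wave.open(buf, "wb") as wf:
            wf.setnchannels(1)      # mono
            wf.setsampwidth(2)      # 16-bit samples
            wf.setframerate(sample_rate)
            wf.writeframes(pcm_bytes)
        return buf.getvalue()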
@@ -464,23 +556,14 @@
     stt=None,
     tts=None,
     vad=None,
+    openai_client=None,
 ) -> VoicePipelineTask:
-    """Create a voice pipeline task for the given WebSocket connection.
-
-    Args:
-        websocket: FastAPI WebSocket connection (already accepted)
-        session_context: Session metadata dict
-        stt: Pre-initialized STT service
-        tts: Pre-initialized TTS service
-        vad: Pre-initialized VAD service
-
-    Returns:
-        VoicePipelineTask ready to run
-    """
+    """Create a voice pipeline task for the given WebSocket connection."""
     return VoicePipelineTask(
         websocket,
         session_context,
         stt=stt,
         tts=tts,
         vad=vad,
+        openai_client=openai_client,
     )