feat: add OpenAI TTS/STT provider support in voice pipeline
- Add STT_PROVIDER/TTS_PROVIDER config (local or openai) in settings
- Pipeline uses OpenAI API for STT/TTS when provider is "openai"
- Skip loading local models (Kokoro/faster-whisper) when using OpenAI
- VAD (Silero) always loads for speech detection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f7d39d8544
commit
c02c2a9a11
|
|
@ -327,8 +327,13 @@ services:
|
|||
- KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m}
|
||||
- KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao}
|
||||
- DEVICE=${VOICE_DEVICE:-cpu}
|
||||
- STT_PROVIDER=${STT_PROVIDER:-local}
|
||||
- TTS_PROVIDER=${TTS_PROVIDER:-local}
|
||||
- OPENAI_API_KEY=${OPENAI_API_KEY}
|
||||
- OPENAI_BASE_URL=${OPENAI_BASE_URL}
|
||||
- OPENAI_STT_MODEL=${OPENAI_STT_MODEL:-gpt-4o-transcribe}
|
||||
- OPENAI_TTS_MODEL=${OPENAI_TTS_MODEL:-tts-1}
|
||||
- OPENAI_TTS_VOICE=${OPENAI_TTS_VOICE:-alloy}
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""]
|
||||
interval: 30s
|
||||
|
|
|
|||
|
|
@ -81,9 +81,28 @@ def _load_models_sync():
|
|||
def _p(msg):
|
||||
print(msg, flush=True)
|
||||
|
||||
_p(f"[bg] Loading models (device={settings.device}, whisper={settings.whisper_model})...")
|
||||
_p(f"[bg] Loading models (stt={settings.stt_provider}, tts={settings.tts_provider}, device={settings.device})...")
|
||||
|
||||
# STT
|
||||
# OpenAI client (shared by STT and TTS if either uses OpenAI)
|
||||
if settings.stt_provider == "openai" or settings.tts_provider == "openai":
|
||||
try:
|
||||
from openai import OpenAI
|
||||
import httpx as _httpx
|
||||
|
||||
kwargs = {"api_key": settings.openai_api_key}
|
||||
if settings.openai_base_url:
|
||||
kwargs["base_url"] = settings.openai_base_url
|
||||
kwargs["http_client"] = _httpx.Client(verify=False)
|
||||
app.state.openai_client = OpenAI(**kwargs)
|
||||
_p(f"[bg] OpenAI client initialized (stt_model={settings.openai_stt_model}, tts_model={settings.openai_tts_model})")
|
||||
except Exception as e:
|
||||
app.state.openai_client = None
|
||||
_p(f"[bg] WARNING: OpenAI client init failed: {e}")
|
||||
else:
|
||||
app.state.openai_client = None
|
||||
|
||||
# STT — only load local model if provider is "local"
|
||||
if settings.stt_provider == "local":
|
||||
try:
|
||||
from ..stt.whisper_service import WhisperSTTService
|
||||
from faster_whisper import WhisperModel
|
||||
|
|
@ -105,8 +124,12 @@ def _load_models_sync():
|
|||
except Exception as e:
|
||||
app.state.stt = None
|
||||
_p(f"[bg] WARNING: STT failed: {e}")
|
||||
else:
|
||||
app.state.stt = None
|
||||
_p(f"[bg] STT: using OpenAI ({settings.openai_stt_model})")
|
||||
|
||||
# TTS
|
||||
# TTS — only load local model if provider is "local"
|
||||
if settings.tts_provider == "local":
|
||||
try:
|
||||
from ..tts.kokoro_service import KokoroTTSService, _patch_misaki_compat
|
||||
|
||||
|
|
@ -120,8 +143,11 @@ def _load_models_sync():
|
|||
except Exception as e:
|
||||
app.state.tts = None
|
||||
_p(f"[bg] WARNING: TTS failed: {e}")
|
||||
else:
|
||||
app.state.tts = None
|
||||
_p(f"[bg] TTS: using OpenAI ({settings.openai_tts_model}, voice={settings.openai_tts_voice})")
|
||||
|
||||
# VAD
|
||||
# VAD — always load (needed for speech detection regardless of provider)
|
||||
try:
|
||||
from ..vad.silero_service import SileroVADService
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -231,6 +231,7 @@ async def voice_websocket(websocket: WebSocket, session_id: str):
|
|||
stt=getattr(app.state, "stt", None),
|
||||
tts=getattr(app.state, "tts", None),
|
||||
vad=getattr(app.state, "vad", None),
|
||||
openai_client=getattr(app.state, "openai_client", None),
|
||||
)
|
||||
|
||||
# Run the pipeline task in the background
|
||||
|
|
|
|||
|
|
@ -17,6 +17,10 @@ class Settings(BaseSettings):
|
|||
# Agent Service
|
||||
agent_service_url: str = "http://agent-service:3002"
|
||||
|
||||
# Voice provider: "local" (Kokoro+faster-whisper) or "openai"
|
||||
stt_provider: str = "local" # "local" or "openai"
|
||||
tts_provider: str = "local" # "local" or "openai"
|
||||
|
||||
# STT (faster-whisper)
|
||||
whisper_model: str = "base"
|
||||
whisper_language: str = "zh"
|
||||
|
|
@ -25,6 +29,13 @@ class Settings(BaseSettings):
|
|||
kokoro_model: str = "kokoro-82m"
|
||||
kokoro_voice: str = "zf_xiaoxiao"
|
||||
|
||||
# OpenAI voice
|
||||
openai_api_key: str = ""
|
||||
openai_base_url: str = ""
|
||||
openai_stt_model: str = "gpt-4o-transcribe"
|
||||
openai_tts_model: str = "tts-1"
|
||||
openai_tts_voice: str = "alloy"
|
||||
|
||||
# Device (cpu or cuda)
|
||||
device: str = "cpu"
|
||||
|
||||
|
|
|
|||
|
|
@ -10,9 +10,13 @@ back as speech via TTS.
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import struct
|
||||
import tempfile
|
||||
import time
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
|
|
@ -60,12 +64,14 @@ class VoicePipelineTask:
|
|||
stt: Optional[WhisperSTTService] = None,
|
||||
tts: Optional[KokoroTTSService] = None,
|
||||
vad: Optional[SileroVADService] = None,
|
||||
openai_client=None,
|
||||
):
|
||||
self.websocket = websocket
|
||||
self.session_context = session_context
|
||||
self.stt = stt
|
||||
self.tts = tts
|
||||
self.vad = vad
|
||||
self.openai_client = openai_client
|
||||
|
||||
# Agent session ID (reused across turns for conversation continuity)
|
||||
self._agent_session_id: Optional[str] = None
|
||||
|
|
@ -272,7 +278,10 @@ class VoicePipelineTask:
|
|||
print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms (first audio at {first_audio_ms or 0}ms, {tts_count} chunks) =====", flush=True)
|
||||
|
||||
async def _transcribe(self, audio_data: bytes) -> str:
|
||||
"""Transcribe audio using STT service."""
|
||||
"""Transcribe audio using STT service (local or OpenAI)."""
|
||||
if settings.stt_provider == "openai":
|
||||
return await self._transcribe_openai(audio_data)
|
||||
|
||||
if self.stt is None or self.stt._model is None:
|
||||
logger.warning("STT not available")
|
||||
return ""
|
||||
|
|
@ -282,6 +291,35 @@ class VoicePipelineTask:
|
|||
logger.error("STT error: %s", exc)
|
||||
return ""
|
||||
|
||||
async def _transcribe_openai(self, audio_data: bytes) -> str:
    """Transcribe raw 16-bit mono PCM audio via the OpenAI transcription API.

    Args:
        audio_data: Raw 16-bit mono PCM samples at the pipeline sample
            rate (``_SAMPLE_RATE``).

    Returns:
        The transcribed text, or ``""`` when the client is unavailable or
        the request fails (errors are logged, never raised to the caller).
    """
    if self.openai_client is None:
        logger.warning("OpenAI client not available for STT")
        return ""
    tmp_path = None
    try:
        # The API wants a named file object, so wrap the PCM in a WAV
        # container and stage it in a temp file.
        wav_bytes = _pcm_to_wav(audio_data, _SAMPLE_RATE)
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        tmp_path = tmp.name
        tmp.write(wav_bytes)
        tmp.close()

        client = self.openai_client
        model = settings.openai_stt_model

        def _do():
            # Blocking SDK call — executed in the thread pool below.
            with open(tmp_path, "rb") as f:
                result = client.audio.transcriptions.create(
                    model=model, file=f, language="zh",
                )
            return result.text

        text = await asyncio.get_running_loop().run_in_executor(None, _do)
        return text or ""
    except Exception as exc:
        logger.error("OpenAI STT error: %s", exc)
        return ""
    finally:
        # BUGFIX: the original only unlinked on the success path, leaking
        # the temp WAV whenever the API call raised.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
|
||||
|
||||
async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
|
||||
"""Stream text from agent-service, yielding chunks as they arrive.
|
||||
|
||||
|
|
@ -406,12 +444,15 @@ class VoicePipelineTask:
|
|||
|
||||
async def _synthesize_chunk(self, text: str):
|
||||
"""Synthesize a single sentence/chunk and send audio over WebSocket."""
|
||||
if self.tts is None or self.tts._pipeline is None:
|
||||
return
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
try:
|
||||
if settings.tts_provider == "openai":
|
||||
audio_bytes = await self._tts_openai(text)
|
||||
else:
|
||||
if self.tts is None or self.tts._pipeline is None:
|
||||
return
|
||||
audio_bytes = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self._tts_sync, text
|
||||
)
|
||||
|
|
@ -433,8 +474,38 @@ class VoicePipelineTask:
|
|||
except Exception as exc:
|
||||
logger.error("TTS chunk error: %s", exc)
|
||||
|
||||
async def _tts_openai(self, text: str) -> bytes:
    """Synthesize *text* with the OpenAI TTS API and return 16 kHz PCM.

    The API yields 24 kHz 16-bit mono PCM; the result is linearly
    resampled down to the pipeline rate (``_SAMPLE_RATE``). Returns
    ``b""`` when the client is missing or the request fails (errors are
    logged, never raised to the caller).
    """
    if self.openai_client is None:
        logger.warning("OpenAI client not available for TTS")
        return b""
    try:
        client = self.openai_client
        model = settings.openai_tts_model
        voice = settings.openai_tts_voice

        def _synthesize():
            # Blocking SDK call — run in the executor below.
            resp = client.audio.speech.create(
                model=model, voice=voice, input=text,
                response_format="pcm",  # raw 24kHz 16-bit mono PCM
            )
            return resp.content

        raw_24k = await asyncio.get_event_loop().run_in_executor(
            None, _synthesize
        )

        samples = np.frombuffer(raw_24k, dtype=np.int16).astype(np.float32)
        if len(samples) == 0:
            return b""
        # Linear-interpolation resample: 24 kHz source -> _SAMPLE_RATE.
        n_out = int(len(samples) / 24000 * _SAMPLE_RATE)
        positions = np.linspace(0, len(samples) - 1, n_out)
        resampled = np.interp(positions, np.arange(len(samples)), samples)
        return resampled.astype(np.int16).tobytes()
    except Exception as exc:
        logger.error("OpenAI TTS error: %s", exc)
        return b""
|
||||
|
||||
def _tts_sync(self, text: str) -> bytes:
|
||||
"""Synchronous TTS synthesis (runs in thread pool)."""
|
||||
"""Synchronous TTS synthesis with Kokoro (runs in thread pool)."""
|
||||
try:
|
||||
samples = []
|
||||
for _, _, audio in self.tts._pipeline(text, voice=self.tts.voice):
|
||||
|
|
@ -457,6 +528,27 @@ class VoicePipelineTask:
|
|||
return b""
|
||||
|
||||
|
||||
def _pcm_to_wav(pcm_bytes: bytes, sample_rate: int) -> bytes:
|
||||
"""Wrap raw 16-bit mono PCM into a WAV container."""
|
||||
buf = io.BytesIO()
|
||||
data_size = len(pcm_bytes)
|
||||
buf.write(b"RIFF")
|
||||
buf.write(struct.pack("<I", 36 + data_size))
|
||||
buf.write(b"WAVE")
|
||||
buf.write(b"fmt ")
|
||||
buf.write(struct.pack("<I", 16))
|
||||
buf.write(struct.pack("<H", 1)) # PCM
|
||||
buf.write(struct.pack("<H", 1)) # mono
|
||||
buf.write(struct.pack("<I", sample_rate))
|
||||
buf.write(struct.pack("<I", sample_rate * 2)) # byte rate
|
||||
buf.write(struct.pack("<H", 2)) # block align
|
||||
buf.write(struct.pack("<H", 16)) # bits per sample
|
||||
buf.write(b"data")
|
||||
buf.write(struct.pack("<I", data_size))
|
||||
buf.write(pcm_bytes)
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
async def create_voice_pipeline(
|
||||
websocket: WebSocket,
|
||||
session_context: dict,
|
||||
|
|
@ -464,23 +556,14 @@ async def create_voice_pipeline(
|
|||
stt=None,
|
||||
tts=None,
|
||||
vad=None,
|
||||
openai_client=None,
|
||||
) -> VoicePipelineTask:
|
||||
"""Create a voice pipeline task for the given WebSocket connection.
|
||||
|
||||
Args:
|
||||
websocket: FastAPI WebSocket connection (already accepted)
|
||||
session_context: Session metadata dict
|
||||
stt: Pre-initialized STT service
|
||||
tts: Pre-initialized TTS service
|
||||
vad: Pre-initialized VAD service
|
||||
|
||||
Returns:
|
||||
VoicePipelineTask ready to run
|
||||
"""
|
||||
"""Create a voice pipeline task for the given WebSocket connection."""
|
||||
return VoicePipelineTask(
|
||||
websocket,
|
||||
session_context,
|
||||
stt=stt,
|
||||
tts=tts,
|
||||
vad=vad,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue