"""
Voice dialogue pipeline — direct WebSocket audio I/O.

Pipeline:
    Audio Input → VAD → STT → Agent Service → TTS → Audio Output

Runs as an async task that reads binary PCM frames from a FastAPI WebSocket,
detects speech with VAD, transcribes with STT, sends text to the Agent service
(which uses Claude SDK with tools), and synthesizes the response back as speech
via TTS.

Each reply ("turn") runs as a background task, so VAD keeps analysing incoming
audio while the reply is generated and spoken; speech detected during a reply
interrupts it (barge-in).
"""

import asyncio
import io
import json
import logging
import os
import re
import struct
import tempfile
import time
import traceback
from typing import AsyncGenerator, List, Optional, Tuple

import httpx
import numpy as np
import websockets
from fastapi import WebSocket

from ..config.settings import settings
from ..stt.whisper_service import WhisperSTTService
from ..tts.kokoro_service import KokoroTTSService
from ..vad.silero_service import SileroVADService

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Minimum speech duration in seconds before we transcribe
_MIN_SPEECH_SECS = 0.5
# Silence duration in seconds after speech ends before we process
_SILENCE_AFTER_SPEECH_SECS = 0.8
# Longest utterance buffered before processing is forced. This bounds memory when
# VAD never reports silence (e.g. no VAD model loaded: all audio counts as speech).
_MAX_SPEECH_SECS = 30.0
# Sample rate of audio received from, and sent back to, the client
_SAMPLE_RATE = 16000
# Bytes per sample (16-bit PCM mono)
_BYTES_PER_SAMPLE = 2
_BYTES_PER_SEC = _SAMPLE_RATE * _BYTES_PER_SAMPLE
# VAD chunk size (512 samples = 32ms at 16kHz, Silero expects this)
_VAD_CHUNK_SAMPLES = 512
_VAD_CHUNK_BYTES = _VAD_CHUNK_SAMPLES * _BYTES_PER_SAMPLE
# Max audio output chunk size sent over WebSocket (4KB)
_WS_AUDIO_CHUNK = 4096
# Output rate of both TTS backends (Kokoro, and OpenAI's "pcm" format); it is
# resampled to _SAMPLE_RATE before sending.
_TTS_NATIVE_RATE = 24000

# Sentence boundaries for splitting TTS chunks. These end a sentence immediately:
# "。", full-width "！？；", ASCII "!?;" and newline. "." (and the ASCII marks)
# also end a sentence when followed by whitespace, so "3.5" does not split.
_SENTENCE_END_RE = re.compile(r'[。！？；!?;\n]|[.!?;]\s')
# Minimum chars before we flush a sentence to TTS
_MIN_SENTENCE_LEN = 4


class VoicePipelineTask:
    """Async voice pipeline that bridges a FastAPI WebSocket to STT/LLM/TTS."""

    def __init__(
        self,
        websocket: WebSocket,
        session_context: dict,
        *,
        stt: Optional[WhisperSTTService] = None,
        tts: Optional[KokoroTTSService] = None,
        vad: Optional[SileroVADService] = None,
        openai_client=None,
    ):
        self.websocket = websocket
        self.session_context = session_context
        self.stt = stt
        self.tts = tts
        self.vad = vad
        self.openai_client = openai_client
        # Agent session ID (reused across turns for conversation continuity)
        self._agent_session_id: Optional[str] = None
        self._auth_header: str = session_context.get("auth_header", "")
        self._cancelled = False  # set by cancel(); ends the run() loop
        self._speaking = False  # True while a reply is being streamed / sent as TTS audio
        self._cancelled_tts = False  # barge-in / cancel flag polled by the reply path

    async def run(self):
        """Main loop: read audio → VAD → STT → LLM → TTS → send audio.

        Incoming binary frames are re-chunked to the VAD window. An utterance
        ends after ``_SILENCE_AFTER_SPEECH_SECS`` of trailing silence, or when it
        reaches ``_MAX_SPEECH_SECS``. Durations are measured in audio samples, not
        wall-clock time, so frames that arrive in bursts are timed correctly.
        Utterances with at least ``_MIN_SPEECH_SECS`` of speech are handed to
        ``_process_speech`` as a background task. The previous turn is always
        awaited first, so turns stay in order. Cancellation of this coroutine is
        logged and swallowed, so it returns normally.
        """
        sid = self.session_context.get("session_id", "?")
        _log(f"Started for session {sid}")
        _log(
            f"STT={self.stt is not None and self.stt._model is not None}, "
            f"TTS={self.tts is not None and self.tts._pipeline is not None}, "
            f"VAD={self.vad is not None and self.vad._model is not None}"
        )

        silence_limit_bytes = int(_SILENCE_AFTER_SPEECH_SECS * _BYTES_PER_SEC)
        max_speech_bytes = int(_MAX_SPEECH_SECS * _BYTES_PER_SEC)

        speech_buffer = bytearray()  # current utterance, including trailing silence
        vad_buffer = bytearray()  # accumulates until _VAD_CHUNK_BYTES
        is_speech_active = False
        trailing_silence_bytes = 0  # silent audio received since the last speech chunk
        chunk_count = 0
        turn_task: Optional[asyncio.Task] = None

        try:
            while not self._cancelled:
                try:
                    data = await asyncio.wait_for(
                        self.websocket.receive_bytes(), timeout=30.0
                    )
                except asyncio.TimeoutError:
                    _log(f"No data for 30s, session {sid}")
                    continue
                except Exception as exc:
                    _log(f"WebSocket read error: {type(exc).__name__}: {exc}")
                    break

                chunk_count += 1
                if chunk_count <= 3 or chunk_count % 100 == 0:
                    _log(f"Received audio chunk #{chunk_count}, len={len(data)}")

                # Accumulate into VAD-sized chunks
                vad_buffer.extend(data)
                while len(vad_buffer) >= _VAD_CHUNK_BYTES:
                    chunk = bytes(vad_buffer[:_VAD_CHUNK_BYTES])
                    del vad_buffer[:_VAD_CHUNK_BYTES]

                    if self._detect_speech(chunk):
                        if not is_speech_active:
                            is_speech_active = True
                            _log("Speech detected!")
                            # Barge-in: the user started talking over a reply.
                            if self._speaking:
                                self._cancelled_tts = True
                                _log("Barge-in!")
                        speech_buffer.extend(chunk)
                        trailing_silence_bytes = 0
                        utterance_done = len(speech_buffer) >= max_speech_bytes
                    elif is_speech_active:
                        # Keep the silence gap too; speech may resume within it.
                        speech_buffer.extend(chunk)
                        trailing_silence_bytes += len(chunk)
                        utterance_done = trailing_silence_bytes >= silence_limit_bytes
                    else:
                        continue  # background silence outside any utterance

                    if not utterance_done:
                        continue

                    # Speech length excludes the trailing silence that ended it.
                    speech_secs = (len(speech_buffer) - trailing_silence_bytes) / _BYTES_PER_SEC
                    buf_secs = len(speech_buffer) / _BYTES_PER_SEC
                    if trailing_silence_bytes:
                        _log(f"Silence after speech: duration={speech_secs:.1f}s, buffer={buf_secs:.1f}s")
                    else:
                        _log(f"Max utterance length reached: buffer={buf_secs:.1f}s")

                    if speech_secs >= _MIN_SPEECH_SECS:
                        await _wait_for_turn(turn_task)
                        turn_task = asyncio.create_task(
                            self._process_speech(bytes(speech_buffer))
                        )

                    # Reset
                    speech_buffer.clear()
                    is_speech_active = False
                    trailing_silence_bytes = 0

        except asyncio.CancelledError:
            _log(f"Cancelled, session {sid}")
        except Exception as exc:
            _log(f"ERROR: {type(exc).__name__}: {exc}")
            traceback.print_exc()
        finally:
            # Don't keep generating/sending a reply for a client that is gone.
            await _wait_for_turn(turn_task, cancel=True)
            _log(f"Ended for session {sid}")

    def cancel(self):
        """Ask run() to stop after the next received frame, and stop any reply audio now."""
        self._cancelled = True
        self._cancelled_tts = True

    def _detect_speech(self, chunk: bytes) -> bool:
        """Run VAD on a single chunk. Returns True if speech detected.

        Without a loaded VAD model, or when VAD raises, the chunk is treated as
        speech.
        """
        if self.vad is None or self.vad._model is None:
            # No VAD — treat everything as speech
            return True
        try:
            return self.vad.detect_speech(chunk)
        except Exception as exc:
            logger.debug("VAD error: %s", exc)
            return True  # Assume speech on error

    async def _send_transcript(self, text: str, role: str) -> None:
        """Best-effort send of a transcript message to the client (failures are only logged at debug)."""
        try:
            await self.websocket.send_text(
                json.dumps({"type": "transcript", "text": text, "role": role})
            )
        except Exception as exc:
            logger.debug("Transcript send failed: %s", exc)

    async def _process_speech(self, audio_data: bytes):
        """Transcribe one utterance, stream the agent reply, and synthesize TTS per sentence.

        Sends the user transcript first. While streaming, each completed
        sentence is passed to ``_synthesize_chunk``. After streaming, the full
        assistant transcript is sent. Setting ``self._cancelled_tts`` (barge-in
        or cancel) stops both the agent stream and any further audio for this
        turn.
        """
        session_id = self.session_context.get("session_id", "?")
        audio_secs = len(audio_data) / _BYTES_PER_SEC
        _log(f"===== Processing speech: {audio_secs:.1f}s of audio, session={session_id} =====")

        # 1. STT
        t0 = time.monotonic()
        text = (await self._transcribe(audio_data) or "").strip()
        stt_ms = int((time.monotonic() - t0) * 1000)
        if not text:
            _log(f"STT returned empty text (took {stt_ms}ms), skipping")
            return
        _log(f"[STT] ({stt_ms}ms) User said: \"{text}\"")

        # Notify client of user transcript
        await self._send_transcript(text, "user")

        # 2. Stream agent response → sentence-split → TTS per sentence
        _log(f"[AGENT] Sending to agent: \"{text}\"")
        t1 = time.monotonic()
        self._speaking = True
        self._cancelled_tts = False
        pending = ""  # text received but not yet ending in a sentence boundary
        full_response: List[str] = []
        tts_count = 0
        first_audio_ms: Optional[int] = None

        stream = self._agent_stream(text)
        try:
            async for piece in stream:
                if self._cancelled_tts:
                    _log("[TTS] Barge-in, stopping TTS stream")
                    break
                full_response.append(piece)
                sentences, pending = _split_sentences(pending + piece)
                for sentence in sentences:
                    if self._cancelled_tts:
                        break  # don't spend TTS time on audio that won't be sent
                    tts_count += 1
                    if first_audio_ms is None:
                        first_audio_ms = int((time.monotonic() - t1) * 1000)
                        _log(f"[TTS] First sentence ready at {first_audio_ms}ms: \"{sentence[:60]}\"")
                    else:
                        _log(f"[TTS] Sentence #{tts_count}: \"{sentence[:60]}\"")
                    await self._synthesize_chunk(sentence)

            # Flush remaining buffer
            remaining = pending.strip()
            if remaining and not self._cancelled_tts:
                tts_count += 1
                if first_audio_ms is None:
                    first_audio_ms = int((time.monotonic() - t1) * 1000)
                _log(f"[TTS] Final flush #{tts_count}: \"{remaining[:60]}\"")
                await self._synthesize_chunk(remaining)
        except Exception as exc:
            _log(f"Stream/TTS error: {type(exc).__name__}: {exc}")
        finally:
            # Close the agent stream (and its WebSocket) now, also after a
            # barge-in break, instead of leaving it to garbage collection.
            try:
                await stream.aclose()
            except Exception as exc:
                _log(f"Agent stream close error: {type(exc).__name__}: {exc}")
            self._speaking = False
            total_ms = int((time.monotonic() - t1) * 1000)
            response_text = "".join(full_response).strip()
            if response_text:
                _log(f"[AGENT] Response ({len(response_text)} chars): \"{response_text[:200]}\"")
                # Send full transcript to client
                await self._send_transcript(response_text, "assistant")
            else:
                _log("[AGENT] No response text collected")
            _log(
                f"===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms "
                f"(first audio at {first_audio_ms or 0}ms, {tts_count} chunks) ====="
            )

    async def _transcribe(self, audio_data: bytes) -> str:
        """Transcribe audio using STT service (local or OpenAI). Returns "" when unavailable or on error."""
        if settings.stt_provider == "openai":
            return await self._transcribe_openai(audio_data)

        if self.stt is None or self.stt._model is None:
            logger.warning("STT not available")
            return ""
        try:
            return await self.stt.transcribe(audio_data)
        except Exception as exc:
            logger.error("STT error: %s", exc)
            return ""

    async def _transcribe_openai(self, audio_data: bytes) -> str:
        """Transcribe audio with the OpenAI transcription API (language fixed to "zh").

        The PCM is wrapped in a WAV container and written to a temporary file,
        which is passed to the SDK as an open file. The temp file is removed
        whether or not the call succeeds. Returns "" on any error.
        """
        if self.openai_client is None:
            logger.warning("OpenAI client not available for STT")
            return ""
        try:
            wav_bytes = _pcm_to_wav(audio_data, _SAMPLE_RATE)
            client = self.openai_client
            model = settings.openai_stt_model

            def _do() -> str:
                tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
                try:
                    with tmp:
                        tmp.write(wav_bytes)
                    with open(tmp.name, "rb") as f:
                        result = client.audio.transcriptions.create(
                            model=model,
                            file=f,
                            language="zh",
                        )
                    return result.text
                finally:
                    try:
                        os.unlink(tmp.name)
                    except OSError:
                        pass

            text = await asyncio.get_running_loop().run_in_executor(None, _do)
            return text or ""
        except Exception as exc:
            logger.error("OpenAI STT error: %s", exc)
            return ""

    async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
        """Stream text from agent-service, yielding chunks as they arrive.

        Flow:
          1. WS connect to /ws/agent → subscribe_session (only if a session exists)
          2. POST /api/v1/agent/tasks → triggers engine stream
          3. subscribe_session again with the returned session/task IDs
          4. Yield text chunks as they arrive; return on 'completed' or 'error'

        If task creation fails or an unexpected error occurs, a short apology
        string is yielded instead. If no terminal event arrives within
        ``timeout_secs``, the generator just ends.
        """
        agent_url = settings.agent_service_url
        headers = {"Content-Type": "application/json"}
        if self._auth_header:
            headers["Authorization"] = self._auth_header

        ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url = f"{ws_url}/ws/agent"

        event_count = 0
        text_count = 0
        total_chars = 0
        timeout_secs = 120

        try:
            _log(f"[AGENT] Connecting WS: {ws_url}")
            async with websockets.connect(ws_url) as ws:
                _log("[AGENT] WS connected")

                # 1. Pre-subscribe with existing session ID (presumably so events
                #    emitted right after task creation are not missed).
                pre_session_id = self._agent_session_id or ""
                if pre_session_id:
                    await ws.send(json.dumps({
                        "event": "subscribe_session",
                        "data": {"sessionId": pre_session_id},
                    }))
                    _log(f"[AGENT] Pre-subscribed session={pre_session_id}")

                # 2. Create agent task (use claude_api engine for streaming TTS)
                body = {"prompt": user_text, "engineType": "claude_api"}
                if self._agent_session_id:
                    body["sessionId"] = self._agent_session_id

                _log(f"[AGENT] POST /tasks prompt=\"{user_text[:80]}\"")
                async with httpx.AsyncClient(timeout=30) as client:
                    resp = await client.post(
                        f"{agent_url}/api/v1/agent/tasks",
                        json=body,
                        headers=headers,
                    )
                _log(f"[AGENT] POST response: {resp.status_code}")

                if resp.status_code not in (200, 201):
                    _log(f"[AGENT] Task creation FAILED: {resp.status_code} {resp.text[:200]}")
                    yield "抱歉,Agent服务暂时不可用。"
                    return

                data = resp.json()
                session_id = data.get("sessionId", "")
                task_id = data.get("taskId", "")
                self._agent_session_id = session_id
                _log(f"[AGENT] Task created: session={session_id}, task={task_id}")

                # 3. Subscribe with actual IDs
                await ws.send(json.dumps({
                    "event": "subscribe_session",
                    "data": {"sessionId": session_id, "taskId": task_id},
                }))
                _log(f"[AGENT] Subscribed session={session_id}, task={task_id}")

                # 4. Stream events until a terminal event or the deadline
                deadline = time.monotonic() + timeout_secs
                while True:
                    remaining = deadline - time.monotonic()
                    if remaining <= 0:
                        _log(f"[AGENT] TIMEOUT after {timeout_secs}s")
                        return
                    try:
                        raw = await asyncio.wait_for(ws.recv(), timeout=min(5.0, remaining))
                    except asyncio.TimeoutError:
                        continue
                    except Exception as ws_err:
                        _log(f"[AGENT] WS recv error: {type(ws_err).__name__}: {ws_err}")
                        return

                    try:
                        msg = json.loads(raw)
                    except (json.JSONDecodeError, TypeError):
                        continue
                    if not isinstance(msg, dict):
                        continue

                    event_type = msg.get("event", "")
                    event_count += 1

                    if event_type == "subscribed":
                        _log(f"[AGENT] Subscription confirmed: {msg.get('data', {})}")
                        continue
                    if event_type != "stream_event":
                        continue

                    evt_data = msg.get("data") or {}
                    evt_type = evt_data.get("type", "")

                    if evt_type == "text":
                        content = evt_data.get("content", "")
                        if content:
                            text_count += 1
                            total_chars += len(content)
                            if text_count <= 3 or text_count % 10 == 0:
                                _log(f"[AGENT] Text #{text_count}: +{len(content)} chars (total: {total_chars})")
                            yield content

                    elif evt_type == "completed":
                        summary = evt_data.get("summary", "")
                        if summary and text_count == 0:
                            _log(f"[AGENT] Using summary: \"{summary[:100]}\"")
                            yield summary
                        _log(
                            f"[AGENT] Completed! {text_count} text events, "
                            f"{total_chars} chars, {event_count} WS events"
                        )
                        return

                    elif evt_type == "error":
                        err_msg = evt_data.get("message", "Unknown error")
                        _log(f"[AGENT] ERROR: {err_msg}")
                        yield f"Agent 错误: {err_msg}"
                        return

        except Exception as exc:
            _log(f"[AGENT] Stream error: {type(exc).__name__}: {exc}")
            yield "抱歉,Agent服务暂时不可用。"

    async def _synthesize_chunk(self, text: str):
        """Synthesize a single sentence/chunk and send its 16 kHz PCM audio over the WebSocket.

        Audio goes out in ``_WS_AUDIO_CHUNK``-byte binary frames with a 5 ms
        pause between frames. Sending stops early once ``_cancelled_tts`` is
        set. Errors are logged, not raised.
        """
        if not text.strip():
            return
        try:
            if settings.tts_provider == "openai":
                audio_bytes = await self._tts_openai(text)
            elif self.tts is None or self.tts._pipeline is None:
                return
            else:
                audio_bytes = await asyncio.get_running_loop().run_in_executor(
                    None, self._tts_sync, text
                )

            if not audio_bytes or self._cancelled_tts:
                return

            # Send audio in WS-sized chunks
            total = len(audio_bytes)
            sent_bytes = 0
            for offset in range(0, total, _WS_AUDIO_CHUNK):
                if self._cancelled_tts:
                    break
                frame = audio_bytes[offset:offset + _WS_AUDIO_CHUNK]
                try:
                    await self.websocket.send_bytes(frame)
                except Exception as e:
                    _log(f"[TTS] send_bytes failed at {offset}/{total}: {e}")
                    break
                sent_bytes += len(frame)
                await asyncio.sleep(0.005)
            _log(f"[TTS] Sent {sent_bytes}/{total} bytes for: \"{text[:20]}\"")
        except Exception as exc:
            logger.error("TTS chunk error: %s", exc)

    async def _tts_openai(self, text: str) -> bytes:
        """Synthesize text to 16kHz 16-bit PCM via the OpenAI TTS API. Returns b"" on failure."""
        if self.openai_client is None:
            logger.warning("OpenAI client not available for TTS")
            return b""
        try:
            client = self.openai_client
            model = settings.openai_tts_model
            voice = settings.openai_tts_voice

            def _do():
                response = client.audio.speech.create(
                    model=model,
                    voice=voice,
                    input=text,
                    response_format="pcm",  # raw 24kHz 16-bit mono PCM
                )
                return response.content

            pcm_native = await asyncio.get_running_loop().run_in_executor(None, _do)
            # OpenAI returns 24kHz 16-bit mono PCM, resample to 16kHz
            audio_np = np.frombuffer(pcm_native, dtype=np.int16).astype(np.float32)
            if len(audio_np) == 0:
                return b""
            return _resample(audio_np, _TTS_NATIVE_RATE).astype(np.int16).tobytes()
        except Exception as exc:
            logger.error("OpenAI TTS error: %s", exc)
            return b""

    def _tts_sync(self, text: str) -> bytes:
        """Synchronous TTS synthesis with Kokoro (runs in thread pool).

        Concatenates the audio segments the pipeline yields, resamples from
        24 kHz to ``_SAMPLE_RATE`` and converts the float samples to 16-bit PCM
        (clipped). Returns b"" on failure or when no audio was produced.
        """
        try:
            samples = [
                audio
                for _, _, audio in self.tts._pipeline(text, voice=self.tts.voice)
                if audio is not None
            ]
            if not samples:
                return b""
            # Kokoro outputs at 24kHz, resample to 16kHz
            audio_np = _resample(np.concatenate(samples), _TTS_NATIVE_RATE)
            return (audio_np * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
        except Exception as exc:
            logger.error("TTS synthesis error: %s", exc)
            return b""


def _log(message: str) -> None:
    """Print a "[pipeline]"-prefixed diagnostic line to stdout, flushed immediately."""
    print(f"[pipeline] {message}", flush=True)


def _split_sentences(buffer: str) -> Tuple[List[str], str]:
    """Split complete sentences off the front of *buffer*.

    Returns ``(sentences, rest)``. Each sentence is stripped and non-empty, and
    *rest* is the unterminated remainder. A boundary that ends before
    ``_MIN_SENTENCE_LEN`` characters is skipped, so a very short fragment is
    merged into the following sentence.
    """
    sentences: List[str] = []
    search_start = 0
    while True:
        match = _SENTENCE_END_RE.search(buffer, search_start)
        if match is None:
            return sentences, buffer
        # Skip matches that are too early — sentence too short
        if match.end() < _MIN_SENTENCE_LEN:
            search_start = match.end()
            continue
        sentence = buffer[:match.end()].strip()
        buffer = buffer[match.end():]
        search_start = 0
        if sentence:
            sentences.append(sentence)


def _resample(audio: np.ndarray, src_rate: int, dst_rate: int = _SAMPLE_RATE) -> np.ndarray:
    """Linearly resample 1-D *audio* from *src_rate* to *dst_rate*.

    Returns a float array (np.interp output). Empty input is returned unchanged.
    """
    if len(audio) == 0:
        return audio
    target_samples = int(len(audio) / src_rate * dst_rate)
    positions = np.linspace(0, len(audio) - 1, target_samples)
    return np.interp(positions, np.arange(len(audio)), audio)


async def _wait_for_turn(task: Optional[asyncio.Task], *, cancel: bool = False) -> None:
    """Wait for a reply task to finish, optionally cancelling it first.

    Uses ``asyncio.wait`` so the task's own cancellation is not re-raised here.
    Any exception the task ended with is logged, so it is never silently lost.
    """
    if task is None:
        return
    if cancel and not task.done():
        task.cancel()
    await asyncio.wait({task})
    if not task.cancelled() and task.exception() is not None:
        exc = task.exception()
        _log(f"Turn failed: {type(exc).__name__}: {exc}")


def _pcm_to_wav(pcm_bytes: bytes, sample_rate: int) -> bytes:
    """Wrap raw 16-bit mono PCM into a WAV container (canonical 44-byte RIFF header)."""
    buf = io.BytesIO()
    data_size = len(pcm_bytes)
    buf.write(struct.pack(
        "<4sI4s4sIHHIIHH4sI",
        b"RIFF", 36 + data_size, b"WAVE",
        b"fmt ", 16,          # fmt chunk id and size
        1, 1,                 # audio format = PCM, channels = 1
        sample_rate,
        sample_rate * 2,      # byte rate = sample_rate * channels * 2 bytes
        2, 16,                # block align, bits per sample
        b"data", data_size,
    ))
    buf.write(pcm_bytes)
    return buf.getvalue()


# NOTE(review): this factory's `def` line (and the tail of _pcm_to_wav) was lost in
# the copy of the file this edit was made from. The name and parameter list below
# are reconstructed from the body and docstring; confirm the name against callers.
# Parameters are positional-or-keyword so either call style keeps working.
def create_voice_pipeline(
    websocket: WebSocket,
    session_context: dict,
    stt: Optional[WhisperSTTService] = None,
    tts: Optional[KokoroTTSService] = None,
    vad: Optional[SileroVADService] = None,
    openai_client=None,
) -> VoicePipelineTask:
    """Create a voice pipeline task for the given WebSocket connection."""
    return VoicePipelineTask(
        websocket,
        session_context,
        stt=stt,
        tts=tts,
        vad=vad,
        openai_client=openai_client,
    )