feat: streaming TTS — synthesize per-sentence as agent tokens arrive

Replace batch TTS (wait for full response) with streaming approach:
- _agent_generate → _agent_stream async generator (yield text chunks)
- _process_speech accumulates tokens, splits on sentence boundaries
- Each sentence is TTS'd and sent immediately while more tokens arrive
- First audio plays within ~1s of agent response vs waiting for full text

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-24 03:14:22 -08:00
parent aa2a49afd4
commit 65e68a0487
1 changed file with 118 additions and 99 deletions

View File

@@ -12,8 +12,9 @@ back as speech via TTS.
import asyncio import asyncio
import json import json
import logging import logging
import re
import time import time
from typing import Optional from typing import AsyncGenerator, Optional
import httpx import httpx
import numpy as np import numpy as np
@@ -42,6 +43,10 @@ _VAD_CHUNK_SAMPLES = 512
_VAD_CHUNK_BYTES = _VAD_CHUNK_SAMPLES * _BYTES_PER_SAMPLE _VAD_CHUNK_BYTES = _VAD_CHUNK_SAMPLES * _BYTES_PER_SAMPLE
# Max audio output chunk size sent over WebSocket (4KB) # Max audio output chunk size sent over WebSocket (4KB)
_WS_AUDIO_CHUNK = 4096 _WS_AUDIO_CHUNK = 4096
# Sentence-ending punctuation for splitting TTS chunks
_SENTENCE_END_RE = re.compile(r'[。!?;\n]|[.!?;]\s')
# Minimum chars before we flush a sentence to TTS
_MIN_SENTENCE_LEN = 4
class VoicePipelineTask: class VoicePipelineTask:
@@ -170,7 +175,7 @@ class VoicePipelineTask:
return True # Assume speech on error return True # Assume speech on error
async def _process_speech(self, audio_data: bytes): async def _process_speech(self, audio_data: bytes):
"""Transcribe speech, generate LLM response, synthesize and send TTS.""" """Transcribe speech, stream agent response, synthesize TTS per-sentence."""
session_id = self.session_context.get("session_id", "?") session_id = self.session_context.get("session_id", "?")
audio_secs = len(audio_data) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE) audio_secs = len(audio_data) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
print(f"[pipeline] ===== Processing speech: {audio_secs:.1f}s of audio, session={session_id} =====", flush=True) print(f"[pipeline] ===== Processing speech: {audio_secs:.1f}s of audio, session={session_id} =====", flush=True)
@@ -185,7 +190,7 @@ class VoicePipelineTask:
print(f"[pipeline] [STT] ({stt_ms}ms) User said: \"{text.strip()}\"", flush=True) print(f"[pipeline] [STT] ({stt_ms}ms) User said: \"{text.strip()}\"", flush=True)
# Notify client that we heard them # Notify client of user transcript
try: try:
await self.websocket.send_text( await self.websocket.send_text(
json.dumps({"type": "transcript", "text": text.strip(), "role": "user"}) json.dumps({"type": "transcript", "text": text.strip(), "role": "user"})
@@ -193,33 +198,73 @@ class VoicePipelineTask:
except Exception: except Exception:
pass pass
# 2. Agent service — create task + subscribe for response # 2. Stream agent response → sentence-split → TTS per sentence
print(f"[pipeline] [AGENT] Sending to agent: \"{text.strip()}\"", flush=True) print(f"[pipeline] [AGENT] Sending to agent: \"{text.strip()}\"", flush=True)
t1 = time.time() t1 = time.time()
response_text = await self._agent_generate(text.strip()) self._speaking = True
agent_ms = int((time.time() - t1) * 1000) self._cancelled_tts = False
if not response_text: sentence_buf = ""
print(f"[pipeline] [AGENT] ({agent_ms}ms) Agent returned empty response!", flush=True) full_response = []
return tts_count = 0
first_audio_ms = None
print(f"[pipeline] [AGENT] ({agent_ms}ms) Agent response ({len(response_text)} chars): \"{response_text[:200]}\"", flush=True)
# Notify client of the response text
try: try:
await self.websocket.send_text( async for chunk in self._agent_stream(text.strip()):
json.dumps({"type": "transcript", "text": response_text, "role": "assistant"}) if self._cancelled_tts:
) print(f"[pipeline] [TTS] Barge-in, stopping TTS stream", flush=True)
except Exception: break
pass
# 3. TTS → send audio back full_response.append(chunk)
print(f"[pipeline] [TTS] Synthesizing {len(response_text)} chars...", flush=True) sentence_buf += chunk
t2 = time.time()
await self._synthesize_and_send(response_text) # Check for sentence boundaries
tts_ms = int((time.time() - t2) * 1000) while True:
print(f"[pipeline] [TTS] ({tts_ms}ms) Audio sent to client", flush=True) match = _SENTENCE_END_RE.search(sentence_buf)
print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Agent={agent_ms}ms + TTS={tts_ms}ms = {stt_ms+agent_ms+tts_ms}ms =====", flush=True) if match and match.end() >= _MIN_SENTENCE_LEN:
sentence = sentence_buf[:match.end()].strip()
sentence_buf = sentence_buf[match.end():]
if sentence:
tts_count += 1
if first_audio_ms is None:
first_audio_ms = int((time.time() - t1) * 1000)
print(f"[pipeline] [TTS] First sentence ready at {first_audio_ms}ms: \"{sentence[:60]}\"", flush=True)
else:
print(f"[pipeline] [TTS] Sentence #{tts_count}: \"{sentence[:60]}\"", flush=True)
await self._synthesize_chunk(sentence)
else:
break
# Flush remaining buffer
remaining = sentence_buf.strip()
if remaining and not self._cancelled_tts:
tts_count += 1
if first_audio_ms is None:
first_audio_ms = int((time.time() - t1) * 1000)
print(f"[pipeline] [TTS] Final flush #{tts_count}: \"{remaining[:60]}\"", flush=True)
await self._synthesize_chunk(remaining)
except Exception as exc:
print(f"[pipeline] Stream/TTS error: {type(exc).__name__}: {exc}", flush=True)
finally:
self._speaking = False
total_ms = int((time.time() - t1) * 1000)
response_text = "".join(full_response).strip()
if response_text:
print(f"[pipeline] [AGENT] Response ({len(response_text)} chars): \"{response_text[:200]}\"", flush=True)
# Send full transcript to client
try:
await self.websocket.send_text(
json.dumps({"type": "transcript", "text": response_text, "role": "assistant"})
)
except Exception:
pass
else:
print(f"[pipeline] [AGENT] No response text collected", flush=True)
print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms (first audio at {first_audio_ms or 0}ms, {tts_count} chunks) =====", flush=True)
async def _transcribe(self, audio_data: bytes) -> str: async def _transcribe(self, audio_data: bytes) -> str:
"""Transcribe audio using STT service.""" """Transcribe audio using STT service."""
@@ -232,15 +277,15 @@ class VoicePipelineTask:
logger.error("STT error: %s", exc) logger.error("STT error: %s", exc)
return "" return ""
async def _agent_generate(self, user_text: str) -> str: async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
"""Send user text to agent-service, subscribe via WS, collect response. """Stream text from agent-service, yielding chunks as they arrive.
Flow (subscribe-first to avoid race condition): Flow:
1. WS connect to /ws/agent subscribe_session (with existing or new sessionId) 1. WS connect to /ws/agent subscribe_session
2. POST /api/v1/agent/tasks triggers engine stream 2. POST /api/v1/agent/tasks triggers engine stream
3. Collect 'text' stream events until 'completed' 3. Yield text chunks as they arrive, return on 'completed'
""" """
agent_url = settings.agent_service_url # http://agent-service:3002 agent_url = settings.agent_service_url
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
if self._auth_header: if self._auth_header:
headers["Authorization"] = self._auth_header headers["Authorization"] = self._auth_header
@@ -248,24 +293,23 @@ class VoicePipelineTask:
ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://") ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws/agent" ws_url = f"{ws_url}/ws/agent"
try: event_count = 0
collected_text = [] text_count = 0
event_count = 0 total_chars = 0
timeout_secs = 120 # Max wait for agent response timeout_secs = 120
try:
print(f"[pipeline] [AGENT] Connecting WS: {ws_url}", flush=True) print(f"[pipeline] [AGENT] Connecting WS: {ws_url}", flush=True)
async with websockets.connect(ws_url) as ws: async with websockets.connect(ws_url) as ws:
print(f"[pipeline] [AGENT] WS connected", flush=True) print(f"[pipeline] [AGENT] WS connected", flush=True)
# 1. Subscribe FIRST (before creating task to avoid missing events) # 1. Pre-subscribe with existing session ID
pre_session_id = self._agent_session_id or "" pre_session_id = self._agent_session_id or ""
if pre_session_id: if pre_session_id:
subscribe_msg = json.dumps({ await ws.send(json.dumps({
"event": "subscribe_session", "event": "subscribe_session",
"data": {"sessionId": pre_session_id}, "data": {"sessionId": pre_session_id},
}) }))
await ws.send(subscribe_msg)
print(f"[pipeline] [AGENT] Pre-subscribed session={pre_session_id}", flush=True) print(f"[pipeline] [AGENT] Pre-subscribed session={pre_session_id}", flush=True)
# 2. Create agent task # 2. Create agent task
@@ -281,9 +325,10 @@ class VoicePipelineTask:
headers=headers, headers=headers,
) )
print(f"[pipeline] [AGENT] POST response: {resp.status_code}", flush=True) print(f"[pipeline] [AGENT] POST response: {resp.status_code}", flush=True)
if resp.status_code != 200 and resp.status_code != 201: if resp.status_code not in (200, 201):
print(f"[pipeline] [AGENT] Task creation FAILED: {resp.status_code} {resp.text[:200]}", flush=True) print(f"[pipeline] [AGENT] Task creation FAILED: {resp.status_code} {resp.text[:200]}", flush=True)
return "抱歉Agent服务暂时不可用。" yield "抱歉Agent服务暂时不可用。"
return
data = resp.json() data = resp.json()
session_id = data.get("sessionId", "") session_id = data.get("sessionId", "")
@@ -291,15 +336,14 @@ class VoicePipelineTask:
self._agent_session_id = session_id self._agent_session_id = session_id
print(f"[pipeline] [AGENT] Task created: session={session_id}, task={task_id}", flush=True) print(f"[pipeline] [AGENT] Task created: session={session_id}, task={task_id}", flush=True)
# 3. Subscribe with actual session/task IDs (covers first-time case) # 3. Subscribe with actual IDs
subscribe_msg = json.dumps({ await ws.send(json.dumps({
"event": "subscribe_session", "event": "subscribe_session",
"data": {"sessionId": session_id, "taskId": task_id}, "data": {"sessionId": session_id, "taskId": task_id},
}) }))
await ws.send(subscribe_msg)
print(f"[pipeline] [AGENT] Subscribed session={session_id}, task={task_id}", flush=True) print(f"[pipeline] [AGENT] Subscribed session={session_id}, task={task_id}", flush=True)
# 4. Collect events until completed # 4. Stream events
deadline = time.time() + timeout_secs deadline = time.time() + timeout_secs
while time.time() < deadline: while time.time() < deadline:
remaining = deadline - time.time() remaining = deadline - time.time()
@@ -307,7 +351,7 @@ class VoicePipelineTask:
raw = await asyncio.wait_for(ws.recv(), timeout=min(5.0, remaining)) raw = await asyncio.wait_for(ws.recv(), timeout=min(5.0, remaining))
except asyncio.TimeoutError: except asyncio.TimeoutError:
if time.time() >= deadline: if time.time() >= deadline:
print(f"[pipeline] [AGENT] TIMEOUT after {timeout_secs}s waiting for events", flush=True) print(f"[pipeline] [AGENT] TIMEOUT after {timeout_secs}s", flush=True)
continue continue
except Exception as ws_err: except Exception as ws_err:
print(f"[pipeline] [AGENT] WS recv error: {type(ws_err).__name__}: {ws_err}", flush=True) print(f"[pipeline] [AGENT] WS recv error: {type(ws_err).__name__}: {ws_err}", flush=True)
@@ -316,7 +360,6 @@ class VoicePipelineTask:
try: try:
msg = json.loads(raw) msg = json.loads(raw)
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
print(f"[pipeline] [AGENT] Non-JSON WS message: {str(raw)[:100]}", flush=True)
continue continue
event_type = msg.get("event", "") event_type = msg.get("event", "")
@@ -328,59 +371,42 @@ class VoicePipelineTask:
elif event_type == "stream_event": elif event_type == "stream_event":
evt_data = msg.get("data", {}) evt_data = msg.get("data", {})
evt_type = evt_data.get("type", "") evt_type = evt_data.get("type", "")
# Engine events are flat: { type, content, summary, ... }
# (no nested "data" sub-field)
if evt_type == "text": if evt_type == "text":
content = evt_data.get("content", "") content = evt_data.get("content", "")
if content: if content:
collected_text.append(content) text_count += 1
# Log first and periodic text events total_chars += len(content)
if len(collected_text) <= 3 or len(collected_text) % 10 == 0: if text_count <= 3 or text_count % 10 == 0:
total_len = sum(len(t) for t in collected_text) print(f"[pipeline] [AGENT] Text #{text_count}: +{len(content)} chars (total: {total_chars})", flush=True)
print(f"[pipeline] [AGENT] Text event #{len(collected_text)}: +{len(content)} chars (total: {total_len})", flush=True) yield content
elif evt_type == "completed": elif evt_type == "completed":
summary = evt_data.get("summary", "") summary = evt_data.get("summary", "")
if summary and not collected_text: if summary and text_count == 0:
collected_text.append(summary) print(f"[pipeline] [AGENT] Using summary: \"{summary[:100]}\"", flush=True)
print(f"[pipeline] [AGENT] Using summary as response: \"{summary[:100]}\"", flush=True) yield summary
total_chars = sum(len(t) for t in collected_text) print(f"[pipeline] [AGENT] Completed! {text_count} text events, {total_chars} chars, {event_count} WS events", flush=True)
print(f"[pipeline] [AGENT] Completed! {len(collected_text)} text events, {total_chars} chars total, {event_count} WS events received", flush=True) return
break
elif evt_type == "error": elif evt_type == "error":
err_msg = evt_data.get("message", "Unknown error") err_msg = evt_data.get("message", "Unknown error")
print(f"[pipeline] [AGENT] ERROR event: {err_msg}", flush=True) print(f"[pipeline] [AGENT] ERROR: {err_msg}", flush=True)
return f"Agent 错误: {err_msg}" yield f"Agent 错误: {err_msg}"
return
else:
print(f"[pipeline] [AGENT] Stream event type={evt_type}", flush=True)
else:
print(f"[pipeline] [AGENT] WS event: {event_type}", flush=True)
result = "".join(collected_text).strip()
if not result:
print(f"[pipeline] [AGENT] WARNING: No text collected after {event_count} events!", flush=True)
return "Agent 未返回回复。"
return result
except Exception as exc: except Exception as exc:
print(f"[pipeline] Agent generate error: {type(exc).__name__}: {exc}", flush=True) print(f"[pipeline] [AGENT] Stream error: {type(exc).__name__}: {exc}", flush=True)
return "抱歉Agent服务暂时不可用,请稍后再试" yield "抱歉Agent服务暂时不可用。"
async def _synthesize_and_send(self, text: str): async def _synthesize_chunk(self, text: str):
"""Synthesize text to speech and send audio chunks over WebSocket.""" """Synthesize a single sentence/chunk and send audio over WebSocket."""
self._speaking = True if self.tts is None or self.tts._pipeline is None:
self._cancelled_tts = False return
if not text.strip():
return
try: try:
if self.tts is None or self.tts._pipeline is None:
logger.warning("TTS not available, skipping audio response")
return
# Run TTS (CPU-bound) in a thread
audio_bytes = await asyncio.get_event_loop().run_in_executor( audio_bytes = await asyncio.get_event_loop().run_in_executor(
None, self._tts_sync, text None, self._tts_sync, text
) )
@@ -388,7 +414,7 @@ class VoicePipelineTask:
if not audio_bytes or self._cancelled_tts: if not audio_bytes or self._cancelled_tts:
return return
# Send audio in chunks # Send audio in WS-sized chunks
offset = 0 offset = 0
while offset < len(audio_bytes) and not self._cancelled_tts: while offset < len(audio_bytes) and not self._cancelled_tts:
end = min(offset + _WS_AUDIO_CHUNK, len(audio_bytes)) end = min(offset + _WS_AUDIO_CHUNK, len(audio_bytes))
@@ -397,13 +423,10 @@ class VoicePipelineTask:
except Exception: except Exception:
break break
offset = end offset = end
# Small yield to not starve the event loop await asyncio.sleep(0.005)
await asyncio.sleep(0.01)
except Exception as exc: except Exception as exc:
logger.error("TTS/send error: %s", exc) logger.error("TTS chunk error: %s", exc)
finally:
self._speaking = False
def _tts_sync(self, text: str) -> bytes: def _tts_sync(self, text: str) -> bytes:
"""Synchronous TTS synthesis (runs in thread pool).""" """Synchronous TTS synthesis (runs in thread pool)."""
@@ -416,13 +439,9 @@ class VoicePipelineTask:
return b"" return b""
audio_np = np.concatenate(samples) audio_np = np.concatenate(samples)
# Kokoro outputs at 24kHz, we need 16kHz # Kokoro outputs at 24kHz, resample to 16kHz
# Resample using linear interpolation
if len(audio_np) > 0: if len(audio_np) > 0:
original_rate = 24000 target_samples = int(len(audio_np) / 24000 * _SAMPLE_RATE)
target_rate = _SAMPLE_RATE
duration = len(audio_np) / original_rate
target_samples = int(duration * target_rate)
indices = np.linspace(0, len(audio_np) - 1, target_samples) indices = np.linspace(0, len(audio_np) - 1, target_samples)
resampled = np.interp(indices, np.arange(len(audio_np)), audio_np) resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
return (resampled * 32768).clip(-32768, 32767).astype(np.int16).tobytes() return (resampled * 32768).clip(-32768, 32767).astype(np.int16).tobytes()