feat: streaming TTS — synthesize per-sentence as agent tokens arrive

Replace batch TTS (wait for full response) with streaming approach:
- _agent_generate → _agent_stream async generator (yields text chunks)
- _process_speech accumulates tokens, splits on sentence boundaries
- Each sentence is TTS'd and sent immediately while more tokens arrive
- First audio now plays within ~1s of the agent starting to respond, instead of after the full response text arrives

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-24 03:14:22 -08:00
parent aa2a49afd4
commit 65e68a0487
1 changed file with 118 additions and 99 deletions

View File

@ -12,8 +12,9 @@ back as speech via TTS.
import asyncio
import json
import logging
import re
import time
from typing import Optional
from typing import AsyncGenerator, Optional
import httpx
import numpy as np
@ -42,6 +43,10 @@ _VAD_CHUNK_SAMPLES = 512
_VAD_CHUNK_BYTES = _VAD_CHUNK_SAMPLES * _BYTES_PER_SAMPLE
# Max audio output chunk size sent over WebSocket (4KB)
_WS_AUDIO_CHUNK = 4096
# Sentence-ending punctuation for splitting TTS chunks
_SENTENCE_END_RE = re.compile(r'[。!?;\n]|[.!?;]\s')
# Minimum chars before we flush a sentence to TTS
_MIN_SENTENCE_LEN = 4
class VoicePipelineTask:
@ -170,7 +175,7 @@ class VoicePipelineTask:
return True # Assume speech on error
async def _process_speech(self, audio_data: bytes):
"""Transcribe speech, generate LLM response, synthesize and send TTS."""
"""Transcribe speech, stream agent response, synthesize TTS per-sentence."""
session_id = self.session_context.get("session_id", "?")
audio_secs = len(audio_data) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
print(f"[pipeline] ===== Processing speech: {audio_secs:.1f}s of audio, session={session_id} =====", flush=True)
@ -185,7 +190,7 @@ class VoicePipelineTask:
print(f"[pipeline] [STT] ({stt_ms}ms) User said: \"{text.strip()}\"", flush=True)
# Notify client that we heard them
# Notify client of user transcript
try:
await self.websocket.send_text(
json.dumps({"type": "transcript", "text": text.strip(), "role": "user"})
@ -193,33 +198,73 @@ class VoicePipelineTask:
except Exception:
pass
# 2. Agent service — create task + subscribe for response
# 2. Stream agent response → sentence-split → TTS per sentence
print(f"[pipeline] [AGENT] Sending to agent: \"{text.strip()}\"", flush=True)
t1 = time.time()
response_text = await self._agent_generate(text.strip())
agent_ms = int((time.time() - t1) * 1000)
self._speaking = True
self._cancelled_tts = False
if not response_text:
print(f"[pipeline] [AGENT] ({agent_ms}ms) Agent returned empty response!", flush=True)
return
sentence_buf = ""
full_response = []
tts_count = 0
first_audio_ms = None
print(f"[pipeline] [AGENT] ({agent_ms}ms) Agent response ({len(response_text)} chars): \"{response_text[:200]}\"", flush=True)
try:
async for chunk in self._agent_stream(text.strip()):
if self._cancelled_tts:
print(f"[pipeline] [TTS] Barge-in, stopping TTS stream", flush=True)
break
# Notify client of the response text
full_response.append(chunk)
sentence_buf += chunk
# Check for sentence boundaries
while True:
match = _SENTENCE_END_RE.search(sentence_buf)
if match and match.end() >= _MIN_SENTENCE_LEN:
sentence = sentence_buf[:match.end()].strip()
sentence_buf = sentence_buf[match.end():]
if sentence:
tts_count += 1
if first_audio_ms is None:
first_audio_ms = int((time.time() - t1) * 1000)
print(f"[pipeline] [TTS] First sentence ready at {first_audio_ms}ms: \"{sentence[:60]}\"", flush=True)
else:
print(f"[pipeline] [TTS] Sentence #{tts_count}: \"{sentence[:60]}\"", flush=True)
await self._synthesize_chunk(sentence)
else:
break
# Flush remaining buffer
remaining = sentence_buf.strip()
if remaining and not self._cancelled_tts:
tts_count += 1
if first_audio_ms is None:
first_audio_ms = int((time.time() - t1) * 1000)
print(f"[pipeline] [TTS] Final flush #{tts_count}: \"{remaining[:60]}\"", flush=True)
await self._synthesize_chunk(remaining)
except Exception as exc:
print(f"[pipeline] Stream/TTS error: {type(exc).__name__}: {exc}", flush=True)
finally:
self._speaking = False
total_ms = int((time.time() - t1) * 1000)
response_text = "".join(full_response).strip()
if response_text:
print(f"[pipeline] [AGENT] Response ({len(response_text)} chars): \"{response_text[:200]}\"", flush=True)
# Send full transcript to client
try:
await self.websocket.send_text(
json.dumps({"type": "transcript", "text": response_text, "role": "assistant"})
)
except Exception:
pass
else:
print(f"[pipeline] [AGENT] No response text collected", flush=True)
# 3. TTS → send audio back
print(f"[pipeline] [TTS] Synthesizing {len(response_text)} chars...", flush=True)
t2 = time.time()
await self._synthesize_and_send(response_text)
tts_ms = int((time.time() - t2) * 1000)
print(f"[pipeline] [TTS] ({tts_ms}ms) Audio sent to client", flush=True)
print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Agent={agent_ms}ms + TTS={tts_ms}ms = {stt_ms+agent_ms+tts_ms}ms =====", flush=True)
print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms (first audio at {first_audio_ms or 0}ms, {tts_count} chunks) =====", flush=True)
async def _transcribe(self, audio_data: bytes) -> str:
"""Transcribe audio using STT service."""
@ -232,15 +277,15 @@ class VoicePipelineTask:
logger.error("STT error: %s", exc)
return ""
async def _agent_generate(self, user_text: str) -> str:
"""Send user text to agent-service, subscribe via WS, collect response.
async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
"""Stream text from agent-service, yielding chunks as they arrive.
Flow (subscribe-first to avoid race condition):
1. WS connect to /ws/agent subscribe_session (with existing or new sessionId)
Flow:
1. WS connect to /ws/agent subscribe_session
2. POST /api/v1/agent/tasks triggers engine stream
3. Collect 'text' stream events until 'completed'
3. Yield text chunks as they arrive, return on 'completed'
"""
agent_url = settings.agent_service_url # http://agent-service:3002
agent_url = settings.agent_service_url
headers = {"Content-Type": "application/json"}
if self._auth_header:
headers["Authorization"] = self._auth_header
@ -248,24 +293,23 @@ class VoicePipelineTask:
ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws/agent"
try:
collected_text = []
event_count = 0
timeout_secs = 120 # Max wait for agent response
text_count = 0
total_chars = 0
timeout_secs = 120
try:
print(f"[pipeline] [AGENT] Connecting WS: {ws_url}", flush=True)
async with websockets.connect(ws_url) as ws:
print(f"[pipeline] [AGENT] WS connected", flush=True)
# 1. Subscribe FIRST (before creating task to avoid missing events)
# 1. Pre-subscribe with existing session ID
pre_session_id = self._agent_session_id or ""
if pre_session_id:
subscribe_msg = json.dumps({
await ws.send(json.dumps({
"event": "subscribe_session",
"data": {"sessionId": pre_session_id},
})
await ws.send(subscribe_msg)
}))
print(f"[pipeline] [AGENT] Pre-subscribed session={pre_session_id}", flush=True)
# 2. Create agent task
@ -281,9 +325,10 @@ class VoicePipelineTask:
headers=headers,
)
print(f"[pipeline] [AGENT] POST response: {resp.status_code}", flush=True)
if resp.status_code != 200 and resp.status_code != 201:
if resp.status_code not in (200, 201):
print(f"[pipeline] [AGENT] Task creation FAILED: {resp.status_code} {resp.text[:200]}", flush=True)
return "抱歉Agent服务暂时不可用。"
yield "抱歉Agent服务暂时不可用。"
return
data = resp.json()
session_id = data.get("sessionId", "")
@ -291,15 +336,14 @@ class VoicePipelineTask:
self._agent_session_id = session_id
print(f"[pipeline] [AGENT] Task created: session={session_id}, task={task_id}", flush=True)
# 3. Subscribe with actual session/task IDs (covers first-time case)
subscribe_msg = json.dumps({
# 3. Subscribe with actual IDs
await ws.send(json.dumps({
"event": "subscribe_session",
"data": {"sessionId": session_id, "taskId": task_id},
})
await ws.send(subscribe_msg)
}))
print(f"[pipeline] [AGENT] Subscribed session={session_id}, task={task_id}", flush=True)
# 4. Collect events until completed
# 4. Stream events
deadline = time.time() + timeout_secs
while time.time() < deadline:
remaining = deadline - time.time()
@ -307,7 +351,7 @@ class VoicePipelineTask:
raw = await asyncio.wait_for(ws.recv(), timeout=min(5.0, remaining))
except asyncio.TimeoutError:
if time.time() >= deadline:
print(f"[pipeline] [AGENT] TIMEOUT after {timeout_secs}s waiting for events", flush=True)
print(f"[pipeline] [AGENT] TIMEOUT after {timeout_secs}s", flush=True)
continue
except Exception as ws_err:
print(f"[pipeline] [AGENT] WS recv error: {type(ws_err).__name__}: {ws_err}", flush=True)
@ -316,7 +360,6 @@ class VoicePipelineTask:
try:
msg = json.loads(raw)
except (json.JSONDecodeError, TypeError):
print(f"[pipeline] [AGENT] Non-JSON WS message: {str(raw)[:100]}", flush=True)
continue
event_type = msg.get("event", "")
@ -328,59 +371,42 @@ class VoicePipelineTask:
elif event_type == "stream_event":
evt_data = msg.get("data", {})
evt_type = evt_data.get("type", "")
# Engine events are flat: { type, content, summary, ... }
# (no nested "data" sub-field)
if evt_type == "text":
content = evt_data.get("content", "")
if content:
collected_text.append(content)
# Log first and periodic text events
if len(collected_text) <= 3 or len(collected_text) % 10 == 0:
total_len = sum(len(t) for t in collected_text)
print(f"[pipeline] [AGENT] Text event #{len(collected_text)}: +{len(content)} chars (total: {total_len})", flush=True)
text_count += 1
total_chars += len(content)
if text_count <= 3 or text_count % 10 == 0:
print(f"[pipeline] [AGENT] Text #{text_count}: +{len(content)} chars (total: {total_chars})", flush=True)
yield content
elif evt_type == "completed":
summary = evt_data.get("summary", "")
if summary and not collected_text:
collected_text.append(summary)
print(f"[pipeline] [AGENT] Using summary as response: \"{summary[:100]}\"", flush=True)
total_chars = sum(len(t) for t in collected_text)
print(f"[pipeline] [AGENT] Completed! {len(collected_text)} text events, {total_chars} chars total, {event_count} WS events received", flush=True)
break
if summary and text_count == 0:
print(f"[pipeline] [AGENT] Using summary: \"{summary[:100]}\"", flush=True)
yield summary
print(f"[pipeline] [AGENT] Completed! {text_count} text events, {total_chars} chars, {event_count} WS events", flush=True)
return
elif evt_type == "error":
err_msg = evt_data.get("message", "Unknown error")
print(f"[pipeline] [AGENT] ERROR event: {err_msg}", flush=True)
return f"Agent 错误: {err_msg}"
else:
print(f"[pipeline] [AGENT] Stream event type={evt_type}", flush=True)
else:
print(f"[pipeline] [AGENT] WS event: {event_type}", flush=True)
result = "".join(collected_text).strip()
if not result:
print(f"[pipeline] [AGENT] WARNING: No text collected after {event_count} events!", flush=True)
return "Agent 未返回回复。"
return result
except Exception as exc:
print(f"[pipeline] Agent generate error: {type(exc).__name__}: {exc}", flush=True)
return "抱歉Agent服务暂时不可用请稍后再试。"
async def _synthesize_and_send(self, text: str):
"""Synthesize text to speech and send audio chunks over WebSocket."""
self._speaking = True
self._cancelled_tts = False
try:
if self.tts is None or self.tts._pipeline is None:
logger.warning("TTS not available, skipping audio response")
print(f"[pipeline] [AGENT] ERROR: {err_msg}", flush=True)
yield f"Agent 错误: {err_msg}"
return
# Run TTS (CPU-bound) in a thread
except Exception as exc:
print(f"[pipeline] [AGENT] Stream error: {type(exc).__name__}: {exc}", flush=True)
yield "抱歉Agent服务暂时不可用。"
async def _synthesize_chunk(self, text: str):
"""Synthesize a single sentence/chunk and send audio over WebSocket."""
if self.tts is None or self.tts._pipeline is None:
return
if not text.strip():
return
try:
audio_bytes = await asyncio.get_event_loop().run_in_executor(
None, self._tts_sync, text
)
@ -388,7 +414,7 @@ class VoicePipelineTask:
if not audio_bytes or self._cancelled_tts:
return
# Send audio in chunks
# Send audio in WS-sized chunks
offset = 0
while offset < len(audio_bytes) and not self._cancelled_tts:
end = min(offset + _WS_AUDIO_CHUNK, len(audio_bytes))
@ -397,13 +423,10 @@ class VoicePipelineTask:
except Exception:
break
offset = end
# Small yield to not starve the event loop
await asyncio.sleep(0.01)
await asyncio.sleep(0.005)
except Exception as exc:
logger.error("TTS/send error: %s", exc)
finally:
self._speaking = False
logger.error("TTS chunk error: %s", exc)
def _tts_sync(self, text: str) -> bytes:
"""Synchronous TTS synthesis (runs in thread pool)."""
@ -416,13 +439,9 @@ class VoicePipelineTask:
return b""
audio_np = np.concatenate(samples)
# Kokoro outputs at 24kHz, we need 16kHz
# Resample using linear interpolation
# Kokoro outputs at 24kHz, resample to 16kHz
if len(audio_np) > 0:
original_rate = 24000
target_rate = _SAMPLE_RATE
duration = len(audio_np) / original_rate
target_samples = int(duration * target_rate)
target_samples = int(len(audio_np) / 24000 * _SAMPLE_RATE)
indices = np.linspace(0, len(audio_np) - 1, target_samples)
resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
return (resampled * 32768).clip(-32768, 32767).astype(np.int16).tobytes()