feat: streaming TTS — synthesize per-sentence as agent tokens arrive
Replace batch TTS (waiting for the full response) with a streaming approach:
- _agent_generate → _agent_stream, an async generator that yields text chunks
- _process_speech accumulates tokens and splits them on sentence boundaries
- Each sentence is synthesized and sent immediately while more tokens arrive
- First audio plays within ~1s of the agent response instead of waiting for the full text

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
aa2a49afd4
commit
65e68a0487
|
|
@ -12,8 +12,9 @@ back as speech via TTS.
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
|
|
@ -42,6 +43,10 @@ _VAD_CHUNK_SAMPLES = 512
|
|||
_VAD_CHUNK_BYTES = _VAD_CHUNK_SAMPLES * _BYTES_PER_SAMPLE
|
||||
# Max audio output chunk size sent over WebSocket (4KB)
|
||||
_WS_AUDIO_CHUNK = 4096
|
||||
# Sentence-ending punctuation for splitting TTS chunks
|
||||
# Matches the end of a sentence so streamed agent text can be cut into pieces
# for TTS. Two alternatives:
#   1. CJK full stop, full-width or ASCII ! ? ;, or a newline. These end a
#      sentence on their own; no trailing whitespace is required.
#   2. ASCII . ! ? ; followed by whitespace. The whitespace requirement keeps
#      a dot inside "3.14" or "example.com" from being treated as a boundary.
# The full-width forms (U+FF01, U+FF1F, U+FF1B) are listed explicitly: they are
# distinct code points from their ASCII look-alikes, so a class containing only
# the ASCII characters never matches them.
_SENTENCE_END_RE = re.compile(r'[。!?;!?;\n]|[.!?;]\s')
|
||||
# Minimum chars before we flush a sentence to TTS
|
||||
_MIN_SENTENCE_LEN = 4
|
||||
|
||||
|
||||
class VoicePipelineTask:
|
||||
|
|
@ -170,7 +175,7 @@ class VoicePipelineTask:
|
|||
return True # Assume speech on error
|
||||
|
||||
async def _process_speech(self, audio_data: bytes):
|
||||
"""Transcribe speech, generate LLM response, synthesize and send TTS."""
|
||||
"""Transcribe speech, stream agent response, synthesize TTS per-sentence."""
|
||||
session_id = self.session_context.get("session_id", "?")
|
||||
audio_secs = len(audio_data) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
|
||||
print(f"[pipeline] ===== Processing speech: {audio_secs:.1f}s of audio, session={session_id} =====", flush=True)
|
||||
|
|
@ -185,7 +190,7 @@ class VoicePipelineTask:
|
|||
|
||||
print(f"[pipeline] [STT] ({stt_ms}ms) User said: \"{text.strip()}\"", flush=True)
|
||||
|
||||
# Notify client that we heard them
|
||||
# Notify client of user transcript
|
||||
try:
|
||||
await self.websocket.send_text(
|
||||
json.dumps({"type": "transcript", "text": text.strip(), "role": "user"})
|
||||
|
|
@ -193,33 +198,73 @@ class VoicePipelineTask:
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Agent service — create task + subscribe for response
|
||||
# 2. Stream agent response → sentence-split → TTS per sentence
|
||||
print(f"[pipeline] [AGENT] Sending to agent: \"{text.strip()}\"", flush=True)
|
||||
t1 = time.time()
|
||||
response_text = await self._agent_generate(text.strip())
|
||||
agent_ms = int((time.time() - t1) * 1000)
|
||||
self._speaking = True
|
||||
self._cancelled_tts = False
|
||||
|
||||
if not response_text:
|
||||
print(f"[pipeline] [AGENT] ({agent_ms}ms) Agent returned empty response!", flush=True)
|
||||
return
|
||||
sentence_buf = ""
|
||||
full_response = []
|
||||
tts_count = 0
|
||||
first_audio_ms = None
|
||||
|
||||
print(f"[pipeline] [AGENT] ({agent_ms}ms) Agent response ({len(response_text)} chars): \"{response_text[:200]}\"", flush=True)
|
||||
|
||||
# Notify client of the response text
|
||||
try:
|
||||
await self.websocket.send_text(
|
||||
json.dumps({"type": "transcript", "text": response_text, "role": "assistant"})
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
async for chunk in self._agent_stream(text.strip()):
|
||||
if self._cancelled_tts:
|
||||
print(f"[pipeline] [TTS] Barge-in, stopping TTS stream", flush=True)
|
||||
break
|
||||
|
||||
# 3. TTS → send audio back
|
||||
print(f"[pipeline] [TTS] Synthesizing {len(response_text)} chars...", flush=True)
|
||||
t2 = time.time()
|
||||
await self._synthesize_and_send(response_text)
|
||||
tts_ms = int((time.time() - t2) * 1000)
|
||||
print(f"[pipeline] [TTS] ({tts_ms}ms) Audio sent to client", flush=True)
|
||||
print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Agent={agent_ms}ms + TTS={tts_ms}ms = {stt_ms+agent_ms+tts_ms}ms =====", flush=True)
|
||||
full_response.append(chunk)
|
||||
sentence_buf += chunk
|
||||
|
||||
# Check for sentence boundaries
|
||||
while True:
|
||||
match = _SENTENCE_END_RE.search(sentence_buf)
|
||||
if match and match.end() >= _MIN_SENTENCE_LEN:
|
||||
sentence = sentence_buf[:match.end()].strip()
|
||||
sentence_buf = sentence_buf[match.end():]
|
||||
if sentence:
|
||||
tts_count += 1
|
||||
if first_audio_ms is None:
|
||||
first_audio_ms = int((time.time() - t1) * 1000)
|
||||
print(f"[pipeline] [TTS] First sentence ready at {first_audio_ms}ms: \"{sentence[:60]}\"", flush=True)
|
||||
else:
|
||||
print(f"[pipeline] [TTS] Sentence #{tts_count}: \"{sentence[:60]}\"", flush=True)
|
||||
await self._synthesize_chunk(sentence)
|
||||
else:
|
||||
break
|
||||
|
||||
# Flush remaining buffer
|
||||
remaining = sentence_buf.strip()
|
||||
if remaining and not self._cancelled_tts:
|
||||
tts_count += 1
|
||||
if first_audio_ms is None:
|
||||
first_audio_ms = int((time.time() - t1) * 1000)
|
||||
print(f"[pipeline] [TTS] Final flush #{tts_count}: \"{remaining[:60]}\"", flush=True)
|
||||
await self._synthesize_chunk(remaining)
|
||||
|
||||
except Exception as exc:
|
||||
print(f"[pipeline] Stream/TTS error: {type(exc).__name__}: {exc}", flush=True)
|
||||
finally:
|
||||
self._speaking = False
|
||||
|
||||
total_ms = int((time.time() - t1) * 1000)
|
||||
response_text = "".join(full_response).strip()
|
||||
|
||||
if response_text:
|
||||
print(f"[pipeline] [AGENT] Response ({len(response_text)} chars): \"{response_text[:200]}\"", flush=True)
|
||||
# Send full transcript to client
|
||||
try:
|
||||
await self.websocket.send_text(
|
||||
json.dumps({"type": "transcript", "text": response_text, "role": "assistant"})
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(f"[pipeline] [AGENT] No response text collected", flush=True)
|
||||
|
||||
print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms (first audio at {first_audio_ms or 0}ms, {tts_count} chunks) =====", flush=True)
|
||||
|
||||
async def _transcribe(self, audio_data: bytes) -> str:
|
||||
"""Transcribe audio using STT service."""
|
||||
|
|
@ -232,15 +277,15 @@ class VoicePipelineTask:
|
|||
logger.error("STT error: %s", exc)
|
||||
return ""
|
||||
|
||||
async def _agent_generate(self, user_text: str) -> str:
|
||||
"""Send user text to agent-service, subscribe via WS, collect response.
|
||||
async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
|
||||
"""Stream text from agent-service, yielding chunks as they arrive.
|
||||
|
||||
Flow (subscribe-first to avoid race condition):
|
||||
1. WS connect to /ws/agent → subscribe_session (with existing or new sessionId)
|
||||
2. POST /api/v1/agent/tasks → triggers engine stream
|
||||
3. Collect 'text' stream events until 'completed'
|
||||
Flow:
|
||||
1. WS connect to /ws/agent → subscribe_session
|
||||
2. POST /api/v1/agent/tasks → triggers engine stream
|
||||
3. Yield text chunks as they arrive, return on 'completed'
|
||||
"""
|
||||
agent_url = settings.agent_service_url # http://agent-service:3002
|
||||
agent_url = settings.agent_service_url
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self._auth_header:
|
||||
headers["Authorization"] = self._auth_header
|
||||
|
|
@ -248,24 +293,23 @@ class VoicePipelineTask:
|
|||
ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||
ws_url = f"{ws_url}/ws/agent"
|
||||
|
||||
try:
|
||||
collected_text = []
|
||||
event_count = 0
|
||||
timeout_secs = 120 # Max wait for agent response
|
||||
event_count = 0
|
||||
text_count = 0
|
||||
total_chars = 0
|
||||
timeout_secs = 120
|
||||
|
||||
try:
|
||||
print(f"[pipeline] [AGENT] Connecting WS: {ws_url}", flush=True)
|
||||
async with websockets.connect(ws_url) as ws:
|
||||
print(f"[pipeline] [AGENT] WS connected", flush=True)
|
||||
|
||||
# 1. Subscribe FIRST (before creating task to avoid missing events)
|
||||
# 1. Pre-subscribe with existing session ID
|
||||
pre_session_id = self._agent_session_id or ""
|
||||
|
||||
if pre_session_id:
|
||||
subscribe_msg = json.dumps({
|
||||
await ws.send(json.dumps({
|
||||
"event": "subscribe_session",
|
||||
"data": {"sessionId": pre_session_id},
|
||||
})
|
||||
await ws.send(subscribe_msg)
|
||||
}))
|
||||
print(f"[pipeline] [AGENT] Pre-subscribed session={pre_session_id}", flush=True)
|
||||
|
||||
# 2. Create agent task
|
||||
|
|
@ -281,9 +325,10 @@ class VoicePipelineTask:
|
|||
headers=headers,
|
||||
)
|
||||
print(f"[pipeline] [AGENT] POST response: {resp.status_code}", flush=True)
|
||||
if resp.status_code != 200 and resp.status_code != 201:
|
||||
if resp.status_code not in (200, 201):
|
||||
print(f"[pipeline] [AGENT] Task creation FAILED: {resp.status_code} {resp.text[:200]}", flush=True)
|
||||
return "抱歉,Agent服务暂时不可用。"
|
||||
yield "抱歉,Agent服务暂时不可用。"
|
||||
return
|
||||
|
||||
data = resp.json()
|
||||
session_id = data.get("sessionId", "")
|
||||
|
|
@ -291,15 +336,14 @@ class VoicePipelineTask:
|
|||
self._agent_session_id = session_id
|
||||
print(f"[pipeline] [AGENT] Task created: session={session_id}, task={task_id}", flush=True)
|
||||
|
||||
# 3. Subscribe with actual session/task IDs (covers first-time case)
|
||||
subscribe_msg = json.dumps({
|
||||
# 3. Subscribe with actual IDs
|
||||
await ws.send(json.dumps({
|
||||
"event": "subscribe_session",
|
||||
"data": {"sessionId": session_id, "taskId": task_id},
|
||||
})
|
||||
await ws.send(subscribe_msg)
|
||||
}))
|
||||
print(f"[pipeline] [AGENT] Subscribed session={session_id}, task={task_id}", flush=True)
|
||||
|
||||
# 4. Collect events until completed
|
||||
# 4. Stream events
|
||||
deadline = time.time() + timeout_secs
|
||||
while time.time() < deadline:
|
||||
remaining = deadline - time.time()
|
||||
|
|
@ -307,7 +351,7 @@ class VoicePipelineTask:
|
|||
raw = await asyncio.wait_for(ws.recv(), timeout=min(5.0, remaining))
|
||||
except asyncio.TimeoutError:
|
||||
if time.time() >= deadline:
|
||||
print(f"[pipeline] [AGENT] TIMEOUT after {timeout_secs}s waiting for events", flush=True)
|
||||
print(f"[pipeline] [AGENT] TIMEOUT after {timeout_secs}s", flush=True)
|
||||
continue
|
||||
except Exception as ws_err:
|
||||
print(f"[pipeline] [AGENT] WS recv error: {type(ws_err).__name__}: {ws_err}", flush=True)
|
||||
|
|
@ -316,7 +360,6 @@ class VoicePipelineTask:
|
|||
try:
|
||||
msg = json.loads(raw)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
print(f"[pipeline] [AGENT] Non-JSON WS message: {str(raw)[:100]}", flush=True)
|
||||
continue
|
||||
|
||||
event_type = msg.get("event", "")
|
||||
|
|
@ -328,59 +371,42 @@ class VoicePipelineTask:
|
|||
elif event_type == "stream_event":
|
||||
evt_data = msg.get("data", {})
|
||||
evt_type = evt_data.get("type", "")
|
||||
# Engine events are flat: { type, content, summary, ... }
|
||||
# (no nested "data" sub-field)
|
||||
|
||||
if evt_type == "text":
|
||||
content = evt_data.get("content", "")
|
||||
if content:
|
||||
collected_text.append(content)
|
||||
# Log first and periodic text events
|
||||
if len(collected_text) <= 3 or len(collected_text) % 10 == 0:
|
||||
total_len = sum(len(t) for t in collected_text)
|
||||
print(f"[pipeline] [AGENT] Text event #{len(collected_text)}: +{len(content)} chars (total: {total_len})", flush=True)
|
||||
text_count += 1
|
||||
total_chars += len(content)
|
||||
if text_count <= 3 or text_count % 10 == 0:
|
||||
print(f"[pipeline] [AGENT] Text #{text_count}: +{len(content)} chars (total: {total_chars})", flush=True)
|
||||
yield content
|
||||
|
||||
elif evt_type == "completed":
|
||||
summary = evt_data.get("summary", "")
|
||||
if summary and not collected_text:
|
||||
collected_text.append(summary)
|
||||
print(f"[pipeline] [AGENT] Using summary as response: \"{summary[:100]}\"", flush=True)
|
||||
total_chars = sum(len(t) for t in collected_text)
|
||||
print(f"[pipeline] [AGENT] Completed! {len(collected_text)} text events, {total_chars} chars total, {event_count} WS events received", flush=True)
|
||||
break
|
||||
if summary and text_count == 0:
|
||||
print(f"[pipeline] [AGENT] Using summary: \"{summary[:100]}\"", flush=True)
|
||||
yield summary
|
||||
print(f"[pipeline] [AGENT] Completed! {text_count} text events, {total_chars} chars, {event_count} WS events", flush=True)
|
||||
return
|
||||
|
||||
elif evt_type == "error":
|
||||
err_msg = evt_data.get("message", "Unknown error")
|
||||
print(f"[pipeline] [AGENT] ERROR event: {err_msg}", flush=True)
|
||||
return f"Agent 错误: {err_msg}"
|
||||
|
||||
else:
|
||||
print(f"[pipeline] [AGENT] Stream event type={evt_type}", flush=True)
|
||||
|
||||
else:
|
||||
print(f"[pipeline] [AGENT] WS event: {event_type}", flush=True)
|
||||
|
||||
result = "".join(collected_text).strip()
|
||||
if not result:
|
||||
print(f"[pipeline] [AGENT] WARNING: No text collected after {event_count} events!", flush=True)
|
||||
return "Agent 未返回回复。"
|
||||
return result
|
||||
print(f"[pipeline] [AGENT] ERROR: {err_msg}", flush=True)
|
||||
yield f"Agent 错误: {err_msg}"
|
||||
return
|
||||
|
||||
except Exception as exc:
|
||||
print(f"[pipeline] Agent generate error: {type(exc).__name__}: {exc}", flush=True)
|
||||
return "抱歉,Agent服务暂时不可用,请稍后再试。"
|
||||
print(f"[pipeline] [AGENT] Stream error: {type(exc).__name__}: {exc}", flush=True)
|
||||
yield "抱歉,Agent服务暂时不可用。"
|
||||
|
||||
async def _synthesize_and_send(self, text: str):
|
||||
"""Synthesize text to speech and send audio chunks over WebSocket."""
|
||||
self._speaking = True
|
||||
self._cancelled_tts = False
|
||||
async def _synthesize_chunk(self, text: str):
|
||||
"""Synthesize a single sentence/chunk and send audio over WebSocket."""
|
||||
if self.tts is None or self.tts._pipeline is None:
|
||||
return
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
try:
|
||||
if self.tts is None or self.tts._pipeline is None:
|
||||
logger.warning("TTS not available, skipping audio response")
|
||||
return
|
||||
|
||||
# Run TTS (CPU-bound) in a thread
|
||||
audio_bytes = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self._tts_sync, text
|
||||
)
|
||||
|
|
@ -388,7 +414,7 @@ class VoicePipelineTask:
|
|||
if not audio_bytes or self._cancelled_tts:
|
||||
return
|
||||
|
||||
# Send audio in chunks
|
||||
# Send audio in WS-sized chunks
|
||||
offset = 0
|
||||
while offset < len(audio_bytes) and not self._cancelled_tts:
|
||||
end = min(offset + _WS_AUDIO_CHUNK, len(audio_bytes))
|
||||
|
|
@ -397,13 +423,10 @@ class VoicePipelineTask:
|
|||
except Exception:
|
||||
break
|
||||
offset = end
|
||||
# Small yield to not starve the event loop
|
||||
await asyncio.sleep(0.01)
|
||||
await asyncio.sleep(0.005)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error("TTS/send error: %s", exc)
|
||||
finally:
|
||||
self._speaking = False
|
||||
logger.error("TTS chunk error: %s", exc)
|
||||
|
||||
def _tts_sync(self, text: str) -> bytes:
|
||||
"""Synchronous TTS synthesis (runs in thread pool)."""
|
||||
|
|
@ -416,13 +439,9 @@ class VoicePipelineTask:
|
|||
return b""
|
||||
|
||||
audio_np = np.concatenate(samples)
|
||||
# Kokoro outputs at 24kHz, we need 16kHz
|
||||
# Resample using linear interpolation
|
||||
# Kokoro outputs at 24kHz, resample to 16kHz
|
||||
if len(audio_np) > 0:
|
||||
original_rate = 24000
|
||||
target_rate = _SAMPLE_RATE
|
||||
duration = len(audio_np) / original_rate
|
||||
target_samples = int(duration * target_rate)
|
||||
target_samples = int(len(audio_np) / 24000 * _SAMPLE_RATE)
|
||||
indices = np.linspace(0, len(audio_np) - 1, target_samples)
|
||||
resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
|
||||
return (resampled * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
|
||||
|
|
|
|||
Loading…
Reference in New Issue