it0/packages/services/voice-service/src/pipeline/base_pipeline.py

574 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Voice dialogue pipeline — direct WebSocket audio I/O.
Pipeline: Audio Input → VAD → STT → Agent Service → TTS → Audio Output
Runs as an async task that reads binary PCM frames from a FastAPI WebSocket,
detects speech with VAD, transcribes with STT, sends text to the Agent
service (which uses Claude SDK with tools), and synthesizes the response
back as speech via TTS.
"""
import asyncio
import io
import json
import logging
import os
import re
import struct
import tempfile
import time
from typing import AsyncGenerator, Optional
import httpx
import numpy as np
import websockets
from fastapi import WebSocket
from ..config.settings import settings
from ..stt.whisper_service import WhisperSTTService
from ..tts.kokoro_service import KokoroTTSService
from ..vad.silero_service import SileroVADService
# Configure process-wide logging at import time.
# NOTE(review): logging.basicConfig mutates global logging state; fine for a
# service entrypoint module, but confirm no other module expects a different
# root configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Minimum speech duration in seconds before we transcribe
_MIN_SPEECH_SECS = 0.5
# Silence duration in seconds after speech ends before we process
_SILENCE_AFTER_SPEECH_SECS = 0.8
# Sample rate (Hz) of the incoming client audio — 16 kHz 16-bit mono PCM
_SAMPLE_RATE = 16000
# Bytes per sample (16-bit PCM mono)
_BYTES_PER_SAMPLE = 2
# VAD chunk size (512 samples = 32ms at 16kHz, Silero expects this)
_VAD_CHUNK_SAMPLES = 512
_VAD_CHUNK_BYTES = _VAD_CHUNK_SAMPLES * _BYTES_PER_SAMPLE
# Max audio output chunk size sent over WebSocket (4KB)
_WS_AUDIO_CHUNK = 4096
# Sentence-ending punctuation for splitting TTS chunks: fullwidth (CJK)
# enders and newline match anywhere; ASCII enders only when followed by
# whitespace, so decimals like "3.14" are not split.
_SENTENCE_END_RE = re.compile(r'[。!?;\n]|[.!?;]\s')
# Minimum chars before we flush a sentence to TTS
_MIN_SENTENCE_LEN = 4
class VoicePipelineTask:
    """Async voice pipeline that bridges a FastAPI WebSocket to STT/LLM/TTS.

    One instance serves one client WebSocket. `run()` is the main loop:
    binary 16 kHz PCM frames in → Silero VAD segmentation → Whisper/OpenAI
    STT → agent-service streaming response → sentence-split Kokoro/OpenAI
    TTS → binary PCM frames out on the same WebSocket. JSON text frames
    carry user/assistant transcripts to the client.
    """

    def __init__(
        self,
        websocket: "WebSocket",
        session_context: dict,
        *,
        stt: Optional["WhisperSTTService"] = None,
        tts: Optional["KokoroTTSService"] = None,
        vad: Optional["SileroVADService"] = None,
        openai_client=None,
    ):
        self.websocket = websocket
        self.session_context = session_context
        self.stt = stt
        self.tts = tts
        self.vad = vad
        self.openai_client = openai_client
        # Agent session ID (reused across turns for conversation continuity)
        self._agent_session_id: Optional[str] = None
        self._auth_header: str = session_context.get("auth_header", "")
        self._cancelled = False
        self._speaking = False  # True while sending TTS audio to client
        # Barge-in flag checked by the TTS send path. Fix: initialize here —
        # previously this attribute was only created inside run() /
        # _process_speech(), so reading it before the first turn could raise
        # AttributeError.
        self._cancelled_tts = False

    async def run(self):
        """Main loop: read audio → VAD → STT → LLM → TTS → send audio.

        Runs until `cancel()` is called, the WebSocket read fails, or the
        task is cancelled. Exceptions are logged, never propagated.
        """
        sid = self.session_context.get("session_id", "?")
        print(f"[pipeline] Started for session {sid}", flush=True)
        print(f"[pipeline] STT={self.stt is not None and self.stt._model is not None}, "
              f"TTS={self.tts is not None and self.tts._pipeline is not None}, "
              f"VAD={self.vad is not None and self.vad._model is not None}", flush=True)
        # Audio buffer for accumulating speech
        speech_buffer = bytearray()
        vad_buffer = bytearray()  # accumulates until _VAD_CHUNK_BYTES
        is_speech_active = False
        silence_start: Optional[float] = None
        speech_start: Optional[float] = None
        chunk_count = 0
        try:
            while not self._cancelled:
                try:
                    data = await asyncio.wait_for(
                        self.websocket.receive_bytes(), timeout=30.0
                    )
                except asyncio.TimeoutError:
                    # Keep the session alive through silence; only log.
                    print(f"[pipeline] No data for 30s, session {sid}", flush=True)
                    continue
                except Exception as exc:
                    # Disconnect or non-binary frame — end the pipeline.
                    print(f"[pipeline] WebSocket read error: {type(exc).__name__}: {exc}", flush=True)
                    break
                chunk_count += 1
                if chunk_count <= 3 or chunk_count % 100 == 0:
                    print(f"[pipeline] Received audio chunk #{chunk_count}, len={len(data)}", flush=True)
                # Accumulate into VAD-sized chunks (Silero expects fixed
                # 512-sample windows).
                vad_buffer.extend(data)
                while len(vad_buffer) >= _VAD_CHUNK_BYTES:
                    chunk = bytes(vad_buffer[:_VAD_CHUNK_BYTES])
                    del vad_buffer[:_VAD_CHUNK_BYTES]
                    # Run VAD
                    has_speech = self._detect_speech(chunk)
                    if has_speech:
                        if not is_speech_active:
                            is_speech_active = True
                            speech_start = time.time()
                            silence_start = None
                            print("[pipeline] Speech detected!", flush=True)
                            # Barge-in: if we were speaking TTS, stop.
                            # NOTE(review): _process_speech() is awaited
                            # inline below, so this loop is not reading while
                            # TTS plays; _speaking can only be True here if
                            # processing is moved to a background task —
                            # confirm this branch is reachable as intended.
                            if self._speaking:
                                self._cancelled_tts = True
                                print("[pipeline] Barge-in!", flush=True)
                        speech_buffer.extend(chunk)
                        silence_start = None
                    else:
                        if is_speech_active:
                            # Still accumulate a bit during silence gap
                            speech_buffer.extend(chunk)
                            if silence_start is None:
                                silence_start = time.time()
                            elif time.time() - silence_start >= _SILENCE_AFTER_SPEECH_SECS:
                                # Silence detected after speech — process
                                speech_duration = time.time() - (speech_start or time.time())
                                buf_secs = len(speech_buffer) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
                                print(f"[pipeline] Silence after speech: duration={speech_duration:.1f}s, buffer={buf_secs:.1f}s", flush=True)
                                if speech_duration >= _MIN_SPEECH_SECS and len(speech_buffer) > 0:
                                    await self._process_speech(bytes(speech_buffer))
                                # Reset per-turn state
                                speech_buffer.clear()
                                is_speech_active = False
                                silence_start = None
                                speech_start = None
        except asyncio.CancelledError:
            print(f"[pipeline] Cancelled, session {sid}", flush=True)
        except Exception as exc:
            print(f"[pipeline] ERROR: {type(exc).__name__}: {exc}", flush=True)
            import traceback
            traceback.print_exc()
        finally:
            print(f"[pipeline] Ended for session {sid}", flush=True)

    def cancel(self):
        """Request the run() loop to stop after the current iteration."""
        self._cancelled = True

    def _detect_speech(self, chunk: bytes) -> bool:
        """Run VAD on a single chunk. Returns True if speech detected.

        Fails open: with no VAD model, or on a VAD error, everything is
        treated as speech so audio is never silently dropped.
        """
        if self.vad is None or self.vad._model is None:
            # No VAD — treat everything as speech
            return True
        try:
            return self.vad.detect_speech(chunk)
        except Exception as exc:
            logger.debug("VAD error: %s", exc)
            return True  # Assume speech on error

    async def _process_speech(self, audio_data: bytes):
        """Transcribe speech, stream agent response, synthesize TTS per-sentence.

        One full conversational turn: STT → transcript to client → agent
        stream split on sentence boundaries → one TTS call per sentence so
        first audio goes out before the full response is complete.
        """
        session_id = self.session_context.get("session_id", "?")
        audio_secs = len(audio_data) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
        print(f"[pipeline] ===== Processing speech: {audio_secs:.1f}s of audio, session={session_id} =====", flush=True)
        # 1. STT
        t0 = time.time()
        text = await self._transcribe(audio_data)
        stt_ms = int((time.time() - t0) * 1000)
        if not text or not text.strip():
            print(f"[pipeline] STT returned empty text (took {stt_ms}ms), skipping", flush=True)
            return
        print(f"[pipeline] [STT] ({stt_ms}ms) User said: \"{text.strip()}\"", flush=True)
        # Notify client of user transcript (best-effort; a closed socket
        # should not abort the turn)
        try:
            await self.websocket.send_text(
                json.dumps({"type": "transcript", "text": text.strip(), "role": "user"})
            )
        except Exception:
            pass
        # 2. Stream agent response → sentence-split → TTS per sentence
        print(f"[pipeline] [AGENT] Sending to agent: \"{text.strip()}\"", flush=True)
        t1 = time.time()
        self._speaking = True
        self._cancelled_tts = False
        sentence_buf = ""
        full_response = []
        tts_count = 0
        first_audio_ms = None
        try:
            async for chunk in self._agent_stream(text.strip()):
                if self._cancelled_tts:
                    print(f"[pipeline] [TTS] Barge-in, stopping TTS stream", flush=True)
                    break
                full_response.append(chunk)
                sentence_buf += chunk
                # Check for sentence boundaries
                search_start = 0
                while True:
                    match = _SENTENCE_END_RE.search(sentence_buf, search_start)
                    if not match:
                        break
                    # Skip matches that are too early — sentence too short
                    if match.end() < _MIN_SENTENCE_LEN:
                        search_start = match.end()
                        continue
                    sentence = sentence_buf[:match.end()].strip()
                    sentence_buf = sentence_buf[match.end():]
                    search_start = 0
                    if sentence:
                        tts_count += 1
                        if first_audio_ms is None:
                            first_audio_ms = int((time.time() - t1) * 1000)
                            print(f"[pipeline] [TTS] First sentence ready at {first_audio_ms}ms: \"{sentence[:60]}\"", flush=True)
                        else:
                            print(f"[pipeline] [TTS] Sentence #{tts_count}: \"{sentence[:60]}\"", flush=True)
                        await self._synthesize_chunk(sentence)
            # Flush remaining buffer (tail without a sentence-end match)
            remaining = sentence_buf.strip()
            if remaining and not self._cancelled_tts:
                tts_count += 1
                if first_audio_ms is None:
                    first_audio_ms = int((time.time() - t1) * 1000)
                print(f"[pipeline] [TTS] Final flush #{tts_count}: \"{remaining[:60]}\"", flush=True)
                await self._synthesize_chunk(remaining)
        except Exception as exc:
            print(f"[pipeline] Stream/TTS error: {type(exc).__name__}: {exc}", flush=True)
        finally:
            self._speaking = False
        total_ms = int((time.time() - t1) * 1000)
        response_text = "".join(full_response).strip()
        if response_text:
            print(f"[pipeline] [AGENT] Response ({len(response_text)} chars): \"{response_text[:200]}\"", flush=True)
            # Send full transcript to client (best-effort)
            try:
                await self.websocket.send_text(
                    json.dumps({"type": "transcript", "text": response_text, "role": "assistant"})
                )
            except Exception:
                pass
        else:
            print(f"[pipeline] [AGENT] No response text collected", flush=True)
        print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms (first audio at {first_audio_ms or 0}ms, {tts_count} chunks) =====", flush=True)

    async def _transcribe(self, audio_data: bytes) -> str:
        """Transcribe audio using STT service (local or OpenAI).

        Returns "" when no STT backend is available or on any STT error.
        """
        if settings.stt_provider == "openai":
            return await self._transcribe_openai(audio_data)
        if self.stt is None or self.stt._model is None:
            logger.warning("STT not available")
            return ""
        try:
            return await self.stt.transcribe(audio_data)
        except Exception as exc:
            logger.error("STT error: %s", exc)
            return ""

    async def _transcribe_openai(self, audio_data: bytes) -> str:
        """Transcribe audio using OpenAI API. Returns "" on any failure."""
        if self.openai_client is None:
            logger.warning("OpenAI client not available for STT")
            return ""
        try:
            # Write PCM to a temp WAV file for the API
            wav_bytes = _pcm_to_wav(audio_data, _SAMPLE_RATE)
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            try:
                tmp.write(wav_bytes)
                tmp.close()
                client = self.openai_client
                model = settings.openai_stt_model

                def _do():
                    # Blocking SDK call — runs in the default thread pool.
                    with open(tmp.name, "rb") as f:
                        result = client.audio.transcriptions.create(
                            model=model, file=f, language="zh",
                        )
                    return result.text

                text = await asyncio.get_running_loop().run_in_executor(None, _do)
                return text or ""
            finally:
                # Fix: always remove the temp file — previously it was only
                # unlinked on the success path and leaked when the API raised.
                tmp.close()
                try:
                    os.unlink(tmp.name)
                except OSError:
                    pass
        except Exception as exc:
            logger.error("OpenAI STT error: %s", exc)
            return ""

    async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
        """Stream text from agent-service, yielding chunks as they arrive.

        Flow:
        1. WS connect to /ws/agent → subscribe_session
        2. POST /api/v1/agent/tasks → triggers engine stream
        3. Yield text chunks as they arrive, return on 'completed'

        On any failure a Chinese fallback sentence is yielded so the user
        always hears something.
        """
        agent_url = settings.agent_service_url
        headers = {"Content-Type": "application/json"}
        if self._auth_header:
            headers["Authorization"] = self._auth_header
        ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url = f"{ws_url}/ws/agent"
        event_count = 0
        text_count = 0
        total_chars = 0
        timeout_secs = 120  # hard cap on one agent turn
        try:
            print(f"[pipeline] [AGENT] Connecting WS: {ws_url}", flush=True)
            async with websockets.connect(ws_url) as ws:
                print(f"[pipeline] [AGENT] WS connected", flush=True)
                # 1. Pre-subscribe with existing session ID so no early
                # events are missed while the task POST is in flight
                pre_session_id = self._agent_session_id or ""
                if pre_session_id:
                    await ws.send(json.dumps({
                        "event": "subscribe_session",
                        "data": {"sessionId": pre_session_id},
                    }))
                    print(f"[pipeline] [AGENT] Pre-subscribed session={pre_session_id}", flush=True)
                # 2. Create agent task (use claude_api engine for streaming TTS)
                body = {"prompt": user_text, "engineType": "claude_api"}
                if self._agent_session_id:
                    body["sessionId"] = self._agent_session_id
                print(f"[pipeline] [AGENT] POST /tasks prompt=\"{user_text[:80]}\"", flush=True)
                async with httpx.AsyncClient(timeout=30) as client:
                    resp = await client.post(
                        f"{agent_url}/api/v1/agent/tasks",
                        json=body,
                        headers=headers,
                    )
                print(f"[pipeline] [AGENT] POST response: {resp.status_code}", flush=True)
                if resp.status_code not in (200, 201):
                    print(f"[pipeline] [AGENT] Task creation FAILED: {resp.status_code} {resp.text[:200]}", flush=True)
                    yield "抱歉Agent服务暂时不可用。"
                    return
                data = resp.json()
                session_id = data.get("sessionId", "")
                task_id = data.get("taskId", "")
                # Remember the session for conversation continuity next turn
                self._agent_session_id = session_id
                print(f"[pipeline] [AGENT] Task created: session={session_id}, task={task_id}", flush=True)
                # 3. Subscribe with actual IDs
                await ws.send(json.dumps({
                    "event": "subscribe_session",
                    "data": {"sessionId": session_id, "taskId": task_id},
                }))
                print(f"[pipeline] [AGENT] Subscribed session={session_id}, task={task_id}", flush=True)
                # 4. Stream events until 'completed', 'error', or deadline
                deadline = time.time() + timeout_secs
                while time.time() < deadline:
                    remaining = deadline - time.time()
                    try:
                        raw = await asyncio.wait_for(ws.recv(), timeout=min(5.0, remaining))
                    except asyncio.TimeoutError:
                        if time.time() >= deadline:
                            print(f"[pipeline] [AGENT] TIMEOUT after {timeout_secs}s", flush=True)
                        continue
                    except Exception as ws_err:
                        print(f"[pipeline] [AGENT] WS recv error: {type(ws_err).__name__}: {ws_err}", flush=True)
                        break
                    try:
                        msg = json.loads(raw)
                    except (json.JSONDecodeError, TypeError):
                        continue  # ignore non-JSON frames
                    event_type = msg.get("event", "")
                    event_count += 1
                    if event_type == "subscribed":
                        print(f"[pipeline] [AGENT] Subscription confirmed: {msg.get('data', {})}", flush=True)
                    elif event_type == "stream_event":
                        evt_data = msg.get("data", {})
                        evt_type = evt_data.get("type", "")
                        if evt_type == "text":
                            content = evt_data.get("content", "")
                            if content:
                                text_count += 1
                                total_chars += len(content)
                                if text_count <= 3 or text_count % 10 == 0:
                                    print(f"[pipeline] [AGENT] Text #{text_count}: +{len(content)} chars (total: {total_chars})", flush=True)
                                yield content
                        elif evt_type == "completed":
                            # Fall back to the completion summary if the
                            # engine emitted no incremental text
                            summary = evt_data.get("summary", "")
                            if summary and text_count == 0:
                                print(f"[pipeline] [AGENT] Using summary: \"{summary[:100]}\"", flush=True)
                                yield summary
                            print(f"[pipeline] [AGENT] Completed! {text_count} text events, {total_chars} chars, {event_count} WS events", flush=True)
                            return
                        elif evt_type == "error":
                            err_msg = evt_data.get("message", "Unknown error")
                            print(f"[pipeline] [AGENT] ERROR: {err_msg}", flush=True)
                            yield f"Agent 错误: {err_msg}"
                            return
        except Exception as exc:
            print(f"[pipeline] [AGENT] Stream error: {type(exc).__name__}: {exc}", flush=True)
            yield "抱歉Agent服务暂时不可用。"

    async def _synthesize_chunk(self, text: str):
        """Synthesize a single sentence/chunk and send audio over WebSocket.

        Audio is sent in _WS_AUDIO_CHUNK-sized binary frames; a barge-in
        (self._cancelled_tts) aborts between frames.
        """
        if not text.strip():
            return
        try:
            if settings.tts_provider == "openai":
                audio_bytes = await self._tts_openai(text)
            else:
                if self.tts is None or self.tts._pipeline is None:
                    return
                audio_bytes = await asyncio.get_running_loop().run_in_executor(
                    None, self._tts_sync, text
                )
            if not audio_bytes or self._cancelled_tts:
                return
            # Send audio in WS-sized chunks
            sent_bytes = 0
            offset = 0
            while offset < len(audio_bytes) and not self._cancelled_tts:
                end = min(offset + _WS_AUDIO_CHUNK, len(audio_bytes))
                try:
                    await self.websocket.send_bytes(audio_bytes[offset:end])
                    sent_bytes += end - offset
                except Exception as e:
                    print(f"[pipeline] [TTS] send_bytes failed at {offset}/{len(audio_bytes)}: {e}", flush=True)
                    break
                offset = end
                # Brief yield so reads/cancellation can interleave
                await asyncio.sleep(0.005)
            print(f"[pipeline] [TTS] Sent {sent_bytes}/{len(audio_bytes)} bytes for: \"{text[:20]}\"", flush=True)
        except Exception as exc:
            logger.error("TTS chunk error: %s", exc)

    async def _tts_openai(self, text: str) -> bytes:
        """Synthesize text to 16kHz PCM via OpenAI TTS API.

        Returns b"" when the client is missing or on any API error.
        """
        if self.openai_client is None:
            logger.warning("OpenAI client not available for TTS")
            return b""
        try:
            client = self.openai_client
            model = settings.openai_tts_model
            voice = settings.openai_tts_voice

            def _do():
                # Blocking SDK call — runs in the default thread pool.
                response = client.audio.speech.create(
                    model=model, voice=voice, input=text,
                    response_format="pcm",  # raw 24kHz 16-bit mono PCM
                )
                return response.content

            pcm_24k = await asyncio.get_running_loop().run_in_executor(None, _do)
            # OpenAI returns 24kHz 16-bit mono PCM, resample to 16kHz via
            # linear interpolation
            audio_np = np.frombuffer(pcm_24k, dtype=np.int16).astype(np.float32)
            if len(audio_np) > 0:
                target_samples = int(len(audio_np) / 24000 * _SAMPLE_RATE)
                indices = np.linspace(0, len(audio_np) - 1, target_samples)
                resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
                return resampled.astype(np.int16).tobytes()
            return b""
        except Exception as exc:
            logger.error("OpenAI TTS error: %s", exc)
            return b""

    def _tts_sync(self, text: str) -> bytes:
        """Synchronous TTS synthesis with Kokoro (runs in thread pool).

        Concatenates the per-segment float waveforms Kokoro yields,
        resamples 24kHz → 16kHz, and converts to 16-bit PCM bytes.
        Returns b"" on any error or empty synthesis.
        """
        try:
            samples = []
            for _, _, audio in self.tts._pipeline(text, voice=self.tts.voice):
                samples.append(audio)
            if not samples:
                return b""
            audio_np = np.concatenate(samples)
            if len(audio_np) == 0:
                return b""
            # Kokoro outputs at 24kHz, resample to 16kHz via linear
            # interpolation, then scale float [-1, 1] to int16 with clipping
            target_samples = int(len(audio_np) / 24000 * _SAMPLE_RATE)
            indices = np.linspace(0, len(audio_np) - 1, target_samples)
            resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
            return (resampled * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
        except Exception as exc:
            logger.error("TTS synthesis error: %s", exc)
            return b""
def _pcm_to_wav(pcm_bytes: bytes, sample_rate: int) -> bytes:
"""Wrap raw 16-bit mono PCM into a WAV container."""
buf = io.BytesIO()
data_size = len(pcm_bytes)
buf.write(b"RIFF")
buf.write(struct.pack("<I", 36 + data_size))
buf.write(b"WAVE")
buf.write(b"fmt ")
buf.write(struct.pack("<I", 16))
buf.write(struct.pack("<H", 1)) # PCM
buf.write(struct.pack("<H", 1)) # mono
buf.write(struct.pack("<I", sample_rate))
buf.write(struct.pack("<I", sample_rate * 2)) # byte rate
buf.write(struct.pack("<H", 2)) # block align
buf.write(struct.pack("<H", 16)) # bits per sample
buf.write(b"data")
buf.write(struct.pack("<I", data_size))
buf.write(pcm_bytes)
return buf.getvalue()
async def create_voice_pipeline(
    websocket: WebSocket,
    session_context: dict,
    *,
    stt=None,
    tts=None,
    vad=None,
    openai_client=None,
) -> VoicePipelineTask:
    """Build a VoicePipelineTask bound to the given WebSocket connection.

    Thin async factory: wires the optional STT/TTS/VAD services and OpenAI
    client into a new task. Performs no I/O itself.
    """
    services = {
        "stt": stt,
        "tts": tts,
        "vad": vad,
        "openai_client": openai_client,
    }
    return VoicePipelineTask(websocket, session_context, **services)