574 lines
24 KiB
Python
574 lines
24 KiB
Python
"""
|
||
Voice dialogue pipeline — direct WebSocket audio I/O.
|
||
|
||
Pipeline: Audio Input → VAD → STT → Agent Service → TTS → Audio Output
|
||
|
||
Runs as an async task that reads binary PCM frames from a FastAPI WebSocket,
|
||
detects speech with VAD, transcribes with STT, sends text to the Agent
|
||
service (which uses Claude SDK with tools), and synthesizes the response
|
||
back as speech via TTS.
|
||
"""
|
||
|
||
import asyncio
|
||
import io
|
||
import json
|
||
import logging
|
||
import os
|
||
import re
|
||
import struct
|
||
import tempfile
|
||
import time
|
||
from typing import AsyncGenerator, Optional
|
||
|
||
import httpx
|
||
import numpy as np
|
||
import websockets
|
||
from fastapi import WebSocket
|
||
|
||
from ..config.settings import settings
|
||
from ..stt.whisper_service import WhisperSTTService
|
||
from ..tts.kokoro_service import KokoroTTSService
|
||
from ..vad.silero_service import SileroVADService
|
||
|
||
# NOTE: basicConfig at import time is a module-level side effect; it only
# takes effect if no handler was configured earlier in the process.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Minimum speech duration in seconds before we transcribe an utterance.
_MIN_SPEECH_SECS = 0.5
# Silence duration in seconds after speech ends before we process the buffer.
_SILENCE_AFTER_SPEECH_SECS = 0.8
# Input/output sample rate in Hz (16 kHz, 16-bit mono PCM throughout).
_SAMPLE_RATE = 16000
# Bytes per sample (16-bit PCM mono).
_BYTES_PER_SAMPLE = 2
# VAD chunk size (512 samples = 32 ms at 16 kHz; Silero expects this size).
_VAD_CHUNK_SAMPLES = 512
# Same chunk expressed in bytes of raw PCM.
_VAD_CHUNK_BYTES = _VAD_CHUNK_SAMPLES * _BYTES_PER_SAMPLE
# Max audio output chunk size sent over the WebSocket per send (4 KiB).
_WS_AUDIO_CHUNK = 4096
# Sentence-ending punctuation (CJK full-width, or ASCII followed by
# whitespace) used to split streamed text into per-sentence TTS chunks.
_SENTENCE_END_RE = re.compile(r'[。!?;\n]|[.!?;]\s')
# Minimum chars before we flush a sentence to TTS (skips tiny fragments).
_MIN_SENTENCE_LEN = 4
|
||
|
||
|
||
class VoicePipelineTask:
    """Async voice pipeline that bridges a FastAPI WebSocket to STT/LLM/TTS.

    Reads 16 kHz 16-bit mono PCM frames from the client WebSocket, segments
    speech with VAD, transcribes each utterance, streams the agent-service
    response text, and sends synthesized speech back as binary audio chunks
    (plus JSON transcript messages) over the same WebSocket.

    Args:
        websocket: Accepted FastAPI WebSocket (binary PCM in; binary audio
            and JSON transcript frames out).
        session_context: Per-connection metadata. ``session_id`` is used in
            logs; ``auth_header`` is forwarded to the agent service.
        stt: Local Whisper STT service, or None / unloaded.
        tts: Local Kokoro TTS service, or None / unloaded.
        vad: Silero VAD service; without it every chunk counts as speech.
        openai_client: OpenAI SDK client, used when ``settings`` selects the
            "openai" STT or TTS provider.
    """

    def __init__(
        self,
        websocket: WebSocket,
        session_context: dict,
        *,
        stt: Optional[WhisperSTTService] = None,
        tts: Optional[KokoroTTSService] = None,
        vad: Optional[SileroVADService] = None,
        openai_client=None,
    ):
        self.websocket = websocket
        self.session_context = session_context
        self.stt = stt
        self.tts = tts
        self.vad = vad
        self.openai_client = openai_client

        # Agent session ID (reused across turns for conversation continuity)
        self._agent_session_id: Optional[str] = None
        self._auth_header: str = session_context.get("auth_header", "")

        self._cancelled = False  # set by cancel() to stop run()
        self._speaking = False  # True while sending TTS audio to client
        # FIX: previously first assigned inside run()/_process_speech(), so
        # reading it (e.g. in _synthesize_chunk) before the first turn could
        # raise AttributeError. Initialize it up front.
        self._cancelled_tts = False  # True when barge-in interrupts TTS

    async def run(self):
        """Main loop: read audio → VAD → STT → LLM → TTS → send audio.

        Runs until cancel() is called, the WebSocket read fails, or the task
        is cancelled. Never raises: errors are logged and the loop exits.
        """
        sid = self.session_context.get("session_id", "?")
        print(f"[pipeline] Started for session {sid}", flush=True)
        print(f"[pipeline] STT={self.stt is not None and self.stt._model is not None}, "
              f"TTS={self.tts is not None and self.tts._pipeline is not None}, "
              f"VAD={self.vad is not None and self.vad._model is not None}", flush=True)

        # Audio buffer accumulating the current utterance (speech + trailing gap)
        speech_buffer = bytearray()
        vad_buffer = bytearray()  # accumulates raw input until _VAD_CHUNK_BYTES
        is_speech_active = False
        silence_start: Optional[float] = None  # wall time silence began
        speech_start: Optional[float] = None  # wall time speech began
        chunk_count = 0

        try:
            while not self._cancelled:
                try:
                    # 30s read timeout keeps the loop alive on idle clients
                    data = await asyncio.wait_for(
                        self.websocket.receive_bytes(), timeout=30.0
                    )
                except asyncio.TimeoutError:
                    print(f"[pipeline] No data for 30s, session {sid}", flush=True)
                    continue
                except Exception as exc:
                    # Disconnect or protocol error — leave the loop
                    print(f"[pipeline] WebSocket read error: {type(exc).__name__}: {exc}", flush=True)
                    break

                chunk_count += 1
                if chunk_count <= 3 or chunk_count % 100 == 0:
                    print(f"[pipeline] Received audio chunk #{chunk_count}, len={len(data)}", flush=True)

                # Accumulate into VAD-sized chunks (Silero needs fixed frames)
                vad_buffer.extend(data)

                while len(vad_buffer) >= _VAD_CHUNK_BYTES:
                    chunk = bytes(vad_buffer[:_VAD_CHUNK_BYTES])
                    del vad_buffer[:_VAD_CHUNK_BYTES]

                    # Run VAD on this frame
                    has_speech = self._detect_speech(chunk)

                    if has_speech:
                        if not is_speech_active:
                            # Speech onset
                            is_speech_active = True
                            speech_start = time.time()
                            silence_start = None
                            print("[pipeline] Speech detected!", flush=True)
                            # Barge-in: if we were speaking TTS, stop it
                            if self._speaking:
                                self._cancelled_tts = True
                                print("[pipeline] Barge-in!", flush=True)

                        speech_buffer.extend(chunk)
                        silence_start = None
                    else:
                        if is_speech_active:
                            # Still accumulate a bit during the silence gap so
                            # word endings aren't clipped
                            speech_buffer.extend(chunk)

                            if silence_start is None:
                                silence_start = time.time()
                            elif time.time() - silence_start >= _SILENCE_AFTER_SPEECH_SECS:
                                # Sustained silence after speech — end of utterance
                                speech_duration = time.time() - (speech_start or time.time())
                                buf_secs = len(speech_buffer) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
                                print(f"[pipeline] Silence after speech: duration={speech_duration:.1f}s, buffer={buf_secs:.1f}s", flush=True)
                                # Skip utterances shorter than the minimum
                                if speech_duration >= _MIN_SPEECH_SECS and len(speech_buffer) > 0:
                                    await self._process_speech(bytes(speech_buffer))

                                # Reset segmentation state for the next turn
                                speech_buffer.clear()
                                is_speech_active = False
                                silence_start = None
                                speech_start = None

        except asyncio.CancelledError:
            print(f"[pipeline] Cancelled, session {sid}", flush=True)
        except Exception as exc:
            print(f"[pipeline] ERROR: {type(exc).__name__}: {exc}", flush=True)
            import traceback
            traceback.print_exc()
        finally:
            print(f"[pipeline] Ended for session {sid}", flush=True)

    def cancel(self):
        """Request the run() loop to exit after the current iteration."""
        self._cancelled = True

    def _detect_speech(self, chunk: bytes) -> bool:
        """Run VAD on a single chunk. Returns True if speech detected.

        Fails open: with no VAD model, or on a VAD error, every chunk is
        treated as speech so audio is never silently dropped.
        """
        if self.vad is None or self.vad._model is None:
            # No VAD — treat everything as speech
            return True
        try:
            return self.vad.detect_speech(chunk)
        except Exception as exc:
            logger.debug("VAD error: %s", exc)
            return True  # Assume speech on error

    async def _process_speech(self, audio_data: bytes):
        """Transcribe speech, stream agent response, synthesize TTS per-sentence.

        One full conversational turn: STT → transcript notification → agent
        stream split at sentence boundaries → per-sentence TTS → final
        assistant transcript. Honors barge-in via self._cancelled_tts.
        """
        session_id = self.session_context.get("session_id", "?")
        audio_secs = len(audio_data) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
        print(f"[pipeline] ===== Processing speech: {audio_secs:.1f}s of audio, session={session_id} =====", flush=True)

        # 1. STT
        t0 = time.time()
        text = await self._transcribe(audio_data)
        stt_ms = int((time.time() - t0) * 1000)
        if not text or not text.strip():
            print(f"[pipeline] STT returned empty text (took {stt_ms}ms), skipping", flush=True)
            return

        print(f"[pipeline] [STT] ({stt_ms}ms) User said: \"{text.strip()}\"", flush=True)

        # Notify client of user transcript (best effort — turn continues on failure)
        try:
            await self.websocket.send_text(
                json.dumps({"type": "transcript", "text": text.strip(), "role": "user"})
            )
        except Exception:
            pass

        # 2. Stream agent response → sentence-split → TTS per sentence
        print(f"[pipeline] [AGENT] Sending to agent: \"{text.strip()}\"", flush=True)
        t1 = time.time()
        self._speaking = True
        self._cancelled_tts = False

        sentence_buf = ""  # text awaiting a sentence boundary
        full_response = []  # all chunks, for the final transcript
        tts_count = 0
        first_audio_ms = None  # latency to first synthesized sentence

        try:
            async for chunk in self._agent_stream(text.strip()):
                if self._cancelled_tts:
                    print(f"[pipeline] [TTS] Barge-in, stopping TTS stream", flush=True)
                    break

                full_response.append(chunk)
                sentence_buf += chunk

                # Flush every complete sentence currently in the buffer
                search_start = 0
                while True:
                    match = _SENTENCE_END_RE.search(sentence_buf, search_start)
                    if not match:
                        break
                    # Skip matches that are too early — sentence too short
                    if match.end() < _MIN_SENTENCE_LEN:
                        search_start = match.end()
                        continue
                    sentence = sentence_buf[:match.end()].strip()
                    sentence_buf = sentence_buf[match.end():]
                    search_start = 0
                    if sentence:
                        tts_count += 1
                        if first_audio_ms is None:
                            first_audio_ms = int((time.time() - t1) * 1000)
                            print(f"[pipeline] [TTS] First sentence ready at {first_audio_ms}ms: \"{sentence[:60]}\"", flush=True)
                        else:
                            print(f"[pipeline] [TTS] Sentence #{tts_count}: \"{sentence[:60]}\"", flush=True)
                        await self._synthesize_chunk(sentence)

            # Flush remaining buffer (text after the last boundary)
            remaining = sentence_buf.strip()
            if remaining and not self._cancelled_tts:
                tts_count += 1
                if first_audio_ms is None:
                    first_audio_ms = int((time.time() - t1) * 1000)
                print(f"[pipeline] [TTS] Final flush #{tts_count}: \"{remaining[:60]}\"", flush=True)
                await self._synthesize_chunk(remaining)

        except Exception as exc:
            print(f"[pipeline] Stream/TTS error: {type(exc).__name__}: {exc}", flush=True)
        finally:
            self._speaking = False

        total_ms = int((time.time() - t1) * 1000)
        response_text = "".join(full_response).strip()

        if response_text:
            print(f"[pipeline] [AGENT] Response ({len(response_text)} chars): \"{response_text[:200]}\"", flush=True)
            # Send full transcript to client (best effort)
            try:
                await self.websocket.send_text(
                    json.dumps({"type": "transcript", "text": response_text, "role": "assistant"})
                )
            except Exception:
                pass
        else:
            print(f"[pipeline] [AGENT] No response text collected", flush=True)

        print(f"[pipeline] ===== Turn complete: STT={stt_ms}ms + Stream+TTS={total_ms}ms (first audio at {first_audio_ms or 0}ms, {tts_count} chunks) =====", flush=True)

    async def _transcribe(self, audio_data: bytes) -> str:
        """Transcribe audio using the configured STT provider.

        Returns the transcript, or "" when STT is unavailable or fails.
        """
        if settings.stt_provider == "openai":
            return await self._transcribe_openai(audio_data)

        if self.stt is None or self.stt._model is None:
            logger.warning("STT not available")
            return ""
        try:
            return await self.stt.transcribe(audio_data)
        except Exception as exc:
            logger.error("STT error: %s", exc)
            return ""

    async def _transcribe_openai(self, audio_data: bytes) -> str:
        """Transcribe audio using the OpenAI API.

        Wraps the raw PCM in a temporary WAV file (the API wants a container)
        and runs the blocking SDK call in the default executor. Returns ""
        on any failure.
        """
        if self.openai_client is None:
            logger.warning("OpenAI client not available for STT")
            return ""
        tmp = None
        try:
            # Write PCM to a temp WAV file for the API
            wav_bytes = _pcm_to_wav(audio_data, _SAMPLE_RATE)
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            tmp.write(wav_bytes)
            tmp.close()

            client = self.openai_client
            model = settings.openai_stt_model

            def _do():
                # Blocking SDK call — executed in a worker thread
                with open(tmp.name, "rb") as f:
                    result = client.audio.transcriptions.create(
                        model=model, file=f, language="zh",
                    )
                return result.text

            text = await asyncio.get_running_loop().run_in_executor(None, _do)
            return text or ""
        except Exception as exc:
            logger.error("OpenAI STT error: %s", exc)
            return ""
        finally:
            # FIX: always remove the temp file — previously it leaked
            # whenever the API call raised.
            if tmp is not None:
                try:
                    os.unlink(tmp.name)
                except OSError:
                    pass

    async def _agent_stream(self, user_text: str) -> AsyncGenerator[str, None]:
        """Stream text from agent-service, yielding chunks as they arrive.

        Flow:
        1. WS connect to /ws/agent → subscribe_session
        2. POST /api/v1/agent/tasks → triggers engine stream
        3. Yield text chunks as they arrive, return on 'completed'

        On any failure yields a user-facing fallback message and returns.
        """
        agent_url = settings.agent_service_url
        headers = {"Content-Type": "application/json"}
        if self._auth_header:
            headers["Authorization"] = self._auth_header

        ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url = f"{ws_url}/ws/agent"

        event_count = 0
        text_count = 0
        total_chars = 0
        timeout_secs = 120  # overall deadline for the whole agent turn

        try:
            print(f"[pipeline] [AGENT] Connecting WS: {ws_url}", flush=True)
            async with websockets.connect(ws_url) as ws:
                print(f"[pipeline] [AGENT] WS connected", flush=True)

                # 1. Pre-subscribe with the existing session ID (if any) so
                # early events from the new task are not missed
                pre_session_id = self._agent_session_id or ""
                if pre_session_id:
                    await ws.send(json.dumps({
                        "event": "subscribe_session",
                        "data": {"sessionId": pre_session_id},
                    }))
                    print(f"[pipeline] [AGENT] Pre-subscribed session={pre_session_id}", flush=True)

                # 2. Create agent task (use claude_api engine for streaming TTS)
                body = {"prompt": user_text, "engineType": "claude_api"}
                if self._agent_session_id:
                    body["sessionId"] = self._agent_session_id

                print(f"[pipeline] [AGENT] POST /tasks prompt=\"{user_text[:80]}\"", flush=True)
                async with httpx.AsyncClient(timeout=30) as client:
                    resp = await client.post(
                        f"{agent_url}/api/v1/agent/tasks",
                        json=body,
                        headers=headers,
                    )
                print(f"[pipeline] [AGENT] POST response: {resp.status_code}", flush=True)
                if resp.status_code not in (200, 201):
                    print(f"[pipeline] [AGENT] Task creation FAILED: {resp.status_code} {resp.text[:200]}", flush=True)
                    yield "抱歉,Agent服务暂时不可用。"
                    return

                data = resp.json()
                session_id = data.get("sessionId", "")
                task_id = data.get("taskId", "")
                # Remember the session so the next turn continues the conversation
                self._agent_session_id = session_id
                print(f"[pipeline] [AGENT] Task created: session={session_id}, task={task_id}", flush=True)

                # 3. Subscribe with actual IDs
                await ws.send(json.dumps({
                    "event": "subscribe_session",
                    "data": {"sessionId": session_id, "taskId": task_id},
                }))
                print(f"[pipeline] [AGENT] Subscribed session={session_id}, task={task_id}", flush=True)

                # 4. Stream events until 'completed', 'error', or deadline
                deadline = time.time() + timeout_secs
                while time.time() < deadline:
                    remaining = deadline - time.time()
                    try:
                        # Short recv timeout so the deadline is re-checked
                        raw = await asyncio.wait_for(ws.recv(), timeout=min(5.0, remaining))
                    except asyncio.TimeoutError:
                        if time.time() >= deadline:
                            print(f"[pipeline] [AGENT] TIMEOUT after {timeout_secs}s", flush=True)
                        continue
                    except Exception as ws_err:
                        print(f"[pipeline] [AGENT] WS recv error: {type(ws_err).__name__}: {ws_err}", flush=True)
                        break

                    try:
                        msg = json.loads(raw)
                    except (json.JSONDecodeError, TypeError):
                        continue  # ignore non-JSON frames

                    event_type = msg.get("event", "")
                    event_count += 1

                    if event_type == "subscribed":
                        print(f"[pipeline] [AGENT] Subscription confirmed: {msg.get('data', {})}", flush=True)

                    elif event_type == "stream_event":
                        evt_data = msg.get("data", {})
                        evt_type = evt_data.get("type", "")

                        if evt_type == "text":
                            content = evt_data.get("content", "")
                            if content:
                                text_count += 1
                                total_chars += len(content)
                                if text_count <= 3 or text_count % 10 == 0:
                                    print(f"[pipeline] [AGENT] Text #{text_count}: +{len(content)} chars (total: {total_chars})", flush=True)
                                yield content

                        elif evt_type == "completed":
                            # Fall back to the summary only if no text arrived
                            summary = evt_data.get("summary", "")
                            if summary and text_count == 0:
                                print(f"[pipeline] [AGENT] Using summary: \"{summary[:100]}\"", flush=True)
                                yield summary
                            print(f"[pipeline] [AGENT] Completed! {text_count} text events, {total_chars} chars, {event_count} WS events", flush=True)
                            return

                        elif evt_type == "error":
                            err_msg = evt_data.get("message", "Unknown error")
                            print(f"[pipeline] [AGENT] ERROR: {err_msg}", flush=True)
                            yield f"Agent 错误: {err_msg}"
                            return

        except Exception as exc:
            print(f"[pipeline] [AGENT] Stream error: {type(exc).__name__}: {exc}", flush=True)
            yield "抱歉,Agent服务暂时不可用。"

    async def _synthesize_chunk(self, text: str):
        """Synthesize a single sentence/chunk and send audio over WebSocket.

        Chooses the configured TTS provider, then streams the PCM result in
        _WS_AUDIO_CHUNK-sized pieces, aborting early on barge-in.
        """
        if not text.strip():
            return

        try:
            if settings.tts_provider == "openai":
                audio_bytes = await self._tts_openai(text)
            else:
                if self.tts is None or self.tts._pipeline is None:
                    return
                # Kokoro synthesis is blocking — run it in the executor
                audio_bytes = await asyncio.get_running_loop().run_in_executor(
                    None, self._tts_sync, text
                )

            if not audio_bytes or self._cancelled_tts:
                return

            # Send audio in WS-sized chunks; small sleeps yield the event loop
            sent_bytes = 0
            offset = 0
            while offset < len(audio_bytes) and not self._cancelled_tts:
                end = min(offset + _WS_AUDIO_CHUNK, len(audio_bytes))
                try:
                    await self.websocket.send_bytes(audio_bytes[offset:end])
                    sent_bytes += end - offset
                except Exception as e:
                    print(f"[pipeline] [TTS] send_bytes failed at {offset}/{len(audio_bytes)}: {e}", flush=True)
                    break
                offset = end
                await asyncio.sleep(0.005)
            print(f"[pipeline] [TTS] Sent {sent_bytes}/{len(audio_bytes)} bytes for: \"{text[:20]}\"", flush=True)

        except Exception as exc:
            logger.error("TTS chunk error: %s", exc)

    async def _tts_openai(self, text: str) -> bytes:
        """Synthesize text to 16kHz PCM via OpenAI TTS API.

        Returns raw 16-bit mono PCM at _SAMPLE_RATE, or b"" on failure.
        """
        if self.openai_client is None:
            logger.warning("OpenAI client not available for TTS")
            return b""
        try:
            client = self.openai_client
            model = settings.openai_tts_model
            voice = settings.openai_tts_voice

            def _do():
                # Blocking SDK call — executed in a worker thread
                response = client.audio.speech.create(
                    model=model, voice=voice, input=text,
                    response_format="pcm",  # raw 24kHz 16-bit mono PCM
                )
                return response.content

            pcm_24k = await asyncio.get_running_loop().run_in_executor(None, _do)
            # OpenAI returns 24kHz 16-bit mono PCM, resample to 16kHz via
            # linear interpolation
            audio_np = np.frombuffer(pcm_24k, dtype=np.int16).astype(np.float32)
            if len(audio_np) > 0:
                target_samples = int(len(audio_np) / 24000 * _SAMPLE_RATE)
                indices = np.linspace(0, len(audio_np) - 1, target_samples)
                resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
                return resampled.astype(np.int16).tobytes()
            return b""
        except Exception as exc:
            logger.error("OpenAI TTS error: %s", exc)
            return b""

    def _tts_sync(self, text: str) -> bytes:
        """Synchronous TTS synthesis with Kokoro (runs in thread pool).

        Concatenates all pipeline segments, resamples 24kHz → 16kHz, and
        converts float samples to 16-bit PCM. Returns b"" on failure.
        """
        try:
            samples = []
            for _, _, audio in self.tts._pipeline(text, voice=self.tts.voice):
                samples.append(audio)

            if not samples:
                return b""

            audio_np = np.concatenate(samples)
            # Kokoro outputs at 24kHz, resample to 16kHz via linear interpolation
            if len(audio_np) > 0:
                target_samples = int(len(audio_np) / 24000 * _SAMPLE_RATE)
                indices = np.linspace(0, len(audio_np) - 1, target_samples)
                resampled = np.interp(indices, np.arange(len(audio_np)), audio_np)
                # Float [-1, 1] → int16, clipped to avoid wraparound
                return (resampled * 32768).clip(-32768, 32767).astype(np.int16).tobytes()

            return (audio_np * 32768).clip(-32768, 32767).astype(np.int16).tobytes()
        except Exception as exc:
            logger.error("TTS synthesis error: %s", exc)
            return b""
|
||
|
||
|
||
def _pcm_to_wav(pcm_bytes: bytes, sample_rate: int) -> bytes:
|
||
"""Wrap raw 16-bit mono PCM into a WAV container."""
|
||
buf = io.BytesIO()
|
||
data_size = len(pcm_bytes)
|
||
buf.write(b"RIFF")
|
||
buf.write(struct.pack("<I", 36 + data_size))
|
||
buf.write(b"WAVE")
|
||
buf.write(b"fmt ")
|
||
buf.write(struct.pack("<I", 16))
|
||
buf.write(struct.pack("<H", 1)) # PCM
|
||
buf.write(struct.pack("<H", 1)) # mono
|
||
buf.write(struct.pack("<I", sample_rate))
|
||
buf.write(struct.pack("<I", sample_rate * 2)) # byte rate
|
||
buf.write(struct.pack("<H", 2)) # block align
|
||
buf.write(struct.pack("<H", 16)) # bits per sample
|
||
buf.write(b"data")
|
||
buf.write(struct.pack("<I", data_size))
|
||
buf.write(pcm_bytes)
|
||
return buf.getvalue()
|
||
|
||
|
||
async def create_voice_pipeline(
    websocket: WebSocket,
    session_context: dict,
    *,
    stt=None,
    tts=None,
    vad=None,
    openai_client=None,
) -> VoicePipelineTask:
    """Build a VoicePipelineTask bound to *websocket* and *session_context*.

    All service handles (STT, TTS, VAD, OpenAI client) are optional and are
    forwarded unchanged to the task constructor.
    """
    services = {
        "stt": stt,
        "tts": tts,
        "vad": vad,
        "openai_client": openai_client,
    }
    return VoicePipelineTask(websocket, session_context, **services)
|