feat: add WebSocket robustness to voice call (heartbeat, reconnect, jitter buffer)
Addresses reliability gaps in the real-time voice WebSocket connection
between Flutter client and Python voice-service backend.
Backend (voice-service):
- Heartbeat: new _heartbeat_sender coroutine sends JSON ping text frames
every 15s alongside the Pipecat pipeline; failed send = dead connection
- Session preservation: on WebSocket disconnect, sessions are now marked
"disconnected" with a timestamp instead of being deleted, allowing
reconnection within a configurable TTL (default 60s)
- Reconnect endpoint: POST /sessions/{id}/reconnect verifies the session
is alive and in "disconnected" state, returns fresh websocket_url
- Reconnect-aware WS handler: detects "disconnected" sessions, cancels
stale pipeline tasks, creates a new pipeline, sends "session.resumed"
- Background cleanup: asyncio loop every 30s removes sessions that have
been disconnected longer than session_ttl
- Structured event protocol: text frames = JSON control messages
(ping/pong/session.resumed/session.ended/error), binary = PCM audio
- New settings: session_ttl (60s), heartbeat_interval (15s),
heartbeat_timeout (45s)
Flutter (agent_call_page.dart):
- Heartbeat monitoring: tracks last server ping timestamp, triggers
reconnect if no ping received in 45s (3 missed intervals)
- Auto-reconnect: exponential backoff (1s→2s→4s→8s→16s), max 5 attempts;
calls /reconnect endpoint to verify session, rebuilds WebSocket,
resets audio buffer, restarts heartbeat
- Reconnecting UI: yellow warning banner "重新连接中... (N/5)" with
spinner overlay during reconnection attempts
- WebSocket data routing: _onWsData distinguishes String (JSON control)
from binary (audio) frames, handles ping/session.resumed/session.ended
- User-initiated disconnect guard: _userEndedCall flag prevents reconnect
attempts when user intentionally hangs up
- session_id field compatibility: supports session_id/sessionId/id
Flutter (pcm_player.dart):
- Jitter buffer: queues incoming PCM chunks, starts playback only after
accumulating 4800 bytes (150ms at 16kHz 16-bit mono) to smooth out
network timing variance
- reset() method: clears buffer on reconnect to discard stale audio
- Buffer underrun handling: re-enters buffering phase if queue empties
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
dfc541b571
commit
a6cd3c20d9
|
|
@ -1,7 +1,8 @@
|
|||
import 'dart:typed_data';
|
||||
import 'package:flutter_sound/flutter_sound.dart';
|
||||
|
||||
/// Wraps [FlutterSoundPlayer] for streaming raw PCM 16kHz mono playback.
|
||||
/// Wraps [FlutterSoundPlayer] for streaming raw PCM 16kHz mono playback
|
||||
/// with a simple jitter buffer to smooth out network-induced timing variance.
|
||||
///
|
||||
/// Usage:
|
||||
/// final player = PcmPlayer();
|
||||
|
|
@ -12,6 +13,15 @@ class PcmPlayer {
|
|||
FlutterSoundPlayer? _player;
|
||||
bool _initialized = false;
|
||||
|
||||
// Jitter buffer
|
||||
final List<Uint8List> _buffer = [];
|
||||
int _bufferedBytes = 0;
|
||||
bool _playbackStarted = false;
|
||||
|
||||
/// Number of bytes to accumulate before starting playback.
|
||||
/// 4800 bytes = 150 ms at 16 kHz, 16-bit mono (2 bytes per sample).
|
||||
static const int _bufferThreshold = 4800;
|
||||
|
||||
/// Open the player and start a streaming session.
|
||||
Future<void> init() async {
|
||||
if (_initialized) return;
|
||||
|
|
@ -31,12 +41,50 @@ class PcmPlayer {
|
|||
}
|
||||
|
||||
/// Feed raw PCM 16-bit signed LE mono 16 kHz data for playback.
///
/// While buffering, chunks are queued until [_bufferThreshold] bytes have
/// accumulated; then playback starts and the queue is drained. Once playback
/// is running, chunks bypass the queue entirely. After a [reset] (e.g. on
/// reconnect) the next calls re-enter the buffering phase automatically.
Future<void> feed(Uint8List pcmData) async {
  final player = _player;
  if (player == null || !_initialized) return;

  if (_playbackStarted) {
    // Live path: playback is already running — push straight through.
    // ignore: deprecated_member_use
    await player.feedUint8FromStream(pcmData);
    return;
  }

  // Buffering path: accumulate until the jitter threshold is reached.
  _buffer.add(pcmData);
  _bufferedBytes += pcmData.length;
  if (_bufferedBytes >= _bufferThreshold) {
    _playbackStarted = true;
    await _drainBuffer();
  }
}
|
||||
|
||||
/// Flush every queued chunk to the underlying player, oldest first.
///
/// The queue is re-checked on each iteration, so chunks appended while an
/// await is in flight are drained as well.
Future<void> _drainBuffer() async {
  while (_buffer.isNotEmpty) {
    final next = _buffer.removeAt(0);
    _bufferedBytes -= next.length;
    // ignore: deprecated_member_use
    await _player!.feedUint8FromStream(next);
  }
}
|
||||
|
||||
/// Discard all buffered audio and return to the buffering phase.
///
/// Call on reconnect so stale chunks from the previous connection are not
/// played back; the next [feed] starts accumulating from scratch.
void reset() {
  _playbackStarted = false;
  _bufferedBytes = 0;
  _buffer.clear();
}
|
||||
|
||||
/// Toggle speaker mode (earpiece vs loudspeaker).
|
||||
Future<void> setSpeakerOn(bool on) async {
|
||||
// flutter_sound doesn't expose a direct speaker toggle;
|
||||
|
|
@ -46,6 +94,7 @@ class PcmPlayer {
|
|||
|
||||
/// Stop playback and release resources.
|
||||
Future<void> dispose() async {
|
||||
reset();
|
||||
if (_player != null) {
|
||||
try {
|
||||
await _player!.stopPlayer();
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import 'dart:async';
|
||||
import 'dart:convert';
|
||||
import 'dart:math';
|
||||
import 'dart:typed_data';
|
||||
import 'package:flutter/material.dart';
|
||||
|
|
@ -50,6 +51,15 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
|
|||
bool _isMuted = false;
|
||||
bool _isSpeakerOn = false;
|
||||
|
||||
// Reconnection
|
||||
int _reconnectAttempts = 0;
|
||||
static const int _maxReconnectAttempts = 5;
|
||||
Timer? _reconnectTimer;
|
||||
Timer? _heartbeatCheckTimer;
|
||||
DateTime? _lastServerPing;
|
||||
bool _isReconnecting = false;
|
||||
bool _userEndedCall = false;
|
||||
|
||||
@override
|
||||
void initState() {
|
||||
super.initState();
|
||||
|
|
@ -75,7 +85,7 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
|
|||
final response = await dio.post('${ApiEndpoints.voice}/sessions');
|
||||
final data = response.data as Map<String, dynamic>;
|
||||
|
||||
_sessionId = data['id'] as String? ?? data['sessionId'] as String?;
|
||||
_sessionId = data['session_id'] as String? ?? data['sessionId'] as String? ?? data['id'] as String?;
|
||||
|
||||
// Build WebSocket URL — prefer backend-returned path
|
||||
String wsUrl;
|
||||
|
|
@ -94,13 +104,16 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
|
|||
// Connect WebSocket
|
||||
_audioChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
|
||||
|
||||
// Listen for incoming audio from the agent
|
||||
// Listen for incoming data (text control messages + binary audio)
|
||||
_audioSubscription = _audioChannel!.stream.listen(
|
||||
(data) => _onAudioReceived(data),
|
||||
onDone: () => _onCallEnded(),
|
||||
onError: (_) => _onCallEnded(),
|
||||
_onWsData,
|
||||
onDone: () => _onWsDisconnected(),
|
||||
onError: (_) => _onWsDisconnected(),
|
||||
);
|
||||
|
||||
// Start bidirectional heartbeat
|
||||
_startHeartbeat();
|
||||
|
||||
// Initialize audio playback
|
||||
await _pcmPlayer.init();
|
||||
|
||||
|
|
@ -158,6 +171,166 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
|
|||
});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// WebSocket data routing & heartbeat
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Route an incoming WebSocket frame by payload type.
///
/// Text frames carry JSON control messages; anything else is treated as
/// binary PCM audio and forwarded to the audio path.
void _onWsData(dynamic data) {
  if (data is! String) {
    _onAudioReceived(data);
    return;
  }
  _handleControlMessage(data);
}
|
||||
|
||||
/// Parse and act on a JSON control frame from the voice service.
///
/// Recognized types: `ping` (server heartbeat), `session.resumed`
/// (reconnect acknowledged), `session.ended`, and `error`. Malformed or
/// unrecognized frames are silently dropped — the whole handler is wrapped
/// in a catch so a bad frame can never take down the stream listener.
void _handleControlMessage(String jsonStr) {
  try {
    final msg = jsonDecode(jsonStr) as Map<String, dynamic>;
    switch (msg['type'] as String?) {
      case 'ping':
      case 'session.resumed':
        // Either frame proves the server is alive — refresh the liveness
        // clock. No pong is sent back: Pipecat owns the server-side read
        // loop and would intercept any text frame we sent.
        _lastServerPing = DateTime.now();
        break;
      case 'session.ended':
        _onCallEnded();
        break;
      case 'error':
        if (mounted) {
          final detail = msg['message'] as String? ?? 'Unknown error';
          ScaffoldMessenger.of(context).showSnackBar(
            SnackBar(
              content: Text('服务端错误: $detail'),
              backgroundColor: AppColors.error,
            ),
          );
        }
        break;
    }
  } catch (_) {
    // Malformed JSON or unexpected payload shape — ignore.
  }
}
|
||||
|
||||
/// Start monitoring server liveness based on incoming pings.
///
/// The server sends `{"type": "ping"}` every 15 s. If no ping (or other
/// liveness-refreshing control frame) has arrived within 45 s — three
/// missed intervals — the connection is presumed dead and a reconnect is
/// triggered.
void _startHeartbeat() {
  // Named constants instead of magic numbers. `staleAfter` mirrors the
  // backend's heartbeat_timeout (45 s) — keep the two in sync.
  const staleAfter = Duration(seconds: 45);
  const checkEvery = Duration(seconds: 10);

  _heartbeatCheckTimer?.cancel();

  // Seed the timestamp so the first 45-second window starts now.
  _lastServerPing = DateTime.now();

  // Periodically verify the server is still alive. Comparing Durations
  // directly avoids the truncation of `.inSeconds` (45.9 s of silence
  // previously did not count as "> 45").
  _heartbeatCheckTimer = Timer.periodic(checkEvery, (_) {
    final lastPing = _lastServerPing; // local copy so the field promotes
    if (lastPing != null &&
        DateTime.now().difference(lastPing) > staleAfter) {
      _triggerReconnect();
    }
  });
}
|
||||
|
||||
/// Stop the heartbeat liveness check, if one is running.
void _stopHeartbeat() {
  final timer = _heartbeatCheckTimer;
  _heartbeatCheckTimer = null;
  timer?.cancel();
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Auto-reconnect with exponential backoff
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Handle an unexpected WebSocket drop (stream onDone/onError).
///
/// Ignored when the call is already over, a reconnect is already in
/// flight, or the user hung up deliberately.
void _onWsDisconnected() {
  final shouldReconnect =
      _phase != _CallPhase.ended && !_isReconnecting && !_userEndedCall;
  if (shouldReconnect) _triggerReconnect();
}
|
||||
|
||||
/// Attempt to reconnect using exponential backoff (1 s, 2 s, 4 s, 8 s,
/// 16 s) with a maximum of [_maxReconnectAttempts] tries.
///
/// Flow: tear down the current socket and heartbeat, then per attempt:
/// wait the backoff delay, POST to the backend /reconnect endpoint to
/// verify the session is still resumable, open a fresh WebSocket, clear
/// the audio jitter buffer, and restart the heartbeat. Any failure falls
/// through to the next attempt; once all attempts are exhausted the call
/// is ended.
Future<void> _triggerReconnect() async {
  // Re-entrancy guard: the heartbeat timer and the socket's onDone/onError
  // callbacks can all invoke this concurrently.
  if (_isReconnecting || _phase == _CallPhase.ended || _userEndedCall) return;
  _isReconnecting = true;
  if (mounted) setState(() {}); // show the reconnecting banner

  // Tear down the current connection & heartbeat before retrying.
  _stopHeartbeat();
  await _audioSubscription?.cancel();
  _audioSubscription = null;
  try {
    await _audioChannel?.sink.close();
  } catch (_) {}
  _audioChannel = null;

  for (int attempt = 0; attempt < _maxReconnectAttempts; attempt++) {
    _reconnectAttempts = attempt + 1;
    if (!mounted || _phase == _CallPhase.ended || _userEndedCall) break;

    // Exponential backoff capped at 16 s: 1, 2, 4, 8, 16.
    final delaySecs = min(pow(2, attempt).toInt(), 16);
    await Future.delayed(Duration(seconds: delaySecs));

    // Re-check after the delay — state may have changed while waiting
    // (widget disposed, call ended, or user hung up).
    if (!mounted || _phase == _CallPhase.ended || _userEndedCall) break;

    try {
      // Ask the backend whether the session is still alive; a non-2xx
      // response throws and is treated as a failed attempt.
      final dio = ref.read(dioClientProvider);
      final response = await dio.post(
        '${ApiEndpoints.voice}/sessions/$_sessionId/reconnect',
      );
      final data = response.data as Map<String, dynamic>;

      // Build the new WebSocket URL — prefer the backend-returned path.
      final config = ref.read(appConfigProvider);
      String wsUrl;
      final backendWsUrl = data['websocket_url'] as String?;
      if (backendWsUrl != null && backendWsUrl.startsWith('/')) {
        // Relative path: graft it onto the configured WS host/port.
        final uri = Uri.parse(config.wsBaseUrl);
        wsUrl = '${uri.scheme}://${uri.host}:${uri.port}$backendWsUrl';
      } else if (backendWsUrl != null) {
        wsUrl = backendWsUrl;
      } else {
        // Fallback: reconstruct the conventional path locally.
        wsUrl = '${config.wsBaseUrl}/api/v1/voice/ws/$_sessionId';
      }

      // Connect the new WebSocket with the same routing/disconnect hooks
      // as the original connection.
      _audioChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
      _audioSubscription = _audioChannel!.stream.listen(
        _onWsData,
        onDone: () => _onWsDisconnected(),
        onError: (_) => _onWsDisconnected(),
      );

      // Reset the audio jitter buffer so stale audio is not played back.
      _pcmPlayer.reset();

      // Restart liveness monitoring for the new connection.
      _startHeartbeat();

      _reconnectAttempts = 0;
      _isReconnecting = false;
      if (mounted) setState(() {});
      return; // success
    } catch (_) {
      // Will retry on the next iteration with a doubled delay.
    }
  }

  // All reconnection attempts exhausted (or state changed mid-loop) —
  // give up and end the call.
  _isReconnecting = false;
  _onCallEnded();
}
|
||||
|
||||
/// Handle incoming audio from the agent side.
|
||||
void _onAudioReceived(dynamic data) {
|
||||
if (!mounted || _phase != _CallPhase.active) return;
|
||||
|
|
@ -203,11 +376,18 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
|
|||
|
||||
/// End the call and clean up all resources.
|
||||
Future<void> _endCall() async {
|
||||
_userEndedCall = true;
|
||||
_isReconnecting = false;
|
||||
setState(() => _phase = _CallPhase.ended);
|
||||
_stopwatch.stop();
|
||||
_durationTimer?.cancel();
|
||||
_waveController.stop();
|
||||
|
||||
// Cancel heartbeat & reconnect timers
|
||||
_stopHeartbeat();
|
||||
_reconnectTimer?.cancel();
|
||||
_reconnectTimer = null;
|
||||
|
||||
// Stop mic
|
||||
await _micSubscription?.cancel();
|
||||
_micSubscription = null;
|
||||
|
|
@ -252,38 +432,76 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
|
|||
return Scaffold(
|
||||
backgroundColor: AppColors.background,
|
||||
body: SafeArea(
|
||||
child: Column(
|
||||
child: Stack(
|
||||
children: [
|
||||
const Spacer(flex: 2),
|
||||
_buildAvatar(),
|
||||
const SizedBox(height: 24),
|
||||
Text(
|
||||
_statusText,
|
||||
style:
|
||||
const TextStyle(fontSize: 24, fontWeight: FontWeight.bold),
|
||||
Column(
|
||||
children: [
|
||||
const Spacer(flex: 2),
|
||||
_buildAvatar(),
|
||||
const SizedBox(height: 24),
|
||||
Text(
|
||||
_statusText,
|
||||
style: const TextStyle(
|
||||
fontSize: 24, fontWeight: FontWeight.bold),
|
||||
),
|
||||
const SizedBox(height: 8),
|
||||
Text(
|
||||
_subtitleText,
|
||||
style: const TextStyle(
|
||||
color: AppColors.textSecondary, fontSize: 15),
|
||||
),
|
||||
const SizedBox(height: 32),
|
||||
if (_phase == _CallPhase.active)
|
||||
Text(
|
||||
_durationLabel,
|
||||
style: const TextStyle(
|
||||
fontSize: 40,
|
||||
fontWeight: FontWeight.w300,
|
||||
color: AppColors.textPrimary,
|
||||
letterSpacing: 4,
|
||||
),
|
||||
),
|
||||
const SizedBox(height: 24),
|
||||
if (_phase == _CallPhase.active) _buildWaveform(),
|
||||
const Spacer(flex: 3),
|
||||
_buildControls(),
|
||||
const SizedBox(height: 48),
|
||||
],
|
||||
),
|
||||
const SizedBox(height: 8),
|
||||
Text(
|
||||
_subtitleText,
|
||||
style: const TextStyle(
|
||||
color: AppColors.textSecondary, fontSize: 15),
|
||||
),
|
||||
const SizedBox(height: 32),
|
||||
if (_phase == _CallPhase.active)
|
||||
Text(
|
||||
_durationLabel,
|
||||
style: const TextStyle(
|
||||
fontSize: 40,
|
||||
fontWeight: FontWeight.w300,
|
||||
color: AppColors.textPrimary,
|
||||
letterSpacing: 4,
|
||||
// Reconnecting overlay
|
||||
if (_isReconnecting)
|
||||
Positioned(
|
||||
top: 0,
|
||||
left: 0,
|
||||
right: 0,
|
||||
child: Container(
|
||||
padding:
|
||||
const EdgeInsets.symmetric(vertical: 10, horizontal: 16),
|
||||
color: AppColors.warning.withOpacity(0.9),
|
||||
child: Row(
|
||||
mainAxisAlignment: MainAxisAlignment.center,
|
||||
children: [
|
||||
const SizedBox(
|
||||
width: 16,
|
||||
height: 16,
|
||||
child: CircularProgressIndicator(
|
||||
strokeWidth: 2,
|
||||
color: Colors.white,
|
||||
),
|
||||
),
|
||||
const SizedBox(width: 10),
|
||||
Text(
|
||||
'重新连接中... ($_reconnectAttempts/$_maxReconnectAttempts)',
|
||||
style: const TextStyle(
|
||||
color: Colors.white,
|
||||
fontSize: 14,
|
||||
fontWeight: FontWeight.w500,
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
),
|
||||
),
|
||||
const SizedBox(height: 24),
|
||||
if (_phase == _CallPhase.active) _buildWaveform(),
|
||||
const Spacer(flex: 3),
|
||||
_buildControls(),
|
||||
const SizedBox(height: 48),
|
||||
],
|
||||
),
|
||||
),
|
||||
|
|
@ -432,7 +650,10 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
|
|||
|
||||
@override
|
||||
void dispose() {
|
||||
_userEndedCall = true;
|
||||
_durationTimer?.cancel();
|
||||
_stopHeartbeat();
|
||||
_reconnectTimer?.cancel();
|
||||
_waveController.dispose();
|
||||
_stopwatch.stop();
|
||||
_micSubscription?.cancel();
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from ..config.settings import settings
|
||||
from .health import router as health_router
|
||||
from .session_router import router as session_router
|
||||
from .twilio_webhook import router as twilio_router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="IT0 Voice Service",
|
||||
description="Real-time voice dialogue engine powered by Pipecat",
|
||||
|
|
@ -25,6 +31,32 @@ app.include_router(session_router, prefix="/api/v1/voice", tags=["sessions"])
|
|||
app.include_router(twilio_router, prefix="/api/v1/twilio", tags=["twilio"])
|
||||
|
||||
|
||||
async def _session_cleanup_loop():
    """Background task: purge voice sessions whose reconnect window expired.

    Wakes every 30 seconds and removes entries from ``app.state.sessions``
    that are in "disconnected" state with a ``disconnected_at`` timestamp
    older than ``settings.session_ttl`` seconds. Errors are logged and the
    loop keeps running; it exits only via task cancellation.
    """
    ttl = settings.session_ttl
    while True:
        await asyncio.sleep(30)
        try:
            sessions: dict = getattr(app.state, "sessions", {})
            # A session expires when it disconnected before this cutoff.
            cutoff = time.time() - ttl
            stale_ids = [
                sid
                for sid, info in sessions.items()
                if info.get("status") == "disconnected"
                and info.get("disconnected_at", 0) < cutoff
            ]
            for sid in stale_ids:
                sessions.pop(sid, None)
                logger.info("Cleaned up expired session %s", sid)
        except Exception:
            # Never let a single bad iteration kill the cleanup loop.
            logger.exception("Error in session cleanup loop")
|
||||
|
||||
|
||||
def _load_models_sync():
|
||||
"""Load ML models in a background thread (all blocking calls)."""
|
||||
from ..config.settings import settings
|
||||
|
|
@ -96,11 +128,19 @@ async def startup():
|
|||
app.state.stt = None
|
||||
app.state.tts = None
|
||||
app.state.vad = None
|
||||
app.state.sessions = {}
|
||||
|
||||
# Load models in background thread so server responds to healthchecks immediately
|
||||
thread = threading.Thread(target=_load_models_sync, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Start background session cleanup task
|
||||
app.state._cleanup_task = asyncio.create_task(_session_cleanup_loop())
|
||||
logger.info(
|
||||
"Session cleanup task started (ttl=%ds, check every 30s)",
|
||||
settings.session_ttl,
|
||||
)
|
||||
|
||||
print("Voice service ready (models loading in background).", flush=True)
|
||||
|
||||
|
||||
|
|
@ -108,3 +148,12 @@ async def startup():
|
|||
async def shutdown():
    """Cleanup on shutdown: stop the background session-cleanup task."""
    print("Voice service shutting down...", flush=True)

    # Cancel the cleanup loop started in startup(), if it is still running,
    # and await it so cancellation completes before the loop closes.
    task = getattr(app.state, "_cleanup_task", None)
    if task is not None and not task.done():
        task.cancel()
        try:
            await task
        except (asyncio.CancelledError, Exception):
            pass
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, UploadFile, File
|
||||
|
|
@ -6,9 +9,12 @@ from fastapi.responses import JSONResponse
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
from ..config.settings import settings
|
||||
from ..pipeline.app_transport import AppTransport
|
||||
from ..pipeline.base_pipeline import create_voice_pipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
|
@ -25,6 +31,39 @@ class SessionResponse(BaseModel):
|
|||
websocket_url: str
|
||||
|
||||
|
||||
async def _heartbeat_sender(websocket: WebSocket, session: dict):
    """Periodically push JSON ping frames to detect dead client connections.

    Companion task to the Pipecat pipeline: every ``heartbeat_interval``
    seconds it sends ``{"type": "ping", "ts": <epoch_ms>}`` as a text frame.
    A failed send means the socket is gone, so the coroutine returns and the
    caller's cleanup path runs. Cancellation also exits quietly.

    Note: Pipecat owns the WebSocket read loop (binary audio frames), so
    client pong replies cannot be observed here — a successful ``send_text``
    is the only liveness signal available on this side. The client streams
    audio continuously during an active call, so the pipeline will also
    detect disconnection on its own.
    """
    period = settings.heartbeat_interval
    try:
        while True:
            await asyncio.sleep(period)
            try:
                await websocket.send_text(
                    json.dumps({"type": "ping", "ts": int(time.time() * 1000)})
                )
            except Exception:
                # Send failed — the WebSocket is already closed. Exit so the
                # caller's cleanup path tears down the session.
                logger.info(
                    "Heartbeat send failed for session %s, connection dead",
                    session.get("session_id", "?"),
                )
                return
    except asyncio.CancelledError:
        return
|
||||
|
||||
|
||||
@router.post("/sessions", response_model=SessionResponse)
|
||||
async def create_session(request: CreateSessionRequest, req: Request):
|
||||
"""Create a new voice dialogue session."""
|
||||
|
|
@ -79,9 +118,62 @@ async def end_session(session_id: str, req: Request):
|
|||
return {"status": "ended", "session_id": session_id}
|
||||
|
||||
|
||||
@router.post("/sessions/{session_id}/reconnect", response_model=SessionResponse)
async def reconnect_session(session_id: str, req: Request):
    """Reconnect to a disconnected voice session.

    Returns session info with a fresh ``websocket_url`` when the session
    exists, is in "disconnected" state, and is still within the
    ``session_ttl`` window. Otherwise:

    * 404 — session unknown, or its reconnect window expired (an expired
      session is removed as a side effect);
    * 409 — session exists but is not in "disconnected" state.
    """
    if not hasattr(req.app.state, "sessions"):
        req.app.state.sessions = {}

    session = req.app.state.sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found or expired", "session_id": session_id},
        )

    # Use .get() so a malformed session dict (missing "status") yields a
    # clean 409 instead of an unhandled KeyError / 500; this also matches
    # the .get() style used by the background cleanup loop.
    status = session.get("status")
    if status != "disconnected":
        return JSONResponse(
            status_code=409,
            content={
                "error": f"Session is in '{status}' state, not reconnectable",
                "session_id": session_id,
            },
        )

    # Enforce the TTL window; expired sessions are garbage-collected here
    # as well as by the background cleanup loop.
    disconnected_at = session.get("disconnected_at", 0)
    if time.time() - disconnected_at > settings.session_ttl:
        req.app.state.sessions.pop(session_id, None)
        return JSONResponse(
            status_code=404,
            content={"error": "Session expired", "session_id": session_id},
        )

    return SessionResponse(
        session_id=session_id,
        status="disconnected",
        websocket_url=f"/api/v1/voice/ws/{session_id}",
    )
|
||||
|
||||
|
||||
@router.websocket("/ws/{session_id}")
|
||||
async def voice_websocket(websocket: WebSocket, session_id: str):
|
||||
"""WebSocket endpoint for real-time voice streaming."""
|
||||
"""WebSocket endpoint for real-time voice streaming.
|
||||
|
||||
Supports both fresh connections and reconnections. Binary frames carry
|
||||
PCM audio and are handled by the Pipecat pipeline. Text frames carry
|
||||
JSON control events (ping/pong) and are handled by a parallel task.
|
||||
|
||||
On disconnect the session is preserved in "disconnected" state for up to
|
||||
``session_ttl`` seconds so the client can reconnect.
|
||||
"""
|
||||
await websocket.accept()
|
||||
|
||||
app = websocket.app
|
||||
|
|
@ -96,11 +188,31 @@ async def voice_websocket(websocket: WebSocket, session_id: str):
|
|||
await websocket.close(code=4004, reason="Session not found")
|
||||
return
|
||||
|
||||
is_reconnect = session["status"] == "disconnected"
|
||||
|
||||
# Cancel any leftover pipeline task from previous connection
|
||||
old_task = session.get("task")
|
||||
if old_task is not None and isinstance(old_task, asyncio.Task) and not old_task.done():
|
||||
old_task.cancel()
|
||||
try:
|
||||
await old_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
# Update session status
|
||||
session["status"] = "active"
|
||||
|
||||
pipeline_task = None
|
||||
heartbeat_task = None
|
||||
|
||||
try:
|
||||
# Notify client of successful reconnection
|
||||
if is_reconnect:
|
||||
logger.info("Session %s reconnected", session_id)
|
||||
await websocket.send_text(
|
||||
json.dumps({"type": "session.resumed", "session_id": session_id})
|
||||
)
|
||||
|
||||
# Create the AppTransport from the websocket connection
|
||||
transport = AppTransport(websocket)
|
||||
|
||||
|
|
@ -124,17 +236,38 @@ async def voice_websocket(websocket: WebSocket, session_id: str):
|
|||
pipeline_task = asyncio.create_task(task.run())
|
||||
session["task"] = pipeline_task
|
||||
|
||||
# Wait for the pipeline task to complete (ends on disconnect or cancel)
|
||||
# Start heartbeat sender as a parallel task
|
||||
heartbeat_task = asyncio.create_task(
|
||||
_heartbeat_sender(websocket, session)
|
||||
)
|
||||
|
||||
# Wait for the pipeline to finish. The heartbeat or text-message
|
||||
# handler may close the websocket which will also cause the pipeline
|
||||
# to end.
|
||||
await pipeline_task
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.exception("Unexpected error in voice_websocket for session %s: %s", session_id, exc)
|
||||
try:
|
||||
await websocket.send_text(
|
||||
json.dumps({"type": "error", "message": str(exc)})
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
# Cleanup: cancel pipeline if still running
|
||||
# Cancel heartbeat task
|
||||
if heartbeat_task is not None and not heartbeat_task.done():
|
||||
heartbeat_task.cancel()
|
||||
try:
|
||||
await heartbeat_task
|
||||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
# Cancel pipeline task if still running
|
||||
if pipeline_task is not None and not pipeline_task.done():
|
||||
pipeline_task.cancel()
|
||||
try:
|
||||
|
|
@ -142,10 +275,16 @@ async def voice_websocket(websocket: WebSocket, session_id: str):
|
|||
except (asyncio.CancelledError, Exception):
|
||||
pass
|
||||
|
||||
# Update session status
|
||||
# Mark session as disconnected (preserve for reconnection)
|
||||
if session_id in app.state.sessions:
|
||||
app.state.sessions[session_id]["status"] = "disconnected"
|
||||
app.state.sessions[session_id]["disconnected_at"] = time.time()
|
||||
app.state.sessions[session_id]["task"] = None
|
||||
logger.info(
|
||||
"Session %s disconnected, preserved for %ds",
|
||||
session_id,
|
||||
settings.session_ttl,
|
||||
)
|
||||
|
||||
# Ensure websocket is closed
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -36,6 +36,11 @@ class Settings(BaseSettings):
|
|||
audio_sample_rate: int = 16000
|
||||
audio_channels: int = 1
|
||||
|
||||
# Session
|
||||
session_ttl: int = 60 # seconds before disconnected sessions are cleaned up
|
||||
heartbeat_interval: int = 15 # seconds between pings
|
||||
heartbeat_timeout: int = 45 # seconds before declaring dead connection
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue