feat: add WebSocket robustness to voice call (heartbeat, reconnect, jitter buffer)

Addresses reliability gaps in the real-time voice WebSocket connection
between Flutter client and Python voice-service backend.

Backend (voice-service):
- Heartbeat: new _heartbeat_sender coroutine sends JSON ping text frames
  every 15s alongside the Pipecat pipeline; failed send = dead connection
- Session preservation: on WebSocket disconnect, sessions are now marked
  "disconnected" with a timestamp instead of being deleted, allowing
  reconnection within a configurable TTL (default 60s)
- Reconnect endpoint: POST /sessions/{id}/reconnect verifies the session
  is alive and in "disconnected" state, returns fresh websocket_url
- Reconnect-aware WS handler: detects "disconnected" sessions, cancels
  stale pipeline tasks, creates a new pipeline, sends "session.resumed"
- Background cleanup: asyncio loop every 30s removes sessions that have
  been disconnected longer than session_ttl
- Structured event protocol: text frames = JSON control messages
  (ping/pong/session.resumed/session.ended/error), binary = PCM audio
- New settings: session_ttl (60s), heartbeat_interval (15s),
  heartbeat_timeout (45s)

Flutter (agent_call_page.dart):
- Heartbeat monitoring: tracks last server ping timestamp, triggers
  reconnect if no ping received in 45s (3 missed intervals)
- Auto-reconnect: exponential backoff (1s→2s→4s→8s→16s), max 5 attempts;
  calls /reconnect endpoint to verify session, rebuilds WebSocket,
  resets audio buffer, restarts heartbeat
- Reconnecting UI: yellow warning banner "重新连接中... (N/5)"
  ("Reconnecting... (N/5)") with spinner overlay during reconnection attempts
- WebSocket data routing: _onWsData distinguishes String (JSON control)
  from binary (audio) frames, handles ping/session.resumed/session.ended
- User-initiated disconnect guard: _userEndedCall flag prevents reconnect
  attempts when user intentionally hangs up
- session_id field compatibility: supports session_id/sessionId/id
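The backoff schedule above is a capped power of two per attempt; sketched in Python (the Dart code uses the equivalent `min(pow(2, attempt).toInt(), 16)`):

```python
def backoff_delays(max_attempts: int = 5, cap_seconds: int = 16) -> list[int]:
    """Delay in seconds before each reconnect attempt: 1, 2, 4, 8, 16."""
    return [min(2 ** attempt, cap_seconds) for attempt in range(max_attempts)]

delays = backoff_delays()
print(delays, sum(delays))  # [1, 2, 4, 8, 16] 31
```

Worst case the client spends 31 s retrying before giving up, which stays inside the 60 s session TTL, so even the final attempt can still find the preserved session alive on the backend.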

Flutter (pcm_player.dart):
- Jitter buffer: queues incoming PCM chunks, starts playback only after
  accumulating 4800 bytes (150ms at 16kHz 16-bit mono) to smooth out
  network timing variance
- reset() method: clears buffer on reconnect to discard stale audio
- Buffer underrun handling: re-enters buffering phase if queue empties
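The buffering gate is easy to model outside Dart. A Python sketch of the same logic (class and method names illustrative), including the derivation of the 4800-byte threshold:

```python
SAMPLE_RATE = 16_000   # Hz
BYTES_PER_SAMPLE = 2   # 16-bit mono
TARGET_MS = 150
THRESHOLD = SAMPLE_RATE * BYTES_PER_SAMPLE * TARGET_MS // 1000  # 4800 bytes

class JitterBuffer:
    """Hold chunks until THRESHOLD bytes accumulate, then pass through."""

    def __init__(self, threshold: int = THRESHOLD):
        self.threshold = threshold
        self.pending: list[bytes] = []
        self.buffered = 0
        self.started = False

    def feed(self, chunk: bytes) -> list[bytes]:
        """Return the chunks that may be played now (empty while buffering)."""
        if self.started:
            return [chunk]
        self.pending.append(chunk)
        self.buffered += len(chunk)
        if self.buffered >= self.threshold:
            # Threshold reached: release everything in arrival order.
            self.started = True
            released, self.pending, self.buffered = self.pending, [], 0
            return released
        return []

jb = JitterBuffer()
# Three 1600-byte chunks (50 ms each): nothing plays until 4800 bytes arrive.
print([len(jb.feed(b"\x00" * 1600)) for _ in range(3)])  # [0, 0, 3]
```

Resetting on underrun corresponds to flipping `started` back to `False`, which is what the Dart `reset()` does on reconnect.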

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Author: hailin
Date: 2026-02-23 07:32:19 -08:00
Commit: a6cd3c20d9 (parent: dfc541b571)
5 changed files with 503 additions and 40 deletions

File: pcm_player.dart
@@ -1,7 +1,8 @@
import 'dart:typed_data';
import 'package:flutter_sound/flutter_sound.dart';

/// Wraps [FlutterSoundPlayer] for streaming raw PCM 16kHz mono playback
/// with a simple jitter buffer to smooth out network-induced timing variance.
///
/// Usage:
///   final player = PcmPlayer();
@@ -12,6 +13,15 @@ class PcmPlayer {
  FlutterSoundPlayer? _player;
  bool _initialized = false;

  // Jitter buffer
  final List<Uint8List> _buffer = [];
  int _bufferedBytes = 0;
  bool _playbackStarted = false;

  /// Number of bytes to accumulate before starting playback.
  /// 4800 bytes = 150 ms at 16 kHz, 16-bit mono (2 bytes per sample).
  static const int _bufferThreshold = 4800;

  /// Open the player and start a streaming session.
  Future<void> init() async {
    if (_initialized) return;
@@ -31,12 +41,50 @@ class PcmPlayer {
  }

  /// Feed raw PCM 16-bit signed LE mono 16 kHz data for playback.
  ///
  /// Incoming chunks are held in a jitter buffer until [_bufferThreshold] bytes
  /// have accumulated. Once playback has started, new chunks are forwarded to
  /// the underlying player immediately. If a buffer underrun occurs (the queue
  /// empties while playing) the next call to [feed] will re-enter the buffering
  /// phase, pausing briefly until the threshold is reached again.
  Future<void> feed(Uint8List pcmData) async {
    if (!_initialized || _player == null) return;

    if (!_playbackStarted) {
      // Still buffering: queue and wait until we reach the threshold.
      _buffer.add(pcmData);
      _bufferedBytes += pcmData.length;
      if (_bufferedBytes >= _bufferThreshold) {
        _playbackStarted = true;
        await _drainBuffer();
      }
      return;
    }

    // Playback already running: feed directly.
    // ignore: deprecated_member_use
    await _player!.feedUint8FromStream(pcmData);
  }

  /// Drain all queued chunks into the player in order.
  Future<void> _drainBuffer() async {
    while (_buffer.isNotEmpty) {
      final chunk = _buffer.removeAt(0);
      _bufferedBytes -= chunk.length;
      // ignore: deprecated_member_use
      await _player!.feedUint8FromStream(chunk);
    }
  }

  /// Clear the jitter buffer and reset playback state.
  ///
  /// Call this on reconnect so stale audio data is not played back.
  void reset() {
    _buffer.clear();
    _bufferedBytes = 0;
    _playbackStarted = false;
  }

  /// Toggle speaker mode (earpiece vs loudspeaker).
  Future<void> setSpeakerOn(bool on) async {
    // flutter_sound doesn't expose a direct speaker toggle;
@@ -46,6 +94,7 @@ class PcmPlayer {
  /// Stop playback and release resources.
  Future<void> dispose() async {
    reset();
    if (_player != null) {
      try {
        await _player!.stopPlayer();

File: agent_call_page.dart

@@ -1,4 +1,5 @@
import 'dart:async';
import 'dart:convert';
import 'dart:math';
import 'dart:typed_data';
import 'package:flutter/material.dart';
@@ -50,6 +51,15 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
  bool _isMuted = false;
  bool _isSpeakerOn = false;

  // Reconnection
  int _reconnectAttempts = 0;
  static const int _maxReconnectAttempts = 5;
  Timer? _reconnectTimer;
  Timer? _heartbeatCheckTimer;
  DateTime? _lastServerPing;
  bool _isReconnecting = false;
  bool _userEndedCall = false;

  @override
  void initState() {
    super.initState();
@@ -75,7 +85,7 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
    final response = await dio.post('${ApiEndpoints.voice}/sessions');
    final data = response.data as Map<String, dynamic>;
    _sessionId = data['session_id'] as String? ?? data['sessionId'] as String? ?? data['id'] as String?;

    // Build WebSocket URL: prefer the backend-returned path.
    String wsUrl;
@@ -94,13 +104,16 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
    // Connect WebSocket
    _audioChannel = WebSocketChannel.connect(Uri.parse(wsUrl));

    // Listen for incoming data (text control messages + binary audio).
    _audioSubscription = _audioChannel!.stream.listen(
      _onWsData,
      onDone: () => _onWsDisconnected(),
      onError: (_) => _onWsDisconnected(),
    );

    // Start heartbeat monitoring.
    _startHeartbeat();

    // Initialize audio playback
    await _pcmPlayer.init();
@@ -158,6 +171,166 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
    });
  }

  // ---------------------------------------------------------------------------
  // WebSocket data routing & heartbeat
  // ---------------------------------------------------------------------------

  /// Top-level handler that distinguishes text (JSON control) from binary
  /// (audio) WebSocket frames.
  void _onWsData(dynamic data) {
    if (data is String) {
      _handleControlMessage(data);
    } else {
      _onAudioReceived(data);
    }
  }

  /// Decode a JSON control frame from the server and respond accordingly.
  void _handleControlMessage(String jsonStr) {
    try {
      final msg = jsonDecode(jsonStr) as Map<String, dynamic>;
      final type = msg['type'] as String?;
      switch (type) {
        case 'ping':
          // Server heartbeat: update the liveness timestamp.
          // Note: we don't send a pong back because Pipecat owns the
          // server-side WebSocket read loop and would intercept the frame.
          _lastServerPing = DateTime.now();
          break;
        case 'session.resumed':
          // Reconnection confirmed by the server.
          _lastServerPing = DateTime.now();
          break;
        case 'session.ended':
          _onCallEnded();
          break;
        case 'error':
          if (mounted) {
            final detail = msg['message'] as String? ?? 'Unknown error';
            ScaffoldMessenger.of(context).showSnackBar(
              SnackBar(
                // "Server error: <detail>"
                content: Text('服务端错误: $detail'),
                backgroundColor: AppColors.error,
              ),
            );
          }
          break;
      }
    } catch (_) {
      // Malformed JSON: ignore.
    }
  }

  /// Start monitoring server liveness based on incoming pings.
  ///
  /// The server sends `{"type": "ping"}` every 15 s. If we haven't received
  /// one in 45 s (3 missed pings), we assume the connection is dead and
  /// trigger a reconnect.
  void _startHeartbeat() {
    _heartbeatCheckTimer?.cancel();
    // Mark an initial timestamp so the first 45-second window starts now.
    _lastServerPing = DateTime.now();
    // Periodically verify the server is alive.
    _heartbeatCheckTimer = Timer.periodic(const Duration(seconds: 10), (_) {
      if (_lastServerPing != null &&
          DateTime.now().difference(_lastServerPing!).inSeconds > 45) {
        _triggerReconnect();
      }
    });
  }

  /// Cancel the heartbeat check timer.
  void _stopHeartbeat() {
    _heartbeatCheckTimer?.cancel();
    _heartbeatCheckTimer = null;
  }

  // ---------------------------------------------------------------------------
  // Auto-reconnect with exponential backoff
  // ---------------------------------------------------------------------------

  /// Called when the WebSocket drops unexpectedly (not user-initiated).
  void _onWsDisconnected() {
    if (_phase == _CallPhase.ended || _isReconnecting || _userEndedCall) return;
    _triggerReconnect();
  }

  /// Attempt to reconnect using exponential backoff (1 s, 2 s, 4 s, 8 s,
  /// 16 s) with a maximum of [_maxReconnectAttempts] tries.
  Future<void> _triggerReconnect() async {
    if (_isReconnecting || _phase == _CallPhase.ended || _userEndedCall) return;
    _isReconnecting = true;
    if (mounted) setState(() {});

    // Tear down the current connection & heartbeat.
    _stopHeartbeat();
    await _audioSubscription?.cancel();
    _audioSubscription = null;
    try {
      await _audioChannel?.sink.close();
    } catch (_) {}
    _audioChannel = null;

    for (int attempt = 0; attempt < _maxReconnectAttempts; attempt++) {
      _reconnectAttempts = attempt + 1;
      if (!mounted || _phase == _CallPhase.ended || _userEndedCall) break;

      final delaySecs = min(pow(2, attempt).toInt(), 16);
      await Future.delayed(Duration(seconds: delaySecs));
      if (!mounted || _phase == _CallPhase.ended || _userEndedCall) break;

      try {
        // Ask the backend whether the session is still alive.
        final dio = ref.read(dioClientProvider);
        final response = await dio.post(
          '${ApiEndpoints.voice}/sessions/$_sessionId/reconnect',
        );
        final data = response.data as Map<String, dynamic>;

        // Build the new WebSocket URL.
        final config = ref.read(appConfigProvider);
        String wsUrl;
        final backendWsUrl = data['websocket_url'] as String?;
        if (backendWsUrl != null && backendWsUrl.startsWith('/')) {
          final uri = Uri.parse(config.wsBaseUrl);
          wsUrl = '${uri.scheme}://${uri.host}:${uri.port}$backendWsUrl';
        } else if (backendWsUrl != null) {
          wsUrl = backendWsUrl;
        } else {
          wsUrl = '${config.wsBaseUrl}/api/v1/voice/ws/$_sessionId';
        }

        // Connect the new WebSocket.
        _audioChannel = WebSocketChannel.connect(Uri.parse(wsUrl));
        _audioSubscription = _audioChannel!.stream.listen(
          _onWsData,
          onDone: () => _onWsDisconnected(),
          onError: (_) => _onWsDisconnected(),
        );

        // Reset the audio jitter buffer for a fresh start.
        _pcmPlayer.reset();

        // Restart heartbeat monitoring.
        _startHeartbeat();

        _reconnectAttempts = 0;
        _isReconnecting = false;
        if (mounted) setState(() {});
        return; // success
      } catch (_) {
        // Will retry on the next iteration.
      }
    }

    // All reconnection attempts exhausted.
    _isReconnecting = false;
    _onCallEnded();
  }

  /// Handle incoming audio from the agent side.
  void _onAudioReceived(dynamic data) {
    if (!mounted || _phase != _CallPhase.active) return;
@@ -203,11 +376,18 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
  /// End the call and clean up all resources.
  Future<void> _endCall() async {
    _userEndedCall = true;
    _isReconnecting = false;
    setState(() => _phase = _CallPhase.ended);
    _stopwatch.stop();
    _durationTimer?.cancel();
    _waveController.stop();

    // Cancel heartbeat & reconnect timers.
    _stopHeartbeat();
    _reconnectTimer?.cancel();
    _reconnectTimer = null;

    // Stop mic
    await _micSubscription?.cancel();
    _micSubscription = null;
@@ -252,15 +432,17 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
    return Scaffold(
      backgroundColor: AppColors.background,
      body: SafeArea(
        child: Stack(
          children: [
            Column(
              children: [
                const Spacer(flex: 2),
                _buildAvatar(),
                const SizedBox(height: 24),
                Text(
                  _statusText,
                  style: const TextStyle(
                      fontSize: 24, fontWeight: FontWeight.bold),
                ),
                const SizedBox(height: 8),
                Text(
@@ -286,6 +468,42 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
                const SizedBox(height: 48),
              ],
            ),

            // Reconnecting overlay
            if (_isReconnecting)
              Positioned(
                top: 0,
                left: 0,
                right: 0,
                child: Container(
                  padding:
                      const EdgeInsets.symmetric(vertical: 10, horizontal: 16),
                  color: AppColors.warning.withOpacity(0.9),
                  child: Row(
                    mainAxisAlignment: MainAxisAlignment.center,
                    children: [
                      const SizedBox(
                        width: 16,
                        height: 16,
                        child: CircularProgressIndicator(
                          strokeWidth: 2,
                          color: Colors.white,
                        ),
                      ),
                      const SizedBox(width: 10),
                      Text(
                        // "Reconnecting... (N/5)"
                        '重新连接中... ($_reconnectAttempts/$_maxReconnectAttempts)',
                        style: const TextStyle(
                          color: Colors.white,
                          fontSize: 14,
                          fontWeight: FontWeight.w500,
                        ),
                      ),
                    ],
                  ),
                ),
              ),
          ],
        ),
      ),
    );
  }
@@ -432,7 +650,10 @@ class _AgentCallPageState extends ConsumerState<AgentCallPage>
  @override
  void dispose() {
    _userEndedCall = true;
    _durationTimer?.cancel();
    _stopHeartbeat();
    _reconnectTimer?.cancel();
    _waveController.dispose();
    _stopwatch.stop();
    _micSubscription?.cancel();

File: voice-service FastAPI app module (filename not shown in this view)

@@ -1,12 +1,18 @@
import asyncio
import logging
import threading
import time

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from ..config.settings import settings
from .health import router as health_router
from .session_router import router as session_router
from .twilio_webhook import router as twilio_router

logger = logging.getLogger(__name__)

app = FastAPI(
    title="IT0 Voice Service",
    description="Real-time voice dialogue engine powered by Pipecat",
@@ -25,6 +31,32 @@ app.include_router(session_router, prefix="/api/v1/voice", tags=["sessions"])
app.include_router(twilio_router, prefix="/api/v1/twilio", tags=["twilio"])


async def _session_cleanup_loop():
    """Periodically remove sessions disconnected for longer than session_ttl.

    Runs every 30 seconds. Sessions in "disconnected" state whose
    ``disconnected_at`` timestamp is older than ``settings.session_ttl``
    are removed from ``app.state.sessions``.
    """
    ttl = settings.session_ttl
    while True:
        await asyncio.sleep(30)
        try:
            sessions: dict = getattr(app.state, "sessions", {})
            now = time.time()
            expired = [
                sid
                for sid, s in sessions.items()
                if s.get("status") == "disconnected"
                and (now - s.get("disconnected_at", 0)) > ttl
            ]
            for sid in expired:
                sessions.pop(sid, None)
                logger.info("Cleaned up expired session %s", sid)
        except Exception:
            logger.exception("Error in session cleanup loop")


def _load_models_sync():
    """Load ML models in a background thread (all blocking calls)."""
    from ..config.settings import settings
@@ -96,11 +128,19 @@ async def startup():
    app.state.stt = None
    app.state.tts = None
    app.state.vad = None
    app.state.sessions = {}

    # Load models in a background thread so the server responds to healthchecks immediately.
    thread = threading.Thread(target=_load_models_sync, daemon=True)
    thread.start()

    # Start the background session cleanup task.
    app.state._cleanup_task = asyncio.create_task(_session_cleanup_loop())
    logger.info(
        "Session cleanup task started (ttl=%ds, check every 30s)",
        settings.session_ttl,
    )

    print("Voice service ready (models loading in background).", flush=True)
@@ -108,3 +148,12 @@ async def startup():
async def shutdown():
    """Cleanup on shutdown."""
    print("Voice service shutting down...", flush=True)
    # Cancel the session cleanup task.
    cleanup_task = getattr(app.state, "_cleanup_task", None)
    if cleanup_task is not None and not cleanup_task.done():
        cleanup_task.cancel()
        try:
            await cleanup_task
        except (asyncio.CancelledError, Exception):
            pass

File: session_router.py

@@ -1,4 +1,7 @@
import asyncio
import json
import logging
import time
import uuid

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, UploadFile, File
@@ -6,9 +9,12 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional

from ..config.settings import settings
from ..pipeline.app_transport import AppTransport
from ..pipeline.base_pipeline import create_voice_pipeline

logger = logging.getLogger(__name__)

router = APIRouter()
@@ -25,6 +31,39 @@ class SessionResponse(BaseModel):
    websocket_url: str


async def _heartbeat_sender(websocket: WebSocket, session: dict):
    """Send periodic pings to keep the connection alive and detect dead clients.

    Runs as a parallel asyncio task alongside the Pipecat pipeline.
    Sends a JSON ``{"type": "ping", "ts": <epoch_ms>}`` text frame every
    ``heartbeat_interval`` seconds. If the send fails, the connection is dead
    and the task exits, which causes the pipeline to be cleaned up.

    Note: Pipecat owns the WebSocket read loop (for binary audio frames), so
    we cannot read client pong responses here. Instead we rely on the fact
    that a failed ``send_text`` indicates a broken connection. The client
    sends audio continuously during an active call, so Pipecat's pipeline
    will also naturally detect disconnection.
    """
    interval = settings.heartbeat_interval
    try:
        while True:
            await asyncio.sleep(interval)
            try:
                await websocket.send_text(
                    json.dumps({"type": "ping", "ts": int(time.time() * 1000)})
                )
            except Exception:
                # WebSocket already closed — exit so cleanup runs.
                logger.info(
                    "Heartbeat send failed for session %s, connection dead",
                    session.get("session_id", "?"),
                )
                return
    except asyncio.CancelledError:
        return


@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Create a new voice dialogue session."""
@@ -79,9 +118,62 @@
    return {"status": "ended", "session_id": session_id}


@router.post("/sessions/{session_id}/reconnect", response_model=SessionResponse)
async def reconnect_session(session_id: str, req: Request):
    """Reconnect to a disconnected voice session.

    If the session still exists and is in "disconnected" state (within the
    ``session_ttl`` window), return session info with a fresh websocket_url
    so the client can open a new WebSocket connection.
    """
    if not hasattr(req.app.state, "sessions"):
        req.app.state.sessions = {}

    session = req.app.state.sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found or expired", "session_id": session_id},
        )

    if session["status"] != "disconnected":
        return JSONResponse(
            status_code=409,
            content={
                "error": f"Session is in '{session['status']}' state, not reconnectable",
                "session_id": session_id,
            },
        )

    # Check whether the session has expired based on TTL.
    disconnected_at = session.get("disconnected_at", 0)
    if time.time() - disconnected_at > settings.session_ttl:
        # Expired: clean it up.
        req.app.state.sessions.pop(session_id, None)
        return JSONResponse(
            status_code=404,
            content={"error": "Session expired", "session_id": session_id},
        )

    websocket_url = f"/api/v1/voice/ws/{session_id}"
    return SessionResponse(
        session_id=session_id,
        status="disconnected",
        websocket_url=websocket_url,
    )


@router.websocket("/ws/{session_id}")
async def voice_websocket(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for real-time voice streaming.

    Supports both fresh connections and reconnections. Binary frames carry
    PCM audio and are handled by the Pipecat pipeline. Text frames carry
    JSON control events (ping/pong) and are handled by a parallel task.

    On disconnect the session is preserved in "disconnected" state for up to
    ``session_ttl`` seconds so the client can reconnect.
    """
    await websocket.accept()
    app = websocket.app
@@ -96,11 +188,31 @@
        await websocket.close(code=4004, reason="Session not found")
        return

    is_reconnect = session["status"] == "disconnected"

    # Cancel any leftover pipeline task from the previous connection.
    old_task = session.get("task")
    if old_task is not None and isinstance(old_task, asyncio.Task) and not old_task.done():
        old_task.cancel()
        try:
            await old_task
        except (asyncio.CancelledError, Exception):
            pass

    # Update session status
    session["status"] = "active"

    pipeline_task = None
    heartbeat_task = None
    try:
        # Notify the client of a successful reconnection.
        if is_reconnect:
            logger.info("Session %s reconnected", session_id)
            await websocket.send_text(
                json.dumps({"type": "session.resumed", "session_id": session_id})
            )

        # Create the AppTransport from the websocket connection
        transport = AppTransport(websocket)
@@ -124,17 +236,38 @@
        pipeline_task = asyncio.create_task(task.run())
        session["task"] = pipeline_task

        # Start the heartbeat sender as a parallel task.
        heartbeat_task = asyncio.create_task(
            _heartbeat_sender(websocket, session)
        )

        # Wait for the pipeline to finish. The heartbeat or text-message
        # handler may close the websocket, which will also cause the pipeline
        # to end.
        await pipeline_task
    except WebSocketDisconnect:
        pass
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        logger.exception("Unexpected error in voice_websocket for session %s: %s", session_id, exc)
        try:
            await websocket.send_text(
                json.dumps({"type": "error", "message": str(exc)})
            )
        except Exception:
            pass
    finally:
        # Cancel the heartbeat task.
        if heartbeat_task is not None and not heartbeat_task.done():
            heartbeat_task.cancel()
            try:
                await heartbeat_task
            except (asyncio.CancelledError, Exception):
                pass

        # Cancel the pipeline task if still running.
        if pipeline_task is not None and not pipeline_task.done():
            pipeline_task.cancel()
            try:
@@ -142,10 +275,16 @@
            except (asyncio.CancelledError, Exception):
                pass

        # Mark the session as disconnected (preserve for reconnection).
        if session_id in app.state.sessions:
            app.state.sessions[session_id]["status"] = "disconnected"
            app.state.sessions[session_id]["disconnected_at"] = time.time()
            app.state.sessions[session_id]["task"] = None
            logger.info(
                "Session %s disconnected, preserved for %ds",
                session_id,
                settings.session_ttl,
            )

        # Ensure websocket is closed
        try:

File: config/settings.py

@@ -36,6 +36,11 @@ class Settings(BaseSettings):
    audio_sample_rate: int = 16000
    audio_channels: int = 1

    # Session
    session_ttl: int = 60  # seconds before disconnected sessions are cleaned up
    heartbeat_interval: int = 15  # seconds between pings
    heartbeat_timeout: int = 45  # seconds before declaring a connection dead

    class Config:
        env_file = ".env"