"""Voice dialogue session routes.

Exposes REST endpoints to create, end and reconnect voice sessions, a
WebSocket endpoint that streams audio through the voice pipeline, and a
one-shot transcription endpoint. Session metadata lives in an in-memory
dict stored on ``app.state.sessions``.
"""

import asyncio
import json
import logging
import time
import uuid
from typing import Optional

from fastapi import APIRouter, File, Request, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from ..config.settings import settings
from ..pipeline.app_transport import AppTransport
from ..pipeline.base_pipeline import create_voice_pipeline

logger = logging.getLogger(__name__)

router = APIRouter()


class CreateSessionRequest(BaseModel):
    """Request to create a voice dialogue session."""

    # Optional identifier stored with the session and passed to the pipeline.
    execution_id: Optional[str] = None
    # Free-form context stored with the session and passed to the pipeline.
    agent_context: dict = {}


class SessionResponse(BaseModel):
    """Voice session info."""

    session_id: str
    status: str
    websocket_url: str


def _get_sessions(app) -> dict:
    """Return the session registry on ``app.state``, creating it on first use."""
    if not hasattr(app.state, "sessions"):
        app.state.sessions = {}
    return app.state.sessions


def _websocket_url(session_id: str) -> str:
    """Build the WebSocket URL advertised to clients for *session_id*."""
    # NOTE(review): this path ("/ws/voice/...") differs from the route declared
    # in this module ("/ws/{session_id}"). Presumably the router is mounted
    # behind a prefix or proxy rewrite -- confirm against the app setup.
    return f"/ws/voice/{session_id}"


async def _cancel_and_wait(task) -> None:
    """Cancel *task* if it is a still-running asyncio task and wait for it.

    Best-effort: anything that is not a pending ``asyncio.Task`` is ignored,
    and errors raised by the task while it unwinds are logged, not re-raised.
    """
    if task is None or not isinstance(task, asyncio.Task) or task.done():
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected outcome of the cancellation requested above.
        pass
    except Exception:
        logger.debug("Task raised while being cancelled", exc_info=True)


async def _heartbeat_sender(websocket: WebSocket, session: dict):
    """Send periodic pings to keep the connection alive and detect dead clients.

    Runs as a parallel asyncio task alongside the Pipecat pipeline. Sends a
    JSON ``{"type": "ping", "ts": <epoch_ms>}`` text frame every
    ``heartbeat_interval`` seconds. If the send fails the connection is
    considered dead and the task returns; ``voice_websocket`` waits on this
    task together with the pipeline and tears the connection down when it
    ends.

    Note: Pipecat owns the WebSocket read loop (for binary audio frames), so
    we cannot read client pong responses here. Instead we rely on the fact
    that a failed ``send_text`` indicates a broken connection.

    Args:
        websocket: The accepted client connection to ping.
        session: Session metadata; only ``session_id`` is read, for logging.
    """
    interval = settings.heartbeat_interval
    try:
        while True:
            await asyncio.sleep(interval)
            try:
                await websocket.send_text(
                    json.dumps({"type": "ping", "ts": int(time.time() * 1000)})
                )
            except Exception:
                # WebSocket already closed -- exit so cleanup runs
                logger.info(
                    "Heartbeat send failed for session %s, connection dead",
                    session.get("session_id", "?"),
                )
                return
    except asyncio.CancelledError:
        # Normal shutdown path: the owning handler cancelled us.
        return


@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Create a new voice dialogue session.

    Registers session metadata under a freshly generated ``vs_...`` id and
    returns the id together with the WebSocket URL the client should open.
    """
    session_id = f"vs_{uuid.uuid4().hex[:12]}"

    sessions = _get_sessions(req.app)
    sessions[session_id] = {
        "session_id": session_id,
        "status": "created",
        "execution_id": request.execution_id,
        "agent_context": request.agent_context,
        "task": None,
    }

    return SessionResponse(
        session_id=session_id,
        status="created",
        websocket_url=_websocket_url(session_id),
    )


@router.delete("/sessions/{session_id}")
async def end_session(session_id: str, req: Request):
    """End a voice dialogue session.

    Cancels the session's pipeline task if one is still running, removes the
    session from the registry and returns ``{"status": "ended", ...}``.
    Unknown ids yield ``{"status": "not_found", ...}``.
    """
    sessions = _get_sessions(req.app)

    session = sessions.get(session_id)
    if session is None:
        return {"status": "not_found", "session_id": session_id}

    # Stop the pipeline before forgetting the session.
    await _cancel_and_wait(session.get("task"))

    sessions.pop(session_id, None)
    return {"status": "ended", "session_id": session_id}


@router.post("/sessions/{session_id}/reconnect", response_model=SessionResponse)
async def reconnect_session(session_id: str, req: Request):
    """Reconnect to a disconnected voice session.

    If the session still exists and is in "disconnected" state (within the
    ``session_ttl`` window), return session info with a fresh websocket_url
    so the client can open a new WebSocket connection.

    Returns:
        ``SessionResponse`` on success; a 404 JSON error when the session is
        unknown or expired; a 409 JSON error when it is not "disconnected".
    """
    sessions = _get_sessions(req.app)

    session = sessions.get(session_id)
    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found or expired", "session_id": session_id},
        )

    if session["status"] != "disconnected":
        return JSONResponse(
            status_code=409,
            content={
                "error": f"Session is in '{session['status']}' state, not reconnectable",
                "session_id": session_id,
            },
        )

    # A missing timestamp defaults to 0, which is treated as long expired.
    disconnected_at = session.get("disconnected_at", 0)
    if time.time() - disconnected_at > settings.session_ttl:
        # Expired -- clean it up
        sessions.pop(session_id, None)
        return JSONResponse(
            status_code=404,
            content={"error": "Session expired", "session_id": session_id},
        )

    return SessionResponse(
        session_id=session_id,
        status="disconnected",
        websocket_url=_websocket_url(session_id),
    )


@router.websocket("/ws/{session_id}")
async def voice_websocket(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for real-time voice streaming.

    Supports both fresh connections and reconnections. The connection is
    handed to the voice pipeline through ``AppTransport``; in parallel a
    heartbeat task sends JSON ``ping`` text frames to the client.

    The handler runs until either the pipeline finishes or the heartbeat
    detects a dead connection, then cancels whichever task is still running.
    On disconnect the session is preserved in "disconnected" state for up to
    ``session_ttl`` seconds so the client can reconnect. If a newer
    connection has taken over the same session in the meantime, this handler
    leaves the session record untouched.
    """
    # Accept first so an unknown session can be rejected with a close code.
    await websocket.accept()

    app = websocket.app
    sessions = _get_sessions(app)

    session = sessions.get(session_id)
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return

    is_reconnect = session["status"] == "disconnected"

    # Claim the session BEFORE cancelling any previous pipeline: the previous
    # handler's cleanup runs as a result of that cancellation and must see
    # that it no longer owns the session, otherwise it would overwrite the
    # state of this live connection with "disconnected".
    connection_id = uuid.uuid4().hex
    session["connection_id"] = connection_id

    # Cancel any leftover pipeline task from previous connection
    await _cancel_and_wait(session.get("task"))

    session["status"] = "active"

    pipeline_task = None
    heartbeat_task = None

    try:
        # Notify client of successful reconnection
        if is_reconnect:
            logger.info("Session %s reconnected", session_id)
            await websocket.send_text(
                json.dumps({"type": "session.resumed", "session_id": session_id})
            )

        transport = AppTransport(websocket)

        session_context = {
            "session_id": session_id,
            "execution_id": session.get("execution_id"),
            "agent_context": session.get("agent_context", {}),
        }

        # Build the pipeline with the shared services stored on app.state.
        task = await create_voice_pipeline(
            transport,
            session_context,
            stt=getattr(app.state, "stt", None),
            tts=getattr(app.state, "tts", None),
            vad=getattr(app.state, "vad", None),
        )

        pipeline_task = asyncio.create_task(task.run())
        session["task"] = pipeline_task

        heartbeat_task = asyncio.create_task(_heartbeat_sender(websocket, session))

        # Wait for whichever ends first. Awaiting only the pipeline would
        # ignore a dead connection detected by the heartbeat.
        done, _pending = await asyncio.wait(
            {pipeline_task, heartbeat_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        if pipeline_task in done:
            # Surface pipeline errors/cancellation to the handlers below.
            await pipeline_task
        else:
            logger.info(
                "Heartbeat ended for session %s; stopping pipeline", session_id
            )

    except (WebSocketDisconnect, asyncio.CancelledError):
        # Client went away, or the pipeline/handler was cancelled: fall
        # through to cleanup.
        pass
    except Exception as exc:
        logger.exception(
            "Unexpected error in voice_websocket for session %s: %s", session_id, exc
        )
        try:
            await websocket.send_text(
                json.dumps({"type": "error", "message": str(exc)})
            )
        except Exception:
            # Best-effort notification; the socket may already be closed.
            pass
    finally:
        await _cancel_and_wait(heartbeat_task)
        await _cancel_and_wait(pipeline_task)

        # Mark session as disconnected (preserve for reconnection), but only
        # if it still exists and this connection is still its owner.
        current = sessions.get(session_id)
        if current is not None and current.get("connection_id") == connection_id:
            current["status"] = "disconnected"
            current["disconnected_at"] = time.time()
            current["task"] = None
            logger.info(
                "Session %s disconnected, preserved for %ds",
                session_id,
                settings.session_ttl,
            )

        # Ensure websocket is closed
        try:
            await websocket.close()
        except Exception:
            # Already closed by the peer or by the pipeline.
            pass


@router.post("/transcribe")
async def transcribe_audio(req: Request, audio: UploadFile = File(...)):
    """Transcribe uploaded audio (PCM 16kHz 16-bit mono) to text using Whisper.

    Returns ``{"text": ...}`` with surrounding whitespace stripped, an empty
    text for an empty upload, or a 503 JSON error when the shared STT service
    on ``app.state`` is missing or has no model loaded.
    """
    stt = getattr(req.app.state, "stt", None)
    # getattr: treat a service without a ``_model`` attribute as not loaded
    # instead of failing with AttributeError.
    if stt is None or getattr(stt, "_model", None) is None:
        return JSONResponse(status_code=503, content={"error": "STT model not loaded"})

    audio_data = await audio.read()
    if not audio_data:
        return {"text": ""}

    text = await stt.transcribe(audio_data)
    # Guard against a None/empty result -- assumes transcribe() returns str;
    # TODO confirm against the STT service.
    return {"text": (text or "").strip()}