# it0/packages/services/voice-service/src/api/session_router.py
#
# (source listing metadata: 309 lines, 10 KiB, Python)

import asyncio
import json
import logging
import time
import uuid
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, UploadFile, File
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from typing import Optional
from ..config.settings import settings
from ..pipeline.app_transport import AppTransport
from ..pipeline.base_pipeline import create_voice_pipeline
# Module-level logger; records are emitted under this module's dotted name.
logger = logging.getLogger(__name__)
# Router carrying the session REST endpoints, the voice WebSocket endpoint and
# the one-shot transcription endpoint; presumably included into the FastAPI
# app (possibly with a prefix) elsewhere — confirm against the app factory.
router = APIRouter()
class CreateSessionRequest(BaseModel):
    """Body accepted by ``POST /sessions``.

    Every field is optional, so a body of ``{}`` creates a session with no
    execution binding and an empty agent context.
    """

    # Identifier of the execution this voice dialogue is attached to, if any.
    execution_id: Optional[str] = None
    # Free-form context stored on the session record and later handed to the
    # voice pipeline when a WebSocket attaches.
    agent_context: dict = {}
class SessionResponse(BaseModel):
    """Session descriptor returned by the create and reconnect endpoints."""

    session_id: str  # server-generated id of the form ``vs_<12 hex chars>``
    status: str  # lifecycle state, e.g. "created" or "disconnected"
    websocket_url: str  # path the client should open its WebSocket on
async def _heartbeat_sender(websocket: WebSocket, session: dict):
    """Periodically ping the client so a dead connection gets noticed.

    Meant to run as a background asyncio task next to the Pipecat pipeline.
    Every ``settings.heartbeat_interval`` seconds it sends the JSON text frame
    ``{"type": "ping", "ts": <epoch_ms>}``.

    Pipecat owns the WebSocket read loop (binary audio frames), so client pong
    replies cannot be read here. A failing ``send_text`` is therefore the only
    liveness signal this task has. When a send fails, the failure is logged
    and the coroutine returns. Whoever awaits the task can then see that it
    finished and tear the connection down.

    Cancellation is deliberately not swallowed: ``asyncio.CancelledError``
    propagates so the task is reported as cancelled, as asyncio expects of
    cooperative tasks. Code that cancels this task should await it and catch
    ``CancelledError``.

    Args:
        websocket: The accepted client WebSocket to send pings on.
        session: The session record; only ``session_id`` is read, for logging.
    """
    interval = settings.heartbeat_interval
    while True:
        await asyncio.sleep(interval)
        try:
            await websocket.send_text(
                json.dumps({"type": "ping", "ts": int(time.time() * 1000)})
            )
        except Exception:
            # Any send failure means the socket is gone; stop pinging and let
            # the awaiting side observe that this task has finished.
            logger.info(
                "Heartbeat send failed for session %s, connection dead",
                session.get("session_id", "?"),
            )
            return
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Register a fresh voice session and tell the client where to connect.

    The record is stored in ``app.state.sessions``, which is created on first
    use. It starts with status ``"created"`` and no pipeline task; a task is
    only attached once a WebSocket connects.
    """
    state = req.app.state
    if not hasattr(state, "sessions"):
        state.sessions = {}

    new_id = "vs_" + uuid.uuid4().hex[:12]
    state.sessions[new_id] = dict(
        session_id=new_id,
        status="created",
        execution_id=request.execution_id,
        agent_context=request.agent_context,
        task=None,
    )

    return SessionResponse(
        session_id=new_id,
        status="created",
        websocket_url=f"/ws/voice/{new_id}",
    )
@router.delete("/sessions/{session_id}")
async def end_session(session_id: str, req: Request):
    """Terminate a voice session and forget it.

    If the session has a pipeline task that is still running, the task is
    cancelled and awaited. Any error it raises while unwinding is ignored.
    Unknown ids are reported with ``status="not_found"`` rather than an
    HTTP error.
    """
    state = req.app.state
    if not hasattr(state, "sessions"):
        state.sessions = {}

    session = state.sessions.get(session_id)
    if session is None:
        return {"status": "not_found", "session_id": session_id}

    running = session.get("task")
    # isinstance() is False for None, so no separate None check is needed.
    if isinstance(running, asyncio.Task) and not running.done():
        running.cancel()
        try:
            await running
        except (asyncio.CancelledError, Exception):
            pass

    state.sessions.pop(session_id, None)
    return {"status": "ended", "session_id": session_id}
@router.post("/sessions/{session_id}/reconnect", response_model=SessionResponse)
async def reconnect_session(session_id: str, req: Request):
    """Let a client resume a session whose WebSocket dropped.

    Responses:
        404 if the session is unknown, or if it was disconnected longer than
            ``settings.session_ttl`` seconds ago (it is purged in that case).
        409 if the session is in any state other than ``"disconnected"``.
        Otherwise the session descriptor, including the WebSocket path to
            reopen.
    """
    state = req.app.state
    if not hasattr(state, "sessions"):
        state.sessions = {}
    session = state.sessions.get(session_id)

    if session is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found or expired", "session_id": session_id},
        )

    current_status = session["status"]
    if current_status != "disconnected":
        return JSONResponse(
            status_code=409,
            content={
                "error": f"Session is in '{current_status}' state, not reconnectable",
                "session_id": session_id,
            },
        )

    elapsed = time.time() - session.get("disconnected_at", 0)
    if elapsed > settings.session_ttl:
        # Past the resume window: drop the stale record.
        state.sessions.pop(session_id, None)
        return JSONResponse(
            status_code=404,
            content={"error": "Session expired", "session_id": session_id},
        )

    return SessionResponse(
        session_id=session_id,
        status="disconnected",
        websocket_url=f"/ws/voice/{session_id}",
    )
async def _cancel_and_wait(task: Optional[asyncio.Task]) -> None:
    """Cancel ``task`` if it is a still-running ``asyncio.Task`` and wait for it.

    ``None``, non-task values and already-finished tasks are ignored. Anything
    the task raises while unwinding (including ``CancelledError``) is
    swallowed, because this is best-effort cleanup.
    """
    if not isinstance(task, asyncio.Task) or task.done():
        return
    task.cancel()
    try:
        await task
    except (asyncio.CancelledError, Exception):
        pass


@router.websocket("/ws/{session_id}")
async def voice_websocket(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for real-time voice streaming.

    Binary frames carry PCM audio and are consumed by the Pipecat pipeline.
    The server also sends JSON text frames: a ``session.resumed`` notice on
    reconnect, periodic pings from a background heartbeat task, and an
    ``error`` event if the handler fails unexpectedly.

    The connection ends when either the pipeline finishes or the heartbeat
    task exits because a ping could not be sent. On disconnect the session is
    kept in ``"disconnected"`` state so the client can reconnect within
    ``settings.session_ttl`` seconds. Only the connection that currently owns
    the session records that state, so a superseded older connection cannot
    clobber a newer one.

    Close codes: 4004 when the session does not exist or its resume window
    has expired.
    """
    await websocket.accept()
    app = websocket.app
    # Initialize sessions dict on app.state if not present
    if not hasattr(app.state, "sessions"):
        app.state.sessions = {}

    session = app.state.sessions.get(session_id)
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return

    is_reconnect = session["status"] == "disconnected"

    # Enforce the same resume window as the REST reconnect endpoint, so a
    # client cannot revive an expired session by dialing the WebSocket directly.
    if is_reconnect and (
        time.time() - session.get("disconnected_at", 0) > settings.session_ttl
    ):
        app.state.sessions.pop(session_id, None)
        await websocket.close(code=4004, reason="Session expired")
        return

    # Claim the session for this connection *before* yielding to the event
    # loop. If an older connection's handler is still unwinding, its cleanup
    # sees a different token and leaves the session alone instead of flipping
    # it back to "disconnected" underneath this connection.
    connection_id = uuid.uuid4().hex
    session["connection_id"] = connection_id

    # Cancel any leftover pipeline task from a previous connection.
    await _cancel_and_wait(session.get("task"))

    session["status"] = "active"
    pipeline_task = None
    heartbeat_task = None
    try:
        # Notify client of successful reconnection
        if is_reconnect:
            logger.info("Session %s reconnected", session_id)
            await websocket.send_text(
                json.dumps({"type": "session.resumed", "session_id": session_id})
            )

        transport = AppTransport(websocket)
        session_context = {
            "session_id": session_id,
            "execution_id": session.get("execution_id"),
            "agent_context": session.get("agent_context", {}),
        }
        # Build the Pipecat voice pipeline using shared services from app.state.
        task = await create_voice_pipeline(
            transport,
            session_context,
            stt=getattr(app.state, "stt", None),
            tts=getattr(app.state, "tts", None),
            vad=getattr(app.state, "vad", None),
        )

        pipeline_task = asyncio.create_task(task.run())
        session["task"] = pipeline_task
        heartbeat_task = asyncio.create_task(_heartbeat_sender(websocket, session))

        # Whichever task finishes first ends this connection. Either the
        # pipeline completes, or the heartbeat returns because a ping could
        # not be sent (dead socket). The other task is torn down in `finally`.
        done, _ = await asyncio.wait(
            {pipeline_task, heartbeat_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        if pipeline_task in done:
            # Re-raise pipeline failures so the handlers below see them.
            pipeline_task.result()
    except WebSocketDisconnect:
        pass
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        logger.exception("Unexpected error in voice_websocket for session %s: %s", session_id, exc)
        try:
            await websocket.send_text(
                json.dumps({"type": "error", "message": str(exc)})
            )
        except Exception:
            pass
    finally:
        await _cancel_and_wait(heartbeat_task)
        await _cancel_and_wait(pipeline_task)

        # Preserve the session for reconnection, but only if this connection
        # still owns it. It may have been ended (popped) or taken over by a
        # newer connection in the meantime.
        current = app.state.sessions.get(session_id)
        if current is not None and current.get("connection_id") == connection_id:
            current["status"] = "disconnected"
            current["disconnected_at"] = time.time()
            current["task"] = None
            logger.info(
                "Session %s disconnected, preserved for %ds",
                session_id,
                settings.session_ttl,
            )

        # Ensure the websocket is closed; it may already be gone.
        try:
            await websocket.close()
        except Exception:
            pass
@router.post("/transcribe")
async def transcribe_audio(req: Request, audio: UploadFile = File(...)):
    """Transcribe uploaded audio (PCM 16kHz 16-bit mono) to text using Whisper.

    Returns 503 when no STT service, or no loaded model, is registered on
    ``app.state``. An empty upload short-circuits to ``{"text": ""}``.
    Otherwise the text returned by ``stt.transcribe`` is sent back with
    surrounding whitespace stripped.
    """
    stt = getattr(req.app.state, "stt", None)
    model_ready = stt is not None and stt._model is not None
    if not model_ready:
        return JSONResponse(status_code=503, content={"error": "STT model not loaded"})

    payload = await audio.read()
    if not payload:
        return {"text": ""}

    transcript = await stt.transcribe(payload)
    return {"text": transcript.strip()}