309 lines
10 KiB
Python
309 lines
10 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
import uuid
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request, UploadFile, File
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel
|
|
from typing import Optional
|
|
|
|
from ..config.settings import settings
|
|
from ..pipeline.app_transport import AppTransport
|
|
from ..pipeline.base_pipeline import create_voice_pipeline
|
|
|
|
# Module logger, named after this module's dotted import path.
logger: logging.Logger = logging.getLogger(__name__)

# Router that collects the voice-session HTTP and WebSocket endpoints defined
# below; any URL prefix is applied wherever the application includes it.
router: APIRouter = APIRouter()
|
|
|
|
|
|
class CreateSessionRequest(BaseModel):
    """Body of the session-creation request.

    Both fields are copied into the in-memory session record by
    ``create_session`` and later passed to the voice pipeline as part of its
    session context.
    """

    # Optional identifier of an external execution this dialogue belongs to.
    execution_id: Optional[str] = None
    # Free-form context forwarded to the pipeline. Pydantic copies mutable
    # field defaults per instance, so this ``{}`` is not shared across requests.
    agent_context: dict = {}
|
|
|
|
|
|
class SessionResponse(BaseModel):
    """Payload returned by the session create and reconnect endpoints."""

    # Identifier generated by ``create_session``.
    session_id: str
    # Lifecycle state of the session, e.g. "created" or "disconnected".
    status: str
    # Path of the WebSocket endpoint the client should open for this session.
    websocket_url: str
|
|
|
|
|
|
async def _heartbeat_sender(websocket: WebSocket, session: dict):
    """Periodically send ping frames so that a dead client connection is noticed.

    Meant to run as a sibling asyncio task next to the Pipecat pipeline. Every
    ``settings.heartbeat_interval`` seconds it sends the JSON text frame
    ``{"type": "ping", "ts": <epoch milliseconds>}``.

    Pipecat owns the WebSocket read loop (binary audio frames), so client pong
    replies cannot be read here; a failing ``send_text`` is the only liveness
    signal available. When a send fails, the coroutine logs and *returns
    normally*. A caller that waits on this task can treat its completion as
    "connection dead" and tear the session down; the coroutine does not close
    the socket or stop the pipeline itself.

    Cancellation is propagated rather than swallowed, so ``task.cancelled()``
    reports correctly for callers that cancel the heartbeat during cleanup.

    Args:
        websocket: The accepted client WebSocket.
        session: The session record; only ``"session_id"`` is read, for logging.
    """
    interval = settings.heartbeat_interval
    while True:
        # A CancelledError raised by sleep() or send_text() is allowed to
        # propagate to whoever cancelled us, as asyncio recommends.
        await asyncio.sleep(interval)
        ping = json.dumps({"type": "ping", "ts": int(time.time() * 1000)})
        try:
            await websocket.send_text(ping)
        except Exception:
            # The socket is already broken or closed; stop so the caller can
            # react to this task finishing.
            logger.info(
                "Heartbeat send failed for session %s, connection dead",
                session.get("session_id", "?"),
            )
            return
|
|
|
|
|
|
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Register a new voice dialogue session and tell the client where to connect.

    The session is stored only in the in-process ``app.state.sessions`` mapping,
    which is created lazily here. No pipeline is started; that happens when the
    client opens the WebSocket.

    Returns:
        SessionResponse carrying the generated id, status ``"created"`` and the
        WebSocket path for the session.
    """
    state = req.app.state
    if not hasattr(state, "sessions"):
        state.sessions = {}

    # First 12 hex digits of a random UUID4, with a recognisable prefix.
    new_id = f"vs_{uuid.uuid4().hex[:12]}"
    # NOTE(review): the WebSocket route declared in this module is
    # "/ws/{session_id}"; confirm the router's mount prefix makes this path
    # resolve to it.
    ws_path = f"/ws/voice/{new_id}"

    record = {
        "session_id": new_id,
        "status": "created",
        "execution_id": request.execution_id,
        "agent_context": request.agent_context,
        "task": None,
    }
    state.sessions[new_id] = record

    return SessionResponse(session_id=new_id, status="created", websocket_url=ws_path)
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
async def end_session(session_id: str, req: Request):
    """Terminate a voice session and forget it.

    Cancels the session's pipeline task if it is still running, waits for it to
    unwind, then removes the record from ``app.state.sessions``. An unknown id
    is reported in the body as ``"status": "not_found"`` with the default
    success status code.
    """
    state = req.app.state
    if not hasattr(state, "sessions"):
        state.sessions = {}

    record = state.sessions.get(session_id)
    if record is None:
        return {"status": "not_found", "session_id": session_id}

    running = record.get("task")
    if isinstance(running, asyncio.Task) and not running.done():
        running.cancel()
        try:
            await running
        except (asyncio.CancelledError, Exception):
            # Best effort: the task is being torn down; its outcome is irrelevant.
            pass

    state.sessions.pop(session_id, None)
    return {"status": "ended", "session_id": session_id}
|
|
|
|
|
|
@router.post("/sessions/{session_id}/reconnect", response_model=SessionResponse)
async def reconnect_session(session_id: str, req: Request):
    """Let a client resume a session whose WebSocket was dropped.

    Responses:
      * 404 if the session is unknown, or if it has been disconnected for
        longer than ``settings.session_ttl`` seconds (it is discarded then).
      * 409 if the session is in any state other than ``"disconnected"``.
      * Otherwise a SessionResponse with the WebSocket path to reopen.
    """
    state = req.app.state
    if not hasattr(state, "sessions"):
        state.sessions = {}

    record = state.sessions.get(session_id)
    if record is None:
        return JSONResponse(
            status_code=404,
            content={"error": "Session not found or expired", "session_id": session_id},
        )

    current_status = record["status"]
    if current_status != "disconnected":
        return JSONResponse(
            status_code=409,
            content={
                "error": f"Session is in '{current_status}' state, not reconnectable",
                "session_id": session_id,
            },
        )

    # A record with no timestamp counts as disconnected at the epoch, i.e. expired.
    idle_seconds = time.time() - record.get("disconnected_at", 0)
    if idle_seconds > settings.session_ttl:
        state.sessions.pop(session_id, None)
        return JSONResponse(
            status_code=404,
            content={"error": "Session expired", "session_id": session_id},
        )

    return SessionResponse(
        session_id=session_id,
        status="disconnected",
        websocket_url=f"/ws/voice/{session_id}",
    )
|
|
|
|
|
|
@router.websocket("/ws/{session_id}")
async def voice_websocket(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for real-time voice streaming.

    Serves both first connections and reconnections to a session created via
    the session-creation endpoint. Binary audio frames are handled by the
    Pipecat pipeline built by ``create_voice_pipeline``. A sibling heartbeat
    task sends JSON ``ping`` text frames to detect dead clients.

    Lifecycle:
      * Unknown session id: the socket is closed with code 4004
        ("Session not found").
      * A ``"disconnected"`` session idle for longer than
        ``settings.session_ttl`` is discarded and the socket is closed with
        code 4004 ("Session expired"). This is the same rule the reconnect
        endpoint applies.
      * Any pipeline left over from a previous connection is cancelled and the
        session is marked ``"active"``. A ``session.resumed`` event is sent
        when this is a reconnection.
      * The handler ends when the pipeline finishes OR the heartbeat detects a
        dead connection, whichever comes first.
      * On exit the session is kept in ``"disconnected"`` state for possible
        reconnection, unless a newer connection has taken it over.
    """
    await websocket.accept()

    app = websocket.app

    # Initialize sessions dict on app.state if not present
    if not hasattr(app.state, "sessions"):
        app.state.sessions = {}

    # Verify session exists
    session = app.state.sessions.get(session_id)
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return

    is_reconnect = session["status"] == "disconnected"

    # Enforce the same reconnection window as ``reconnect_session``. Without
    # this check a client could open the socket directly and revive an
    # expired session.
    if is_reconnect:
        disconnected_at = session.get("disconnected_at", 0)
        if time.time() - disconnected_at > settings.session_ttl:
            app.state.sessions.pop(session_id, None)
            await websocket.close(code=4004, reason="Session expired")
            return

    async def _cancel_and_reap(task) -> None:
        """Cancel *task* if it is a still-running Task, then wait for it.

        Its result or exception is deliberately discarded; this is used only
        for teardown.
        """
        if not isinstance(task, asyncio.Task) or task.done():
            return
        task.cancel()
        try:
            await task
        except (asyncio.CancelledError, Exception):
            pass

    # Claim ownership of the session record before tearing down a previous
    # connection. That connection's cleanup then sees it no longer owns the
    # record and does not overwrite our status/task with "disconnected"/None.
    connection_id = uuid.uuid4().hex
    session["connection_id"] = connection_id

    # Cancel any leftover pipeline task from a previous connection
    await _cancel_and_reap(session.get("task"))

    session["status"] = "active"

    pipeline_task = None
    heartbeat_task = None

    try:
        # Notify client of successful reconnection
        if is_reconnect:
            logger.info("Session %s reconnected", session_id)
            await websocket.send_text(
                json.dumps({"type": "session.resumed", "session_id": session_id})
            )

        # Create the AppTransport from the websocket connection
        transport = AppTransport(websocket)

        # Build the session context from stored session data
        session_context = {
            "session_id": session_id,
            "execution_id": session.get("execution_id"),
            "agent_context": session.get("agent_context", {}),
        }

        # Create the Pipecat voice pipeline using shared services from app.state
        task = await create_voice_pipeline(
            transport,
            session_context,
            stt=getattr(app.state, "stt", None),
            tts=getattr(app.state, "tts", None),
            vad=getattr(app.state, "vad", None),
        )

        # Run the pipeline and the heartbeat side by side.
        pipeline_task = asyncio.create_task(task.run())
        session["task"] = pipeline_task
        heartbeat_task = asyncio.create_task(_heartbeat_sender(websocket, session))

        # Stop as soon as either finishes. The heartbeat only returns when a
        # ping could not be sent (dead connection), and then the pipeline must
        # be torn down too; waiting on the pipeline alone would ignore it.
        done, _pending = await asyncio.wait(
            {pipeline_task, heartbeat_task},
            return_when=asyncio.FIRST_COMPLETED,
        )
        if pipeline_task in done:
            # Re-raise the pipeline's outcome (disconnect, error, cancellation)
            # so the handlers below see it just as a direct ``await`` would.
            pipeline_task.result()
        else:
            logger.info(
                "Heartbeat ended for session %s; stopping voice pipeline",
                session_id,
            )

    except WebSocketDisconnect:
        pass
    except asyncio.CancelledError:
        pass
    except Exception as exc:
        logger.exception("Unexpected error in voice_websocket for session %s: %s", session_id, exc)
        try:
            await websocket.send_text(
                json.dumps({"type": "error", "message": str(exc)})
            )
        except Exception:
            pass
    finally:
        # Stop both sibling tasks; whichever finished first is already done.
        await _cancel_and_reap(heartbeat_task)
        await _cancel_and_reap(pipeline_task)

        # Preserve the session for reconnection, but only if it still exists
        # and no newer connection has taken it over.
        current = app.state.sessions.get(session_id)
        if current is not None and current.get("connection_id") == connection_id:
            current["status"] = "disconnected"
            current["disconnected_at"] = time.time()
            current["task"] = None
            logger.info(
                "Session %s disconnected, preserved for %ds",
                session_id,
                settings.session_ttl,
            )

        # Ensure websocket is closed
        try:
            await websocket.close()
        except Exception:
            pass
|
|
|
|
|
|
@router.post("/transcribe")
async def transcribe_audio(req: Request, audio: UploadFile = File(...)):
    """Transcribe an uploaded audio clip with the shared STT service.

    The upload is expected to be PCM 16 kHz 16-bit mono. The raw bytes are
    passed unchanged to ``app.state.stt.transcribe``.

    Returns:
        ``{"text": <transcription with surrounding whitespace stripped>}``,
        ``{"text": ""}`` for an empty upload, or HTTP 503 with an ``error``
        message when no STT service/model is loaded.
    """
    engine = getattr(req.app.state, "stt", None)
    model_ready = engine is not None and engine._model is not None
    if not model_ready:
        return JSONResponse(status_code=503, content={"error": "STT model not loaded"})

    payload = await audio.read()
    if not payload:
        return {"text": ""}

    transcript = await engine.transcribe(payload)
    return {"text": transcript.strip()}
|