154 lines
4.5 KiB
Python
154 lines
4.5 KiB
Python
import asyncio
import logging
import uuid
from typing import Optional

from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Request
from pydantic import BaseModel

from ..pipeline.app_transport import AppTransport
from ..pipeline.base_pipeline import create_voice_pipeline
|
|
|
|
# Collects the voice-session HTTP and WebSocket routes defined in this module.
# The "/api/v1/voice" prefix that appears in the websocket URLs handed to
# clients is presumably applied where this router is included — TODO confirm.
router: APIRouter = APIRouter()
|
|
|
|
|
|
class CreateSessionRequest(BaseModel):
    """Payload for creating a voice dialogue session.

    Every field is optional, so an empty JSON object is a valid request.
    """

    # Optional ID linking the session to an existing execution; it is stored
    # on the session record and later placed in the pipeline's session context.
    execution_id: Optional[str] = None

    # Free-form context passed through to the voice pipeline unchanged.
    agent_context: dict = {}
|
|
|
|
|
|
class SessionResponse(BaseModel):
    """Details of a voice session, returned to the client on creation."""

    # Generated identifier of the form "vs_<12 hex chars>".
    session_id: str

    # Lifecycle status of the session at the time of the response.
    status: str

    # Server-relative path (not an absolute URL) of the session's WebSocket.
    websocket_url: str
|
|
|
|
|
|
@router.post("/sessions", response_model=SessionResponse)
async def create_session(request: CreateSessionRequest, req: Request):
    """Create a new voice dialogue session.

    Registers a session record in ``app.state.sessions`` (creating that dict
    on first use) and returns the path the client should open a WebSocket on.
    No pipeline is started here; that happens when the WebSocket connects.

    Args:
        request: Client-supplied session parameters.
        req: The raw request, used only to reach ``app.state``.

    Returns:
        A ``SessionResponse`` with status ``"created"``.
    """
    state = req.app.state
    if not hasattr(state, "sessions"):
        state.sessions = {}

    # "vs_" followed by the first 12 hex characters of a random UUID4.
    new_id = f"vs_{uuid.uuid4().hex[:12]}"

    record = {
        "session_id": new_id,
        "status": "created",
        "execution_id": request.execution_id,
        "agent_context": request.agent_context,
        # Filled in with the pipeline task once a WebSocket attaches.
        "task": None,
    }
    state.sessions[new_id] = record

    return SessionResponse(
        session_id=record["session_id"],
        status=record["status"],
        websocket_url=f"/api/v1/voice/ws/{new_id}",
    )
|
|
|
|
|
|
@router.delete("/sessions/{session_id}")
async def end_session(session_id: str, req: Request):
    """End a voice dialogue session.

    Removes the session from ``app.state.sessions`` and, if its pipeline task
    is still running, cancels it and waits for it to finish.

    Args:
        session_id: ID returned by ``POST /sessions``.
        req: The raw request, used only to reach ``app.state``.

    Returns:
        ``{"status": "ended", "session_id": ...}``, or
        ``{"status": "not_found", "session_id": ...}`` for an unknown ID.
    """
    if not hasattr(req.app.state, "sessions"):
        req.app.state.sessions = {}

    # Detach the session before waiting on its pipeline. A WebSocket that
    # connects during the wait then gets "Session not found" instead of
    # attaching a new pipeline to a session that is being torn down.
    session = req.app.state.sessions.pop(session_id, None)
    if session is None:
        return {"status": "not_found", "session_id": session_id}

    task = session.get("task")
    if isinstance(task, asyncio.Task) and not task.done():
        task.cancel()
        # asyncio.wait returns once the task has finished without re-raising
        # the task's own outcome; the WebSocket handler that owns the task
        # deals with that. Unlike a blanket ``except CancelledError``, it still
        # lets a cancellation of this request propagate.
        await asyncio.wait({task})

    return {"status": "ended", "session_id": session_id}
|
|
|
|
|
|
@router.websocket("/ws/{session_id}")
async def voice_websocket(websocket: WebSocket, session_id: str):
    """WebSocket endpoint for real-time voice streaming.

    Accepts the connection and looks up ``session_id`` in
    ``app.state.sessions``. Unknown sessions are closed with code 4004.
    Otherwise a voice pipeline is built around the socket and run as a
    background task. The task handle is stored on the session so that
    ``DELETE /sessions/{session_id}`` can cancel it. The handler returns once
    the pipeline finishes.

    Args:
        websocket: The client connection.
        session_id: ID returned by ``POST /sessions``.
    """
    await websocket.accept()

    app = websocket.app
    if not hasattr(app.state, "sessions"):
        app.state.sessions = {}

    session = app.state.sessions.get(session_id)
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return

    session["status"] = "active"

    pipeline_task = None
    try:
        transport = AppTransport(websocket)
        session_context = {
            "session_id": session_id,
            "execution_id": session.get("execution_id"),
            "agent_context": session.get("agent_context", {}),
        }

        # Shared STT/TTS/VAD services come from app.state when configured.
        task = await create_voice_pipeline(
            transport,
            session_context,
            stt=getattr(app.state, "stt", None),
            tts=getattr(app.state, "tts", None),
            vad=getattr(app.state, "vad", None),
        )

        pipeline_task = asyncio.create_task(task.run())
        # Publish the handle so end_session can cancel this pipeline.
        session["task"] = pipeline_task

        # Use asyncio.wait instead of ``await pipeline_task``. It does not
        # re-raise the pipeline's outcome, so a cancel issued by end_session is
        # not mistaken for cancellation of this handler. A real cancellation
        # of this handler still propagates.
        await asyncio.wait({pipeline_task})
        if not pipeline_task.cancelled():
            # Retrieving the exception also marks it as handled for asyncio.
            error = pipeline_task.exception()
            if error is not None and not isinstance(error, WebSocketDisconnect):
                logging.getLogger(__name__).error(
                    "Voice pipeline for session %s failed",
                    session_id,
                    exc_info=error,
                )
    except WebSocketDisconnect:
        # The client went away while the pipeline was being set up.
        pass
    except Exception:
        logging.getLogger(__name__).exception(
            "Voice session %s failed during setup", session_id
        )
    finally:
        # Stop the pipeline if it is still running, e.g. after this handler
        # was cancelled. This is best-effort teardown.
        if pipeline_task is not None and not pipeline_task.done():
            pipeline_task.cancel()
            try:
                await pipeline_task
            except (asyncio.CancelledError, Exception):
                pass

        # Release the session only if it still belongs to this connection.
        # A newer connection for the same session may have replaced the task,
        # and its state must not be clobbered.
        current = app.state.sessions.get(session_id)
        if current is not None:
            owner = current.get("task")
            if owner is None or owner is pipeline_task:
                current["status"] = "disconnected"
                current["task"] = None

        # Best effort: the socket may already have been closed.
        try:
            await websocket.close()
        except Exception:
            pass
|