fix: subscribe to agent WS before creating task to avoid race condition

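The engine stream could emit text events before the voice pipeline
subscribed, causing all text to be lost. Now we connect to the WebSocket
first and, when a session ID already exists, subscribe before POSTing the
task; after the POST we subscribe again with the returned session and
task IDs, which also covers the first-time case.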

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-24 02:35:57 -08:00
parent 82d12a5ff5
commit 370e32599f
1 changed file with 44 additions and 33 deletions

View File

@ -225,9 +225,9 @@ class VoicePipelineTask:
async def _agent_generate(self, user_text: str) -> str: async def _agent_generate(self, user_text: str) -> str:
"""Send user text to agent-service, subscribe via WS, collect response. """Send user text to agent-service, subscribe via WS, collect response.
Mirrors the Flutter chat flow: Flow (subscribe-first to avoid race condition):
1. POST /api/v1/agent/tasks get sessionId + taskId 1. WS connect to /ws/agent subscribe_session (with existing or new sessionId)
2. WS connect to /ws/agent subscribe_session 2. POST /api/v1/agent/tasks triggers engine stream
3. Collect 'text' stream events until 'completed' 3. Collect 'text' stream events until 'completed'
""" """
agent_url = settings.agent_service_url # http://agent-service:3002 agent_url = settings.agent_service_url # http://agent-service:3002
@ -235,8 +235,27 @@ class VoicePipelineTask:
if self._auth_header: if self._auth_header:
headers["Authorization"] = self._auth_header headers["Authorization"] = self._auth_header
ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws/agent"
try: try:
# 1. Create agent task collected_text = []
timeout_secs = 120 # Max wait for agent response
async with websockets.connect(ws_url) as ws:
# 1. Subscribe FIRST (before creating task to avoid missing events)
# Use existing session ID for subscription; if none, we'll re-subscribe after task creation
pre_session_id = self._agent_session_id or ""
if pre_session_id:
subscribe_msg = json.dumps({
"event": "subscribe_session",
"data": {"sessionId": pre_session_id},
})
await ws.send(subscribe_msg)
print(f"[pipeline] Pre-subscribed to agent WS session={pre_session_id}", flush=True)
# 2. Create agent task
body = {"prompt": user_text} body = {"prompt": user_text}
if self._agent_session_id: if self._agent_session_id:
body["sessionId"] = self._agent_session_id body["sessionId"] = self._agent_session_id
@ -258,23 +277,15 @@ class VoicePipelineTask:
self._agent_session_id = session_id self._agent_session_id = session_id
print(f"[pipeline] Agent task created: session={session_id}, task={task_id}", flush=True) print(f"[pipeline] Agent task created: session={session_id}, task={task_id}", flush=True)
# 2. Subscribe via WebSocket and collect text events # 3. Subscribe with actual session/task IDs (covers first-time case)
ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws/agent"
collected_text = []
timeout_secs = 60 # Max wait for agent response
async with websockets.connect(ws_url) as ws:
# Subscribe to the session
subscribe_msg = json.dumps({ subscribe_msg = json.dumps({
"event": "subscribe_session", "event": "subscribe_session",
"data": {"sessionId": session_id, "taskId": task_id}, "data": {"sessionId": session_id, "taskId": task_id},
}) })
await ws.send(subscribe_msg) await ws.send(subscribe_msg)
print(f"[pipeline] Subscribed to agent WS session={session_id}", flush=True) print(f"[pipeline] Subscribed to agent WS session={session_id}, task={task_id}", flush=True)
# Collect events until completed # 4. Collect events until completed
deadline = time.time() + timeout_secs deadline = time.time() + timeout_secs
while time.time() < deadline: while time.time() < deadline:
try: try: