fix: subscribe to agent WS before creating task to avoid race condition

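The engine stream could emit text events before the voice pipeline had
subscribed to the agent WebSocket, causing all text to be lost. Now we
connect and subscribe first (using the existing session ID, if any), then
POST the task, and subscribe again with the returned session/task IDs to
cover the first-time case.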

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
hailin 2026-02-24 02:35:57 -08:00
parent 82d12a5ff5
commit 370e32599f
1 changed file with 44 additions and 33 deletions

View File

@ -225,9 +225,9 @@ class VoicePipelineTask:
async def _agent_generate(self, user_text: str) -> str:
"""Send user text to agent-service, subscribe via WS, collect response.
Mirrors the Flutter chat flow:
1. POST /api/v1/agent/tasks get sessionId + taskId
2. WS connect to /ws/agent subscribe_session
Flow (subscribe-first to avoid race condition):
1. WS connect to /ws/agent subscribe_session (with existing or new sessionId)
2. POST /api/v1/agent/tasks triggers engine stream
3. Collect 'text' stream events until 'completed'
"""
agent_url = settings.agent_service_url # http://agent-service:3002
@ -235,46 +235,57 @@ class VoicePipelineTask:
if self._auth_header:
headers["Authorization"] = self._auth_header
ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws/agent"
try:
# 1. Create agent task
body = {"prompt": user_text}
if self._agent_session_id:
body["sessionId"] = self._agent_session_id
print(f"[pipeline] Creating agent task: {user_text[:60]}", flush=True)
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
f"{agent_url}/api/v1/agent/tasks",
json=body,
headers=headers,
)
if resp.status_code != 200 and resp.status_code != 201:
print(f"[pipeline] Agent task creation failed: {resp.status_code} {resp.text}", flush=True)
return "抱歉Agent服务暂时不可用。"
data = resp.json()
session_id = data.get("sessionId", "")
task_id = data.get("taskId", "")
self._agent_session_id = session_id
print(f"[pipeline] Agent task created: session={session_id}, task={task_id}", flush=True)
# 2. Subscribe via WebSocket and collect text events
ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
ws_url = f"{ws_url}/ws/agent"
collected_text = []
timeout_secs = 60 # Max wait for agent response
timeout_secs = 120 # Max wait for agent response
async with websockets.connect(ws_url) as ws:
# Subscribe to the session
# 1. Subscribe FIRST (before creating task to avoid missing events)
# Use existing session ID for subscription; if none, we'll re-subscribe after task creation
pre_session_id = self._agent_session_id or ""
if pre_session_id:
subscribe_msg = json.dumps({
"event": "subscribe_session",
"data": {"sessionId": pre_session_id},
})
await ws.send(subscribe_msg)
print(f"[pipeline] Pre-subscribed to agent WS session={pre_session_id}", flush=True)
# 2. Create agent task
body = {"prompt": user_text}
if self._agent_session_id:
body["sessionId"] = self._agent_session_id
print(f"[pipeline] Creating agent task: {user_text[:60]}", flush=True)
async with httpx.AsyncClient(timeout=30) as client:
resp = await client.post(
f"{agent_url}/api/v1/agent/tasks",
json=body,
headers=headers,
)
if resp.status_code != 200 and resp.status_code != 201:
print(f"[pipeline] Agent task creation failed: {resp.status_code} {resp.text}", flush=True)
return "抱歉Agent服务暂时不可用。"
data = resp.json()
session_id = data.get("sessionId", "")
task_id = data.get("taskId", "")
self._agent_session_id = session_id
print(f"[pipeline] Agent task created: session={session_id}, task={task_id}", flush=True)
# 3. Subscribe with actual session/task IDs (covers first-time case)
subscribe_msg = json.dumps({
"event": "subscribe_session",
"data": {"sessionId": session_id, "taskId": task_id},
})
await ws.send(subscribe_msg)
print(f"[pipeline] Subscribed to agent WS session={session_id}", flush=True)
print(f"[pipeline] Subscribed to agent WS session={session_id}, task={task_id}", flush=True)
# Collect events until completed
# 4. Collect events until completed
deadline = time.time() + timeout_secs
while time.time() < deadline:
try: