feat: route voice pipeline through agent-service instead of direct LLM
Voice calls now use the same agent task + WS subscription flow as the chat UI, enabling tool use and command execution during voice sessions.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
7afbd54fce
commit
abf5e29419
|
|
@ -318,9 +318,11 @@ services:
|
||||||
- "13008:3008"
|
- "13008:3008"
|
||||||
environment:
|
environment:
|
||||||
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
|
||||||
|
- ANTHROPIC_BASE_URL=${ANTHROPIC_BASE_URL}
|
||||||
- AGENT_SERVICE_URL=http://agent-service:3002
|
- AGENT_SERVICE_URL=http://agent-service:3002
|
||||||
- WHISPER_MODEL=${WHISPER_MODEL:-base}
|
- WHISPER_MODEL=${WHISPER_MODEL:-base}
|
||||||
- KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m}
|
- KOKORO_MODEL=${KOKORO_MODEL:-kokoro-82m}
|
||||||
|
- KOKORO_VOICE=${KOKORO_VOICE:-zf_xiaoxiao}
|
||||||
- DEVICE=${VOICE_DEVICE:-cpu}
|
- DEVICE=${VOICE_DEVICE:-cpu}
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""]
|
test: ["CMD-SHELL", "python3 -c \"import urllib.request; urllib.request.urlopen('http://localhost:3008/docs')\""]
|
||||||
|
|
|
||||||
|
|
@ -73,12 +73,16 @@ async def create_session(request: CreateSessionRequest, req: Request):
|
||||||
if not hasattr(req.app.state, "sessions"):
|
if not hasattr(req.app.state, "sessions"):
|
||||||
req.app.state.sessions = {}
|
req.app.state.sessions = {}
|
||||||
|
|
||||||
|
# Capture JWT from the incoming request (Kong forwards it)
|
||||||
|
auth_header = req.headers.get("authorization", "")
|
||||||
|
|
||||||
# Store session metadata
|
# Store session metadata
|
||||||
req.app.state.sessions[session_id] = {
|
req.app.state.sessions[session_id] = {
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"status": "created",
|
"status": "created",
|
||||||
"execution_id": request.execution_id,
|
"execution_id": request.execution_id,
|
||||||
"agent_context": request.agent_context,
|
"agent_context": request.agent_context,
|
||||||
|
"auth_header": auth_header,
|
||||||
"task": None,
|
"task": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -217,6 +221,7 @@ async def voice_websocket(websocket: WebSocket, session_id: str):
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"execution_id": session.get("execution_id"),
|
"execution_id": session.get("execution_id"),
|
||||||
"agent_context": session.get("agent_context", {}),
|
"agent_context": session.get("agent_context", {}),
|
||||||
|
"auth_header": session.get("auth_header", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create the voice pipeline using the WebSocket directly
|
# Create the voice pipeline using the WebSocket directly
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,8 @@ class Settings(BaseSettings):
|
||||||
|
|
||||||
# Claude API
|
# Claude API
|
||||||
anthropic_api_key: str = ""
|
anthropic_api_key: str = ""
|
||||||
claude_model: str = "claude-sonnet-4-5-20250929"
|
anthropic_base_url: str = ""
|
||||||
|
claude_model: str = "claude-sonnet-4-5-20250514"
|
||||||
|
|
||||||
# Agent Service
|
# Agent Service
|
||||||
agent_service_url: str = "http://agent-service:3002"
|
agent_service_url: str = "http://agent-service:3002"
|
||||||
|
|
|
||||||
|
|
@ -1,20 +1,23 @@
|
||||||
"""
|
"""
|
||||||
Voice dialogue pipeline — direct WebSocket audio I/O.
|
Voice dialogue pipeline — direct WebSocket audio I/O.
|
||||||
|
|
||||||
Pipeline: Audio Input → VAD → STT → LLM → TTS → Audio Output
|
Pipeline: Audio Input → VAD → STT → Agent Service → TTS → Audio Output
|
||||||
|
|
||||||
Runs as an async task that reads binary PCM frames from a FastAPI WebSocket,
|
Runs as an async task that reads binary PCM frames from a FastAPI WebSocket,
|
||||||
detects speech with VAD, transcribes with STT, generates a response via
|
detects speech with VAD, transcribes with STT, sends text to the Agent
|
||||||
Claude LLM, synthesizes speech with TTS, and sends audio back.
|
service (which uses Claude SDK with tools), and synthesizes the response
|
||||||
|
back as speech via TTS.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import anthropic
|
import httpx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import websockets
|
||||||
from fastapi import WebSocket
|
from fastapi import WebSocket
|
||||||
|
|
||||||
from ..config.settings import settings
|
from ..config.settings import settings
|
||||||
|
|
@ -22,7 +25,9 @@ from ..stt.whisper_service import WhisperSTTService
|
||||||
from ..tts.kokoro_service import KokoroTTSService
|
from ..tts.kokoro_service import KokoroTTSService
|
||||||
from ..vad.silero_service import SileroVADService
|
from ..vad.silero_service import SileroVADService
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
# Minimum speech duration in seconds before we transcribe
|
# Minimum speech duration in seconds before we transcribe
|
||||||
_MIN_SPEECH_SECS = 0.5
|
_MIN_SPEECH_SECS = 0.5
|
||||||
|
|
@ -57,26 +62,20 @@ class VoicePipelineTask:
|
||||||
self.tts = tts
|
self.tts = tts
|
||||||
self.vad = vad
|
self.vad = vad
|
||||||
|
|
||||||
self._conversation: list[dict] = [
|
# Agent session ID (reused across turns for conversation continuity)
|
||||||
{
|
self._agent_session_id: Optional[str] = None
|
||||||
"role": "user",
|
self._auth_header: str = session_context.get("auth_header", "")
|
||||||
"content": (
|
|
||||||
"You are iAgent, an AI voice assistant for IT operations. "
|
|
||||||
"Respond concisely in Chinese. Keep answers under 2 sentences "
|
|
||||||
"when possible. You are in a real-time voice conversation."
|
|
||||||
),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "好的,我是 iAgent 智能运维语音助手。有什么可以帮您的?",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
self._cancelled = False
|
self._cancelled = False
|
||||||
self._speaking = False # True while sending TTS audio to client
|
self._speaking = False # True while sending TTS audio to client
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
"""Main loop: read audio → VAD → STT → LLM → TTS → send audio."""
|
"""Main loop: read audio → VAD → STT → LLM → TTS → send audio."""
|
||||||
logger.info("Voice pipeline started for session %s", self.session_context.get("session_id"))
|
sid = self.session_context.get("session_id", "?")
|
||||||
|
print(f"[pipeline] Started for session {sid}", flush=True)
|
||||||
|
print(f"[pipeline] STT={self.stt is not None and self.stt._model is not None}, "
|
||||||
|
f"TTS={self.tts is not None and self.tts._pipeline is not None}, "
|
||||||
|
f"VAD={self.vad is not None and self.vad._model is not None}", flush=True)
|
||||||
|
|
||||||
# Audio buffer for accumulating speech
|
# Audio buffer for accumulating speech
|
||||||
speech_buffer = bytearray()
|
speech_buffer = bytearray()
|
||||||
|
|
@ -84,6 +83,7 @@ class VoicePipelineTask:
|
||||||
is_speech_active = False
|
is_speech_active = False
|
||||||
silence_start: Optional[float] = None
|
silence_start: Optional[float] = None
|
||||||
speech_start: Optional[float] = None
|
speech_start: Optional[float] = None
|
||||||
|
chunk_count = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while not self._cancelled:
|
while not self._cancelled:
|
||||||
|
|
@ -92,12 +92,16 @@ class VoicePipelineTask:
|
||||||
self.websocket.receive_bytes(), timeout=30.0
|
self.websocket.receive_bytes(), timeout=30.0
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
# No data for 30s — connection might be dead
|
print(f"[pipeline] No data for 30s, session {sid}", flush=True)
|
||||||
continue
|
continue
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
# WebSocket closed
|
print(f"[pipeline] WebSocket read error: {type(exc).__name__}: {exc}", flush=True)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
chunk_count += 1
|
||||||
|
if chunk_count <= 3 or chunk_count % 100 == 0:
|
||||||
|
print(f"[pipeline] Received audio chunk #{chunk_count}, len={len(data)}", flush=True)
|
||||||
|
|
||||||
# Accumulate into VAD-sized chunks
|
# Accumulate into VAD-sized chunks
|
||||||
vad_buffer.extend(data)
|
vad_buffer.extend(data)
|
||||||
|
|
||||||
|
|
@ -113,10 +117,11 @@ class VoicePipelineTask:
|
||||||
is_speech_active = True
|
is_speech_active = True
|
||||||
speech_start = time.time()
|
speech_start = time.time()
|
||||||
silence_start = None
|
silence_start = None
|
||||||
|
print(f"[pipeline] Speech detected!", flush=True)
|
||||||
# Barge-in: if we were speaking TTS, stop
|
# Barge-in: if we were speaking TTS, stop
|
||||||
if self._speaking:
|
if self._speaking:
|
||||||
self._cancelled_tts = True
|
self._cancelled_tts = True
|
||||||
logger.debug("Barge-in detected")
|
print("[pipeline] Barge-in!", flush=True)
|
||||||
|
|
||||||
speech_buffer.extend(chunk)
|
speech_buffer.extend(chunk)
|
||||||
silence_start = None
|
silence_start = None
|
||||||
|
|
@ -130,6 +135,8 @@ class VoicePipelineTask:
|
||||||
elif time.time() - silence_start >= _SILENCE_AFTER_SPEECH_SECS:
|
elif time.time() - silence_start >= _SILENCE_AFTER_SPEECH_SECS:
|
||||||
# Silence detected after speech — process
|
# Silence detected after speech — process
|
||||||
speech_duration = time.time() - (speech_start or time.time())
|
speech_duration = time.time() - (speech_start or time.time())
|
||||||
|
buf_secs = len(speech_buffer) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
|
||||||
|
print(f"[pipeline] Silence after speech: duration={speech_duration:.1f}s, buffer={buf_secs:.1f}s", flush=True)
|
||||||
if speech_duration >= _MIN_SPEECH_SECS and len(speech_buffer) > 0:
|
if speech_duration >= _MIN_SPEECH_SECS and len(speech_buffer) > 0:
|
||||||
await self._process_speech(bytes(speech_buffer))
|
await self._process_speech(bytes(speech_buffer))
|
||||||
|
|
||||||
|
|
@ -140,11 +147,13 @@ class VoicePipelineTask:
|
||||||
speech_start = None
|
speech_start = None
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Voice pipeline cancelled")
|
print(f"[pipeline] Cancelled, session {sid}", flush=True)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Voice pipeline error: %s", exc)
|
print(f"[pipeline] ERROR: {type(exc).__name__}: {exc}", flush=True)
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
finally:
|
finally:
|
||||||
logger.info("Voice pipeline ended for session %s", self.session_context.get("session_id"))
|
print(f"[pipeline] Ended for session {sid}", flush=True)
|
||||||
|
|
||||||
def cancel(self):
|
def cancel(self):
|
||||||
self._cancelled = True
|
self._cancelled = True
|
||||||
|
|
@ -163,14 +172,16 @@ class VoicePipelineTask:
|
||||||
async def _process_speech(self, audio_data: bytes):
|
async def _process_speech(self, audio_data: bytes):
|
||||||
"""Transcribe speech, generate LLM response, synthesize and send TTS."""
|
"""Transcribe speech, generate LLM response, synthesize and send TTS."""
|
||||||
session_id = self.session_context.get("session_id", "?")
|
session_id = self.session_context.get("session_id", "?")
|
||||||
|
audio_secs = len(audio_data) / (_SAMPLE_RATE * _BYTES_PER_SAMPLE)
|
||||||
|
print(f"[pipeline] Processing speech: {audio_secs:.1f}s of audio", flush=True)
|
||||||
|
|
||||||
# 1. STT
|
# 1. STT
|
||||||
text = await self._transcribe(audio_data)
|
text = await self._transcribe(audio_data)
|
||||||
if not text or not text.strip():
|
if not text or not text.strip():
|
||||||
logger.debug("[%s] STT returned empty text, skipping", session_id)
|
print(f"[pipeline] STT returned empty text, skipping", flush=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[%s] User said: %s", session_id, text.strip())
|
print(f"[pipeline] User said: {text.strip()}", flush=True)
|
||||||
|
|
||||||
# Notify client that we heard them
|
# Notify client that we heard them
|
||||||
try:
|
try:
|
||||||
|
|
@ -181,19 +192,16 @@ class VoicePipelineTask:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 2. LLM
|
# 2. Agent service — create task + subscribe for response
|
||||||
self._conversation.append({"role": "user", "content": text.strip()})
|
response_text = await self._agent_generate(text.strip())
|
||||||
response_text = await self._llm_generate()
|
|
||||||
if not response_text:
|
if not response_text:
|
||||||
logger.warning("[%s] LLM returned empty response", session_id)
|
print(f"[pipeline] Agent returned empty response", flush=True)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[%s] Agent says: %s", session_id, response_text)
|
print(f"[pipeline] Agent says: {response_text[:100]}", flush=True)
|
||||||
self._conversation.append({"role": "assistant", "content": response_text})
|
|
||||||
|
|
||||||
# Notify client of the response text
|
# Notify client of the response text
|
||||||
try:
|
try:
|
||||||
import json
|
|
||||||
await self.websocket.send_text(
|
await self.websocket.send_text(
|
||||||
json.dumps({"type": "transcript", "text": response_text, "role": "assistant"})
|
json.dumps({"type": "transcript", "text": response_text, "role": "assistant"})
|
||||||
)
|
)
|
||||||
|
|
@ -214,23 +222,105 @@ class VoicePipelineTask:
|
||||||
logger.error("STT error: %s", exc)
|
logger.error("STT error: %s", exc)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _llm_generate(self) -> str:
|
async def _agent_generate(self, user_text: str) -> str:
|
||||||
"""Generate a response using Anthropic Claude."""
|
"""Send user text to agent-service, subscribe via WS, collect response.
|
||||||
if not settings.anthropic_api_key:
|
|
||||||
logger.warning("Anthropic API key not set, returning default response")
|
Mirrors the Flutter chat flow:
|
||||||
return "抱歉,语音助手暂时无法连接到AI服务。"
|
1. POST /api/v1/agent/tasks → get sessionId + taskId
|
||||||
|
2. WS connect to /ws/agent → subscribe_session
|
||||||
|
3. Collect 'text' stream events until 'completed'
|
||||||
|
"""
|
||||||
|
agent_url = settings.agent_service_url # http://agent-service:3002
|
||||||
|
headers = {"Content-Type": "application/json"}
|
||||||
|
if self._auth_header:
|
||||||
|
headers["Authorization"] = self._auth_header
|
||||||
|
|
||||||
try:
|
try:
|
||||||
client = anthropic.AsyncAnthropic(api_key=settings.anthropic_api_key)
|
# 1. Create agent task
|
||||||
response = await client.messages.create(
|
body = {"prompt": user_text}
|
||||||
model=settings.claude_model,
|
if self._agent_session_id:
|
||||||
max_tokens=256,
|
body["sessionId"] = self._agent_session_id
|
||||||
messages=self._conversation,
|
|
||||||
)
|
print(f"[pipeline] Creating agent task: {user_text[:60]}", flush=True)
|
||||||
return response.content[0].text if response.content else ""
|
async with httpx.AsyncClient(timeout=30) as client:
|
||||||
|
resp = await client.post(
|
||||||
|
f"{agent_url}/api/v1/agent/tasks",
|
||||||
|
json=body,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
if resp.status_code != 200 and resp.status_code != 201:
|
||||||
|
print(f"[pipeline] Agent task creation failed: {resp.status_code} {resp.text}", flush=True)
|
||||||
|
return "抱歉,Agent服务暂时不可用。"
|
||||||
|
|
||||||
|
data = resp.json()
|
||||||
|
session_id = data.get("sessionId", "")
|
||||||
|
task_id = data.get("taskId", "")
|
||||||
|
self._agent_session_id = session_id
|
||||||
|
print(f"[pipeline] Agent task created: session={session_id}, task={task_id}", flush=True)
|
||||||
|
|
||||||
|
# 2. Subscribe via WebSocket and collect text events
|
||||||
|
ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
|
||||||
|
ws_url = f"{ws_url}/ws/agent"
|
||||||
|
|
||||||
|
collected_text = []
|
||||||
|
timeout_secs = 60 # Max wait for agent response
|
||||||
|
|
||||||
|
async with websockets.connect(ws_url) as ws:
|
||||||
|
# Subscribe to the session
|
||||||
|
subscribe_msg = json.dumps({
|
||||||
|
"event": "subscribe_session",
|
||||||
|
"data": {"sessionId": session_id, "taskId": task_id},
|
||||||
|
})
|
||||||
|
await ws.send(subscribe_msg)
|
||||||
|
print(f"[pipeline] Subscribed to agent WS session={session_id}", flush=True)
|
||||||
|
|
||||||
|
# Collect events until completed
|
||||||
|
deadline = time.time() + timeout_secs
|
||||||
|
while time.time() < deadline:
|
||||||
|
try:
|
||||||
|
raw = await asyncio.wait_for(ws.recv(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
msg = json.loads(raw)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
continue
|
||||||
|
|
||||||
|
event_type = msg.get("event", "")
|
||||||
|
|
||||||
|
if event_type == "stream_event":
|
||||||
|
evt_data = msg.get("data", {})
|
||||||
|
evt_type = evt_data.get("type", "")
|
||||||
|
payload = evt_data.get("data", {})
|
||||||
|
|
||||||
|
if evt_type == "text":
|
||||||
|
content = payload.get("content", "")
|
||||||
|
if content:
|
||||||
|
collected_text.append(content)
|
||||||
|
|
||||||
|
elif evt_type == "completed":
|
||||||
|
summary = payload.get("summary", "")
|
||||||
|
if summary and not collected_text:
|
||||||
|
collected_text.append(summary)
|
||||||
|
print(f"[pipeline] Agent completed", flush=True)
|
||||||
|
break
|
||||||
|
|
||||||
|
elif evt_type == "error":
|
||||||
|
err_msg = payload.get("message", "Unknown error")
|
||||||
|
print(f"[pipeline] Agent error: {err_msg}", flush=True)
|
||||||
|
return f"Agent 错误: {err_msg}"
|
||||||
|
|
||||||
|
result = "".join(collected_text).strip()
|
||||||
|
if not result:
|
||||||
|
return "Agent 未返回回复。"
|
||||||
|
return result
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("LLM error: %s", exc)
|
print(f"[pipeline] Agent generate error: {type(exc).__name__}: {exc}", flush=True)
|
||||||
return "抱歉,AI服务暂时不可用,请稍后再试。"
|
return "抱歉,Agent服务暂时不可用,请稍后再试。"
|
||||||
|
|
||||||
async def _synthesize_and_send(self, text: str):
|
async def _synthesize_and_send(self, text: str):
|
||||||
"""Synthesize text to speech and send audio chunks over WebSocket."""
|
"""Synthesize text to speech and send audio chunks over WebSocket."""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue