it0/packages/services/voice-agent/src/plugins/agent_llm.py

526 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Custom LLM plugin that proxies to IT0 agent-service.
Instead of calling Claude directly, this plugin:
1. POSTs to agent-service /api/v1/agent/tasks (engineType configurable: claude_agent_sdk or claude_api)
2. Subscribes to the agent-service WebSocket /ws/agent for streaming text events
3. Emits ChatChunk objects into the LiveKit pipeline
In Agent SDK mode, the prompt is wrapped with voice-conversation instructions
so the agent outputs concise spoken Chinese without tool-call details.
This preserves all agent-service capabilities: Tool Use, conversation history,
tenant isolation, and session management.
"""
import asyncio
import json
import logging
import time
import uuid
from typing import Any

import httpx
import websockets
from livekit.agents import llm
from livekit.agents.types import (
    DEFAULT_API_CONNECT_OPTIONS,
    NOT_GIVEN,
    APIConnectOptions,
    NotGivenOr,
)
logger = logging.getLogger(__name__)
class AgentServiceLLM(llm.LLM):
    """LLM that proxies to IT0 agent-service for Claude + Tool Use.

    Each request is POSTed to agent-service as a task and the streamed reply
    is read back over the service's ``/ws/agent`` WebSocket.  The session ID
    returned by the service is cached on this instance so successive turns
    share conversation history on the agent-service side.
    """

    # System prompt sent in Agent SDK ("voice") mode.  This is a runtime
    # payload forwarded verbatim to agent-service — do not alter the text.
    _VOICE_SYSTEM_PROMPT = (
        "你正在通过语音与用户实时对话。请严格遵守以下规则:\n"
        "1. 只输出用户关注的最终答案,不要输出工具调用过程、中间步骤或技术细节\n"
        "2. 用简洁自然的口语中文回答,像面对面对话一样\n"
        "3. 回复要简短精炼适合语音播报通常1-3句话即可\n"
        "4. 不要使用markdown格式、代码块、列表符号等文本格式"
    )

    def __init__(
        self,
        *,
        agent_service_url: str = "http://agent-service:3002",
        auth_header: str = "",
        engine_type: str = "claude_agent_sdk",
    ) -> None:
        """
        Args:
            agent_service_url: Base HTTP URL of the agent-service.
            auth_header: Full ``Authorization`` header value ("" disables it).
            engine_type: "claude_agent_sdk" (voice mode) or "claude_api".
        """
        super().__init__()
        self._agent_service_url = agent_service_url
        self._auth_header = auth_header
        self._engine_type = engine_type
        # Session ID assigned by agent-service on the first task; reused so
        # conversation context persists across turns.
        self._agent_session_id: str | None = None
        # Guard: stream error handling must not clear the session while an
        # inject_text_message() call owns it.
        self._injecting = False

    @property
    def model(self) -> str:
        return "agent-service-proxy"

    @property
    def provider(self) -> str:
        return "it0-agent"

    def chat(
        self,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool] | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> "AgentServiceLLMStream":
        """Create a streaming request against agent-service.

        parallel_tool_calls / tool_choice / extra_kwargs are accepted for
        llm.LLM interface compatibility but are not forwarded — tool use is
        handled entirely inside agent-service.
        """
        return AgentServiceLLMStream(
            llm_instance=self,
            chat_ctx=chat_ctx,
            tools=tools or [],
            conn_options=conn_options,
        )

    async def inject_text_message(
        self,
        *,
        text: str = "",
        attachments: list[dict] | None = None,
    ) -> str:
        """Inject a text message (with optional attachments) into the agent session.

        Returns the complete response text for TTS playback via
        ``session.say()``, or "" on failure.  Uses the same session ID so
        conversation context is preserved.
        """
        if not text and not attachments:
            return ""
        self._injecting = True
        try:
            return await self._do_inject(text, attachments)
        except Exception as exc:
            # logger.exception records the traceback (logger.error dropped it).
            logger.exception("inject_text_message error: %s: %s", type(exc).__name__, exc)
            return ""
        finally:
            self._injecting = False

    async def _do_inject(
        self,
        text: str,
        attachments: list[dict] | None,
    ) -> str:
        """Execute inject: WS+HTTP stream, collect full response text."""
        agent_url = self._agent_service_url
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self._auth_header:
            headers["Authorization"] = self._auth_header
        # Derive the ws(s):// endpoint from the configured http(s):// base.
        ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url = f"{ws_url}/ws/agent"
        timeout_secs = 120  # overall streaming deadline
        engine_type = self._engine_type
        voice_mode = engine_type == "claude_agent_sdk"
        body: dict[str, Any] = {
            "prompt": text if text else "(see attachments)",
            "engineType": engine_type,
            "voiceMode": voice_mode,
        }
        if voice_mode:
            body["systemPrompt"] = self._VOICE_SYSTEM_PROMPT
        if self._agent_session_id:
            body["sessionId"] = self._agent_session_id
        if attachments:
            body["attachments"] = attachments
        logger.info(
            "inject POST /tasks engine=%s text=%s attachments=%d",
            engine_type,
            text[:80] if text else "(empty)",
            len(attachments) if attachments else 0,
        )
        collected_text = ""
        async with websockets.connect(
            ws_url,
            open_timeout=10,
            close_timeout=5,
            ping_interval=20,
            ping_timeout=10,
        ) as ws:
            # Pre-subscribe with the previous session ID so events emitted
            # between task creation and the re-subscribe below are buffered.
            if self._agent_session_id:
                await ws.send(json.dumps({
                    "event": "subscribe_session",
                    "data": {"sessionId": self._agent_session_id},
                }))
            # Create the task over HTTP.
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(connect=10, read=30, write=10, pool=10),
            ) as client:
                resp = await client.post(
                    f"{agent_url}/api/v1/agent/tasks",
                    json=body,
                    headers=headers,
                )
                if resp.status_code not in (200, 201):
                    logger.error(
                        "inject task creation failed: %d %s",
                        resp.status_code, resp.text[:200],
                    )
                    return ""
                data = resp.json()
            session_id = data.get("sessionId", "")
            task_id = data.get("taskId", "")
            self._agent_session_id = session_id
            logger.info(
                "inject task created: session=%s, task=%s",
                session_id, task_id,
            )
            # Subscribe with the IDs actually returned by the service.
            await ws.send(json.dumps({
                "event": "subscribe_session",
                "data": {"sessionId": session_id, "taskId": task_id},
            }))
            # Stream events until "completed"/"error", a closed socket,
            # or the overall deadline.
            deadline = time.time() + timeout_secs
            while time.time() < deadline:
                remaining = deadline - time.time()
                try:
                    raw = await asyncio.wait_for(
                        ws.recv(), timeout=min(30.0, remaining)
                    )
                except asyncio.TimeoutError:
                    # Per-recv timeout only; keep looping until the deadline.
                    if time.time() >= deadline:
                        logger.warning("inject stream timeout after %ds", timeout_secs)
                    continue
                except websockets.exceptions.ConnectionClosed:
                    logger.warning("inject WS connection closed")
                    break
                try:
                    msg = json.loads(raw)
                except (json.JSONDecodeError, TypeError):
                    continue  # skip frames that are not valid JSON
                if msg.get("event", "") != "stream_event":
                    continue
                evt_data = msg.get("data", {})
                evt_type = evt_data.get("type", "")
                if evt_type == "text":
                    content = evt_data.get("content", "")
                    if content:
                        collected_text += content
                elif evt_type == "completed":
                    logger.info(
                        "inject stream completed, text length=%d",
                        len(collected_text),
                    )
                    return collected_text
                elif evt_type == "error":
                    err_msg = evt_data.get("message", "Unknown error")
                    logger.error("inject error: %s", err_msg)
                    # Dead/aborted sessions cannot be resumed — start fresh
                    # next time.
                    if "aborted" in err_msg.lower() or "exited" in err_msg.lower():
                        self._agent_session_id = None
                    return collected_text
        # Deadline elapsed or socket closed: return whatever was collected.
        return collected_text
class AgentServiceLLMStream(llm.LLMStream):
"""Streams text from agent-service via WebSocket."""
def __init__(
self,
*,
llm_instance: AgentServiceLLM,
chat_ctx: llm.ChatContext,
tools: list[llm.Tool],
conn_options: APIConnectOptions,
):
super().__init__(
llm_instance,
chat_ctx=chat_ctx,
tools=tools,
conn_options=conn_options,
)
self._llm_instance = llm_instance
# Retry configuration
_MAX_RETRIES = 2
_RETRY_DELAYS = [1.0, 3.0] # seconds between retries
async def _run(self) -> None:
# Extract the latest user message from ChatContext
# items can contain ChatMessage and AgentConfigUpdate; filter by type
user_text = ""
for item in reversed(self._chat_ctx.items):
if getattr(item, "type", None) != "message":
continue
if item.role == "user":
user_text = item.text_content
break
if not user_text:
# on_enter/generate_reply may call LLM without a user message;
# look for the developer/system instruction to use as prompt
for item in self._chat_ctx.items:
if getattr(item, "type", None) != "message":
continue
if item.role in ("developer", "system"):
user_text = item.text_content
break
if not user_text:
logger.warning("No user message found in chat context")
return
request_id = f"agent-{uuid.uuid4().hex[:12]}"
last_error: Exception | None = None
for attempt in range(self._MAX_RETRIES + 1):
try:
if attempt > 0:
delay = self._RETRY_DELAYS[min(attempt - 1, len(self._RETRY_DELAYS) - 1)]
logger.info("Retry %d/%d after %.1fs", attempt, self._MAX_RETRIES, delay)
await asyncio.sleep(delay)
await self._do_stream(user_text, request_id)
return # success
except (httpx.ConnectError, httpx.ConnectTimeout, OSError) as exc:
# Network-level errors — retryable
last_error = exc
logger.warning(
"Agent stream attempt %d failed (network): %s: %s",
attempt + 1, type(exc).__name__, exc,
)
except websockets.exceptions.InvalidStatusCode as exc:
last_error = exc
logger.warning(
"Agent WS connect attempt %d failed: status %s",
attempt + 1, getattr(exc, "status_code", "?"),
)
except Exception as exc:
# Non-retryable errors — fail immediately
logger.error("Agent stream error: %s: %s", type(exc).__name__, exc)
self._event_ch.send_nowait(
llm.ChatChunk(
id=request_id,
delta=llm.ChoiceDelta(
role="assistant",
content="抱歉Agent服务暂时不可用。",
),
)
)
return
# All retries exhausted
logger.error(
"Agent stream failed after %d attempts: %s",
self._MAX_RETRIES + 1, last_error,
)
self._event_ch.send_nowait(
llm.ChatChunk(
id=request_id,
delta=llm.ChoiceDelta(
role="assistant",
content="抱歉Agent服务暂时不可用请稍后再试。",
),
)
)
    async def _do_stream(self, user_text: str, request_id: str) -> None:
        """Execute a single WS+HTTP streaming attempt.

        Opens the agent-service WebSocket first (so streamed events are not
        missed), POSTs a task, subscribes with the returned session/task IDs,
        then forwards each "text" event as a ChatChunk until a "completed" or
        "error" event, a closed socket, or the 120 s deadline.
        """
        import time
        agent_url = self._llm_instance._agent_service_url
        auth_header = self._llm_instance._auth_header
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if auth_header:
            headers["Authorization"] = auth_header
        # Derive the ws(s):// endpoint from the configured http(s):// base URL.
        ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url = f"{ws_url}/ws/agent"
        timeout_secs = 120  # overall streaming deadline, not per-event
        logger.info("Connecting to agent-service WS: %s", ws_url)
        async with websockets.connect(
            ws_url,
            open_timeout=10,
            close_timeout=5,
            ping_interval=20,
            ping_timeout=10,
        ) as ws:
            # 1. Pre-subscribe with existing session ID (for event buffering)
            if self._llm_instance._agent_session_id:
                await ws.send(json.dumps({
                    "event": "subscribe_session",
                    "data": {"sessionId": self._llm_instance._agent_session_id},
                }))
            # 2. Create agent task (with timeout)
            engine_type = self._llm_instance._engine_type
            # Voice mode flag: tell agent-service to filter intermediate events
            # (tool_use, tool_result, thinking) — only stream text + completed + error
            voice_mode = engine_type == "claude_agent_sdk"
            body: dict[str, Any] = {
                "prompt": user_text,  # always send clean user text (no wrapping)
                "engineType": engine_type,
                "voiceMode": voice_mode,
            }
            # Agent SDK mode: set systemPrompt once (not per-message) so
            # conversation history stays clean — identical to text chat pattern
            if voice_mode:
                body["systemPrompt"] = (
                    "你正在通过语音与用户实时对话。请严格遵守以下规则:\n"
                    "1. 只输出用户关注的最终答案,不要输出工具调用过程、中间步骤或技术细节\n"
                    "2. 用简洁自然的口语中文回答,像面对面对话一样\n"
                    "3. 回复要简短精炼适合语音播报通常1-3句话即可\n"
                    "4. 不要使用markdown格式、代码块、列表符号等文本格式"
                )
            # Resume the cached session, if any, to keep conversation context.
            if self._llm_instance._agent_session_id:
                body["sessionId"] = self._llm_instance._agent_session_id
            logger.info(
                "POST /tasks engine=%s voiceMode=%s user_text=%s",
                engine_type,
                voice_mode,
                user_text[:80],
            )
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(connect=10, read=30, write=10, pool=10),
            ) as client:
                resp = await client.post(
                    f"{agent_url}/api/v1/agent/tasks",
                    json=body,
                    headers=headers,
                )
                if resp.status_code not in (200, 201):
                    logger.error(
                        "Task creation failed: %d %s",
                        resp.status_code, resp.text[:200],
                    )
                    # Emit a spoken fallback chunk so the pipeline is not silent.
                    self._event_ch.send_nowait(
                        llm.ChatChunk(
                            id=request_id,
                            delta=llm.ChoiceDelta(
                                role="assistant",
                                content="抱歉Agent服务暂时不可用。",
                            ),
                        )
                    )
                    return
                data = resp.json()
            session_id = data.get("sessionId", "")
            task_id = data.get("taskId", "")
            # Cache the session so the next turn resumes this conversation.
            self._llm_instance._agent_session_id = session_id
            logger.info(
                "Task created: session=%s, task=%s", session_id, task_id
            )
            # 3. Subscribe with actual IDs
            await ws.send(json.dumps({
                "event": "subscribe_session",
                "data": {"sessionId": session_id, "taskId": task_id},
            }))
            # 4. Send initial role delta
            self._event_ch.send_nowait(
                llm.ChatChunk(
                    id=request_id,
                    delta=llm.ChoiceDelta(role="assistant"),
                )
            )
            # 5. Stream events → ChatChunk
            deadline = time.time() + timeout_secs
            while time.time() < deadline:
                remaining = deadline - time.time()
                try:
                    raw = await asyncio.wait_for(
                        ws.recv(), timeout=min(30.0, remaining)
                    )
                except asyncio.TimeoutError:
                    # Per-recv timeout only; keep looping until the deadline.
                    if time.time() >= deadline:
                        logger.warning("Agent stream timeout after %ds", timeout_secs)
                    continue
                except websockets.exceptions.ConnectionClosed:
                    logger.warning("Agent WS connection closed during streaming")
                    break
                try:
                    msg = json.loads(raw)
                except (json.JSONDecodeError, TypeError):
                    continue  # skip frames that are not valid JSON
                event_type = msg.get("event", "")
                if event_type == "stream_event":
                    evt_data = msg.get("data", {})
                    evt_type = evt_data.get("type", "")
                    if evt_type == "text":
                        content = evt_data.get("content", "")
                        if content:
                            self._event_ch.send_nowait(
                                llm.ChatChunk(
                                    id=request_id,
                                    delta=llm.ChoiceDelta(content=content),
                                )
                            )
                    elif evt_type == "completed":
                        logger.info("Agent stream completed")
                        return
                    elif evt_type == "error":
                        err_msg = evt_data.get("message", "Unknown error")
                        logger.error("Agent error: %s", err_msg)
                        # Clear session so next task starts fresh
                        # (don't try to resume a dead/aborted session)
                        # But skip if inject is in progress — it owns the session
                        if "aborted" in err_msg.lower() or "exited" in err_msg.lower():
                            if not self._llm_instance._injecting:
                                logger.info("Clearing agent session after abort/exit")
                                self._llm_instance._agent_session_id = None
                            else:
                                logger.info("Skipping session clear — inject in progress")
                        self._event_ch.send_nowait(
                            llm.ChatChunk(
                                id=request_id,
                                delta=llm.ChoiceDelta(
                                    content=f"Agent 错误: {err_msg}"
                                ),
                            )
                        )
                        return