"""Custom LLM plugin that proxies to IT0 agent-service.

Instead of calling Claude directly, this plugin:

1. POSTs to agent-service /api/v1/agent/tasks (engineType configurable: claude_agent_sdk or claude_api)
2. Subscribes to the agent-service WebSocket /ws/agent for streaming text events
3. Emits ChatChunk objects into the LiveKit pipeline

In Agent SDK mode, the prompt is wrapped with voice-conversation instructions
so the agent outputs concise spoken Chinese without tool-call details.

This preserves all agent-service capabilities: Tool Use, conversation history,
tenant isolation, and session management.
"""

import asyncio
import json
import logging
import time
import uuid
from typing import Any

import httpx
import websockets

from livekit.agents import llm
from livekit.agents.types import (
    DEFAULT_API_CONNECT_OPTIONS,
    NOT_GIVEN,
    APIConnectOptions,
    NotGivenOr,
)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
class AgentServiceLLM(llm.LLM):
    """LLM that proxies to IT0 agent-service for Claude + Tool Use.

    Rather than calling a model API directly, each turn creates a task via
    ``POST /api/v1/agent/tasks`` and collects streamed text over the
    ``/ws/agent`` WebSocket. A single agent-service session ID is reused
    across turns so conversation context is preserved server-side.
    """

    # System prompt sent in Agent SDK mode so the agent answers in concise
    # spoken Chinese without tool-call details. Runtime string — must stay
    # byte-identical to what agent-service expects.
    _VOICE_SYSTEM_PROMPT = (
        "你正在通过语音与用户实时对话。请严格遵守以下规则:\n"
        "1. 只输出用户关注的最终答案,不要输出工具调用过程、中间步骤或技术细节\n"
        "2. 用简洁自然的口语中文回答,像面对面对话一样\n"
        "3. 回复要简短精炼,适合语音播报,通常1-3句话即可\n"
        "4. 不要使用markdown格式、代码块、列表符号等文本格式"
    )

    def __init__(
        self,
        *,
        agent_service_url: str = "http://agent-service:3002",
        auth_header: str = "",
        engine_type: str = "claude_agent_sdk",
    ) -> None:
        """Create the proxy LLM.

        Args:
            agent_service_url: Base HTTP URL of the agent-service.
            auth_header: Full ``Authorization`` header value; empty to omit.
            engine_type: ``claude_agent_sdk`` (voice mode) or ``claude_api``.
        """
        super().__init__()
        self._agent_service_url = agent_service_url
        self._auth_header = auth_header
        self._engine_type = engine_type
        # Agent-service session ID, reused across turns for context.
        self._agent_session_id: str | None = None
        # Guard: while an inject is in flight, a concurrent stream must not
        # clear the shared session ID on abort/exit errors.
        self._injecting = False

    @property
    def model(self) -> str:
        """Synthetic model name reported to the LiveKit pipeline."""
        return "agent-service-proxy"

    @property
    def provider(self) -> str:
        """Provider label reported to the LiveKit pipeline."""
        return "it0-agent"

    def _request_headers(self) -> dict[str, str]:
        """Build HTTP headers for task creation (JSON + optional auth)."""
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self._auth_header:
            headers["Authorization"] = self._auth_header
        return headers

    def _ws_endpoint(self) -> str:
        """Derive the /ws/agent WebSocket URL from the HTTP base URL."""
        base = self._agent_service_url.replace("http://", "ws://").replace(
            "https://", "wss://"
        )
        return f"{base}/ws/agent"

    def chat(
        self,
        *,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool] | None = None,
        conn_options: APIConnectOptions = DEFAULT_API_CONNECT_OPTIONS,
        parallel_tool_calls: NotGivenOr[bool] = NOT_GIVEN,
        tool_choice: NotGivenOr[llm.ToolChoice] = NOT_GIVEN,
        extra_kwargs: NotGivenOr[dict[str, Any]] = NOT_GIVEN,
    ) -> "AgentServiceLLMStream":
        """Start a streaming chat turn.

        ``parallel_tool_calls``, ``tool_choice`` and ``extra_kwargs`` are
        accepted for interface compatibility but unused — tool handling
        lives entirely inside agent-service.
        """
        return AgentServiceLLMStream(
            llm_instance=self,
            chat_ctx=chat_ctx,
            tools=tools or [],
            conn_options=conn_options,
        )

    async def inject_text_message(
        self,
        *,
        text: str = "",
        attachments: list[dict] | None = None,
    ) -> str:
        """Inject a text message (with optional attachments) into the agent session.

        Returns the complete response text for TTS playback via
        ``session.say()``, or ``""`` on empty input or any failure.
        Uses the same session ID so conversation context is preserved.
        """
        if not text and not attachments:
            return ""

        # Mark the session as owned by this inject so concurrent streams
        # skip clearing it (see AgentServiceLLMStream error handling).
        self._injecting = True
        try:
            return await self._do_inject(text, attachments)
        except Exception as exc:
            # Best-effort: the voice pipeline must not crash on inject failure.
            logger.error("inject_text_message error: %s: %s", type(exc).__name__, exc)
            return ""
        finally:
            self._injecting = False

    async def _do_inject(
        self,
        text: str,
        attachments: list[dict] | None,
    ) -> str:
        """Execute one inject: create the task over HTTP, collect streamed
        text over the WebSocket, and return the full response text.
        """
        timeout_secs = 120
        engine_type = self._engine_type
        voice_mode = engine_type == "claude_agent_sdk"

        body: dict[str, Any] = {
            "prompt": text if text else "(see attachments)",
            "engineType": engine_type,
            "voiceMode": voice_mode,
        }
        if voice_mode:
            body["systemPrompt"] = self._VOICE_SYSTEM_PROMPT
        if self._agent_session_id:
            body["sessionId"] = self._agent_session_id
        if attachments:
            body["attachments"] = attachments

        logger.info(
            "inject POST /tasks engine=%s text=%s attachments=%d",
            engine_type,
            text[:80] if text else "(empty)",
            len(attachments) if attachments else 0,
        )

        collected_text = ""

        async with websockets.connect(
            self._ws_endpoint(),
            open_timeout=10,
            close_timeout=5,
            ping_interval=20,
            ping_timeout=10,
        ) as ws:
            # Pre-subscribe with the existing session ID so events emitted
            # between task creation and the real subscribe are buffered.
            if self._agent_session_id:
                await ws.send(json.dumps({
                    "event": "subscribe_session",
                    "data": {"sessionId": self._agent_session_id},
                }))

            # Create the task over HTTP.
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(connect=10, read=30, write=10, pool=10),
            ) as client:
                resp = await client.post(
                    f"{self._agent_service_url}/api/v1/agent/tasks",
                    json=body,
                    headers=self._request_headers(),
                )

            if resp.status_code not in (200, 201):
                logger.error(
                    "inject task creation failed: %d %s",
                    resp.status_code, resp.text[:200],
                )
                return ""

            data = resp.json()
            session_id = data.get("sessionId", "")
            task_id = data.get("taskId", "")
            self._agent_session_id = session_id
            logger.info(
                "inject task created: session=%s, task=%s",
                session_id, task_id,
            )

            # Subscribe again with the IDs the service actually assigned.
            await ws.send(json.dumps({
                "event": "subscribe_session",
                "data": {"sessionId": session_id, "taskId": task_id},
            }))

            # Stream events and accumulate text until completed/error/timeout.
            deadline = time.time() + timeout_secs
            while time.time() < deadline:
                remaining = deadline - time.time()
                try:
                    raw = await asyncio.wait_for(
                        ws.recv(), timeout=min(30.0, remaining)
                    )
                except asyncio.TimeoutError:
                    # Short recv timeout — loop again; the while guard ends
                    # the loop once the overall deadline has passed.
                    if time.time() >= deadline:
                        logger.warning("inject stream timeout after %ds", timeout_secs)
                    continue
                except websockets.exceptions.ConnectionClosed:
                    logger.warning("inject WS connection closed")
                    break

                try:
                    msg = json.loads(raw)
                except (json.JSONDecodeError, TypeError):
                    continue

                if msg.get("event", "") != "stream_event":
                    continue

                evt_data = msg.get("data", {})
                evt_type = evt_data.get("type", "")

                if evt_type == "text":
                    content = evt_data.get("content", "")
                    if content:
                        collected_text += content
                elif evt_type == "completed":
                    logger.info(
                        "inject stream completed, text length=%d",
                        len(collected_text),
                    )
                    return collected_text
                elif evt_type == "error":
                    err_msg = evt_data.get("message", "Unknown error")
                    logger.error("inject error: %s", err_msg)
                    # A dead/aborted session cannot be resumed — drop it so
                    # the next task starts fresh.
                    if "aborted" in err_msg.lower() or "exited" in err_msg.lower():
                        self._agent_session_id = None
                    return collected_text if collected_text else ""

        return collected_text
class AgentServiceLLMStream(llm.LLMStream):
    """Streams text from agent-service via WebSocket.

    One instance handles one chat turn: it resolves the prompt from the
    ChatContext, creates an agent task, and forwards streamed ``text``
    events into the LiveKit event channel as ChatChunk deltas, retrying
    transient network failures.
    """

    # Retry configuration for transient network/handshake failures.
    _MAX_RETRIES = 2
    _RETRY_DELAYS = [1.0, 3.0]  # seconds between retries

    # System prompt sent in Agent SDK (voice) mode. Runtime string — must
    # stay byte-identical to what agent-service expects.
    _VOICE_SYSTEM_PROMPT = (
        "你正在通过语音与用户实时对话。请严格遵守以下规则:\n"
        "1. 只输出用户关注的最终答案,不要输出工具调用过程、中间步骤或技术细节\n"
        "2. 用简洁自然的口语中文回答,像面对面对话一样\n"
        "3. 回复要简短精炼,适合语音播报,通常1-3句话即可\n"
        "4. 不要使用markdown格式、代码块、列表符号等文本格式"
    )

    def __init__(
        self,
        *,
        llm_instance: AgentServiceLLM,
        chat_ctx: llm.ChatContext,
        tools: list[llm.Tool],
        conn_options: APIConnectOptions,
    ) -> None:
        super().__init__(
            llm_instance,
            chat_ctx=chat_ctx,
            tools=tools,
            conn_options=conn_options,
        )
        self._llm_instance = llm_instance

    def _emit_chunk(self, request_id: str, **delta_kwargs: Any) -> None:
        """Push one ChatChunk with the given ChoiceDelta fields into the
        pipeline event channel."""
        self._event_ch.send_nowait(
            llm.ChatChunk(
                id=request_id,
                delta=llm.ChoiceDelta(**delta_kwargs),
            )
        )

    def _resolve_prompt(self) -> str:
        """Pick the prompt text from the ChatContext.

        Prefer the newest user message. If there is none (on_enter /
        generate_reply may call the LLM without user input), fall back to
        the first developer/system instruction. Items may also contain
        non-message entries (e.g. AgentConfigUpdate), so filter by type.
        """
        prompt = ""
        for item in reversed(self._chat_ctx.items):
            if getattr(item, "type", None) != "message":
                continue
            if item.role == "user":
                prompt = item.text_content
                break
        if not prompt:
            for item in self._chat_ctx.items:
                if getattr(item, "type", None) != "message":
                    continue
                if item.role in ("developer", "system"):
                    prompt = item.text_content
                    break
        return prompt

    async def _run(self) -> None:
        """Entry point: resolve the prompt, then stream with retries."""
        user_text = self._resolve_prompt()
        if not user_text:
            logger.warning("No user message found in chat context")
            return

        request_id = f"agent-{uuid.uuid4().hex[:12]}"
        last_error: Exception | None = None

        for attempt in range(self._MAX_RETRIES + 1):
            try:
                if attempt > 0:
                    delay = self._RETRY_DELAYS[min(attempt - 1, len(self._RETRY_DELAYS) - 1)]
                    logger.info("Retry %d/%d after %.1fs", attempt, self._MAX_RETRIES, delay)
                    await asyncio.sleep(delay)

                await self._do_stream(user_text, request_id)
                return  # success

            except (httpx.ConnectError, httpx.ConnectTimeout, OSError) as exc:
                # Network-level errors — retryable.
                last_error = exc
                logger.warning(
                    "Agent stream attempt %d failed (network): %s: %s",
                    attempt + 1, type(exc).__name__, exc,
                )
            except websockets.exceptions.InvalidStatusCode as exc:
                # WS handshake rejected (e.g. gateway error) — retryable.
                last_error = exc
                logger.warning(
                    "Agent WS connect attempt %d failed: status %s",
                    attempt + 1, getattr(exc, "status_code", "?"),
                )
            except Exception as exc:
                # Anything else is treated as non-retryable: tell the user
                # once and stop instead of hammering a broken service.
                logger.error("Agent stream error: %s: %s", type(exc).__name__, exc)
                self._emit_chunk(
                    request_id,
                    role="assistant",
                    content="抱歉,Agent服务暂时不可用。",
                )
                return

        # All retries exhausted.
        logger.error(
            "Agent stream failed after %d attempts: %s",
            self._MAX_RETRIES + 1, last_error,
        )
        self._emit_chunk(
            request_id,
            role="assistant",
            content="抱歉,Agent服务暂时不可用,请稍后再试。",
        )

    async def _do_stream(self, user_text: str, request_id: str) -> None:
        """Execute a single WS+HTTP streaming attempt.

        Opens the WebSocket first (so early events can be buffered),
        creates the task over HTTP, subscribes with the returned IDs, then
        forwards ``text`` events as ChatChunks until ``completed``,
        ``error``, disconnect, or the overall deadline.
        """
        llm_inst = self._llm_instance
        agent_url = llm_inst._agent_service_url

        headers: dict[str, str] = {"Content-Type": "application/json"}
        if llm_inst._auth_header:
            headers["Authorization"] = llm_inst._auth_header

        ws_url = agent_url.replace("http://", "ws://").replace("https://", "wss://")
        ws_url = f"{ws_url}/ws/agent"

        timeout_secs = 120

        logger.info("Connecting to agent-service WS: %s", ws_url)
        async with websockets.connect(
            ws_url,
            open_timeout=10,
            close_timeout=5,
            ping_interval=20,
            ping_timeout=10,
        ) as ws:
            # 1. Pre-subscribe with the existing session ID so events fired
            #    before the post-creation subscribe are buffered.
            if llm_inst._agent_session_id:
                await ws.send(json.dumps({
                    "event": "subscribe_session",
                    "data": {"sessionId": llm_inst._agent_session_id},
                }))

            # 2. Create the agent task.
            engine_type = llm_inst._engine_type
            # Voice mode: agent-service filters intermediate events
            # (tool_use, tool_result, thinking) — only text/completed/error
            # are streamed back.
            voice_mode = engine_type == "claude_agent_sdk"

            body: dict[str, Any] = {
                "prompt": user_text,  # always send clean user text (no wrapping)
                "engineType": engine_type,
                "voiceMode": voice_mode,
            }
            # Agent SDK mode: set systemPrompt once (not per-message) so
            # conversation history stays clean — same pattern as text chat.
            if voice_mode:
                body["systemPrompt"] = self._VOICE_SYSTEM_PROMPT
            if llm_inst._agent_session_id:
                body["sessionId"] = llm_inst._agent_session_id

            logger.info(
                "POST /tasks engine=%s voiceMode=%s user_text=%s",
                engine_type,
                voice_mode,
                user_text[:80],
            )
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(connect=10, read=30, write=10, pool=10),
            ) as client:
                resp = await client.post(
                    f"{agent_url}/api/v1/agent/tasks",
                    json=body,
                    headers=headers,
                )

            if resp.status_code not in (200, 201):
                logger.error(
                    "Task creation failed: %d %s",
                    resp.status_code, resp.text[:200],
                )
                self._emit_chunk(
                    request_id,
                    role="assistant",
                    content="抱歉,Agent服务暂时不可用。",
                )
                return

            data = resp.json()
            session_id = data.get("sessionId", "")
            task_id = data.get("taskId", "")
            llm_inst._agent_session_id = session_id
            logger.info(
                "Task created: session=%s, task=%s", session_id, task_id
            )

            # 3. Subscribe with the IDs the service actually assigned.
            await ws.send(json.dumps({
                "event": "subscribe_session",
                "data": {"sessionId": session_id, "taskId": task_id},
            }))

            # 4. Initial role delta so downstream knows a reply started.
            self._emit_chunk(request_id, role="assistant")

            # 5. Stream events → ChatChunk.
            deadline = time.time() + timeout_secs
            while time.time() < deadline:
                remaining = deadline - time.time()
                try:
                    raw = await asyncio.wait_for(
                        ws.recv(), timeout=min(30.0, remaining)
                    )
                except asyncio.TimeoutError:
                    # Short recv timeout — loop again; the while guard ends
                    # the loop once the overall deadline has passed.
                    if time.time() >= deadline:
                        logger.warning("Agent stream timeout after %ds", timeout_secs)
                    continue
                except websockets.exceptions.ConnectionClosed:
                    logger.warning("Agent WS connection closed during streaming")
                    break

                try:
                    msg = json.loads(raw)
                except (json.JSONDecodeError, TypeError):
                    continue

                if msg.get("event", "") != "stream_event":
                    continue

                evt_data = msg.get("data", {})
                evt_type = evt_data.get("type", "")

                if evt_type == "text":
                    content = evt_data.get("content", "")
                    if content:
                        self._emit_chunk(request_id, content=content)
                elif evt_type == "completed":
                    logger.info("Agent stream completed")
                    return
                elif evt_type == "error":
                    err_msg = evt_data.get("message", "Unknown error")
                    logger.error("Agent error: %s", err_msg)
                    # Clear the session so the next task starts fresh rather
                    # than resuming a dead/aborted one — unless an inject is
                    # in progress, in which case it owns the session.
                    if "aborted" in err_msg.lower() or "exited" in err_msg.lower():
                        if not llm_inst._injecting:
                            logger.info("Clearing agent session after abort/exit")
                            llm_inst._agent_session_id = None
                        else:
                            logger.info("Skipping session clear — inject in progress")
                    self._emit_chunk(
                        request_id,
                        content=f"Agent 错误: {err_msg}",
                    )
                    return