# it0/packages/services/voice-agent/src/agent.py

"""
IT0 Voice Agent — LiveKit Agents v1.x entry point.
Uses the official AgentServer + @server.rtc_session() pattern.
Pipeline: VAD → STT → LLM (via agent-service) → TTS.
Usage:
python -m src.agent start
"""
import asyncio
import json
import logging
import ssl
import aiohttp
from livekit.agents import (
Agent,
AgentServer,
AgentSession,
JobContext,
JobProcess,
cli,
room_io,
)
from livekit.agents.utils import http_context
from livekit.plugins import silero
from .config import settings
# ---------------------------------------------------------------------------
# Monkey-patch: disable SSL verification for aiohttp sessions.
#
# The OpenAI Realtime STT uses aiohttp WebSocket (not httpx), so passing
# verify=False to the httpx/OpenAI client does NOT help. LiveKit's
# http_context._new_session_ctx creates an aiohttp.TCPConnector without
# ssl=False, causing SSL errors when OPENAI_BASE_URL points to a proxy
# with a self-signed certificate.
#
# We replace _new_session_ctx to inject ssl=False into the connector.
# ---------------------------------------------------------------------------
_original_new_session_ctx = http_context._new_session_ctx

# TLS context that accepts any certificate: no hostname check, no CA check.
_no_verify_ssl = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
_no_verify_ssl.check_hostname = False
_no_verify_ssl.verify_mode = ssl.CERT_NONE


def _patched_new_session_ctx():
    """Same as the original but with ssl verification disabled."""
    _cached = None

    def _get_session():
        nonlocal _cached
        if _cached is not None and not _cached.closed:
            return _cached
        from livekit.agents.job import get_job_context

        try:
            proxy = get_job_context().proc.http_proxy
        except RuntimeError:
            # No active job context — run without a per-job proxy.
            proxy = None
        _cached = aiohttp.ClientSession(
            proxy=proxy,
            connector=aiohttp.TCPConnector(
                limit_per_host=50,
                keepalive_timeout=120,
                ssl=_no_verify_ssl,
            ),
        )
        return _cached

    http_context._ContextVar.set(_get_session)
    return _get_session


http_context._new_session_ctx = _patched_new_session_ctx
from .plugins.agent_llm import AgentServiceLLM
from .plugins.whisper_stt import LocalWhisperSTT
from .plugins.kokoro_tts import LocalKokoroTTS, patch_misaki_compat
# Basic stdout logging for the worker; module-level logger for this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class IT0VoiceAgent(Agent):
    """Voice agent for the IT0 server-operations platform."""

    def __init__(self) -> None:
        # System prompt (Chinese): server-ops assistant, concise answers
        # suited to a spoken conversation.
        super().__init__(
            instructions=(
                "你是 IT0 服务器运维助手。用户通过语音与你对话,"
                "你帮助管理和监控服务器集群。回答简洁,适合语音对话场景。"
            ),
        )

    async def on_enter(self) -> None:
        """Stay silent on join — the user speaks first."""
# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------
# Module-level AgentServer instance; the CLI entry point at the bottom runs it.
server = AgentServer()
def prewarm(proc: JobProcess) -> None:
    """Pre-load ML models into shared process memory.

    Runs once per worker process. Every session the process handles
    reuses the objects stashed in ``proc.userdata``, avoiding redundant
    model loading.
    """
    logger.info(
        "Prewarming models (stt=%s, tts=%s, device=%s)",
        settings.stt_provider,
        settings.tts_provider,
        settings.device,
    )

    # Silero VAD is needed regardless of the STT/TTS providers.
    proc.userdata["vad"] = silero.VAD.load()
    logger.info("VAD loaded: Silero VAD")

    # STT — load faster-whisper locally; otherwise the session uses OpenAI.
    whisper = None
    if settings.stt_provider == "local":
        from faster_whisper import WhisperModel

        precision = "float16" if settings.device == "cuda" else "int8"
        try:
            whisper = WhisperModel(
                settings.whisper_model,
                device=settings.device,
                compute_type=precision,
            )
        except Exception as e:
            logger.warning("Whisper GPU failed, falling back to CPU: %s", e)
            whisper = WhisperModel(
                settings.whisper_model, device="cpu", compute_type="int8"
            )
    proc.userdata["whisper_model"] = whisper
    if whisper is not None:
        logger.info("STT loaded: faster-whisper %s", settings.whisper_model)
    else:
        logger.info("STT: using OpenAI %s", settings.openai_stt_model)

    # TTS — load Kokoro locally; otherwise the session uses OpenAI.
    pipeline = None
    if settings.tts_provider == "local":
        patch_misaki_compat()
        from kokoro import KPipeline

        pipeline = KPipeline(lang_code="z")
    proc.userdata["kokoro_pipeline"] = pipeline
    if pipeline is not None:
        logger.info("TTS loaded: Kokoro-82M voice=%s", settings.kokoro_voice)
    else:
        logger.info("TTS: using OpenAI %s", settings.openai_tts_model)

    logger.info("Prewarm complete.")


server.setup_fnc = prewarm
# ---------------------------------------------------------------------------
# Session entrypoint — called for each voice session (room join)
# ---------------------------------------------------------------------------
@server.rtc_session(agent_name="voice-agent")
async def entrypoint(ctx: JobContext) -> None:
    """Main entrypoint — called for each voice session.

    NOTE: session.start() returns immediately while the session continues
    running in the background. Resources (httpx clients) must stay alive
    for the session's lifetime and are cleaned up via the room disconnect
    listener, NOT in a finally block.
    """
    logger.info("New voice session: room=%s", ctx.room.name)

    # httpx clients that must be closed when the room disconnects.
    _http_clients: list = []

    async def _on_room_disconnect() -> None:
        """Best-effort close of every httpx client opened for this session."""
        for client in _http_clients:
            try:
                await client.aclose()
            except Exception:
                pass
        logger.info("Cleaned up %d httpx client(s) for room %s",
                    len(_http_clients), ctx.room.name)

    # Register cleanup before anything else so it fires even on errors.
    ctx.room.on("disconnected", lambda *_: asyncio.ensure_future(_on_room_disconnect()))

    try:
        # Extract auth header from job metadata. The token endpoint embeds
        # {"auth_header": "Bearer ..."} via RoomAgentDispatch metadata, which
        # LiveKit passes through as job.metadata to the agent worker.
        auth_header = ""
        tts_voice = settings.openai_tts_voice
        tts_style = ""
        engine_type = "claude_agent_sdk"
        meta = {}
        try:
            meta = json.loads(ctx.job.metadata or "{}")
            auth_header = meta.get("auth_header", "")
            tts_voice = meta.get("tts_voice", settings.openai_tts_voice)
            tts_style = meta.get("tts_style", "")
            engine_type = meta.get("engine_type", "claude_agent_sdk")
        except Exception as e:
            logger.warning("Failed to parse job metadata: %s", e)
        logger.info("Auth header present: %s, TTS: voice=%s, style=%s, engine=%s",
                    bool(auth_header), tts_voice, tts_style[:50] if tts_style else "(default)", engine_type)

        # ── Resolve STT provider (metadata > agent-service config > env default) ──
        stt_provider = meta.get("stt_provider", "")
        if not stt_provider and auth_header:
            try:
                import httpx as _httpx_cfg

                async with _httpx_cfg.AsyncClient(timeout=_httpx_cfg.Timeout(5)) as cfg_client:
                    cfg_resp = await cfg_client.get(
                        f"{settings.agent_service_url}/api/v1/agent/voice-config",
                        headers={"Authorization": auth_header},
                    )
                    if cfg_resp.status_code == 200:
                        stt_provider = cfg_resp.json().get("stt_provider", "")
                        logger.info("Voice config from agent-service: stt_provider=%s", stt_provider)
            except Exception as e:
                logger.warning("Failed to fetch voice config from agent-service: %s", e)
        if not stt_provider:
            stt_provider = settings.stt_provider  # env var fallback

        # ── Build STT ──
        if stt_provider == "openai":
            from livekit.plugins import openai as openai_plugin
            import httpx as _httpx
            import openai as _openai

            # OPENAI_BASE_URL may use a self-signed certificate (e.g. proxy).
            stt_http = _httpx.AsyncClient(verify=False)
            _http_clients.append(stt_http)
            stt = openai_plugin.STT(
                model=settings.openai_stt_model,
                language=settings.whisper_language,
                client=_openai.AsyncOpenAI(http_client=stt_http),
                use_realtime=True,
                # Increase silence_duration_ms so Chinese speech isn't chopped
                # into tiny fragments (default 350ms is too aggressive).
                turn_detection={
                    "type": "server_vad",
                    "threshold": 0.6,
                    "prefix_padding_ms": 600,
                    "silence_duration_ms": 800,
                },
            )
        else:
            stt = LocalWhisperSTT(
                model=ctx.proc.userdata.get("whisper_model"),
                language=settings.whisper_language,
            )
        logger.info("STT provider selected: %s", stt_provider)

        # ── Build TTS ──
        if settings.tts_provider == "openai":
            from livekit.plugins import openai as openai_plugin
            import httpx as _httpx
            import openai as _openai

            tts_http = _httpx.AsyncClient(verify=False)
            _http_clients.append(tts_http)
            default_instructions = "用自然、友好的中文语气说话,语速稍快,简洁干练,像专业助手一样。"
            tts = openai_plugin.TTS(
                model=settings.openai_tts_model,
                voice=tts_voice,
                instructions=tts_style if tts_style else default_instructions,
                client=_openai.AsyncOpenAI(http_client=tts_http),
            )
        else:
            tts = LocalKokoroTTS(
                pipeline=ctx.proc.userdata.get("kokoro_pipeline"),
                voice=settings.kokoro_voice,
            )

        # Custom LLM that proxies chat turns to agent-service.
        llm = AgentServiceLLM(
            agent_service_url=settings.agent_service_url,
            auth_header=auth_header,
            engine_type=engine_type,
        )

        # Assemble and start the VAD → STT → LLM → TTS pipeline.
        session = AgentSession(
            vad=ctx.proc.userdata["vad"],
            stt=stt,
            llm=llm,
            tts=tts,
        )
        await session.start(
            agent=IT0VoiceAgent(),
            room=ctx.room,
            room_input_options=room_io.RoomInputOptions(),
            room_output_options=room_io.RoomOutputOptions(),
        )
        logger.info("Voice session started for room %s", ctx.room.name)

        # -----------------------------------------------------------------
        # Data channel listener: receive text + attachments from Flutter client
        # -----------------------------------------------------------------
        async def _on_data_received(data_packet) -> None:
            try:
                if data_packet.topic != "text_inject":
                    return
                payload = json.loads(data_packet.data.decode("utf-8"))
                text = payload.get("text", "")
                attachments = payload.get("attachments")
                logger.info(
                    "text_inject received: text=%s attachments=%d",
                    text[:80] if text else "(empty)",
                    len(attachments) if attachments else 0,
                )
                if not text and not attachments:
                    return
                # Call agent-service with the same session (context preserved).
                response = await llm.inject_text_message(
                    text=text,
                    attachments=attachments,
                )
                if not response:
                    logger.warning("inject_text_message returned empty response")
                    return
                logger.info("inject response: %s", response[:100])
                session.say(response)
                # Send response text back to Flutter for display.
                try:
                    await ctx.room.local_participant.publish_data(
                        json.dumps({
                            "type": "text_reply",
                            "text": response,
                        }).encode("utf-8"),
                        reliable=True,
                        topic="text_reply",
                    )
                except Exception as pub_err:
                    logger.warning("Failed to publish text_reply: %s", pub_err)
            except Exception as exc:
                logger.error(
                    "text_inject handler error: %s: %s",
                    type(exc).__name__, exc, exc_info=True,
                )

        # Use ensure_future because ctx.room.on() uses a sync event emitter
        # (same pattern as the "disconnected" handler above).
        ctx.room.on("data_received", lambda dp: asyncio.ensure_future(_on_data_received(dp)))
    except Exception as exc:
        logger.error(
            "Voice session failed for room %s: %s: %s",
            ctx.room.name, type(exc).__name__, exc, exc_info=True,
        )
if __name__ == "__main__":
    # Launch the worker via the LiveKit Agents CLI (`python -m src.agent start`).
    cli.run_app(server)