# it0/packages/services/voice-agent/src/agent.py
"""
IT0 Voice Agent — LiveKit Agents v1.x entry point.
Uses the official AgentServer + @server.rtc_session() pattern.
Pipeline: VAD → STT → LLM (via agent-service) → TTS.
Usage:
python -m src.agent start
"""
import json
import logging
from livekit.agents import (
Agent,
AgentServer,
AgentSession,
JobContext,
JobProcess,
cli,
room_io,
)
from livekit.plugins import silero
from .config import settings
from .plugins.agent_llm import AgentServiceLLM
from .plugins.whisper_stt import LocalWhisperSTT
from .plugins.kokoro_tts import LocalKokoroTTS, patch_misaki_compat
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class IT0VoiceAgent(Agent):
    """Voice assistant for the IT0 server-operations platform.

    The system prompt (in Chinese) tells the model it helps manage and
    monitor server clusters and should keep replies short enough for a
    spoken conversation.
    """

    _SYSTEM_INSTRUCTIONS = (
        "你是 IT0 服务器运维助手。用户通过语音与你对话,"
        "你帮助管理和监控服务器集群。回答简洁,适合语音对话场景。"
    )

    def __init__(self) -> None:
        super().__init__(instructions=self._SYSTEM_INSTRUCTIONS)

    async def on_enter(self) -> None:
        """Greet the user as soon as this agent becomes active in a session."""
        # Fire-and-forget: we do not wait for playout of the greeting.
        self.session.generate_reply(
            instructions="用一句简短的话打招呼告诉用户你是IT0运维助手可以帮助什么。"
        )
# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------
# Single module-level AgentServer; prewarm hook and the rtc_session
# entrypoint below are registered on this instance.
server = AgentServer()
def prewarm(proc: JobProcess) -> None:
    """Pre-load ML models into shared per-process memory.

    Runs once per worker process. Everything stored in ``proc.userdata``
    is shared by all sessions handled by this process, so each model is
    loaded exactly once.
    """
    logger.info(
        "Prewarming models (stt=%s, tts=%s, device=%s)",
        settings.stt_provider,
        settings.tts_provider,
        settings.device,
    )

    # VAD is needed regardless of which STT/TTS providers are configured.
    proc.userdata["vad"] = silero.VAD.load()
    logger.info("VAD loaded: Silero VAD")

    # STT: pre-load local faster-whisper; None means "OpenAI at runtime".
    whisper_model = None
    if settings.stt_provider == "local":
        from faster_whisper import WhisperModel

        compute_type = "float16" if settings.device == "cuda" else "int8"
        try:
            whisper_model = WhisperModel(
                settings.whisper_model,
                device=settings.device,
                compute_type=compute_type,
            )
        except Exception as e:
            # GPU init can fail (no CUDA, OOM, driver issues) — degrade
            # to CPU with int8 quantization rather than crash the worker.
            logger.warning("Whisper GPU failed, falling back to CPU: %s", e)
            whisper_model = WhisperModel(
                settings.whisper_model, device="cpu", compute_type="int8"
            )
        logger.info("STT loaded: faster-whisper %s", settings.whisper_model)
    else:
        logger.info("STT: using OpenAI %s", settings.openai_stt_model)
    proc.userdata["whisper_model"] = whisper_model

    # TTS: pre-load local Kokoro pipeline; None means "OpenAI at runtime".
    kokoro_pipeline = None
    if settings.tts_provider == "local":
        patch_misaki_compat()
        from kokoro import KPipeline

        kokoro_pipeline = KPipeline(lang_code="z")
        logger.info("TTS loaded: Kokoro-82M voice=%s", settings.kokoro_voice)
    else:
        logger.info("TTS: using OpenAI %s", settings.openai_tts_model)
    proc.userdata["kokoro_pipeline"] = kokoro_pipeline

    logger.info("Prewarm complete.")


server.setup_fnc = prewarm
# ---------------------------------------------------------------------------
# Session entrypoint — called for each voice session (room join)
# ---------------------------------------------------------------------------
def _insecure_openai_client():
    """Build an AsyncOpenAI client with TLS certificate verification disabled.

    OPENAI_BASE_URL may point at an internal proxy with a self-signed
    certificate, so the underlying httpx client is created with
    ``verify=False``.

    NOTE(review): ``verify=False`` disables certificate checking entirely
    and is only acceptable for trusted internal endpoints — confirm the
    proxy is not reachable over untrusted networks.
    """
    import httpx as _httpx
    import openai as _openai

    return _openai.AsyncOpenAI(http_client=_httpx.AsyncClient(verify=False))


@server.rtc_session(agent_name="voice-agent")
async def entrypoint(ctx: JobContext) -> None:
    """Per-session entrypoint — wires VAD → STT → LLM → TTS and starts the agent.

    Called by the AgentServer for each voice session (room join). Provider
    selection comes from ``settings``; locally pre-loaded models are read
    from ``ctx.proc.userdata`` (populated by ``prewarm``).
    """
    logger.info("New voice session: room=%s", ctx.room.name)

    # The token endpoint embeds {"auth_header": "Bearer ..."} via
    # RoomAgentDispatch metadata, which LiveKit passes through as
    # job.metadata to the agent worker. Missing/bad metadata is tolerated:
    # we proceed with an empty auth header.
    auth_header = ""
    try:
        meta = json.loads(ctx.job.metadata or "{}")
        auth_header = meta.get("auth_header", "")
    except Exception as e:
        logger.warning("Failed to parse job metadata: %s", e)
    logger.info("Auth header present: %s", bool(auth_header))

    # STT — hosted OpenAI or local faster-whisper (pre-loaded in prewarm).
    if settings.stt_provider == "openai":
        from livekit.plugins import openai as openai_plugin

        stt = openai_plugin.STT(
            model=settings.openai_stt_model,
            language=settings.whisper_language,
            client=_insecure_openai_client(),
        )
    else:
        stt = LocalWhisperSTT(
            model=ctx.proc.userdata.get("whisper_model"),
            language=settings.whisper_language,
        )

    # TTS — hosted OpenAI or local Kokoro (pre-loaded in prewarm).
    if settings.tts_provider == "openai":
        from livekit.plugins import openai as openai_plugin

        tts = openai_plugin.TTS(
            model=settings.openai_tts_model,
            voice=settings.openai_tts_voice,
            client=_insecure_openai_client(),
        )
    else:
        tts = LocalKokoroTTS(
            pipeline=ctx.proc.userdata.get("kokoro_pipeline"),
            voice=settings.kokoro_voice,
        )

    # LLM — custom plugin that proxies chat requests to agent-service,
    # forwarding the caller's auth header for per-user authorization.
    llm = AgentServiceLLM(
        agent_service_url=settings.agent_service_url,
        auth_header=auth_header,
    )

    # Assemble and start the full pipeline for this room.
    session = AgentSession(
        vad=ctx.proc.userdata["vad"],
        stt=stt,
        llm=llm,
        tts=tts,
    )
    await session.start(
        agent=IT0VoiceAgent(),
        room=ctx.room,
        room_input_options=room_io.RoomInputOptions(),
        room_output_options=room_io.RoomOutputOptions(),
    )
    logger.info("Voice session started for room %s", ctx.room.name)
if __name__ == "__main__":
    # Hand control to the LiveKit Agents CLI (`python -m src.agent start`).
    cli.run_app(server)