"""
|
||
IT0 Voice Agent — LiveKit Agents v1.x entry point.
|
||
|
||
Uses the official AgentServer + @server.rtc_session() pattern.
|
||
Pipeline: VAD → STT → LLM (via agent-service) → TTS.
|
||
|
||
Usage:
|
||
python -m src.agent start
|
||
"""
|
||
|
||
# Standard library
import json
import logging

# Third-party: LiveKit Agents framework (v1.x)
from livekit.agents import (
    Agent,
    AgentServer,
    AgentSession,
    JobContext,
    JobProcess,
    cli,
    room_io,
)
from livekit.plugins import silero

# Local: settings plus custom STT/TTS/LLM plugin adapters
from .config import settings
from .plugins.agent_llm import AgentServiceLLM
from .plugins.whisper_stt import LocalWhisperSTT
from .plugins.kokoro_tts import LocalKokoroTTS, patch_misaki_compat

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class IT0VoiceAgent(Agent):
    """Voice agent for the IT0 server-operations platform.

    The system instructions are Chinese runtime strings sent to the LLM;
    they are kept byte-identical.
    """

    # LLM system instructions (Chinese, phrased for voice conversation).
    _INSTRUCTIONS = (
        "你是 IT0 服务器运维助手。用户通过语音与你对话,"
        "你帮助管理和监控服务器集群。回答简洁,适合语音对话场景。"
    )

    def __init__(self) -> None:
        super().__init__(instructions=self._INSTRUCTIONS)

    async def on_enter(self) -> None:
        """Greet the user when the agent becomes active in the session."""
        # generate_reply returns a handle; not awaited here, matching the
        # original behavior.
        self.session.generate_reply(
            instructions="用一句简短的话打招呼,告诉用户你是IT0运维助手,可以帮助什么。"
        )
# ---------------------------------------------------------------------------
# Server setup
# ---------------------------------------------------------------------------

# Single module-level AgentServer; the prewarm hook and the rtc_session
# entrypoint are registered on it below.
server = AgentServer()
def _prewarm_stt(proc: JobProcess) -> None:
    """Load the local faster-whisper STT model, or mark OpenAI STT in use."""
    if settings.stt_provider == "local":
        from faster_whisper import WhisperModel

        # float16 is only viable on CUDA; int8 is the safe CPU default.
        compute_type = "float16" if settings.device == "cuda" else "int8"
        try:
            model = WhisperModel(
                settings.whisper_model,
                device=settings.device,
                compute_type=compute_type,
            )
        except Exception as e:
            # GPU init can fail (driver/CUDA mismatch); degrade to CPU int8.
            logger.warning("Whisper GPU failed, falling back to CPU: %s", e)
            model = WhisperModel(
                settings.whisper_model, device="cpu", compute_type="int8"
            )
        proc.userdata["whisper_model"] = model
        logger.info("STT loaded: faster-whisper %s", settings.whisper_model)
    else:
        proc.userdata["whisper_model"] = None
        logger.info("STT: using OpenAI %s", settings.openai_stt_model)


def _prewarm_tts(proc: JobProcess) -> None:
    """Load the local Kokoro TTS pipeline, or mark OpenAI TTS in use."""
    if settings.tts_provider == "local":
        patch_misaki_compat()
        from kokoro import KPipeline

        # lang_code="z" — presumably selects Mandarin in Kokoro; confirm
        # against the kokoro package docs.
        proc.userdata["kokoro_pipeline"] = KPipeline(lang_code="z")
        logger.info("TTS loaded: Kokoro-82M voice=%s", settings.kokoro_voice)
    else:
        proc.userdata["kokoro_pipeline"] = None
        logger.info("TTS: using OpenAI %s", settings.openai_tts_model)


def prewarm(proc: JobProcess) -> None:
    """Pre-load ML models into shared process memory.

    Called once per worker process. Models are stored in ``proc.userdata``
    and shared across all sessions handled by this process, avoiding
    redundant loading.
    """
    logger.info(
        "Prewarming models (stt=%s, tts=%s, device=%s)",
        settings.stt_provider,
        settings.tts_provider,
        settings.device,
    )

    # VAD — always needed regardless of the STT/TTS provider choice.
    proc.userdata["vad"] = silero.VAD.load()
    logger.info("VAD loaded: Silero VAD")

    _prewarm_stt(proc)
    _prewarm_tts(proc)

    logger.info("Prewarm complete.")
# Register prewarm so each worker process loads models once, before sessions.
server.setup_fnc = prewarm
||
# ---------------------------------------------------------------------------
# Session entrypoint — called for each voice session (room join)
# ---------------------------------------------------------------------------


def _auth_header_from_metadata(ctx: JobContext) -> str:
    """Best-effort extraction of the auth header from job metadata.

    The token endpoint embeds {"auth_header": "Bearer ..."} via
    RoomAgentDispatch metadata, which LiveKit passes through as
    ``job.metadata`` to the agent worker. Returns "" when the header is
    absent or the metadata is unparsable.
    """
    try:
        meta = json.loads(ctx.job.metadata or "{}")
        return meta.get("auth_header", "")
    except (json.JSONDecodeError, AttributeError, TypeError) as e:
        # AttributeError/TypeError cover metadata that is not a JSON object
        # (e.g. a list or non-string payload).
        logger.warning("Failed to parse job metadata: %s", e)
        return ""


def _insecure_openai_client():
    """Build an AsyncOpenAI client that skips TLS verification.

    SECURITY NOTE: verify=False disables certificate checking because
    OPENAI_BASE_URL may sit behind a proxy with a self-signed certificate.
    Prefer supplying the proxy's CA bundle via httpx's ``verify=`` argument
    instead of disabling verification outright.
    """
    import httpx as _httpx
    import openai as _openai

    return _openai.AsyncOpenAI(http_client=_httpx.AsyncClient(verify=False))


@server.rtc_session(agent_name="voice-agent")
async def entrypoint(ctx: JobContext) -> None:
    """Main entrypoint — build and start the VAD→STT→LLM→TTS pipeline.

    Called once per voice session (room join). Reads an optional auth
    header from the job metadata and forwards it to the
    agent-service-backed LLM.
    """
    logger.info("New voice session: room=%s", ctx.room.name)

    auth_header = _auth_header_from_metadata(ctx)
    logger.info("Auth header present: %s", bool(auth_header))

    # Build STT — OpenAI-hosted, or local faster-whisper (prewarmed model).
    if settings.stt_provider == "openai":
        from livekit.plugins import openai as openai_plugin

        stt = openai_plugin.STT(
            model=settings.openai_stt_model,
            language=settings.whisper_language,
            client=_insecure_openai_client(),
        )
    else:
        stt = LocalWhisperSTT(
            model=ctx.proc.userdata.get("whisper_model"),
            language=settings.whisper_language,
        )

    # Build TTS — OpenAI-hosted, or local Kokoro (prewarmed pipeline).
    if settings.tts_provider == "openai":
        from livekit.plugins import openai as openai_plugin

        tts = openai_plugin.TTS(
            model=settings.openai_tts_model,
            voice=settings.openai_tts_voice,
            client=_insecure_openai_client(),
        )
    else:
        tts = LocalKokoroTTS(
            pipeline=ctx.proc.userdata.get("kokoro_pipeline"),
            voice=settings.kokoro_voice,
        )

    # Custom LLM that proxies completions to the agent-service, forwarding
    # the caller's auth header for per-user authorization.
    llm = AgentServiceLLM(
        agent_service_url=settings.agent_service_url,
        auth_header=auth_header,
    )

    # Assemble and start the full pipeline for this room.
    session = AgentSession(
        vad=ctx.proc.userdata["vad"],
        stt=stt,
        llm=llm,
        tts=tts,
    )
    await session.start(
        agent=IT0VoiceAgent(),
        room=ctx.room,
        room_input_options=room_io.RoomInputOptions(),
        room_output_options=room_io.RoomOutputOptions(),
    )

    logger.info("Voice session started for room %s", ctx.room.name)
if __name__ == "__main__":
|
||
cli.run_app(server)
|