# sglang_v0.5.2/sglang/sgl-router/py_test/unit/test_startup_sequence.py
#
# File-viewer header: 1054 lines, 40 KiB, Python

"""
Unit tests for startup sequence logic in sglang_router.
These tests focus on testing the startup sequence logic in isolation,
including router initialization, configuration validation, and startup flow.
"""
import logging
from types import SimpleNamespace
from unittest.mock import MagicMock, call, patch
import pytest
from sglang_router.launch_router import RouterArgs, launch_router
from sglang_router.router import policy_from_str
# Local helper mirroring the router logger setup used in production
def setup_logger():
    """Return the ``"router"`` logger, set to INFO level.

    A ``StreamHandler`` using the ``[Router (Python)]`` message format is
    attached only when the logger has no handlers yet, so repeated calls do
    not stack additional handlers.

    Returns:
        logging.Logger: the logger registered under the name ``"router"``.
    """
    router_logger = logging.getLogger("router")
    router_logger.setLevel(logging.INFO)

    if router_logger.handlers:
        # Some handler is already attached; leave the existing setup alone.
        return router_logger

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    router_logger.addHandler(stream_handler)
    return router_logger
from sglang_router_rs import PolicyType
class TestSetupLogger:
    """Test logger setup functionality provided by the local ``setup_logger``."""

    def test_setup_logger_returns_logger(self):
        """Test that setup_logger returns the INFO-level ``router`` logger."""
        logger = setup_logger()
        assert isinstance(logger, logging.Logger)
        assert logger.name == "router"
        assert logger.level == logging.INFO

    def test_setup_logger_has_handler(self):
        """Test that setup_logger leaves a stream handler on the logger."""
        logger = setup_logger()
        assert len(logger.handlers) > 0
        handler = logger.handlers[0]
        assert isinstance(handler, logging.StreamHandler)

    def test_setup_logger_has_formatter(self):
        """Test that the first handler formats records with the router prefix."""
        logger = setup_logger()
        formatter = logger.handlers[0].formatter
        assert formatter is not None
        # Render a probe record through the public Formatter.format() API
        # rather than peeking at the private ``_fmt`` attribute.
        record = logging.LogRecord(
            name="router",
            level=logging.INFO,
            pathname=__file__,
            lineno=0,
            msg="probe",
            args=(),
            exc_info=None,
        )
        assert "[Router (Python)]" in formatter.format(record)

    def test_setup_logger_multiple_calls(self):
        """Test that repeated calls return the same logger without new handlers."""
        logger1 = setup_logger()
        handler_count = len(logger1.handlers)
        logger2 = setup_logger()
        # Should return the same logger instance
        assert logger1 is logger2
        # A second call must not attach a duplicate handler.
        assert len(logger2.handlers) == handler_count
class TestPolicyFromStr:
    """Policy-name to ``PolicyType`` conversion as exercised at startup."""

    def test_policy_conversion_in_startup(self):
        """Each supported policy name converts to its matching enum member."""
        name_to_enum = {
            "random": PolicyType.Random,
            "round_robin": PolicyType.RoundRobin,
            "cache_aware": PolicyType.CacheAware,
            "power_of_two": PolicyType.PowerOfTwo,
        }
        for policy_name, enum_member in name_to_enum.items():
            assert policy_from_str(policy_name) == enum_member

    def test_invalid_policy_in_startup(self):
        """An unrecognised policy name is expected to raise ``KeyError``."""
        with pytest.raises(KeyError):
            policy_from_str("invalid_policy")
class TestRouterInitialization:
    """Test router initialization logic.

    Every test builds a ``RouterArgs``, runs ``launch_router`` with
    ``sglang_router.launch_router.Router`` mocked out, and checks the field
    values seen by ``Router.from_args``.
    """

    @staticmethod
    def _launch_and_capture(args, field_names):
        """Run ``launch_router(args)`` against a mocked ``Router``.

        Args:
            args: the ``RouterArgs`` instance handed to ``launch_router``.
            field_names: attribute names to snapshot from the object that
                ``Router.from_args`` receives.

        Returns:
            dict: ``{name: value}`` for each requested attribute, read at the
            moment ``Router.from_args`` is invoked.

        Raises:
            AssertionError: if ``Router.from_args`` or the returned router's
                ``start()`` was not called exactly once.
        """
        captured = {}
        router_instance = MagicMock()

        def fake_from_args(router_args):
            # Snapshot inside the call so the values are those from_args saw.
            captured.update(
                {name: getattr(router_args, name) for name in field_names}
            )
            return router_instance

        with patch("sglang_router.launch_router.Router") as router_cls:
            router_cls.from_args = MagicMock(side_effect=fake_from_args)
            launch_router(args)

        router_cls.from_args.assert_called_once()
        router_instance.start.assert_called_once()
        return captured

    def test_router_initialization_basic(self):
        """Test basic router initialization."""
        args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="cache_aware",
        )
        captured = self._launch_and_capture(
            args, ("host", "port", "worker_urls", "policy")
        )
        assert captured["host"] == "127.0.0.1"
        assert captured["port"] == 30000
        assert captured["worker_urls"] == ["http://worker1:8000"]
        assert policy_from_str(captured["policy"]) == PolicyType.CacheAware

    def test_router_initialization_pd_mode(self):
        """Test router initialization in PD mode."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", 9000)],
            decode_urls=["http://decode1:8001"],
            policy="power_of_two",
        )
        captured = self._launch_and_capture(
            args, ("pd_disaggregation", "prefill_urls", "decode_urls", "policy")
        )
        assert captured["pd_disaggregation"] is True
        assert captured["prefill_urls"] == [("http://prefill1:8000", 9000)]
        assert captured["decode_urls"] == ["http://decode1:8001"]
        assert policy_from_str(captured["policy"]) == PolicyType.PowerOfTwo

    def test_router_initialization_with_service_discovery(self):
        """Test router initialization with service discovery."""
        args = RouterArgs(
            service_discovery=True,
            selector={"app": "worker", "env": "prod"},
            service_discovery_port=8080,
            service_discovery_namespace="default",
        )
        captured = self._launch_and_capture(
            args,
            (
                "service_discovery",
                "selector",
                "service_discovery_port",
                "service_discovery_namespace",
            ),
        )
        assert captured["service_discovery"] is True
        assert captured["selector"] == {"app": "worker", "env": "prod"}
        assert captured["service_discovery_port"] == 8080
        assert captured["service_discovery_namespace"] == "default"

    def test_router_initialization_with_retry_config(self):
        """Test router initialization with retry configuration."""
        args = RouterArgs(
            retry_max_retries=3,
            retry_initial_backoff_ms=100,
            retry_max_backoff_ms=10000,
            retry_backoff_multiplier=2.0,
            retry_jitter_factor=0.1,
            disable_retries=False,
        )
        captured = self._launch_and_capture(
            args,
            (
                "retry_max_retries",
                "retry_initial_backoff_ms",
                "retry_max_backoff_ms",
                "retry_backoff_multiplier",
                "retry_jitter_factor",
                "disable_retries",
            ),
        )
        assert captured["retry_max_retries"] == 3
        assert captured["retry_initial_backoff_ms"] == 100
        assert captured["retry_max_backoff_ms"] == 10000
        assert captured["retry_backoff_multiplier"] == 2.0
        assert captured["retry_jitter_factor"] == 0.1
        assert captured["disable_retries"] is False

    def test_router_initialization_with_circuit_breaker_config(self):
        """Test router initialization with circuit breaker configuration."""
        args = RouterArgs(
            cb_failure_threshold=5,
            cb_success_threshold=2,
            cb_timeout_duration_secs=30,
            cb_window_duration_secs=60,
            disable_circuit_breaker=False,
        )
        captured = self._launch_and_capture(
            args,
            (
                "cb_failure_threshold",
                "cb_success_threshold",
                "cb_timeout_duration_secs",
                "cb_window_duration_secs",
                "disable_circuit_breaker",
            ),
        )
        assert captured["cb_failure_threshold"] == 5
        assert captured["cb_success_threshold"] == 2
        assert captured["cb_timeout_duration_secs"] == 30
        assert captured["cb_window_duration_secs"] == 60
        assert captured["disable_circuit_breaker"] is False

    def test_router_initialization_with_rate_limiting_config(self):
        """Test router initialization with rate limiting configuration."""
        args = RouterArgs(
            max_concurrent_requests=512,
            queue_size=200,
            queue_timeout_secs=120,
            rate_limit_tokens_per_second=100,
        )
        captured = self._launch_and_capture(
            args,
            (
                "max_concurrent_requests",
                "queue_size",
                "queue_timeout_secs",
                "rate_limit_tokens_per_second",
            ),
        )
        assert captured["max_concurrent_requests"] == 512
        assert captured["queue_size"] == 200
        assert captured["queue_timeout_secs"] == 120
        assert captured["rate_limit_tokens_per_second"] == 100

    def test_router_initialization_with_health_check_config(self):
        """Test router initialization with health check configuration."""
        args = RouterArgs(
            health_failure_threshold=2,
            health_success_threshold=1,
            health_check_timeout_secs=3,
            health_check_interval_secs=30,
            health_check_endpoint="/healthz",
        )
        captured = self._launch_and_capture(
            args,
            (
                "health_failure_threshold",
                "health_success_threshold",
                "health_check_timeout_secs",
                "health_check_interval_secs",
                "health_check_endpoint",
            ),
        )
        assert captured["health_failure_threshold"] == 2
        assert captured["health_success_threshold"] == 1
        assert captured["health_check_timeout_secs"] == 3
        assert captured["health_check_interval_secs"] == 30
        assert captured["health_check_endpoint"] == "/healthz"

    def test_router_initialization_with_prometheus_config(self):
        """Test router initialization with Prometheus configuration."""
        args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")
        captured = self._launch_and_capture(
            args, ("prometheus_port", "prometheus_host")
        )
        assert captured["prometheus_port"] == 29000
        assert captured["prometheus_host"] == "127.0.0.1"

    def test_router_initialization_with_cors_config(self):
        """Test router initialization with CORS configuration."""
        args = RouterArgs(
            cors_allowed_origins=["http://localhost:3000", "https://example.com"]
        )
        captured = self._launch_and_capture(args, ("cors_allowed_origins",))
        assert captured["cors_allowed_origins"] == [
            "http://localhost:3000",
            "https://example.com",
        ]

    # Note: model_path and tokenizer_path are not available in current RouterArgs
    @pytest.mark.skip(
        reason="Tokenizer configuration not available in current implementation"
    )
    def test_router_initialization_with_tokenizer_config(self):
        """Test router initialization with tokenizer configuration."""
class TestStartupValidation:
    """Test startup validation logic."""

    @staticmethod
    def _launch_with_mocked_logger(args):
        """Run ``launch_router(args)`` with ``Router`` and the args logger mocked.

        Args:
            args: the ``RouterArgs`` instance handed to ``launch_router``.

        Returns:
            tuple: ``(router_cls, mock_logger)`` — the mock standing in for
            ``sglang_router.launch_router.Router`` and the mock standing in
            for ``sglang_router.router_args.logger``.
        """
        # The policy messages are emitted by router_args logger
        with patch("sglang_router.launch_router.Router") as router_cls, patch(
            "sglang_router.router_args.logger"
        ) as mock_logger:
            router_cls.from_args = MagicMock(return_value=MagicMock())
            launch_router(args)
        return router_cls, mock_logger

    def test_pd_mode_validation_during_startup(self):
        """Test PD mode validation during startup."""
        # PD mode without URLs should fail
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )
        with pytest.raises(
            ValueError, match="PD disaggregation mode requires --prefill"
        ):
            launch_router(args)

    def test_pd_mode_with_service_discovery_validation(self):
        """Test PD mode with service discovery validation during startup."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )
        # Should not raise validation error
        with patch("sglang_router.launch_router.Router") as router_cls:
            router_cls.from_args = MagicMock(return_value=MagicMock())
            launch_router(args)
        # Should create router instance
        router_cls.from_args.assert_called_once()

    def test_policy_warning_during_startup(self):
        """Test policy warning during startup in PD mode."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", None)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
            prefill_policy="power_of_two",
            decode_policy="round_robin",
        )
        router_cls, mock_logger = self._launch_with_mocked_logger(args)
        # Should log warning about policy usage
        mock_logger.warning.assert_called_once()
        warning_message = mock_logger.warning.call_args[0][0]
        assert (
            "Both --prefill-policy and --decode-policy are specified"
            in warning_message
        )
        # Should create router instance
        router_cls.from_args.assert_called_once()

    def test_policy_info_during_startup(self):
        """Test policy info logging during startup in PD mode."""
        # Test with only prefill policy specified
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", None)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
            prefill_policy="power_of_two",
            decode_policy=None,
        )
        router_cls, mock_logger = self._launch_with_mocked_logger(args)
        # Should log info about policy usage
        mock_logger.info.assert_called_once()
        info_message = mock_logger.info.call_args[0][0]
        assert "Using --prefill-policy 'power_of_two'" in info_message
        assert "and --policy 'cache_aware'" in info_message
        # Should create router instance
        router_cls.from_args.assert_called_once()

    def test_policy_info_decode_only_during_startup(self):
        """Test policy info logging during startup with only decode policy specified."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", None)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
            prefill_policy=None,
            decode_policy="round_robin",
        )
        router_cls, mock_logger = self._launch_with_mocked_logger(args)
        # Should log info about policy usage
        mock_logger.info.assert_called_once()
        info_message = mock_logger.info.call_args[0][0]
        assert "Using --policy 'cache_aware'" in info_message
        assert "and --decode-policy 'round_robin'" in info_message
        # Should create router instance
        router_cls.from_args.assert_called_once()
class TestStartupErrorHandling:
    """Checks that failures inside ``launch_router`` are logged and propagate."""

    def test_router_creation_error_handling(self):
        """A failure raised by ``Router.from_args`` is logged once and surfaces."""
        args = RouterArgs(
            host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
        )
        with patch("sglang_router.launch_router.Router") as router_cls, patch(
            "sglang_router.launch_router.logger"
        ) as mock_logger:
            # Make router construction itself blow up.
            router_cls.from_args = MagicMock(
                side_effect=Exception("Router creation failed")
            )
            with pytest.raises(Exception, match="Router creation failed"):
                launch_router(args)

        # Exactly one error entry, carrying the original failure text.
        mock_logger.error.assert_called_once()
        positional, _keywords = mock_logger.error.call_args
        assert "Error starting router: Router creation failed" in positional[0]

    def test_router_start_error_handling(self):
        """A failure raised by ``router.start()`` is logged once and surfaces."""
        args = RouterArgs(
            host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"]
        )
        failing_router = MagicMock()
        # Construction succeeds; starting the router is what fails.
        failing_router.start.side_effect = Exception("Router start failed")

        with patch("sglang_router.launch_router.Router") as router_cls, patch(
            "sglang_router.launch_router.logger"
        ) as mock_logger:
            router_cls.from_args = MagicMock(return_value=failing_router)
            with pytest.raises(Exception, match="Router start failed"):
                launch_router(args)

        # Exactly one error entry, carrying the original failure text.
        mock_logger.error.assert_called_once()
        positional, _keywords = mock_logger.error.call_args
        assert "Error starting router: Router start failed" in positional[0]
# --- Added unit tests for Router wrapper and launch_server helpers ---
def _install_sglang_stubs(monkeypatch):
    """Register lightweight stand-ins for ``sglang.srt`` and friends.

    Stub modules for ``sglang``, ``sglang.srt``, ``sglang.srt.entrypoints``,
    ``sglang.srt.entrypoints.http_server``, ``sglang.srt.server_args`` and
    ``sglang.srt.utils``, plus namespace stubs for ``requests`` and
    ``setproctitle``, are placed into ``sys.modules`` through
    ``monkeypatch.setitem`` so heavy dependencies are not needed in unit tests.

    Args:
        monkeypatch: object exposing ``setitem(mapping, key, value)``
            (the pytest ``monkeypatch`` fixture in the tests of this file).
    """
    import sys
    import types

    def launch_server(_args):
        return None

    class ServerArgs:
        # Minimal fields used by launch_server_process
        def __init__(self):
            self.port = 0
            self.base_gpu_id = 0
            self.dp_size = 1
            self.tp_size = 1

        @staticmethod
        def add_cli_args(_parser):
            return None

        @staticmethod
        def from_cli_args(_args):
            parsed = ServerArgs()
            # Copy the parallelism sizes only when the CLI namespace has them.
            for field in ("dp_size", "tp_size"):
                if hasattr(_args, field):
                    setattr(parsed, field, getattr(_args, field))
            # host falls back to loopback when the namespace does not carry it.
            parsed.host = getattr(_args, "host", "127.0.0.1")
            return parsed

    def is_port_available(_port: int) -> bool:
        return True

    # Also stub external deps imported at module top-level
    def _dummy_get(*_a, **_k):
        raise NotImplementedError

    http_server_mod = types.ModuleType("sglang.srt.entrypoints.http_server")
    http_server_mod.launch_server = launch_server

    server_args_mod = types.ModuleType("sglang.srt.server_args")
    server_args_mod.ServerArgs = ServerArgs

    utils_mod = types.ModuleType("sglang.srt.utils")
    utils_mod.is_port_available = is_port_available

    requests_stub = types.SimpleNamespace(
        exceptions=types.SimpleNamespace(RequestException=Exception),
        get=_dummy_get,
    )
    setproctitle_stub = types.SimpleNamespace(setproctitle=lambda *_a, **_k: None)

    stubbed_modules = (
        ("requests", requests_stub),
        ("setproctitle", setproctitle_stub),
        ("sglang", types.ModuleType("sglang")),
        ("sglang.srt", types.ModuleType("sglang.srt")),
        ("sglang.srt.entrypoints", types.ModuleType("sglang.srt.entrypoints")),
        ("sglang.srt.entrypoints.http_server", http_server_mod),
        ("sglang.srt.server_args", server_args_mod),
        ("sglang.srt.utils", utils_mod),
    )
    for module_name, stub in stubbed_modules:
        monkeypatch.setitem(sys.modules, module_name, stub)
def test_router_defaults_and_start(monkeypatch):
    """Router wrapper: keyword arguments built by ``from_args`` and ``start()``.

    ``router_mod._Router`` is swapped for a recording fake, so the keyword
    arguments produced by ``Router.from_args`` can be inspected without the
    native backend.
    """
    from sglang_router import router as router_mod
    from sglang_router.router_args import RouterArgs as _RouterArgs
    from sglang_router_rs import PolicyType

    seen = {}

    class FakeRouter:
        def __init__(self, **kwargs):
            seen.update(kwargs)

        def start(self):
            seen["started"] = True

    monkeypatch.setattr(router_mod, "_Router", FakeRouter, raising=True)

    router_args = _RouterArgs(
        worker_urls=["http://w1:8000"],
        policy="round_robin",
        selector=None,
        prefill_selector=None,
        decode_selector=None,
        cors_allowed_origins=None,
    )
    wrapper = router_mod.Router.from_args(router_args)

    # Defaults preserved/normalized by Router.from_args
    for none_valued in (
        "selector",
        "prefill_selector",
        "decode_selector",
        "cors_allowed_origins",
    ):
        assert seen[none_valued] is None
    assert seen["worker_urls"] == ["http://w1:8000"]
    assert seen["policy"] == PolicyType.RoundRobin

    wrapper.start()
    assert seen.get("started") is True
def test_find_available_ports_and_wait_health(monkeypatch):
    """launch_server helpers: port search and health wait with one transient error.

    ``random.randint`` is pinned to 100 so the expected port list is fixed.
    The fake ``requests.get`` raises on its first call and returns a 200
    response afterwards. ``time.sleep`` and ``time.perf_counter`` are replaced
    so the health wait does not depend on real time.
    """
    import importlib

    _install_sglang_stubs(monkeypatch)
    ls = importlib.import_module("sglang_router.launch_server")

    # Deterministic increments
    monkeypatch.setattr(ls.random, "randint", lambda a, b: 100)
    assert ls.find_available_ports(30000, 3) == [30000, 30100, 30200]

    attempts = {"n": 0}

    class Ok:
        status_code = 200

    def fake_get(_url, timeout=5):
        attempts["n"] += 1
        if attempts["n"] == 1:
            raise ls.requests.exceptions.RequestException("boom")
        return Ok()

    clock = {"t": 0.0}

    def fake_perf_counter():
        # Every reading advances the fake clock by 0.1 and returns the new value.
        clock["t"] = clock["t"] + 0.1
        return clock["t"]

    monkeypatch.setattr(ls.requests, "get", fake_get)
    monkeypatch.setattr(ls.time, "sleep", lambda _s: None)
    monkeypatch.setattr(ls.time, "perf_counter", fake_perf_counter)

    assert ls.wait_for_server_health("127.0.0.1", 12345, timeout=1)
def test_launch_server_process_and_cleanup(monkeypatch):
    """launch_server: process creation arguments and cleanup signalling.

    ``mp.Process`` is replaced by a recording fake and ``os.killpg`` by a
    recorder, so no real process is spawned or signalled.
    """
    import importlib
    import signal as _sig
    import sys as _sys

    _install_sglang_stubs(monkeypatch)
    ls = importlib.import_module("sglang_router.launch_server")

    spawned = {}

    class FakeProcess:
        def __init__(self, target, args):
            spawned["target"] = target
            spawned["args"] = args
            self.pid = 4242
            self._alive = True

        def start(self):
            spawned["started"] = True

        def join(self, timeout=None):
            return None

        def is_alive(self):
            return self._alive

    monkeypatch.setattr(ls.mp, "Process", FakeProcess)

    server_args = _sys.modules["sglang.srt.server_args"].ServerArgs()
    server_args.tp_size = 2
    ls.launch_server_process(server_args, worker_port=31001, dp_id=3)

    assert spawned.get("started") is True
    assert spawned["target"] is ls.run_server
    forwarded = spawned["args"][0]
    assert forwarded.port == 31001
    assert forwarded.base_gpu_id == 3 * 2
    assert forwarded.dp_size == 1

    # cleanup_processes
    finished = FakeProcess(target=None, args=())
    finished._alive = False
    running = FakeProcess(target=None, args=())
    running._alive = True

    signals_sent = []
    monkeypatch.setattr(
        ls.os, "killpg", lambda pid, sig: signals_sent.append((pid, sig))
    )
    ls.cleanup_processes([finished, running])

    assert (finished.pid, _sig.SIGTERM) in signals_sent
    assert (running.pid, _sig.SIGTERM) in signals_sent
    assert (running.pid, _sig.SIGKILL) in signals_sent
def test_validation_error_handling():
    """Test error handling when validation fails.

    Launches the router in PD disaggregation mode with no prefill/decode URLs
    and service discovery off, and expects ``launch_router`` to raise
    ``ValueError`` and to log exactly one error through
    ``sglang_router.launch_router.logger``.
    """
    # Defined without ``self``: this definition follows the module-level helper
    # tests and precedes the next class, so it is not part of any class body
    # and pytest must be able to collect it as a plain test function.
    args = RouterArgs(
        pd_disaggregation=True,
        prefill_urls=[],
        decode_urls=[],
        service_discovery=False,
    )
    with patch("sglang_router.launch_router.logger") as mock_logger:
        with pytest.raises(
            ValueError, match="PD disaggregation mode requires --prefill"
        ):
            launch_router(args)
    # Should log error for validation failures
    mock_logger.error.assert_called_once()
class TestStartupFlow:
    """Test complete startup flow.

    ``sglang_router.launch_router.Router`` is mocked in every test, so only
    the flow through ``launch_router`` itself is exercised.
    """

    def test_complete_startup_flow_basic(self):
        """Test complete startup flow for basic configuration."""
        args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000", "http://worker2:8000"],
            policy="cache_aware",
            cache_threshold=0.5,
            balance_abs_threshold=32,
            balance_rel_threshold=1.5,
        )
        router_instance = MagicMock()
        with patch("sglang_router.launch_router.Router") as router_cls:
            router_cls.from_args = MagicMock(return_value=router_instance)
            launch_router(args)
        # Verify complete flow
        router_cls.from_args.assert_called_once()
        router_instance.start.assert_called_once()

    def test_complete_startup_flow_pd_mode(self):
        """Test complete startup flow for PD mode configuration."""
        args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[
                ("http://prefill1:8000", 9000),
                ("http://prefill2:8000", None),
            ],
            decode_urls=["http://decode1:8001", "http://decode2:8001"],
            policy="power_of_two",
            prefill_policy="cache_aware",
            decode_policy="round_robin",
        )
        router_instance = MagicMock()
        with patch("sglang_router.launch_router.Router") as router_cls, patch(
            "sglang_router.router_args.logger"
        ) as mock_logger:
            router_cls.from_args = MagicMock(return_value=router_instance)
            launch_router(args)
        # Verify complete flow
        router_cls.from_args.assert_called_once()
        router_instance.start.assert_called_once()
        # Verify policy warning was logged
        mock_logger.warning.assert_called_once()

    def test_complete_startup_flow_with_all_features(self):
        """Test complete startup flow with all features enabled."""
        # Single source of truth: the same mapping builds RouterArgs, selects
        # the attributes to snapshot, and supplies the expected values.
        expected = {
            "host": "0.0.0.0",
            "port": 30001,
            "worker_urls": ["http://worker1:8000"],
            "policy": "round_robin",
            "service_discovery": True,
            "selector": {"app": "worker"},
            "service_discovery_port": 8080,
            "service_discovery_namespace": "default",
            "dp_aware": True,
            "api_key": "test-key",
            "log_dir": "/tmp/logs",
            "log_level": "debug",
            "prometheus_port": 29000,
            "prometheus_host": "0.0.0.0",
            "request_id_headers": ["x-request-id", "x-trace-id"],
            "request_timeout_secs": 1200,
            "max_concurrent_requests": 512,
            "queue_size": 200,
            "queue_timeout_secs": 120,
            "rate_limit_tokens_per_second": 100,
            "cors_allowed_origins": ["http://localhost:3000"],
            "retry_max_retries": 3,
            "retry_initial_backoff_ms": 100,
            "retry_max_backoff_ms": 10000,
            "retry_backoff_multiplier": 2.0,
            "retry_jitter_factor": 0.1,
            "cb_failure_threshold": 5,
            "cb_success_threshold": 2,
            "cb_timeout_duration_secs": 30,
            "cb_window_duration_secs": 60,
            "health_failure_threshold": 2,
            "health_success_threshold": 1,
            "health_check_timeout_secs": 3,
            "health_check_interval_secs": 30,
            "health_check_endpoint": "/healthz",
        }
        args = RouterArgs(**expected)

        captured_args = {}
        router_instance = MagicMock()

        def fake_from_args(router_args):
            # Snapshot inside the call so the values are those from_args saw.
            captured_args.update(
                {name: getattr(router_args, name) for name in expected}
            )
            captured_args["policy_enum"] = policy_from_str(router_args.policy)
            return router_instance

        with patch("sglang_router.launch_router.Router") as router_cls:
            router_cls.from_args = MagicMock(side_effect=fake_from_args)
            launch_router(args)

        # Verify complete flow
        router_cls.from_args.assert_called_once()
        router_instance.start.assert_called_once()

        # Verify key parameters were propagated into RouterArgs
        assert captured_args["policy_enum"] == PolicyType.RoundRobin
        for name, value in expected.items():
            assert captured_args[name] == value, f"unexpected value for {name}"