sglang_v0.5.2/sglang/sgl-router/py_test/unit/test_arg_parser.py

629 lines
23 KiB
Python

"""
Unit tests for argument parsing functionality in sglang_router.
These tests focus on testing the argument parsing logic in isolation,
without starting actual router instances.
"""
import argparse
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from sglang_router.launch_router import RouterArgs, parse_router_args
from sglang_router.router import policy_from_str
class TestRouterArgs:
"""Test RouterArgs dataclass and its methods."""
def test_default_values(self):
"""Test that RouterArgs has correct default values."""
args = RouterArgs()
# Test basic defaults
assert args.host == "127.0.0.1"
assert args.port == 30000
assert args.policy == "cache_aware"
assert args.worker_urls == []
assert args.pd_disaggregation is False
assert args.prefill_urls == []
assert args.decode_urls == []
# Test PD-specific defaults
assert args.prefill_policy is None
assert args.decode_policy is None
# Test service discovery defaults
assert args.service_discovery is False
assert args.selector == {}
assert args.service_discovery_port == 80
assert args.service_discovery_namespace is None
# Test retry and circuit breaker defaults
assert args.retry_max_retries == 5
assert args.cb_failure_threshold == 10
assert args.disable_retries is False
assert args.disable_circuit_breaker is False
def test_parse_selector_valid(self):
"""Test parsing valid selector arguments."""
# Test single key-value pair
result = RouterArgs._parse_selector(["app=worker"])
assert result == {"app": "worker"}
# Test multiple key-value pairs
result = RouterArgs._parse_selector(["app=worker", "env=prod", "version=v1"])
assert result == {"app": "worker", "env": "prod", "version": "v1"}
# Test empty list
result = RouterArgs._parse_selector([])
assert result == {}
# Test None
result = RouterArgs._parse_selector(None)
assert result == {}
def test_parse_selector_invalid(self):
"""Test parsing invalid selector arguments."""
# Test malformed selector (no equals sign)
result = RouterArgs._parse_selector(["app"])
assert result == {}
# Test multiple equals signs (should use first one)
result = RouterArgs._parse_selector(["app=worker=extra"])
assert result == {"app": "worker=extra"}
def test_parse_prefill_urls_valid(self):
"""Test parsing valid prefill URL arguments."""
# Test with bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "9000"]])
assert result == [("http://prefill1:8000", 9000)]
# Test with 'none' bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "none"]])
assert result == [("http://prefill1:8000", None)]
# Test without bootstrap port
result = RouterArgs._parse_prefill_urls([["http://prefill1:8000"]])
assert result == [("http://prefill1:8000", None)]
# Test multiple prefill URLs
result = RouterArgs._parse_prefill_urls(
[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
["http://prefill3:8000"],
]
)
expected = [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
("http://prefill3:8000", None),
]
assert result == expected
# Test empty list
result = RouterArgs._parse_prefill_urls([])
assert result == []
# Test None
result = RouterArgs._parse_prefill_urls(None)
assert result == []
def test_parse_prefill_urls_invalid(self):
"""Test parsing invalid prefill URL arguments."""
# Test invalid bootstrap port
with pytest.raises(ValueError, match="Invalid bootstrap port"):
RouterArgs._parse_prefill_urls([["http://prefill1:8000", "invalid"]])
def test_parse_decode_urls_valid(self):
"""Test parsing valid decode URL arguments."""
# Test single decode URL
result = RouterArgs._parse_decode_urls([["http://decode1:8001"]])
assert result == ["http://decode1:8001"]
# Test multiple decode URLs
result = RouterArgs._parse_decode_urls(
[["http://decode1:8001"], ["http://decode2:8001"]]
)
assert result == ["http://decode1:8001", "http://decode2:8001"]
# Test empty list
result = RouterArgs._parse_decode_urls([])
assert result == []
# Test None
result = RouterArgs._parse_decode_urls(None)
assert result == []
def test_from_cli_args_basic(self):
"""Test creating RouterArgs from basic CLI arguments."""
args = SimpleNamespace(
host="0.0.0.0",
port=30001,
worker_urls=["http://worker1:8000", "http://worker2:8000"],
policy="round_robin",
prefill=None,
decode=None,
router_policy="round_robin",
router_pd_disaggregation=False,
router_prefill_policy=None,
router_decode_policy=None,
router_worker_startup_timeout_secs=300,
router_worker_startup_check_interval=15,
router_cache_threshold=0.7,
router_balance_abs_threshold=128,
router_balance_rel_threshold=2.0,
router_eviction_interval=180,
router_max_tree_size=2**28,
router_max_payload_size=1024 * 1024 * 1024, # 1GB
router_dp_aware=True,
router_api_key="test-key",
router_log_dir="/tmp/logs",
router_log_level="debug",
router_service_discovery=True,
router_selector=["app=worker", "env=test"],
router_service_discovery_port=8080,
router_service_discovery_namespace="default",
router_prefill_selector=["app=prefill"],
router_decode_selector=["app=decode"],
router_prometheus_port=29000,
router_prometheus_host="0.0.0.0",
router_request_id_headers=["x-request-id", "x-trace-id"],
router_request_timeout_secs=1200,
router_max_concurrent_requests=512,
router_queue_size=200,
router_queue_timeout_secs=120,
router_rate_limit_tokens_per_second=100,
router_cors_allowed_origins=["http://localhost:3000"],
router_retry_max_retries=3,
router_retry_initial_backoff_ms=100,
router_retry_max_backoff_ms=10000,
router_retry_backoff_multiplier=2.0,
router_retry_jitter_factor=0.1,
router_cb_failure_threshold=5,
router_cb_success_threshold=2,
router_cb_timeout_duration_secs=30,
router_cb_window_duration_secs=60,
router_disable_retries=False,
router_disable_circuit_breaker=False,
router_health_failure_threshold=2,
router_health_success_threshold=1,
router_health_check_timeout_secs=3,
router_health_check_interval_secs=30,
router_health_check_endpoint="/healthz",
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Test basic configuration
assert router_args.host == "0.0.0.0"
assert router_args.port == 30001
assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
assert router_args.policy == "round_robin"
# Test PD configuration
assert router_args.pd_disaggregation is False
assert router_args.prefill_urls == []
assert router_args.decode_urls == []
# Test service discovery
assert router_args.service_discovery is True
assert router_args.selector == {"app": "worker", "env": "test"}
assert router_args.service_discovery_port == 8080
assert router_args.service_discovery_namespace == "default"
assert router_args.prefill_selector == {"app": "prefill"}
assert router_args.decode_selector == {"app": "decode"}
# Test other configurations
assert router_args.dp_aware is True
assert router_args.api_key == "test-key"
assert router_args.log_dir == "/tmp/logs"
assert router_args.log_level == "debug"
assert router_args.prometheus_port == 29000
assert router_args.prometheus_host == "0.0.0.0"
assert router_args.request_id_headers == ["x-request-id", "x-trace-id"]
assert router_args.request_timeout_secs == 1200
assert router_args.max_concurrent_requests == 512
assert router_args.queue_size == 200
assert router_args.queue_timeout_secs == 120
assert router_args.rate_limit_tokens_per_second == 100
assert router_args.cors_allowed_origins == ["http://localhost:3000"]
# Test retry configuration
assert router_args.retry_max_retries == 3
assert router_args.retry_initial_backoff_ms == 100
assert router_args.retry_max_backoff_ms == 10000
assert router_args.retry_backoff_multiplier == 2.0
assert router_args.retry_jitter_factor == 0.1
# Test circuit breaker configuration
assert router_args.cb_failure_threshold == 5
assert router_args.cb_success_threshold == 2
assert router_args.cb_timeout_duration_secs == 30
assert router_args.cb_window_duration_secs == 60
assert router_args.disable_retries is False
assert router_args.disable_circuit_breaker is False
# Test health check configuration
assert router_args.health_failure_threshold == 2
assert router_args.health_success_threshold == 1
assert router_args.health_check_timeout_secs == 3
assert router_args.health_check_interval_secs == 30
assert router_args.health_check_endpoint == "/healthz"
# Note: model_path and tokenizer_path are not available in current RouterArgs
def test_from_cli_args_pd_mode(self):
"""Test creating RouterArgs from CLI arguments in PD mode."""
args = SimpleNamespace(
host="127.0.0.1",
port=30000,
worker_urls=[],
policy="cache_aware",
prefill=[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
],
decode=[["http://decode1:8001"], ["http://decode2:8001"]],
router_prefill=[
["http://prefill1:8000", "9000"],
["http://prefill2:8000", "none"],
],
router_decode=[["http://decode1:8001"], ["http://decode2:8001"]],
router_policy="cache_aware",
router_pd_disaggregation=True,
router_prefill_policy="power_of_two",
router_decode_policy="round_robin",
# Include all required fields with defaults
router_worker_startup_timeout_secs=600,
router_worker_startup_check_interval=30,
router_cache_threshold=0.3,
router_balance_abs_threshold=64,
router_balance_rel_threshold=1.5,
router_eviction_interval=120,
router_max_tree_size=2**26,
router_max_payload_size=512 * 1024 * 1024,
router_dp_aware=False,
router_api_key=None,
router_log_dir=None,
router_log_level=None,
router_service_discovery=False,
router_selector=None,
router_service_discovery_port=80,
router_service_discovery_namespace=None,
router_prefill_selector=None,
router_decode_selector=None,
router_prometheus_port=None,
router_prometheus_host=None,
router_request_id_headers=None,
router_request_timeout_secs=1800,
router_max_concurrent_requests=256,
router_queue_size=100,
router_queue_timeout_secs=60,
router_rate_limit_tokens_per_second=None,
router_cors_allowed_origins=[],
router_retry_max_retries=5,
router_retry_initial_backoff_ms=50,
router_retry_max_backoff_ms=30000,
router_retry_backoff_multiplier=1.5,
router_retry_jitter_factor=0.2,
router_cb_failure_threshold=10,
router_cb_success_threshold=3,
router_cb_timeout_duration_secs=60,
router_cb_window_duration_secs=120,
router_disable_retries=False,
router_disable_circuit_breaker=False,
router_health_failure_threshold=3,
router_health_success_threshold=2,
router_health_check_timeout_secs=5,
router_health_check_interval_secs=60,
router_health_check_endpoint="/health",
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)
# Test PD configuration
assert router_args.pd_disaggregation is True
assert router_args.prefill_urls == [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
]
assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
assert router_args.prefill_policy == "power_of_two"
assert router_args.decode_policy == "round_robin"
assert router_args.policy == "cache_aware" # Main policy still set
def test_from_cli_args_without_prefix(self):
"""Test creating RouterArgs from CLI arguments without router prefix."""
args = SimpleNamespace(
host="127.0.0.1",
port=30000,
worker_urls=["http://worker1:8000"],
policy="random",
prefill=None,
decode=None,
pd_disaggregation=False,
prefill_policy=None,
decode_policy=None,
worker_startup_timeout_secs=600,
worker_startup_check_interval=30,
cache_threshold=0.3,
balance_abs_threshold=64,
balance_rel_threshold=1.5,
eviction_interval=120,
max_tree_size=2**26,
max_payload_size=512 * 1024 * 1024,
dp_aware=False,
api_key=None,
log_dir=None,
log_level=None,
service_discovery=False,
selector=None,
service_discovery_port=80,
service_discovery_namespace=None,
prefill_selector=None,
decode_selector=None,
prometheus_port=None,
prometheus_host=None,
request_id_headers=None,
request_timeout_secs=1800,
max_concurrent_requests=256,
queue_size=100,
queue_timeout_secs=60,
rate_limit_tokens_per_second=None,
cors_allowed_origins=[],
retry_max_retries=5,
retry_initial_backoff_ms=50,
retry_max_backoff_ms=30000,
retry_backoff_multiplier=1.5,
retry_jitter_factor=0.2,
cb_failure_threshold=10,
cb_success_threshold=3,
cb_timeout_duration_secs=60,
cb_window_duration_secs=120,
disable_retries=False,
disable_circuit_breaker=False,
health_failure_threshold=3,
health_success_threshold=2,
health_check_timeout_secs=5,
health_check_interval_secs=60,
health_check_endpoint="/health",
model_path=None,
tokenizer_path=None,
)
router_args = RouterArgs.from_cli_args(args, use_router_prefix=False)
assert router_args.host == "127.0.0.1"
assert router_args.port == 30000
assert router_args.worker_urls == ["http://worker1:8000"]
assert router_args.policy == "random"
assert router_args.pd_disaggregation is False
class TestPolicyFromStr:
"""Test policy string to enum conversion."""
def test_valid_policies(self):
"""Test conversion of valid policy strings."""
from sglang_router_rs import PolicyType
assert policy_from_str("random") == PolicyType.Random
assert policy_from_str("round_robin") == PolicyType.RoundRobin
assert policy_from_str("cache_aware") == PolicyType.CacheAware
assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo
def test_invalid_policy(self):
"""Test conversion of invalid policy string."""
with pytest.raises(KeyError):
policy_from_str("invalid_policy")
class TestParseRouterArgs:
"""Test the parse_router_args function."""
def test_parse_basic_args(self):
"""Test parsing basic router arguments."""
args = [
"--host",
"0.0.0.0",
"--port",
"30001",
"--worker-urls",
"http://worker1:8000",
"http://worker2:8000",
"--policy",
"round_robin",
]
router_args = parse_router_args(args)
assert router_args.host == "0.0.0.0"
assert router_args.port == 30001
assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"]
assert router_args.policy == "round_robin"
def test_parse_pd_args(self):
"""Test parsing PD disaggregated mode arguments."""
args = [
"--pd-disaggregation",
"--prefill",
"http://prefill1:8000",
"9000",
"--prefill",
"http://prefill2:8000",
"none",
"--decode",
"http://decode1:8001",
"--decode",
"http://decode2:8001",
"--prefill-policy",
"power_of_two",
"--decode-policy",
"round_robin",
]
router_args = parse_router_args(args)
assert router_args.pd_disaggregation is True
assert router_args.prefill_urls == [
("http://prefill1:8000", 9000),
("http://prefill2:8000", None),
]
assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"]
assert router_args.prefill_policy == "power_of_two"
assert router_args.decode_policy == "round_robin"
def test_parse_service_discovery_args(self):
"""Test parsing service discovery arguments."""
args = [
"--service-discovery",
"--selector",
"app=worker",
"env=prod",
"--service-discovery-port",
"8080",
"--service-discovery-namespace",
"default",
]
router_args = parse_router_args(args)
assert router_args.service_discovery is True
assert router_args.selector == {"app": "worker", "env": "prod"}
assert router_args.service_discovery_port == 8080
assert router_args.service_discovery_namespace == "default"
def test_parse_retry_and_circuit_breaker_args(self):
"""Test parsing retry and circuit breaker arguments."""
args = [
"--retry-max-retries",
"3",
"--retry-initial-backoff-ms",
"100",
"--retry-max-backoff-ms",
"10000",
"--retry-backoff-multiplier",
"2.0",
"--retry-jitter-factor",
"0.1",
"--disable-retries",
"--cb-failure-threshold",
"5",
"--cb-success-threshold",
"2",
"--cb-timeout-duration-secs",
"30",
"--cb-window-duration-secs",
"60",
"--disable-circuit-breaker",
]
router_args = parse_router_args(args)
# Test retry configuration
assert router_args.retry_max_retries == 3
assert router_args.retry_initial_backoff_ms == 100
assert router_args.retry_max_backoff_ms == 10000
assert router_args.retry_backoff_multiplier == 2.0
assert router_args.retry_jitter_factor == 0.1
assert router_args.disable_retries is True
# Test circuit breaker configuration
assert router_args.cb_failure_threshold == 5
assert router_args.cb_success_threshold == 2
assert router_args.cb_timeout_duration_secs == 30
assert router_args.cb_window_duration_secs == 60
assert router_args.disable_circuit_breaker is True
def test_parse_rate_limiting_args(self):
"""Test parsing rate limiting arguments."""
args = [
"--max-concurrent-requests",
"512",
"--queue-size",
"200",
"--queue-timeout-secs",
"120",
"--rate-limit-tokens-per-second",
"100",
]
router_args = parse_router_args(args)
assert router_args.max_concurrent_requests == 512
assert router_args.queue_size == 200
assert router_args.queue_timeout_secs == 120
assert router_args.rate_limit_tokens_per_second == 100
def test_parse_health_check_args(self):
"""Test parsing health check arguments."""
args = [
"--health-failure-threshold",
"2",
"--health-success-threshold",
"1",
"--health-check-timeout-secs",
"3",
"--health-check-interval-secs",
"30",
"--health-check-endpoint",
"/healthz",
]
router_args = parse_router_args(args)
assert router_args.health_failure_threshold == 2
assert router_args.health_success_threshold == 1
assert router_args.health_check_timeout_secs == 3
assert router_args.health_check_interval_secs == 30
assert router_args.health_check_endpoint == "/healthz"
def test_parse_cors_args(self):
"""Test parsing CORS arguments."""
args = [
"--cors-allowed-origins",
"http://localhost:3000",
"https://example.com",
]
router_args = parse_router_args(args)
assert router_args.cors_allowed_origins == [
"http://localhost:3000",
"https://example.com",
]
def test_parse_tokenizer_args(self):
"""Test parsing tokenizer arguments."""
# Note: model-path and tokenizer-path arguments are not available in current implementation
# This test is skipped until those arguments are added
pytest.skip("Tokenizer arguments not available in current implementation")
def test_parse_invalid_args(self):
"""Test parsing invalid arguments."""
# Test invalid policy
with pytest.raises(SystemExit):
parse_router_args(["--policy", "invalid_policy"])
# Test invalid bootstrap port
with pytest.raises(ValueError, match="Invalid bootstrap port"):
parse_router_args(
[
"--pd-disaggregation",
"--prefill",
"http://prefill1:8000",
"invalid_port",
]
)
def test_help_output(self):
"""Test that help output is generated correctly."""
with pytest.raises(SystemExit) as exc_info:
parse_router_args(["--help"])
# SystemExit with code 0 indicates help was displayed
assert exc_info.value.code == 0