""" Unit tests for argument parsing functionality in sglang_router. These tests focus on testing the argument parsing logic in isolation, without starting actual router instances. """ import argparse from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest from sglang_router.launch_router import RouterArgs, parse_router_args from sglang_router.router import policy_from_str class TestRouterArgs: """Test RouterArgs dataclass and its methods.""" def test_default_values(self): """Test that RouterArgs has correct default values.""" args = RouterArgs() # Test basic defaults assert args.host == "127.0.0.1" assert args.port == 30000 assert args.policy == "cache_aware" assert args.worker_urls == [] assert args.pd_disaggregation is False assert args.prefill_urls == [] assert args.decode_urls == [] # Test PD-specific defaults assert args.prefill_policy is None assert args.decode_policy is None # Test service discovery defaults assert args.service_discovery is False assert args.selector == {} assert args.service_discovery_port == 80 assert args.service_discovery_namespace is None # Test retry and circuit breaker defaults assert args.retry_max_retries == 5 assert args.cb_failure_threshold == 10 assert args.disable_retries is False assert args.disable_circuit_breaker is False def test_parse_selector_valid(self): """Test parsing valid selector arguments.""" # Test single key-value pair result = RouterArgs._parse_selector(["app=worker"]) assert result == {"app": "worker"} # Test multiple key-value pairs result = RouterArgs._parse_selector(["app=worker", "env=prod", "version=v1"]) assert result == {"app": "worker", "env": "prod", "version": "v1"} # Test empty list result = RouterArgs._parse_selector([]) assert result == {} # Test None result = RouterArgs._parse_selector(None) assert result == {} def test_parse_selector_invalid(self): """Test parsing invalid selector arguments.""" # Test malformed selector (no equals sign) result = RouterArgs._parse_selector(["app"]) assert result == {} # Test multiple equals signs (should use first one) result = RouterArgs._parse_selector(["app=worker=extra"]) assert result == {"app": "worker=extra"} def test_parse_prefill_urls_valid(self): """Test parsing valid prefill URL arguments.""" # Test with bootstrap port result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "9000"]]) assert result == [("http://prefill1:8000", 9000)] # Test with 'none' bootstrap port result = RouterArgs._parse_prefill_urls([["http://prefill1:8000", "none"]]) assert result == [("http://prefill1:8000", None)] # Test without bootstrap port result = RouterArgs._parse_prefill_urls([["http://prefill1:8000"]]) assert result == [("http://prefill1:8000", None)] # Test multiple prefill URLs result = RouterArgs._parse_prefill_urls( [ ["http://prefill1:8000", "9000"], ["http://prefill2:8000", "none"], ["http://prefill3:8000"], ] ) expected = [ ("http://prefill1:8000", 9000), ("http://prefill2:8000", None), ("http://prefill3:8000", None), ] assert result == expected # Test empty list result = RouterArgs._parse_prefill_urls([]) assert result == [] # Test None result = RouterArgs._parse_prefill_urls(None) assert result == [] def test_parse_prefill_urls_invalid(self): """Test parsing invalid prefill URL arguments.""" # Test invalid bootstrap port with pytest.raises(ValueError, match="Invalid bootstrap port"): RouterArgs._parse_prefill_urls([["http://prefill1:8000", "invalid"]]) def test_parse_decode_urls_valid(self): """Test parsing valid decode URL arguments.""" # Test single decode URL result = RouterArgs._parse_decode_urls([["http://decode1:8001"]]) assert result == ["http://decode1:8001"] # Test multiple decode URLs result = RouterArgs._parse_decode_urls( [["http://decode1:8001"], ["http://decode2:8001"]] ) assert result == ["http://decode1:8001", "http://decode2:8001"] # Test empty list result = RouterArgs._parse_decode_urls([]) assert result == [] # Test None result = RouterArgs._parse_decode_urls(None) assert result == [] def test_from_cli_args_basic(self): """Test creating RouterArgs from basic CLI arguments.""" args = SimpleNamespace( host="0.0.0.0", port=30001, worker_urls=["http://worker1:8000", "http://worker2:8000"], policy="round_robin", prefill=None, decode=None, router_policy="round_robin", router_pd_disaggregation=False, router_prefill_policy=None, router_decode_policy=None, router_worker_startup_timeout_secs=300, router_worker_startup_check_interval=15, router_cache_threshold=0.7, router_balance_abs_threshold=128, router_balance_rel_threshold=2.0, router_eviction_interval=180, router_max_tree_size=2**28, router_max_payload_size=1024 * 1024 * 1024, # 1GB router_dp_aware=True, router_api_key="test-key", router_log_dir="/tmp/logs", router_log_level="debug", router_service_discovery=True, router_selector=["app=worker", "env=test"], router_service_discovery_port=8080, router_service_discovery_namespace="default", router_prefill_selector=["app=prefill"], router_decode_selector=["app=decode"], router_prometheus_port=29000, router_prometheus_host="0.0.0.0", router_request_id_headers=["x-request-id", "x-trace-id"], router_request_timeout_secs=1200, router_max_concurrent_requests=512, router_queue_size=200, router_queue_timeout_secs=120, router_rate_limit_tokens_per_second=100, router_cors_allowed_origins=["http://localhost:3000"], router_retry_max_retries=3, router_retry_initial_backoff_ms=100, router_retry_max_backoff_ms=10000, router_retry_backoff_multiplier=2.0, router_retry_jitter_factor=0.1, router_cb_failure_threshold=5, router_cb_success_threshold=2, router_cb_timeout_duration_secs=30, router_cb_window_duration_secs=60, router_disable_retries=False, router_disable_circuit_breaker=False, router_health_failure_threshold=2, router_health_success_threshold=1, router_health_check_timeout_secs=3, router_health_check_interval_secs=30, router_health_check_endpoint="/healthz", ) router_args = RouterArgs.from_cli_args(args, use_router_prefix=True) # Test basic configuration assert router_args.host == "0.0.0.0" assert router_args.port == 30001 assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"] assert router_args.policy == "round_robin" # Test PD configuration assert router_args.pd_disaggregation is False assert router_args.prefill_urls == [] assert router_args.decode_urls == [] # Test service discovery assert router_args.service_discovery is True assert router_args.selector == {"app": "worker", "env": "test"} assert router_args.service_discovery_port == 8080 assert router_args.service_discovery_namespace == "default" assert router_args.prefill_selector == {"app": "prefill"} assert router_args.decode_selector == {"app": "decode"} # Test other configurations assert router_args.dp_aware is True assert router_args.api_key == "test-key" assert router_args.log_dir == "/tmp/logs" assert router_args.log_level == "debug" assert router_args.prometheus_port == 29000 assert router_args.prometheus_host == "0.0.0.0" assert router_args.request_id_headers == ["x-request-id", "x-trace-id"] assert router_args.request_timeout_secs == 1200 assert router_args.max_concurrent_requests == 512 assert router_args.queue_size == 200 assert router_args.queue_timeout_secs == 120 assert router_args.rate_limit_tokens_per_second == 100 assert router_args.cors_allowed_origins == ["http://localhost:3000"] # Test retry configuration assert router_args.retry_max_retries == 3 assert router_args.retry_initial_backoff_ms == 100 assert router_args.retry_max_backoff_ms == 10000 assert router_args.retry_backoff_multiplier == 2.0 assert router_args.retry_jitter_factor == 0.1 # Test circuit breaker configuration assert router_args.cb_failure_threshold == 5 assert router_args.cb_success_threshold == 2 assert router_args.cb_timeout_duration_secs == 30 assert router_args.cb_window_duration_secs == 60 assert router_args.disable_retries is False assert router_args.disable_circuit_breaker is False # Test health check configuration assert router_args.health_failure_threshold == 2 assert router_args.health_success_threshold == 1 assert router_args.health_check_timeout_secs == 3 assert router_args.health_check_interval_secs == 30 assert router_args.health_check_endpoint == "/healthz" # Note: model_path and tokenizer_path are not available in current RouterArgs def test_from_cli_args_pd_mode(self): """Test creating RouterArgs from CLI arguments in PD mode.""" args = SimpleNamespace( host="127.0.0.1", port=30000, worker_urls=[], policy="cache_aware", prefill=[ ["http://prefill1:8000", "9000"], ["http://prefill2:8000", "none"], ], decode=[["http://decode1:8001"], ["http://decode2:8001"]], router_prefill=[ ["http://prefill1:8000", "9000"], ["http://prefill2:8000", "none"], ], router_decode=[["http://decode1:8001"], ["http://decode2:8001"]], router_policy="cache_aware", router_pd_disaggregation=True, router_prefill_policy="power_of_two", router_decode_policy="round_robin", # Include all required fields with defaults router_worker_startup_timeout_secs=600, router_worker_startup_check_interval=30, router_cache_threshold=0.3, router_balance_abs_threshold=64, router_balance_rel_threshold=1.5, router_eviction_interval=120, router_max_tree_size=2**26, router_max_payload_size=512 * 1024 * 1024, router_dp_aware=False, router_api_key=None, router_log_dir=None, router_log_level=None, router_service_discovery=False, router_selector=None, router_service_discovery_port=80, router_service_discovery_namespace=None, router_prefill_selector=None, router_decode_selector=None, router_prometheus_port=None, router_prometheus_host=None, router_request_id_headers=None, router_request_timeout_secs=1800, router_max_concurrent_requests=256, router_queue_size=100, router_queue_timeout_secs=60, router_rate_limit_tokens_per_second=None, router_cors_allowed_origins=[], router_retry_max_retries=5, router_retry_initial_backoff_ms=50, router_retry_max_backoff_ms=30000, router_retry_backoff_multiplier=1.5, router_retry_jitter_factor=0.2, router_cb_failure_threshold=10, router_cb_success_threshold=3, router_cb_timeout_duration_secs=60, router_cb_window_duration_secs=120, router_disable_retries=False, router_disable_circuit_breaker=False, router_health_failure_threshold=3, router_health_success_threshold=2, router_health_check_timeout_secs=5, router_health_check_interval_secs=60, router_health_check_endpoint="/health", ) router_args = RouterArgs.from_cli_args(args, use_router_prefix=True) # Test PD configuration assert router_args.pd_disaggregation is True assert router_args.prefill_urls == [ ("http://prefill1:8000", 9000), ("http://prefill2:8000", None), ] assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"] assert router_args.prefill_policy == "power_of_two" assert router_args.decode_policy == "round_robin" assert router_args.policy == "cache_aware" # Main policy still set def test_from_cli_args_without_prefix(self): """Test creating RouterArgs from CLI arguments without router prefix.""" args = SimpleNamespace( host="127.0.0.1", port=30000, worker_urls=["http://worker1:8000"], policy="random", prefill=None, decode=None, pd_disaggregation=False, prefill_policy=None, decode_policy=None, worker_startup_timeout_secs=600, worker_startup_check_interval=30, cache_threshold=0.3, balance_abs_threshold=64, balance_rel_threshold=1.5, eviction_interval=120, max_tree_size=2**26, max_payload_size=512 * 1024 * 1024, dp_aware=False, api_key=None, log_dir=None, log_level=None, service_discovery=False, selector=None, service_discovery_port=80, service_discovery_namespace=None, prefill_selector=None, decode_selector=None, prometheus_port=None, prometheus_host=None, request_id_headers=None, request_timeout_secs=1800, max_concurrent_requests=256, queue_size=100, queue_timeout_secs=60, rate_limit_tokens_per_second=None, cors_allowed_origins=[], retry_max_retries=5, retry_initial_backoff_ms=50, retry_max_backoff_ms=30000, retry_backoff_multiplier=1.5, retry_jitter_factor=0.2, cb_failure_threshold=10, cb_success_threshold=3, cb_timeout_duration_secs=60, cb_window_duration_secs=120, disable_retries=False, disable_circuit_breaker=False, health_failure_threshold=3, health_success_threshold=2, health_check_timeout_secs=5, health_check_interval_secs=60, health_check_endpoint="/health", model_path=None, tokenizer_path=None, ) router_args = RouterArgs.from_cli_args(args, use_router_prefix=False) assert router_args.host == "127.0.0.1" assert router_args.port == 30000 assert router_args.worker_urls == ["http://worker1:8000"] assert router_args.policy == "random" assert router_args.pd_disaggregation is False class TestPolicyFromStr: """Test policy string to enum conversion.""" def test_valid_policies(self): """Test conversion of valid policy strings.""" from sglang_router_rs import PolicyType assert policy_from_str("random") == PolicyType.Random assert policy_from_str("round_robin") == PolicyType.RoundRobin assert policy_from_str("cache_aware") == PolicyType.CacheAware assert policy_from_str("power_of_two") == PolicyType.PowerOfTwo def test_invalid_policy(self): """Test conversion of invalid policy string.""" with pytest.raises(KeyError): policy_from_str("invalid_policy") class TestParseRouterArgs: """Test the parse_router_args function.""" def test_parse_basic_args(self): """Test parsing basic router arguments.""" args = [ "--host", "0.0.0.0", "--port", "30001", "--worker-urls", "http://worker1:8000", "http://worker2:8000", "--policy", "round_robin", ] router_args = parse_router_args(args) assert router_args.host == "0.0.0.0" assert router_args.port == 30001 assert router_args.worker_urls == ["http://worker1:8000", "http://worker2:8000"] assert router_args.policy == "round_robin" def test_parse_pd_args(self): """Test parsing PD disaggregated mode arguments.""" args = [ "--pd-disaggregation", "--prefill", "http://prefill1:8000", "9000", "--prefill", "http://prefill2:8000", "none", "--decode", "http://decode1:8001", "--decode", "http://decode2:8001", "--prefill-policy", "power_of_two", "--decode-policy", "round_robin", ] router_args = parse_router_args(args) assert router_args.pd_disaggregation is True assert router_args.prefill_urls == [ ("http://prefill1:8000", 9000), ("http://prefill2:8000", None), ] assert router_args.decode_urls == ["http://decode1:8001", "http://decode2:8001"] assert router_args.prefill_policy == "power_of_two" assert router_args.decode_policy == "round_robin" def test_parse_service_discovery_args(self): """Test parsing service discovery arguments.""" args = [ "--service-discovery", "--selector", "app=worker", "env=prod", "--service-discovery-port", "8080", "--service-discovery-namespace", "default", ] router_args = parse_router_args(args) assert router_args.service_discovery is True assert router_args.selector == {"app": "worker", "env": "prod"} assert router_args.service_discovery_port == 8080 assert router_args.service_discovery_namespace == "default" def test_parse_retry_and_circuit_breaker_args(self): """Test parsing retry and circuit breaker arguments.""" args = [ "--retry-max-retries", "3", "--retry-initial-backoff-ms", "100", "--retry-max-backoff-ms", "10000", "--retry-backoff-multiplier", "2.0", "--retry-jitter-factor", "0.1", "--disable-retries", "--cb-failure-threshold", "5", "--cb-success-threshold", "2", "--cb-timeout-duration-secs", "30", "--cb-window-duration-secs", "60", "--disable-circuit-breaker", ] router_args = parse_router_args(args) # Test retry configuration assert router_args.retry_max_retries == 3 assert router_args.retry_initial_backoff_ms == 100 assert router_args.retry_max_backoff_ms == 10000 assert router_args.retry_backoff_multiplier == 2.0 assert router_args.retry_jitter_factor == 0.1 assert router_args.disable_retries is True # Test circuit breaker configuration assert router_args.cb_failure_threshold == 5 assert router_args.cb_success_threshold == 2 assert router_args.cb_timeout_duration_secs == 30 assert router_args.cb_window_duration_secs == 60 assert router_args.disable_circuit_breaker is True def test_parse_rate_limiting_args(self): """Test parsing rate limiting arguments.""" args = [ "--max-concurrent-requests", "512", "--queue-size", "200", "--queue-timeout-secs", "120", "--rate-limit-tokens-per-second", "100", ] router_args = parse_router_args(args) assert router_args.max_concurrent_requests == 512 assert router_args.queue_size == 200 assert router_args.queue_timeout_secs == 120 assert router_args.rate_limit_tokens_per_second == 100 def test_parse_health_check_args(self): """Test parsing health check arguments.""" args = [ "--health-failure-threshold", "2", "--health-success-threshold", "1", "--health-check-timeout-secs", "3", "--health-check-interval-secs", "30", "--health-check-endpoint", "/healthz", ] router_args = parse_router_args(args) assert router_args.health_failure_threshold == 2 assert router_args.health_success_threshold == 1 assert router_args.health_check_timeout_secs == 3 assert router_args.health_check_interval_secs == 30 assert router_args.health_check_endpoint == "/healthz" def test_parse_cors_args(self): """Test parsing CORS arguments.""" args = [ "--cors-allowed-origins", "http://localhost:3000", "https://example.com", ] router_args = parse_router_args(args) assert router_args.cors_allowed_origins == [ "http://localhost:3000", "https://example.com", ] def test_parse_tokenizer_args(self): """Test parsing tokenizer arguments.""" # Note: model-path and tokenizer-path arguments are not available in current implementation # This test is skipped until those arguments are added pytest.skip("Tokenizer arguments not available in current implementation") def test_parse_invalid_args(self): """Test parsing invalid arguments.""" # Test invalid policy with pytest.raises(SystemExit): parse_router_args(["--policy", "invalid_policy"]) # Test invalid bootstrap port with pytest.raises(ValueError, match="Invalid bootstrap port"): parse_router_args( [ "--pd-disaggregation", "--prefill", "http://prefill1:8000", "invalid_port", ] ) def test_help_output(self): """Test that help output is generated correctly.""" with pytest.raises(SystemExit) as exc_info: parse_router_args(["--help"]) # SystemExit with code 0 indicates help was displayed assert exc_info.value.code == 0