507 lines
19 KiB
Python
507 lines
19 KiB
Python
"""
Unit tests for validation logic in sglang_router.

These tests focus on testing the validation logic in isolation,
including parameter validation, URL validation, and configuration validation.
"""
|
|
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from sglang_router.launch_router import RouterArgs, launch_router
|
|
|
|
|
|
class TestURLValidation:
    """Checks that RouterArgs stores the URL lists it is constructed with."""

    def test_valid_worker_urls(self):
        """Every well-formed worker URL ends up in ``worker_urls``."""
        candidates = (
            "http://worker1:8000",
            "https://worker2:8000",
            "http://localhost:8000",
            "http://127.0.0.1:8000",
            "http://192.168.1.100:8000",
            "http://worker.example.com:8000",
        )

        for candidate in candidates:
            # Construction must succeed and the URL must be retained.
            router_args = RouterArgs(worker_urls=[candidate])
            assert candidate in router_args.worker_urls

    def test_valid_prefill_urls(self):
        """Every (url, bootstrap_port) pair ends up in ``prefill_urls``."""
        candidates = (
            ("http://prefill1:8000", 9000),
            ("https://prefill2:8000", None),
            ("http://localhost:8000", 9000),
            ("http://127.0.0.1:8000", None),
        )

        for prefill_url, port in candidates:
            # Construction must succeed and the pair must be retained.
            router_args = RouterArgs(prefill_urls=[(prefill_url, port)])
            assert (prefill_url, port) in router_args.prefill_urls

    def test_valid_decode_urls(self):
        """Every well-formed decode URL ends up in ``decode_urls``."""
        candidates = (
            "http://decode1:8001",
            "https://decode2:8001",
            "http://localhost:8001",
            "http://127.0.0.1:8001",
        )

        for candidate in candidates:
            # Construction must succeed and the URL must be retained.
            router_args = RouterArgs(decode_urls=[candidate])
            assert candidate in router_args.decode_urls

    def test_malformed_urls(self):
        """Pin the current behaviour: malformed URLs are stored unchanged.

        This test expects RouterArgs to accept these values without raising,
        i.e. no URL format check happens at construction time. Tightening
        that would be a possible future improvement and would require
        updating this test.
        """
        questionable = (
            "not-a-url",
            "ftp://worker1:8000",  # Wrong protocol
            "http://",  # Missing host
            ":8000",  # Missing protocol and host
            "http://worker1",  # Missing port
        )

        for candidate in questionable:
            router_args = RouterArgs(worker_urls=[candidate])
            # Accepted as-is today.
            assert candidate in router_args.worker_urls
|
|
|
|
|
|
class TestPortValidation:
    """Checks that RouterArgs stores the port values it is constructed with."""

    def test_valid_ports(self):
        """Ports within 1..65535 are stored unchanged."""
        for candidate in (1, 80, 8000, 30000, 65535):
            router_args = RouterArgs(port=candidate)
            assert router_args.port == candidate

    def test_invalid_ports(self):
        """Pin the current behaviour: out-of-range ports are stored unchanged.

        This test expects RouterArgs to accept these values without raising,
        i.e. no port range check happens at construction time. Adding one
        would be a possible future improvement and would require updating
        this test.
        """
        for candidate in (0, -1, 65536, 70000):
            router_args = RouterArgs(port=candidate)
            # Accepted as-is today.
            assert router_args.port == candidate

    def test_bootstrap_port_validation(self):
        """Bootstrap ports (including ``None``) survive in the prefill entry."""
        for candidate in (1, 80, 9000, 30000, 65535, None):
            router_args = RouterArgs(
                prefill_urls=[("http://prefill1:8000", candidate)]
            )
            first_entry = router_args.prefill_urls[0]
            # Index 1 of the (url, bootstrap_port) pair is the bootstrap port.
            assert first_entry[1] == candidate
|
|
|
|
|
|
class TestParameterValidation:
    """Checks that RouterArgs stores tuning parameters exactly as given."""

    def test_cache_threshold_validation(self):
        """Cache thresholds from 0.0 to 1.0 are stored unchanged."""
        for value in (0.0, 0.1, 0.5, 0.9, 1.0):
            router_args = RouterArgs(cache_threshold=value)
            assert router_args.cache_threshold == value

    def test_balance_threshold_validation(self):
        """Absolute and relative balance thresholds are stored unchanged."""
        # Absolute thresholds.
        for value in (0, 1, 32, 64, 128, 1000):
            router_args = RouterArgs(balance_abs_threshold=value)
            assert router_args.balance_abs_threshold == value

        # Relative thresholds.
        for value in (1.0, 1.1, 1.5, 2.0, 10.0):
            router_args = RouterArgs(balance_rel_threshold=value)
            assert router_args.balance_rel_threshold == value

    def test_timeout_validation(self):
        """All four timeout/interval fields accept the same value together."""
        timeout_fields = (
            "worker_startup_timeout_secs",
            "worker_startup_check_interval",
            "request_timeout_secs",
            "queue_timeout_secs",
        )

        for seconds in (1, 30, 60, 300, 600, 1800, 3600):
            # Every field receives the same value in a single construction.
            router_args = RouterArgs(**dict.fromkeys(timeout_fields, seconds))
            for field_name in timeout_fields:
                assert getattr(router_args, field_name) == seconds

    def test_retry_parameter_validation(self):
        """Retry count, backoff, multiplier and jitter are stored unchanged."""
        # Retry counts.
        for attempts in (0, 1, 3, 5, 10):
            router_args = RouterArgs(retry_max_retries=attempts)
            assert router_args.retry_max_retries == attempts

        # Initial and max backoff are set to the same value together.
        for millis in (1, 50, 100, 1000, 30000):
            router_args = RouterArgs(
                retry_initial_backoff_ms=millis,
                retry_max_backoff_ms=millis,
            )
            assert router_args.retry_initial_backoff_ms == millis
            assert router_args.retry_max_backoff_ms == millis

        # Backoff multipliers.
        for factor in (1.0, 1.5, 2.0, 3.0):
            router_args = RouterArgs(retry_backoff_multiplier=factor)
            assert router_args.retry_backoff_multiplier == factor

        # Jitter factors.
        for factor in (0.0, 0.1, 0.2, 0.5):
            router_args = RouterArgs(retry_jitter_factor=factor)
            assert router_args.retry_jitter_factor == factor

    def test_circuit_breaker_parameter_validation(self):
        """Circuit breaker thresholds and durations are stored unchanged."""
        # Failure thresholds.
        for value in (1, 3, 5, 10, 20):
            router_args = RouterArgs(cb_failure_threshold=value)
            assert router_args.cb_failure_threshold == value

        # Success thresholds.
        for value in (1, 2, 3, 5):
            router_args = RouterArgs(cb_success_threshold=value)
            assert router_args.cb_success_threshold == value

        # Timeout and window durations are set to the same value together.
        duration_fields = ("cb_timeout_duration_secs", "cb_window_duration_secs")
        for seconds in (10, 30, 60, 120, 300):
            router_args = RouterArgs(**dict.fromkeys(duration_fields, seconds))
            for field_name in duration_fields:
                assert getattr(router_args, field_name) == seconds

    def test_health_check_parameter_validation(self):
        """Health check thresholds, timeouts and intervals are stored unchanged."""
        # Failure thresholds.
        for value in (1, 2, 3, 5, 10):
            router_args = RouterArgs(health_failure_threshold=value)
            assert router_args.health_failure_threshold == value

        # Success thresholds.
        for value in (1, 2, 3, 5):
            router_args = RouterArgs(health_success_threshold=value)
            assert router_args.health_success_threshold == value

        # Timeout and interval are set to the same value together.
        for seconds in (1, 5, 10, 30, 60, 120):
            router_args = RouterArgs(
                health_check_timeout_secs=seconds,
                health_check_interval_secs=seconds,
            )
            assert router_args.health_check_timeout_secs == seconds
            assert router_args.health_check_interval_secs == seconds

    def test_rate_limiting_parameter_validation(self):
        """Concurrency limit, queue size and token rate are stored unchanged."""
        # Concurrent request limits.
        for value in (1, 10, 64, 256, 512, 1000):
            router_args = RouterArgs(max_concurrent_requests=value)
            assert router_args.max_concurrent_requests == value

        # Queue sizes (zero included).
        for value in (0, 10, 50, 100, 500, 1000):
            router_args = RouterArgs(queue_size=value)
            assert router_args.queue_size == value

        # Token rates.
        for value in (1, 10, 50, 100, 500, 1000):
            router_args = RouterArgs(rate_limit_tokens_per_second=value)
            assert router_args.rate_limit_tokens_per_second == value

    def test_tree_size_validation(self):
        """Power-of-two tree sizes are stored unchanged."""
        for exponent_value in (2**10, 2**20, 2**24, 2**26, 2**28, 2**30):
            router_args = RouterArgs(max_tree_size=exponent_value)
            assert router_args.max_tree_size == exponent_value

    def test_payload_size_validation(self):
        """Payload sizes from 1KB to 1GB are stored unchanged."""
        byte_counts = (
            1024,  # 1KB
            1024 * 1024,  # 1MB
            10 * 1024 * 1024,  # 10MB
            100 * 1024 * 1024,  # 100MB
            512 * 1024 * 1024,  # 512MB
            1024 * 1024 * 1024,  # 1GB
        )

        for byte_count in byte_counts:
            router_args = RouterArgs(max_payload_size=byte_count)
            assert router_args.max_payload_size == byte_count
|
|
|
|
|
|
class TestConfigurationValidation:
    """Checks that RouterArgs stores higher-level configuration as given."""

    def test_pd_mode_validation(self):
        """A PD configuration with one prefill and one decode URL is kept."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", 9000)],
            decode_urls=["http://decode1:8001"],
        )

        assert router_args.pd_disaggregation is True
        assert len(router_args.prefill_urls) > 0
        assert len(router_args.decode_urls) > 0

    def test_service_discovery_validation(self):
        """Service discovery selector, port and namespace are kept."""
        router_args = RouterArgs(
            service_discovery=True,
            selector={"app": "worker", "env": "prod"},
            service_discovery_port=8080,
            service_discovery_namespace="default",
        )

        assert router_args.service_discovery is True
        assert router_args.selector == {"app": "worker", "env": "prod"}
        assert router_args.service_discovery_port == 8080
        assert router_args.service_discovery_namespace == "default"

    def test_pd_service_discovery_validation(self):
        """PD mode with service discovery keeps both role selectors."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            service_discovery=True,
            prefill_selector={"app": "prefill"},
            decode_selector={"app": "decode"},
        )

        assert router_args.pd_disaggregation is True
        assert router_args.service_discovery is True
        assert router_args.prefill_selector == {"app": "prefill"}
        assert router_args.decode_selector == {"app": "decode"}

    def test_policy_validation(self):
        """Each known routing policy name is stored unchanged."""
        for policy_name in ("random", "round_robin", "cache_aware", "power_of_two"):
            router_args = RouterArgs(policy=policy_name)
            assert router_args.policy == policy_name

    def test_pd_policy_validation(self):
        """Every prefill/decode policy combination is stored unchanged."""
        policy_names = ["random", "round_robin", "cache_aware", "power_of_two"]
        # Same order as nesting decode inside prefill.
        combinations = [
            (prefill_name, decode_name)
            for prefill_name in policy_names
            for decode_name in policy_names
        ]

        for prefill_name, decode_name in combinations:
            router_args = RouterArgs(
                pd_disaggregation=True,
                prefill_urls=[("http://prefill1:8000", None)],
                decode_urls=["http://decode1:8001"],
                prefill_policy=prefill_name,
                decode_policy=decode_name,
            )
            assert router_args.prefill_policy == prefill_name
            assert router_args.decode_policy == decode_name

    def test_cors_validation(self):
        """CORS origin lists, including empty and wildcard, are kept."""
        origin_sets = (
            [],
            ["http://localhost:3000"],
            ["https://example.com"],
            ["http://localhost:3000", "https://example.com"],
            ["*"],  # Wildcard (if supported)
        )

        for allowed in origin_sets:
            router_args = RouterArgs(cors_allowed_origins=allowed)
            assert router_args.cors_allowed_origins == allowed

    def test_logging_validation(self):
        """Each standard log level name is stored unchanged."""
        for level_name in ("debug", "info", "warning", "error", "critical"):
            router_args = RouterArgs(log_level=level_name)
            assert router_args.log_level == level_name

    def test_prometheus_validation(self):
        """Prometheus host and port are stored unchanged."""
        router_args = RouterArgs(prometheus_port=29000, prometheus_host="127.0.0.1")

        assert router_args.prometheus_port == 29000
        assert router_args.prometheus_host == "127.0.0.1"

    def test_tokenizer_validation(self):
        """Placeholder for tokenizer configuration; always skipped for now."""
        # Per the skip reason, RouterArgs does not expose tokenizer settings
        # (model_path / tokenizer_path), so there is nothing to check.
        pytest.skip("Tokenizer configuration not available in current implementation")

    def test_request_id_headers_validation(self):
        """Request ID header lists of various lengths are kept."""
        header_sets = (
            ["x-request-id"],
            ["x-request-id", "x-trace-id"],
            ["x-request-id", "x-trace-id", "x-correlation-id"],
            ["custom-header"],
        )

        for header_names in header_sets:
            router_args = RouterArgs(request_id_headers=header_names)
            assert router_args.request_id_headers == header_names
|
|
|
|
|
|
class TestLaunchValidation:
    """Exercises the checks performed by launch_router at launch time."""

    @staticmethod
    def _launch_expecting_router(router_args):
        """Call launch_router with ``Router`` patched out.

        Asserts that ``Router.from_args`` is invoked exactly once, which is
        how these tests detect that launch got past validation and built a
        router.
        """
        with patch("sglang_router.launch_router.Router") as router_cls:
            router_cls.from_args = MagicMock(return_value=MagicMock())

            launch_router(router_args)

            # The router must have been built via from_args.
            router_cls.from_args.assert_called_once()

    def test_pd_mode_requires_urls(self):
        """PD mode with no URLs and no service discovery raises ValueError."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=False,
        )

        expected_message = "PD disaggregation mode requires --prefill"
        with pytest.raises(ValueError, match=expected_message):
            launch_router(router_args)

    def test_pd_mode_with_service_discovery_allows_empty_urls(self):
        """PD mode with empty URL lists launches when service discovery is on."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[],
            decode_urls=[],
            service_discovery=True,
        )

        # Must not raise a validation error.
        self._launch_expecting_router(router_args)

    def test_regular_mode_allows_empty_worker_urls(self):
        """Regular mode launches even when no worker URLs are given."""
        router_args = RouterArgs(worker_urls=[], service_discovery=False)

        # Must not raise a validation error.
        self._launch_expecting_router(router_args)

    def test_launch_with_valid_config(self):
        """A plain host/port/worker/policy configuration launches."""
        router_args = RouterArgs(
            host="127.0.0.1",
            port=30000,
            worker_urls=["http://worker1:8000"],
            policy="cache_aware",
        )

        # Must not raise a validation error.
        self._launch_expecting_router(router_args)

    def test_launch_with_pd_config(self):
        """A PD configuration with prefill and decode URLs launches."""
        router_args = RouterArgs(
            pd_disaggregation=True,
            prefill_urls=[("http://prefill1:8000", 9000)],
            decode_urls=["http://decode1:8001"],
            policy="cache_aware",
        )

        # Must not raise a validation error.
        self._launch_expecting_router(router_args)

    def test_launch_with_service_discovery_config(self):
        """A service discovery configuration with a selector launches."""
        router_args = RouterArgs(
            service_discovery=True,
            selector={"app": "worker"},
            service_discovery_port=8080,
        )

        # Must not raise a validation error.
        self._launch_expecting_router(router_args)
|