# sglang.0.4.8.post1/sglang/sgl-router/py_test/test_launch_router.py
#
# 293 lines
# 10 KiB
# Python

import multiprocessing
import time
import unittest
from types import SimpleNamespace
def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> None:
    """Stop ``process``, escalating from terminate() to kill() if needed.

    A process that is not alive is left untouched. Otherwise ``terminate()``
    is called and the process is joined for up to ``timeout`` seconds; if it
    is still alive after that, ``kill()`` is called and the process is joined
    with no timeout.

    Args:
        process: The process to stop.
        timeout: Seconds to wait after ``terminate()`` before calling
            ``kill()``.
    """
    if process.is_alive():
        # First attempt: ask politely and give the process a bounded grace period.
        process.terminate()
        process.join(timeout=timeout)

        still_running = process.is_alive()
        if still_running:
            # Grace period expired: kill and wait for the process to be reaped.
            process.kill()
            process.join()
def _launch_router_in_child(args):
    """Child-process entry point: import and launch the router with ``args``.

    Defined at module level (rather than nested inside a test method) so that
    ``multiprocessing`` can pickle it as a ``Process`` target under the
    "spawn" and "forkserver" start methods as well as "fork".

    A value returned from a ``Process`` target is discarded, so a failed
    launch (an exception, or ``launch_router`` returning ``None``) is
    reported through the child's exit code instead.

    Args:
        args: Namespace of router arguments passed to ``launch_router``.
    """
    try:
        # Imported here so an import failure is reported like a launch failure.
        from sglang_router.launch_router import launch_router

        router = launch_router(args)
    except Exception as e:
        print(e)
        raise SystemExit(1)
    if router is None:
        raise SystemExit(1)


class TestLaunchRouter(unittest.TestCase):
    """Tests for launching the router and for parsing its CLI arguments."""

    def setUp(self):
        """Set up default arguments for router tests."""
        self.default_args = SimpleNamespace(
            host="127.0.0.1",
            port=30000,
            policy="cache_aware",
            worker_startup_timeout_secs=600,
            worker_startup_check_interval=10,
            cache_threshold=0.5,
            balance_abs_threshold=32,
            balance_rel_threshold=1.0001,
            eviction_interval=60,
            max_tree_size=2**24,
            max_payload_size=256 * 1024 * 1024,  # 256MB
            verbose=False,
            log_dir=None,
            service_discovery=False,
            selector=None,
            service_discovery_port=80,
            service_discovery_namespace=None,
            prometheus_port=None,
            prometheus_host=None,
            # PD-specific attributes
            pd_disaggregation=False,
            prefill=None,
            decode=None,
            # Keep worker_urls for regular mode
            worker_urls=[],
        )

    def create_router_args(self, **kwargs):
        """Return a new namespace: the default args overridden by ``kwargs``.

        ``self.default_args`` itself is not modified.
        """
        args_dict = vars(self.default_args).copy()
        args_dict.update(kwargs)
        return SimpleNamespace(**args_dict)

    def run_router_process(self, args):
        """Launch the router in a child process and assert that it stays up.

        The child gets three seconds; if it is still alive after that, startup
        is treated as successful. The child is always terminated afterwards,
        whether or not the assertion passes.

        Args:
            args: Namespace of router arguments handed to the child process.
        """
        process = multiprocessing.Process(
            target=_launch_router_in_child, args=(args,)
        )
        try:
            process.start()
            # Wait 3 seconds
            time.sleep(3)
            # Process is still running means router started successfully
            self.assertTrue(
                process.is_alive(),
                f"router process exited early (exit code {process.exitcode})",
            )
        finally:
            terminate_process(process)

    def test_launch_router_common(self):
        """The router stays up when given a single worker URL."""
        args = self.create_router_args(worker_urls=["http://localhost:8000"])
        self.run_router_process(args)

    def test_launch_router_with_empty_worker_urls(self):
        """Run the router with an empty worker URL list."""
        args = self.create_router_args(worker_urls=[])
        # NOTE(review): the original remark here was "Expected error", but
        # run_router_process asserts the child is still alive, so this test
        # passes only if the router starts without workers. Confirm intent.
        self.run_router_process(args)

    def test_launch_router_with_service_discovery(self):
        """The router stays up with service discovery and a label selector."""
        args = self.create_router_args(
            worker_urls=[], service_discovery=True, selector=["app=test-worker"]
        )
        self.run_router_process(args)

    def test_launch_router_with_service_discovery_namespace(self):
        """The router stays up with service discovery and a namespace set."""
        args = self.create_router_args(
            worker_urls=[],
            service_discovery=True,
            selector=["app=test-worker"],
            service_discovery_namespace="test-namespace",
        )
        self.run_router_process(args)

    def test_launch_router_pd_mode_basic(self):
        """Test basic PD router functionality without actually starting servers."""
        # This test just verifies the PD router can be created and configured
        # without actually starting it (which would require real prefill/decode servers)
        from sglang_router import Router
        from sglang_router.launch_router import RouterArgs
        from sglang_router_rs import PolicyType

        # Test RouterArgs parsing for PD mode
        # Simulate the parsed args structure from argparse with action="append"
        args = self.create_router_args(
            pd_disaggregation=True,
            policy="power_of_two",  # PowerOfTwo is only valid in PD mode
            prefill=[
                ["http://prefill1:8080", "9000"],
                ["http://prefill2:8080", "none"],
            ],
            decode=[
                ["http://decode1:8081"],
                ["http://decode2:8081"],
            ],
            worker_urls=[],  # Empty for PD mode
        )
        router_args = RouterArgs.from_cli_args(args)
        self.assertTrue(router_args.pd_disaggregation)
        self.assertEqual(router_args.policy, "power_of_two")
        self.assertEqual(len(router_args.prefill_urls), 2)
        self.assertEqual(len(router_args.decode_urls), 2)

        # Verify the parsed URLs and bootstrap ports: a numeric port string is
        # expected as an int, and "none" is expected as None.
        self.assertEqual(router_args.prefill_urls[0], ("http://prefill1:8080", 9000))
        self.assertEqual(router_args.prefill_urls[1], ("http://prefill2:8080", None))
        self.assertEqual(router_args.decode_urls[0], "http://decode1:8081")
        self.assertEqual(router_args.decode_urls[1], "http://decode2:8081")

        # Test Router creation in PD mode
        router = Router(
            worker_urls=[],  # Empty for PD mode
            pd_disaggregation=True,
            prefill_urls=[
                ("http://prefill1:8080", 9000),
                ("http://prefill2:8080", None),
            ],
            decode_urls=["http://decode1:8081", "http://decode2:8081"],
            policy=PolicyType.CacheAware,
            host="127.0.0.1",
            port=3001,
        )
        self.assertIsNotNone(router)

    def test_policy_validation(self):
        """Test that invalid policy/mode combinations are rejected."""
        from sglang_router.launch_router import launch_router

        # Test 1: PowerOfTwo is only valid in PD mode
        args = self.create_router_args(
            pd_disaggregation=False,
            policy="power_of_two",
            worker_urls=["http://localhost:8000"],
        )
        # Should raise error
        with self.assertRaises(ValueError) as cm:
            launch_router(args)
        self.assertIn(
            "PowerOfTwo policy is only supported in PD disaggregated mode",
            str(cm.exception),
        )

        # Test 2: RoundRobin is not valid in PD mode
        args = self.create_router_args(
            pd_disaggregation=True,
            policy="round_robin",
            prefill=[["http://prefill1:8080", "9000"]],
            decode=[["http://decode1:8081"]],
            worker_urls=[],
        )
        # Should raise error
        with self.assertRaises(ValueError) as cm:
            launch_router(args)
        self.assertIn(
            "RoundRobin policy is not supported in PD disaggregated mode",
            str(cm.exception),
        )

        # NOTE: the accepted combinations (round_robin in regular mode,
        # power_of_two in PD mode) are not exercised here. Passing them to
        # launch_router would go past validation and attempt to start a
        # router, which may fail to connect.

    def test_pd_service_discovery_args_parsing(self):
        """Test PD service discovery CLI argument parsing."""
        import argparse

        from sglang_router.launch_router import RouterArgs

        parser = argparse.ArgumentParser()
        RouterArgs.add_cli_args(parser)
        args = parser.parse_args(
            [
                "--pd-disaggregation",
                "--service-discovery",
                "--prefill-selector",
                "app=sglang",
                "component=prefill",
                "--decode-selector",
                "app=sglang",
                "component=decode",
                "--service-discovery-port",
                "8000",
                "--service-discovery-namespace",
                "production",
                "--policy",
                "cache_aware",
            ]
        )
        router_args = RouterArgs.from_cli_args(args)
        self.assertTrue(router_args.pd_disaggregation)
        self.assertTrue(router_args.service_discovery)
        # Each "key=value" token is expected to become one selector entry.
        self.assertEqual(
            router_args.prefill_selector, {"app": "sglang", "component": "prefill"}
        )
        self.assertEqual(
            router_args.decode_selector, {"app": "sglang", "component": "decode"}
        )
        self.assertEqual(router_args.service_discovery_port, 8000)
        self.assertEqual(router_args.service_discovery_namespace, "production")

    def test_regular_service_discovery_args_parsing(self):
        """Test regular mode service discovery CLI argument parsing."""
        import argparse

        from sglang_router.launch_router import RouterArgs

        parser = argparse.ArgumentParser()
        RouterArgs.add_cli_args(parser)
        args = parser.parse_args(
            [
                "--service-discovery",
                "--selector",
                "app=sglang-worker",
                "environment=staging",
                "--service-discovery-port",
                "8000",
                "--policy",
                "round_robin",
            ]
        )
        router_args = RouterArgs.from_cli_args(args)
        self.assertFalse(router_args.pd_disaggregation)
        self.assertTrue(router_args.service_discovery)
        self.assertEqual(
            router_args.selector, {"app": "sglang-worker", "environment": "staging"}
        )
        # PD selectors were not passed, so they are expected to be empty.
        self.assertEqual(router_args.prefill_selector, {})
        self.assertEqual(router_args.decode_selector, {})
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()