import argparse
import dataclasses
import logging
import sys
from typing import Dict, List, Optional, Union

from sglang_router import Router
from sglang_router_rs import PolicyType


def setup_logger():
    """Configure and return the ``"router"`` logger.

    The logger level is set to INFO and a console ``StreamHandler`` with the
    router's log format is attached. The handler is attached only when the
    logger has no handlers yet, so calling this function repeatedly does not
    produce duplicated log lines.

    Returns:
        The configured ``logging.Logger`` named ``"router"``.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # logging.getLogger returns the same object for the same name, so an
    # unconditional addHandler would stack one more handler per call.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


@dataclasses.dataclass
class RouterArgs:
    """Configuration values used to construct and launch a ``Router``.

    Instances are normally built from parsed CLI arguments through
    :meth:`from_cli_args`; the matching CLI options are registered by
    :meth:`add_cli_args`.
    """

    # Worker configuration
    worker_urls: List[str] = dataclasses.field(default_factory=list)
    host: str = "127.0.0.1"
    port: int = 30000

    # PD-specific configuration
    pd_disaggregation: bool = False  # Enable PD disaggregated mode
    prefill_urls: List[tuple] = dataclasses.field(
        default_factory=list
    )  # List of (url, bootstrap_port)
    decode_urls: List[str] = dataclasses.field(default_factory=list)

    # Routing policy
    policy: str = "cache_aware"
    worker_startup_timeout_secs: int = 300
    worker_startup_check_interval: int = 10
    cache_threshold: float = 0.5
    balance_abs_threshold: int = 32
    balance_rel_threshold: float = 1.0001
    eviction_interval: int = 60
    max_tree_size: int = 2**24
    max_payload_size: int = 256 * 1024 * 1024  # 256MB default for large batches
    verbose: bool = False
    log_dir: Optional[str] = None

    # Service discovery configuration
    service_discovery: bool = False
    selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    service_discovery_port: int = 80
    service_discovery_namespace: Optional[str] = None

    # PD service discovery configuration
    prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"

    # Prometheus configuration
    prometheus_port: Optional[int] = None
    prometheus_host: Optional[str] = None

    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
        use_router_prefix: bool = False,
        exclude_host_port: bool = False,
    ):
        """
        Add router-specific arguments to an argument parser.

        Args:
            parser: The argument parser to add arguments to
            use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
            exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)

        Note:
            ``--host``, ``--port`` and ``--worker-urls`` are never prefixed.
        """
        prefix = "router-" if use_router_prefix else ""

        # Worker configuration
        if not exclude_host_port:
            parser.add_argument(
                "--host",
                type=str,
                default=RouterArgs.host,
                help="Host address to bind the router server",
            )
            parser.add_argument(
                "--port",
                type=int,
                default=RouterArgs.port,
                help="Port number to bind the router server",
            )

        parser.add_argument(
            "--worker-urls",
            type=str,
            nargs="+",
            help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
        )

        # Routing policy configuration
        parser.add_argument(
            f"--{prefix}policy",
            type=str,
            default=RouterArgs.policy,
            choices=["random", "round_robin", "cache_aware", "power_of_two"],
            help="Load balancing policy to use. Note: power_of_two is only available in PD disaggregated mode",
        )

        # PD-specific arguments
        parser.add_argument(
            f"--{prefix}pd-disaggregation",
            action="store_true",
            help="Enable PD (Prefill-Decode) disaggregated mode",
        )
        parser.add_argument(
            f"--{prefix}prefill",
            nargs=2,
            action="append",
            metavar=("URL", "BOOTSTRAP_PORT"),
            help="Prefill server URL and bootstrap port. Can be specified multiple times. BOOTSTRAP_PORT can be 'none' for no bootstrap port.",
        )
        parser.add_argument(
            f"--{prefix}decode",
            nargs=1,
            action="append",
            metavar=("URL",),
            help="Decode server URL. Can be specified multiple times.",
        )
        parser.add_argument(
            f"--{prefix}worker-startup-timeout-secs",
            type=int,
            default=RouterArgs.worker_startup_timeout_secs,
            help="Timeout in seconds for worker startup",
        )
        parser.add_argument(
            f"--{prefix}worker-startup-check-interval",
            type=int,
            default=RouterArgs.worker_startup_check_interval,
            help="Interval in seconds between checks for worker startup",
        )
        parser.add_argument(
            f"--{prefix}cache-threshold",
            type=float,
            default=RouterArgs.cache_threshold,
            help="Cache threshold (0.0-1.0) for cache-aware routing",
        )
        parser.add_argument(
            f"--{prefix}balance-abs-threshold",
            type=int,
            default=RouterArgs.balance_abs_threshold,
            help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
        )
        parser.add_argument(
            f"--{prefix}balance-rel-threshold",
            type=float,
            default=RouterArgs.balance_rel_threshold,
            help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, use cache aware",
        )
        parser.add_argument(
            f"--{prefix}eviction-interval",
            type=int,
            default=RouterArgs.eviction_interval,
            help="Interval in seconds between cache eviction operations",
        )
        parser.add_argument(
            f"--{prefix}max-tree-size",
            type=int,
            default=RouterArgs.max_tree_size,
            help="Maximum size of the approximation tree for cache-aware routing",
        )
        parser.add_argument(
            f"--{prefix}max-payload-size",
            type=int,
            default=RouterArgs.max_payload_size,
            help="Maximum payload size in bytes",
        )
        parser.add_argument(
            f"--{prefix}verbose",
            action="store_true",
            help="Enable verbose logging",
        )
        parser.add_argument(
            f"--{prefix}log-dir",
            type=str,
            default=None,
            help="Directory to store log files. If not specified, logs are only output to console.",
        )

        # Service discovery configuration
        parser.add_argument(
            f"--{prefix}service-discovery",
            action="store_true",
            help="Enable Kubernetes service discovery",
        )
        parser.add_argument(
            f"--{prefix}selector",
            type=str,
            nargs="+",
            help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
        )
        parser.add_argument(
            f"--{prefix}service-discovery-port",
            type=int,
            default=RouterArgs.service_discovery_port,
            help="Port to use for discovered worker pods",
        )
        parser.add_argument(
            f"--{prefix}service-discovery-namespace",
            type=str,
            help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
        )
        parser.add_argument(
            f"--{prefix}prefill-selector",
            type=str,
            nargs="+",
            help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
        )
        parser.add_argument(
            f"--{prefix}decode-selector",
            type=str,
            nargs="+",
            help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
        )

        # Prometheus configuration
        # NOTE(review): the CLI default here is 29000, while the help text says
        # metrics are disabled when the port is not specified and
        # RouterArgs.prometheus_port defaults to None - confirm which is intended.
        parser.add_argument(
            f"--{prefix}prometheus-port",
            type=int,
            default=29000,
            help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
        )
        parser.add_argument(
            f"--{prefix}prometheus-host",
            type=str,
            default="127.0.0.1",
            help="Host address to bind the Prometheus metrics server",
        )

    @classmethod
    def from_cli_args(
        cls, args: argparse.Namespace, use_router_prefix: bool = False
    ) -> "RouterArgs":
        """
        Create RouterArgs instance from parsed command line arguments.

        Args:
            args: Parsed command line arguments
            use_router_prefix: If True, look for arguments with 'router-' prefix

        Returns:
            A RouterArgs populated from ``args``.

        Raises:
            AttributeError: If ``args`` lacks ``host``, ``port`` or one of the
                attributes that are read without a fallback (policy, startup
                timeouts, thresholds, eviction interval, tree/payload sizes,
                service discovery port).
            ValueError: If a prefill bootstrap port is neither an integer nor
                'none'.
        """
        prefix = "router_" if use_router_prefix else ""

        # argparse stores None when --worker-urls is registered but not given,
        # so a getattr default alone is not enough; normalize to an empty list.
        worker_urls = getattr(args, "worker_urls", None) or []

        # Parse PD URLs
        prefill_urls = cls._parse_prefill_urls(getattr(args, f"{prefix}prefill", None))
        decode_urls = cls._parse_decode_urls(getattr(args, f"{prefix}decode", None))

        return cls(
            worker_urls=worker_urls,
            host=args.host,
            port=args.port,
            pd_disaggregation=getattr(args, f"{prefix}pd_disaggregation", False),
            prefill_urls=prefill_urls,
            decode_urls=decode_urls,
            policy=getattr(args, f"{prefix}policy"),
            worker_startup_timeout_secs=getattr(
                args, f"{prefix}worker_startup_timeout_secs"
            ),
            worker_startup_check_interval=getattr(
                args, f"{prefix}worker_startup_check_interval"
            ),
            cache_threshold=getattr(args, f"{prefix}cache_threshold"),
            balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"),
            balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"),
            eviction_interval=getattr(args, f"{prefix}eviction_interval"),
            max_tree_size=getattr(args, f"{prefix}max_tree_size"),
            max_payload_size=getattr(args, f"{prefix}max_payload_size"),
            verbose=getattr(args, f"{prefix}verbose", False),
            log_dir=getattr(args, f"{prefix}log_dir", None),
            service_discovery=getattr(args, f"{prefix}service_discovery", False),
            selector=cls._parse_selector(getattr(args, f"{prefix}selector", None)),
            service_discovery_port=getattr(args, f"{prefix}service_discovery_port"),
            service_discovery_namespace=getattr(
                args, f"{prefix}service_discovery_namespace", None
            ),
            prefill_selector=cls._parse_selector(
                getattr(args, f"{prefix}prefill_selector", None)
            ),
            decode_selector=cls._parse_selector(
                getattr(args, f"{prefix}decode_selector", None)
            ),
            bootstrap_port_annotation="sglang.ai/bootstrap-port",  # Mooncake-specific annotation
            prometheus_port=getattr(args, f"{prefix}prometheus_port", None),
            prometheus_host=getattr(args, f"{prefix}prometheus_host", None),
        )

    @staticmethod
    def _parse_selector(selector_list: Optional[List[str]]) -> Dict[str, str]:
        """Convert ``key=value`` strings into a selector dictionary.

        Each item is split on its first ``=``. Items that contain no ``=`` are
        skipped. ``None`` or an empty list yields an empty dictionary.
        """
        if not selector_list:
            return {}

        selector = {}
        for item in selector_list:
            if "=" in item:
                key, value = item.split("=", 1)
                selector[key] = value
        return selector

    @staticmethod
    def _parse_prefill_urls(prefill_list: Optional[List[List[str]]]) -> List[tuple]:
        """Parse prefill URLs from --prefill arguments.

        Format: --prefill URL BOOTSTRAP_PORT
        Example: --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none

        Returns:
            A list of ``(url, bootstrap_port)`` tuples, where ``bootstrap_port``
            is an int, or None when 'none' (case-insensitive) was given.

        Raises:
            ValueError: If a bootstrap port is neither an integer nor 'none'.
        """
        if not prefill_list:
            return []

        prefill_urls = []
        for url, bootstrap_port_str in prefill_list:
            # Handle 'none' as None
            if bootstrap_port_str.lower() == "none":
                bootstrap_port = None
            else:
                try:
                    bootstrap_port = int(bootstrap_port_str)
                except ValueError as err:
                    # Chain explicitly so the int() failure is kept as the cause.
                    raise ValueError(
                        f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
                    ) from err

            prefill_urls.append((url, bootstrap_port))

        return prefill_urls

    @staticmethod
    def _parse_decode_urls(decode_list: Optional[List[List[str]]]) -> List[str]:
        """Parse decode URLs from --decode arguments.

        Format: --decode URL
        Example: --decode http://decode1:8081 --decode http://decode2:8081
        """
        if not decode_list:
            return []

        # decode_list is a list of single-element lists due to nargs=1
        return [url[0] for url in decode_list]


def policy_from_str(policy_str: str) -> PolicyType:
    """Convert policy string to PolicyType enum.

    Raises:
        KeyError: If ``policy_str`` is not one of "random", "round_robin",
            "cache_aware" or "power_of_two".
    """
    policy_map = {
        "random": PolicyType.Random,
        "round_robin": PolicyType.RoundRobin,
        "cache_aware": PolicyType.CacheAware,
        "power_of_two": PolicyType.PowerOfTwo,
    }
    return policy_map[policy_str]


def launch_router(args: Union[argparse.Namespace, RouterArgs]) -> Optional[Router]:
    """
    Launch the SGLang router with the configuration from parsed arguments.

    Args:
        args: Namespace object containing router configuration
            Can be either raw argparse.Namespace or converted RouterArgs

    Returns:
        The Router instance, after ``start()`` has been called on it.

    Raises:
        ValueError: In PD disaggregation mode without service discovery, when
            no prefill or no decode URLs are configured.
        Exception: Any other error raised while building or starting the
            router is logged on the "router" logger and re-raised unchanged.
    """
    logger = logging.getLogger("router")
    try:
        # Convert to RouterArgs if needed
        if not isinstance(args, RouterArgs):
            router_args = RouterArgs.from_cli_args(args)
        else:
            router_args = args

        # Validate configuration based on mode
        if router_args.pd_disaggregation:
            # Validate PD configuration - skip URL requirements if using service discovery
            if not router_args.service_discovery:
                if not router_args.prefill_urls:
                    raise ValueError("PD disaggregation mode requires --prefill")
                if not router_args.decode_urls:
                    raise ValueError("PD disaggregation mode requires --decode")

        # Create router with unified constructor
        router = Router(
            # Static worker URLs are only passed in regular mode; with service
            # discovery or PD mode an empty list is passed instead.
            worker_urls=(
                []
                if router_args.service_discovery or router_args.pd_disaggregation
                else router_args.worker_urls
            ),
            host=router_args.host,
            port=router_args.port,
            policy=policy_from_str(router_args.policy),
            worker_startup_timeout_secs=router_args.worker_startup_timeout_secs,
            worker_startup_check_interval=router_args.worker_startup_check_interval,
            cache_threshold=router_args.cache_threshold,
            balance_abs_threshold=router_args.balance_abs_threshold,
            balance_rel_threshold=router_args.balance_rel_threshold,
            eviction_interval_secs=router_args.eviction_interval,
            max_tree_size=router_args.max_tree_size,
            max_payload_size=router_args.max_payload_size,
            verbose=router_args.verbose,
            log_dir=router_args.log_dir,
            service_discovery=router_args.service_discovery,
            selector=router_args.selector,
            service_discovery_port=router_args.service_discovery_port,
            service_discovery_namespace=router_args.service_discovery_namespace,
            prefill_selector=router_args.prefill_selector,
            decode_selector=router_args.decode_selector,
            prometheus_port=router_args.prometheus_port,
            prometheus_host=router_args.prometheus_host,
            pd_disaggregation=router_args.pd_disaggregation,
            prefill_urls=(
                router_args.prefill_urls if router_args.pd_disaggregation else None
            ),
            decode_urls=(
                router_args.decode_urls if router_args.pd_disaggregation else None
            ),
        )

        router.start()
        return router

    except Exception as e:
        logger.error("Error starting router: %s", e)
        # Bare raise re-raises the active exception with its original traceback.
        raise


class CustomHelpFormatter(
    argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter
):
    """Custom formatter that preserves both description formatting and shows defaults"""

    pass


def parse_router_args(args: List[str]) -> RouterArgs:
    """Parse command line arguments and return RouterArgs instance.

    Args:
        args: Argument strings to parse (for example ``sys.argv[1:]``).

    Returns:
        RouterArgs built from the parsed, un-prefixed router options.
    """
    parser = argparse.ArgumentParser(
        description="""SGLang Router - High-performance request distribution across worker nodes

Usage:
This launcher enables starting a router with individual worker instances.
It is useful for multi-node setups or when you want to start workers and router separately.

Examples:
  # Regular mode
  python -m sglang_router.launch_router --worker-urls http://worker1:8000 http://worker2:8000

  # PD disaggregated mode
  python -m sglang_router.launch_router --pd-disaggregation \\
    --prefill http://prefill1:8000 9000 --prefill http://prefill2:8000 none \\
    --decode http://decode1:8001 --decode http://decode2:8001 \\
    --policy cache_aware

    """,
        formatter_class=CustomHelpFormatter,
    )

    RouterArgs.add_cli_args(parser, use_router_prefix=False)
    return RouterArgs.from_cli_args(parser.parse_args(args), use_router_prefix=False)


def main() -> None:
    """Entry point: parse ``sys.argv`` and launch the router."""
    router_args = parse_router_args(sys.argv[1:])
    launch_router(router_args)


if __name__ == "__main__":
    main()