import argparse
import dataclasses
import logging
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)


@dataclasses.dataclass
class RouterArgs:
    # Worker configuration
    worker_urls: List[str] = dataclasses.field(default_factory=list)
    host: str = "127.0.0.1"
    port: int = 30000

    # PD-specific configuration
    mini_lb: bool = False
    pd_disaggregation: bool = False  # Enable PD disaggregated mode
    prefill_urls: List[tuple] = dataclasses.field(
        default_factory=list
    )  # List of (url, bootstrap_port)
    decode_urls: List[str] = dataclasses.field(default_factory=list)

    # Routing policy
    policy: str = "cache_aware"
    prefill_policy: Optional[str] = None  # Specific policy for prefill nodes in PD mode
    decode_policy: Optional[str] = None  # Specific policy for decode nodes in PD mode
    worker_startup_timeout_secs: int = 600
    worker_startup_check_interval: int = 30
    cache_threshold: float = 0.3
    balance_abs_threshold: int = 64
    balance_rel_threshold: float = 1.5
    eviction_interval_secs: int = 120
    max_tree_size: int = 2**26
    max_payload_size: int = 512 * 1024 * 1024  # 512MB default for large batches
    dp_aware: bool = False
    api_key: Optional[str] = None
    log_dir: Optional[str] = None
    log_level: Optional[str] = None

    # Service discovery configuration
    service_discovery: bool = False
    selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    service_discovery_port: int = 80
    service_discovery_namespace: Optional[str] = None

    # PD service discovery configuration
    prefill_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    decode_selector: Dict[str, str] = dataclasses.field(default_factory=dict)
    bootstrap_port_annotation: str = "sglang.ai/bootstrap-port"

    # Prometheus configuration
    prometheus_port: Optional[int] = None
    prometheus_host: Optional[str] = None

    # Request ID headers configuration
    request_id_headers: Optional[List[str]] = None

    # Request timeout in seconds
    request_timeout_secs: int = 1800

    # Max concurrent requests for rate limiting
    max_concurrent_requests: int = 256

    # Queue size for pending requests when max concurrent limit reached
    queue_size: int = 100

    # Maximum time (in seconds) a request can wait in queue before timing out
    queue_timeout_secs: int = 60

    # Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests
    rate_limit_tokens_per_second: Optional[int] = None

    # CORS allowed origins
    cors_allowed_origins: List[str] = dataclasses.field(default_factory=list)

    # Retry configuration
    retry_max_retries: int = 5
    retry_initial_backoff_ms: int = 50
    retry_max_backoff_ms: int = 30_000
    retry_backoff_multiplier: float = 1.5
    retry_jitter_factor: float = 0.2
    disable_retries: bool = False

    # Health check configuration
    health_failure_threshold: int = 3
    health_success_threshold: int = 2
    health_check_timeout_secs: int = 5
    health_check_interval_secs: int = 60
    health_check_endpoint: str = "/health"

    # Circuit breaker configuration
    cb_failure_threshold: int = 10
    cb_success_threshold: int = 3
    cb_timeout_duration_secs: int = 60
    cb_window_duration_secs: int = 120
    disable_circuit_breaker: bool = False

    # Tokenizer configuration
    model_path: Optional[str] = None
    tokenizer_path: Optional[str] = None
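
    # Illustrative note (not from the original source): RouterArgs can also be
    # constructed directly in Python, with unspecified fields falling back to the
    # defaults above, e.g.
    #   RouterArgs(worker_urls=["http://localhost:30001"], policy="round_robin")
    # The worker URL here is a placeholder.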
    @staticmethod
    def add_cli_args(
        parser: argparse.ArgumentParser,
        use_router_prefix: bool = False,
        exclude_host_port: bool = False,
    ):
        """
        Add router-specific arguments to an argument parser.

        Args:
            parser: The argument parser to add arguments to
            use_router_prefix: If True, prefix all arguments with 'router-' to avoid conflicts
            exclude_host_port: If True, don't add host and port arguments (used when inheriting from server)
        """
        prefix = "router-" if use_router_prefix else ""

        # Worker configuration
        if not exclude_host_port:
            parser.add_argument(
                "--host",
                type=str,
                default=RouterArgs.host,
                help="Host address to bind the router server",
            )
            parser.add_argument(
                "--port",
                type=int,
                default=RouterArgs.port,
                help="Port number to bind the router server",
            )
        parser.add_argument(
            "--worker-urls",
            type=str,
            nargs="*",
            default=[],
            help="List of worker URLs (e.g., http://worker1:8000 http://worker2:8000)",
        )

        # Routing policy configuration
        parser.add_argument(
            f"--{prefix}policy",
            type=str,
            default=RouterArgs.policy,
            choices=["random", "round_robin", "cache_aware", "power_of_two"],
            help="Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden",
        )
        parser.add_argument(
            f"--{prefix}prefill-policy",
            type=str,
            default=None,
            choices=["random", "round_robin", "cache_aware", "power_of_two"],
            help="Specific policy for prefill nodes in PD mode. If not specified, uses the main policy",
        )
        parser.add_argument(
            f"--{prefix}decode-policy",
            type=str,
            default=None,
            choices=["random", "round_robin", "cache_aware", "power_of_two"],
            help="Specific policy for decode nodes in PD mode. If not specified, uses the main policy",
        )

        # PD-specific arguments
        parser.add_argument(
            f"--{prefix}mini-lb",
            action="store_true",
            help="Enable MiniLB",
        )
        parser.add_argument(
            f"--{prefix}pd-disaggregation",
            action="store_true",
            help="Enable PD (Prefill-Decode) disaggregated mode",
        )
        parser.add_argument(
            f"--{prefix}prefill",
            nargs="+",
            action="append",
            help="Prefill server URL and optional bootstrap port. Can be specified multiple times. "
            "Format: --prefill URL [BOOTSTRAP_PORT]. "
            "BOOTSTRAP_PORT can be a port number, 'none', or omitted (defaults to none).",
        )
        parser.add_argument(
            f"--{prefix}decode",
            nargs=1,
            action="append",
            metavar=("URL",),
            help="Decode server URL. Can be specified multiple times.",
        )
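        # Illustrative example (hypothetical hosts and ports, not from the original
        # source): in PD mode the flags above are typically combined as
        #   --pd-disaggregation \
        #   --prefill http://prefill1:8080 9000 --prefill http://prefill2:8080 none \
        #   --decode http://decode1:8081 --decode http://decode2:8081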
        parser.add_argument(
            f"--{prefix}worker-startup-timeout-secs",
            type=int,
            default=RouterArgs.worker_startup_timeout_secs,
            help="Timeout in seconds for worker startup",
        )
        parser.add_argument(
            f"--{prefix}worker-startup-check-interval",
            type=int,
            default=RouterArgs.worker_startup_check_interval,
            help="Interval in seconds between checks for worker startup",
        )
        parser.add_argument(
            f"--{prefix}cache-threshold",
            type=float,
            default=RouterArgs.cache_threshold,
            help="Cache threshold (0.0-1.0) for cache-aware routing",
        )
        parser.add_argument(
            f"--{prefix}balance-abs-threshold",
            type=int,
            default=RouterArgs.balance_abs_threshold,
            help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, cache-aware routing is used",
        )
        parser.add_argument(
            f"--{prefix}balance-rel-threshold",
            type=float,
            default=RouterArgs.balance_rel_threshold,
            help="Load balancing is triggered when (max_load - min_load) > abs_threshold AND max_load > min_load * rel_threshold. Otherwise, cache-aware routing is used",
        )
        parser.add_argument(
            f"--{prefix}eviction-interval-secs",
            type=int,
            default=RouterArgs.eviction_interval_secs,
            help="Interval in seconds between cache eviction operations",
        )
        parser.add_argument(
            f"--{prefix}max-tree-size",
            type=int,
            default=RouterArgs.max_tree_size,
            help="Maximum size of the approximation tree for cache-aware routing",
        )
        parser.add_argument(
            f"--{prefix}max-payload-size",
            type=int,
            default=RouterArgs.max_payload_size,
            help="Maximum payload size in bytes",
        )
        parser.add_argument(
            f"--{prefix}dp-aware",
            action="store_true",
            help="Enable data-parallelism-aware scheduling",
        )
        parser.add_argument(
            f"--{prefix}api-key",
            type=str,
            default=None,
            help="The API key used for authorization with the workers. Useful when the DP-aware scheduling strategy is enabled.",
        )
        parser.add_argument(
            f"--{prefix}log-dir",
            type=str,
            default=None,
            help="Directory to store log files. If not specified, logs are only output to console.",
        )
        parser.add_argument(
            f"--{prefix}log-level",
            type=str,
            default="info",
            choices=["debug", "info", "warning", "error", "critical"],
            help="Set the logging level. If not specified, defaults to INFO.",
        )
        parser.add_argument(
            f"--{prefix}service-discovery",
            action="store_true",
            help="Enable Kubernetes service discovery",
        )
        parser.add_argument(
            f"--{prefix}selector",
            type=str,
            nargs="+",
            default={},
            help="Label selector for Kubernetes service discovery (format: key1=value1 key2=value2)",
        )
        parser.add_argument(
            f"--{prefix}service-discovery-port",
            type=int,
            default=RouterArgs.service_discovery_port,
            help="Port to use for discovered worker pods",
        )
        parser.add_argument(
            f"--{prefix}service-discovery-namespace",
            type=str,
            help="Kubernetes namespace to watch for pods. If not provided, watches all namespaces (requires cluster-wide permissions)",
        )
        parser.add_argument(
            f"--{prefix}prefill-selector",
            type=str,
            nargs="+",
            default={},
            help="Label selector for prefill server pods in PD mode (format: key1=value1 key2=value2)",
        )
        parser.add_argument(
            f"--{prefix}decode-selector",
            type=str,
            nargs="+",
            default={},
            help="Label selector for decode server pods in PD mode (format: key1=value1 key2=value2)",
        )
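        # Illustrative example (hypothetical label values, not from the original
        # source): selectors are given as space-separated key=value pairs, e.g.
        #   --service-discovery --selector app=sglang-worker env=prod
        # or, in PD mode,
        #   --prefill-selector app=sglang-prefill --decode-selector app=sglang-decode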
        # Prometheus configuration
        parser.add_argument(
            f"--{prefix}prometheus-port",
            type=int,
            default=29000,
            help="Port to expose Prometheus metrics. If not specified, Prometheus metrics are disabled",
        )
        parser.add_argument(
            f"--{prefix}prometheus-host",
            type=str,
            default="127.0.0.1",
            help="Host address to bind the Prometheus metrics server",
        )
        parser.add_argument(
            f"--{prefix}request-id-headers",
            type=str,
            nargs="*",
            help="Custom HTTP headers to check for request IDs (e.g., x-request-id x-trace-id). If not specified, uses common defaults.",
        )
        parser.add_argument(
            f"--{prefix}request-timeout-secs",
            type=int,
            default=RouterArgs.request_timeout_secs,
            help="Request timeout in seconds",
        )

        # Retry configuration
        parser.add_argument(
            f"--{prefix}retry-max-retries",
            type=int,
            default=RouterArgs.retry_max_retries,
        )
        parser.add_argument(
            f"--{prefix}retry-initial-backoff-ms",
            type=int,
            default=RouterArgs.retry_initial_backoff_ms,
        )
        parser.add_argument(
            f"--{prefix}retry-max-backoff-ms",
            type=int,
            default=RouterArgs.retry_max_backoff_ms,
        )
        parser.add_argument(
            f"--{prefix}retry-backoff-multiplier",
            type=float,
            default=RouterArgs.retry_backoff_multiplier,
        )
        parser.add_argument(
            f"--{prefix}retry-jitter-factor",
            type=float,
            default=RouterArgs.retry_jitter_factor,
        )
        parser.add_argument(
            f"--{prefix}disable-retries",
            action="store_true",
            help="Disable retries (equivalent to setting retry_max_retries=1)",
        )

        # Circuit breaker configuration
        parser.add_argument(
            f"--{prefix}cb-failure-threshold",
            type=int,
            default=RouterArgs.cb_failure_threshold,
        )
        parser.add_argument(
            f"--{prefix}cb-success-threshold",
            type=int,
            default=RouterArgs.cb_success_threshold,
        )
        parser.add_argument(
            f"--{prefix}cb-timeout-duration-secs",
            type=int,
            default=RouterArgs.cb_timeout_duration_secs,
        )
        parser.add_argument(
            f"--{prefix}cb-window-duration-secs",
            type=int,
            default=RouterArgs.cb_window_duration_secs,
        )
        parser.add_argument(
            f"--{prefix}disable-circuit-breaker",
            action="store_true",
            help="Disable circuit breaker (equivalent to setting cb_failure_threshold to u32::MAX)",
        )

        # Health check configuration
        parser.add_argument(
            f"--{prefix}health-failure-threshold",
            type=int,
            default=RouterArgs.health_failure_threshold,
            help="Number of consecutive health check failures before marking worker unhealthy",
        )
        parser.add_argument(
            f"--{prefix}health-success-threshold",
            type=int,
            default=RouterArgs.health_success_threshold,
            help="Number of consecutive health check successes before marking worker healthy",
        )
        parser.add_argument(
            f"--{prefix}health-check-timeout-secs",
            type=int,
            default=RouterArgs.health_check_timeout_secs,
            help="Timeout in seconds for health check requests",
        )
        parser.add_argument(
            f"--{prefix}health-check-interval-secs",
            type=int,
            default=RouterArgs.health_check_interval_secs,
            help="Interval in seconds between runtime health checks",
        )
        parser.add_argument(
            f"--{prefix}health-check-endpoint",
            type=str,
            default=RouterArgs.health_check_endpoint,
            help="Health check endpoint path",
        )
        parser.add_argument(
            f"--{prefix}max-concurrent-requests",
            type=int,
            default=RouterArgs.max_concurrent_requests,
            help="Maximum number of concurrent requests allowed (for rate limiting)",
        )
        parser.add_argument(
            f"--{prefix}queue-size",
            type=int,
            default=RouterArgs.queue_size,
            help="Queue size for pending requests when max concurrent limit reached (0 = no queue, return 429 immediately)",
        )
        parser.add_argument(
            f"--{prefix}queue-timeout-secs",
            type=int,
            default=RouterArgs.queue_timeout_secs,
            help="Maximum time (in seconds) a request can wait in queue before timing out",
        )
        parser.add_argument(
            f"--{prefix}rate-limit-tokens-per-second",
            type=int,
            default=RouterArgs.rate_limit_tokens_per_second,
            help="Token bucket refill rate (tokens per second). If not set, defaults to max_concurrent_requests",
        )
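        # Illustrative reading of the defaults above (no new behavior added here):
        # with max_concurrent_requests=256, queue_size=100 and queue_timeout_secs=60,
        # a request arriving while 256 are already in flight waits in a queue of up
        # to 100 entries for at most 60 seconds, whereas queue_size=0 makes the
        # router return 429 immediately, as documented in the --queue-size help.
        # If --rate-limit-tokens-per-second is unset, the refill rate falls back to
        # max_concurrent_requests.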
        parser.add_argument(
            f"--{prefix}cors-allowed-origins",
            type=str,
            nargs="*",
            default=[],
            help="CORS allowed origins (e.g., http://localhost:3000 https://example.com)",
        )

        # Tokenizer configuration
        parser.add_argument(
            f"--{prefix}model-path",
            type=str,
            default=None,
            help="Model path for loading tokenizer (HuggingFace model ID or local path)",
        )
        parser.add_argument(
            f"--{prefix}tokenizer-path",
            type=str,
            default=None,
            help="Explicit tokenizer path (overrides model_path tokenizer if provided)",
        )

    @classmethod
    def from_cli_args(
        cls, args: argparse.Namespace, use_router_prefix: bool = False
    ) -> "RouterArgs":
        """
        Create RouterArgs instance from parsed command line arguments.

        Args:
            args: Parsed command line arguments
            use_router_prefix: If True, look for arguments with 'router-' prefix
        """
        prefix = "router_" if use_router_prefix else ""
        cli_args_dict = vars(args)
        args_dict = {}
        for attr in dataclasses.fields(cls):
            # Auto strip prefix from args
            if f"{prefix}{attr.name}" in cli_args_dict:
                args_dict[attr.name] = cli_args_dict[f"{prefix}{attr.name}"]
            elif attr.name in cli_args_dict:
                args_dict[attr.name] = cli_args_dict[attr.name]

        # Parse the special --prefill/--decode and selector arguments into their
        # structured values, overriding any raw values picked up above
        args_dict["prefill_urls"] = cls._parse_prefill_urls(
            cli_args_dict.get(f"{prefix}prefill", None)
        )
        args_dict["decode_urls"] = cls._parse_decode_urls(
            cli_args_dict.get(f"{prefix}decode", None)
        )
        args_dict["selector"] = cls._parse_selector(
            cli_args_dict.get(f"{prefix}selector", None)
        )
        args_dict["prefill_selector"] = cls._parse_selector(
            cli_args_dict.get(f"{prefix}prefill_selector", None)
        )
        args_dict["decode_selector"] = cls._parse_selector(
            cli_args_dict.get(f"{prefix}decode_selector", None)
        )

        # Mooncake-specific annotation
        args_dict["bootstrap_port_annotation"] = "sglang.ai/bootstrap-port"

        return cls(**args_dict)

    def _validate_router_args(self):
        # Validate configuration based on mode
        if self.pd_disaggregation:
            # Validate PD configuration - skip URL requirements if using service discovery
            if not self.service_discovery:
                if not self.prefill_urls:
                    raise ValueError("PD disaggregation mode requires --prefill")
                if not self.decode_urls:
                    raise ValueError("PD disaggregation mode requires --decode")

            # Warn about policy usage in PD mode
            if self.prefill_policy and self.decode_policy and self.policy:
                logger.warning(
                    "Both --prefill-policy and --decode-policy are specified. "
                    "The main --policy flag will be ignored for PD mode."
                )
            elif self.prefill_policy and not self.decode_policy and self.policy:
                logger.info(
                    f"Using --prefill-policy '{self.prefill_policy}' for prefill nodes "
                    f"and --policy '{self.policy}' for decode nodes."
                )
            elif self.decode_policy and not self.prefill_policy and self.policy:
                logger.info(
                    f"Using --policy '{self.policy}' for prefill nodes "
                    f"and --decode-policy '{self.decode_policy}' for decode nodes."
                )

    @staticmethod
    def _parse_selector(selector_list):
        if not selector_list:
            return {}

        selector = {}
        for item in selector_list:
            if "=" in item:
                key, value = item.split("=", 1)
                selector[key] = value
        return selector
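    # Illustrative example (hypothetical labels, not from the original source):
    #   RouterArgs._parse_selector(["app=sglang-worker", "env=prod"])
    # returns {"app": "sglang-worker", "env": "prod"}; items without "=" are skipped.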
    @staticmethod
    def _parse_prefill_urls(prefill_list):
        """Parse prefill URLs from --prefill arguments.

        Format: --prefill URL [BOOTSTRAP_PORT]

        Example:
            --prefill http://prefill1:8080 9000   # With bootstrap port
            --prefill http://prefill2:8080 none   # Explicitly no bootstrap port
            --prefill http://prefill3:8080        # Defaults to no bootstrap port
        """
        if not prefill_list:
            return []

        prefill_urls = []
        for prefill_args in prefill_list:
            url = prefill_args[0]

            # Handle optional bootstrap port
            if len(prefill_args) >= 2:
                bootstrap_port_str = prefill_args[1]
                # Handle 'none' as None
                if bootstrap_port_str.lower() == "none":
                    bootstrap_port = None
                else:
                    try:
                        bootstrap_port = int(bootstrap_port_str)
                    except ValueError:
                        raise ValueError(
                            f"Invalid bootstrap port: {bootstrap_port_str}. Must be a number or 'none'"
                        )
            else:
                # No bootstrap port specified, default to None
                bootstrap_port = None

            prefill_urls.append((url, bootstrap_port))

        return prefill_urls

    @staticmethod
    def _parse_decode_urls(decode_list):
        """Parse decode URLs from --decode arguments.

        Format: --decode URL

        Example:
            --decode http://decode1:8081
            --decode http://decode2:8081
        """
        if not decode_list:
            return []

        # decode_list is a list of single-element lists due to nargs=1
        return [url[0] for url in decode_list]
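

# Minimal usage sketch (an illustrative addition, not part of the original module):
# build a parser, register the router arguments, and construct RouterArgs from a
# sample command line. The worker URLs below are placeholders.
if __name__ == "__main__":
    example_parser = argparse.ArgumentParser(description="RouterArgs example")
    RouterArgs.add_cli_args(example_parser)
    example_args = example_parser.parse_args(
        [
            "--worker-urls",
            "http://localhost:30001",
            "http://localhost:30002",
            "--policy",
            "round_robin",
        ]
    )
    example_router_args = RouterArgs.from_cli_args(example_args)
    print(example_router_args)

    # The PD parsing helpers can be exercised directly as well, e.g.
    # _parse_prefill_urls([["http://prefill1:8080", "9000"], ["http://prefill2:8080"]])
    # returns [("http://prefill1:8080", 9000), ("http://prefill2:8080", None)].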