2715 lines
109 KiB
Python
2715 lines
109 KiB
Python
# Copyright 2023-2024 SGLang Team
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""The arguments of the server."""
|
|
|
|
import argparse
|
|
import dataclasses
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import sys
|
|
import tempfile
|
|
from typing import List, Literal, Optional, Union
|
|
|
|
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
|
from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
|
|
from sglang.srt.lora.lora_registry import LoRARef
|
|
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
|
from sglang.srt.utils import (
|
|
LORA_TARGET_ALL_MODULES,
|
|
SUPPORTED_LORA_TARGET_MODULES,
|
|
configure_ipv6,
|
|
get_device,
|
|
get_device_memory_capacity,
|
|
is_cuda,
|
|
is_flashinfer_available,
|
|
is_hip,
|
|
is_port_available,
|
|
is_remote_url,
|
|
is_sm90_supported,
|
|
is_sm100_supported,
|
|
is_triton_kernels_available,
|
|
is_valid_ipv6_address,
|
|
nullable_str,
|
|
)
|
|
from sglang.utils import is_in_ci
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Define constants
|
|
# Accepted values for ``--load-format`` / ``ServerArgs.load_format``.
# Every *_CHOICES constant is a plain mutable list so the ``add_*_choices``
# hooks can register additional entries from external code.
LOAD_FORMAT_CHOICES = (
    "auto pt safetensors npcache dummy sharded_state gguf bitsandbytes layered remote"
).split()

# Accepted values for ``ServerArgs.quantization``.
QUANTIZATION_CHOICES = (
    "awq fp8 gptq marlin gptq_marlin awq_marlin bitsandbytes gguf modelopt "
    "modelopt_fp4 petit_nvfp4 w8a8_int8 w8a8_fp8 moe_wna16 qoq w4afp8 mxfp4"
).split()

# Accepted values for ``ServerArgs.attention_backend``, grouped by the
# platform each backend targets. Concatenation order is significant only for
# how argparse lists the choices in ``--help``.
_COMMON_ATTENTION_BACKENDS = ("triton", "torch_native")
_NVIDIA_ATTENTION_BACKENDS = (
    "cutlass_mla",
    "fa3",
    "flashinfer",
    "flashmla",
    "trtllm_mla",
    "trtllm_mha",
    "dual_chunk_flash_attn",
    "hybrid_linear_attn",
)
_AMD_ATTENTION_BACKENDS = ("aiter", "wave")
_OTHER_PLATFORM_ATTENTION_BACKENDS = ("intel_amx", "ascend")
ATTENTION_BACKEND_CHOICES = [
    *_COMMON_ATTENTION_BACKENDS,
    *_NVIDIA_ATTENTION_BACKENDS,
    *_AMD_ATTENTION_BACKENDS,
    *_OTHER_PLATFORM_ATTENTION_BACKENDS,
]

# Accepted values for ``ServerArgs.disaggregation_transfer_backend``
# (KV transfer between prefill and decode servers in PD disaggregation).
DISAGG_TRANSFER_BACKEND_CHOICES = "mooncake nixl ascend fake".split()

# Accepted values for ``ServerArgs.grammar_backend``.
GRAMMAR_BACKEND_CHOICES = "xgrammar outlines llguidance none".split()
|
|
|
|
|
|
# Allow external code to add more choices
|
|
def _extend_choices(target, choices):
    """Register each new entry of ``choices`` on the module-level list ``target``.

    ``target`` is mutated in place (never rebound) so that argparse parsers
    built later see the additions.

    A bare ``str`` is treated as ONE choice. Plain ``list.extend("foo")``
    would register the characters ``"f"``, ``"o"``, ``"o"``. Entries that are
    already present, in ``target`` or earlier in ``choices``, are skipped so
    that repeated registration (e.g. a plugin module imported twice) does not
    duplicate choices in ``--help`` output.

    Args:
        target: One of the module-level ``*_CHOICES`` lists.
        choices: A single choice string, or any iterable of choice strings.
    """
    if isinstance(choices, str):
        choices = [choices]
    for choice in choices:
        if choice not in target:
            target.append(choice)


def add_load_format_choices(choices):
    """Allow external code to register extra ``--load-format`` values."""
    _extend_choices(LOAD_FORMAT_CHOICES, choices)


def add_quantization_method_choices(choices):
    """Allow external code to register extra quantization method names."""
    _extend_choices(QUANTIZATION_CHOICES, choices)


def add_attention_backend_choices(choices):
    """Allow external code to register extra attention backend names."""
    _extend_choices(ATTENTION_BACKEND_CHOICES, choices)


def add_disagg_transfer_backend_choices(choices):
    """Allow external code to register extra PD-disaggregation transfer backends."""
    _extend_choices(DISAGG_TRANSFER_BACKEND_CHOICES, choices)


def add_grammar_backend_choices(choices):
    """Allow external code to register extra grammar backend names."""
    _extend_choices(GRAMMAR_BACKEND_CHOICES, choices)
|
|
|
|
|
|
@dataclasses.dataclass
|
|
class ServerArgs:
|
|
# Model and tokenizer
|
|
model_path: str
|
|
tokenizer_path: Optional[str] = None
|
|
tokenizer_mode: str = "auto"
|
|
tokenizer_worker_num: int = 1
|
|
skip_tokenizer_init: bool = False
|
|
load_format: str = "auto"
|
|
model_loader_extra_config: str = "{}"
|
|
trust_remote_code: bool = False
|
|
context_length: Optional[int] = None
|
|
is_embedding: bool = False
|
|
enable_multimodal: Optional[bool] = None
|
|
revision: Optional[str] = None
|
|
model_impl: str = "auto"
|
|
|
|
# HTTP server
|
|
host: str = "127.0.0.1"
|
|
port: int = 30000
|
|
skip_server_warmup: bool = False
|
|
warmups: Optional[str] = None
|
|
nccl_port: Optional[int] = None
|
|
|
|
# Quantization and data type
|
|
dtype: str = "auto"
|
|
quantization: Optional[str] = None
|
|
quantization_param_path: Optional[str] = None
|
|
kv_cache_dtype: str = "auto"
|
|
|
|
# Memory and scheduling
|
|
mem_fraction_static: Optional[float] = None
|
|
max_running_requests: Optional[int] = None
|
|
max_queued_requests: Optional[int] = sys.maxsize
|
|
max_total_tokens: Optional[int] = None
|
|
chunked_prefill_size: Optional[int] = None
|
|
max_prefill_tokens: int = 16384
|
|
schedule_policy: str = "fcfs"
|
|
schedule_conservativeness: float = 1.0
|
|
page_size: Optional[int] = None
|
|
hybrid_kvcache_ratio: Optional[float] = None
|
|
swa_full_tokens_ratio: float = 0.8
|
|
disable_hybrid_swa_memory: bool = False
|
|
|
|
# Runtime options
|
|
device: Optional[str] = None
|
|
tp_size: int = 1
|
|
pp_size: int = 1
|
|
max_micro_batch_size: Optional[int] = None
|
|
stream_interval: int = 1
|
|
stream_output: bool = False
|
|
random_seed: Optional[int] = None
|
|
constrained_json_whitespace_pattern: Optional[str] = None
|
|
watchdog_timeout: float = 300
|
|
dist_timeout: Optional[int] = None # timeout for torch.distributed
|
|
download_dir: Optional[str] = None
|
|
base_gpu_id: int = 0
|
|
gpu_id_step: int = 1
|
|
sleep_on_idle: bool = False
|
|
|
|
# Logging
|
|
log_level: str = "info"
|
|
log_level_http: Optional[str] = None
|
|
log_requests: bool = False
|
|
log_requests_level: int = 2
|
|
crash_dump_folder: Optional[str] = None
|
|
show_time_cost: bool = False
|
|
enable_metrics: bool = False
|
|
enable_metrics_for_all_schedulers: bool = False
|
|
bucket_time_to_first_token: Optional[List[float]] = None
|
|
bucket_inter_token_latency: Optional[List[float]] = None
|
|
bucket_e2e_request_latency: Optional[List[float]] = None
|
|
collect_tokens_histogram: bool = False
|
|
prompt_tokens_buckets: Optional[List[str]] = None
|
|
generation_tokens_buckets: Optional[List[str]] = None
|
|
decode_log_interval: int = 40
|
|
enable_request_time_stats_logging: bool = False
|
|
kv_events_config: Optional[str] = None
|
|
gc_warning_threshold_secs: float = 0.0
|
|
|
|
# API related
|
|
api_key: Optional[str] = None
|
|
served_model_name: Optional[str] = None
|
|
weight_version: str = "default"
|
|
chat_template: Optional[str] = None
|
|
completion_template: Optional[str] = None
|
|
file_storage_path: str = "sglang_storage"
|
|
enable_cache_report: bool = False
|
|
reasoning_parser: Optional[str] = None
|
|
tool_call_parser: Optional[str] = None
|
|
tool_server: Optional[str] = None
|
|
|
|
# Data parallelism
|
|
dp_size: int = 1
|
|
load_balance_method: str = "round_robin"
|
|
# FIXME: remove this after dp rank scheduling is fully supported with PD-Disaggregation
|
|
prefill_round_robin_balance: bool = False
|
|
|
|
# Multi-node distributed serving
|
|
dist_init_addr: Optional[str] = None
|
|
nnodes: int = 1
|
|
node_rank: int = 0
|
|
|
|
# Model override args in JSON
|
|
json_model_override_args: str = "{}"
|
|
preferred_sampling_params: Optional[str] = None
|
|
|
|
# LoRA
|
|
enable_lora: Optional[bool] = None
|
|
max_lora_rank: Optional[int] = None
|
|
lora_target_modules: Optional[Union[set[str], List[str]]] = None
|
|
lora_paths: Optional[
|
|
Union[dict[str, str], List[dict[str, str]], List[str], List[LoRARef]]
|
|
] = None
|
|
max_loaded_loras: Optional[int] = None
|
|
max_loras_per_batch: int = 8
|
|
lora_backend: str = "triton"
|
|
|
|
# Kernel backend
|
|
attention_backend: Optional[str] = None
|
|
decode_attention_backend: Optional[str] = None
|
|
prefill_attention_backend: Optional[str] = None
|
|
sampling_backend: Optional[str] = None
|
|
grammar_backend: Optional[str] = None
|
|
mm_attention_backend: Optional[str] = None
|
|
|
|
# Speculative decoding
|
|
speculative_algorithm: Optional[str] = None
|
|
speculative_draft_model_path: Optional[str] = None
|
|
speculative_draft_model_revision: Optional[str] = None
|
|
speculative_num_steps: Optional[int] = None
|
|
speculative_eagle_topk: Optional[int] = None
|
|
speculative_num_draft_tokens: Optional[int] = None
|
|
speculative_accept_threshold_single: float = 1.0
|
|
speculative_accept_threshold_acc: float = 1.0
|
|
speculative_token_map: Optional[str] = None
|
|
speculative_attention_mode: str = "prefill"
|
|
|
|
# Expert parallelism
|
|
ep_size: int = 1
|
|
moe_a2a_backend: Literal["none", "deepep"] = "none"
|
|
moe_runner_backend: Literal[
|
|
"auto",
|
|
"triton",
|
|
"triton_kernel",
|
|
"flashinfer_trtllm",
|
|
"flashinfer_cutlass",
|
|
"flashinfer_mxfp4",
|
|
] = "auto"
|
|
flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
|
|
enable_flashinfer_allreduce_fusion: bool = False
|
|
deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
|
|
ep_num_redundant_experts: int = 0
|
|
ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
|
|
init_expert_location: str = "trivial"
|
|
enable_eplb: bool = False
|
|
eplb_algorithm: str = "auto"
|
|
eplb_rebalance_num_iterations: int = 1000
|
|
eplb_rebalance_layers_per_chunk: Optional[int] = None
|
|
eplb_min_rebalancing_utilization_threshold: float = 1.0
|
|
expert_distribution_recorder_mode: Optional[
|
|
Literal["stat", "stat_approx", "per_pass", "per_token"]
|
|
] = None
|
|
expert_distribution_recorder_buffer_size: Optional[int] = None
|
|
enable_expert_distribution_metrics: bool = False
|
|
deepep_config: Optional[str] = None
|
|
moe_dense_tp_size: Optional[int] = None
|
|
|
|
# Hierarchical cache
|
|
enable_hierarchical_cache: bool = False
|
|
hicache_ratio: float = 2.0
|
|
hicache_size: int = 0
|
|
hicache_write_policy: str = "write_through"
|
|
hicache_io_backend: str = "kernel"
|
|
hicache_mem_layout: str = "layer_first"
|
|
hicache_storage_backend: Optional[str] = None
|
|
hicache_storage_prefetch_policy: str = "best_effort"
|
|
hicache_storage_backend_extra_config: Optional[str] = None
|
|
# LMCache
|
|
enable_lmcache: bool = False
|
|
|
|
# Double Sparsity
|
|
enable_double_sparsity: bool = False
|
|
ds_channel_config_path: Optional[str] = None
|
|
ds_heavy_channel_num: int = 32
|
|
ds_heavy_token_num: int = 256
|
|
ds_heavy_channel_type: str = "qk"
|
|
ds_sparse_decode_threshold: int = 4096
|
|
|
|
# Offloading
|
|
cpu_offload_gb: int = 0
|
|
offload_group_size: int = -1
|
|
offload_num_in_group: int = 1
|
|
offload_prefetch_step: int = 1
|
|
offload_mode: str = "cpu"
|
|
|
|
# Optimization/debug options
|
|
disable_radix_cache: bool = False
|
|
cuda_graph_max_bs: Optional[int] = None
|
|
cuda_graph_bs: Optional[List[int]] = None
|
|
disable_cuda_graph: bool = False
|
|
disable_cuda_graph_padding: bool = False
|
|
enable_profile_cuda_graph: bool = False
|
|
enable_cudagraph_gc: bool = False
|
|
enable_nccl_nvls: bool = False
|
|
enable_symm_mem: bool = False
|
|
disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
|
|
enable_tokenizer_batch_encode: bool = False
|
|
disable_outlines_disk_cache: bool = False
|
|
disable_custom_all_reduce: bool = False
|
|
enable_mscclpp: bool = False
|
|
disable_overlap_schedule: bool = False
|
|
enable_mixed_chunk: bool = False
|
|
enable_dp_attention: bool = False
|
|
enable_dp_lm_head: bool = False
|
|
enable_two_batch_overlap: bool = False
|
|
tbo_token_distribution_threshold: float = 0.48
|
|
enable_torch_compile: bool = False
|
|
torch_compile_max_bs: int = 32
|
|
torchao_config: str = ""
|
|
enable_nan_detection: bool = False
|
|
enable_p2p_check: bool = False
|
|
triton_attention_reduce_in_fp32: bool = False
|
|
triton_attention_num_kv_splits: int = 8
|
|
num_continuous_decode_steps: int = 1
|
|
delete_ckpt_after_loading: bool = False
|
|
enable_memory_saver: bool = False
|
|
allow_auto_truncate: bool = False
|
|
enable_custom_logit_processor: bool = False
|
|
flashinfer_mla_disable_ragged: bool = False
|
|
disable_shared_experts_fusion: bool = False
|
|
disable_chunked_prefix_cache: bool = False
|
|
disable_fast_image_processor: bool = False
|
|
enable_return_hidden_states: bool = False
|
|
scheduler_recv_interval: int = 1
|
|
numa_node: Optional[List[int]] = None
|
|
|
|
# Debug tensor dumps
|
|
debug_tensor_dump_output_folder: Optional[str] = None
|
|
debug_tensor_dump_input_file: Optional[str] = None
|
|
debug_tensor_dump_inject: bool = False
|
|
debug_tensor_dump_prefill_only: bool = False
|
|
|
|
# PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
|
|
disaggregation_mode: str = "null"
|
|
disaggregation_transfer_backend: str = "mooncake"
|
|
disaggregation_bootstrap_port: int = 8998
|
|
disaggregation_decode_tp: Optional[int] = None
|
|
disaggregation_decode_dp: Optional[int] = None
|
|
disaggregation_prefill_pp: Optional[int] = 1
|
|
disaggregation_ib_device: Optional[str] = None
|
|
num_reserved_decode_tokens: int = 512 # used for decode kv cache offload in PD
|
|
|
|
# For model weight update
|
|
custom_weight_loader: Optional[List[str]] = None
|
|
weight_loader_disable_mmap: bool = False
|
|
|
|
# For PD-Multiplexing
|
|
enable_pdmux: bool = False
|
|
sm_group_num: int = 3
|
|
|
|
# Mamba cache
|
|
max_mamba_cache_size: Optional[int] = None
|
|
mamba_ssm_dtype: str = "float32"
|
|
|
|
# Deprecated arguments
|
|
enable_ep_moe: bool = False
|
|
enable_deepep_moe: bool = False
|
|
enable_flashinfer_cutlass_moe: bool = False
|
|
enable_flashinfer_trtllm_moe: bool = False
|
|
enable_triton_kernel_moe: bool = False
|
|
enable_flashinfer_mxfp4_moe: bool = False
|
|
|
|
def __post_init__(self) -> None:
    """Normalize deprecated flags, derive unset defaults, and validate the config.

    Runs automatically after the dataclass-generated ``__init__``. In order, it:

    * maps the deprecated ``enable_*_moe`` flags onto ``ep_size``,
      ``moe_a2a_backend`` or ``moe_runner_backend``;
    * fills ``tokenizer_path``, ``served_model_name``, ``device`` and
      ``random_seed`` when they are ``None``;
    * derives ``mem_fraction_static``, ``chunked_prefill_size`` and
      ``cuda_graph_max_bs`` from the device memory capacity when unset;
    * picks sampling/attention/grammar backends and forces the page size
      some attention backends require;
    * adjusts DP-attention, MoE, EPLB, pipeline-parallel, hierarchical-cache,
      speculative-decoding, GGUF/remote loading and PD-disaggregation settings;
    * writes ``SGLANG_ENABLE_TORCH_COMPILE``, ``SGLANG_MAMBA_SSM_DTYPE`` and
      ``SGLANG_DISABLE_OUTLINES_DISK_CACHE`` into ``os.environ``.

    Statement order matters: later sections read values written by earlier
    ones (e.g. the backend-specific page-size overrides run before the
    ``page_size = 1`` fallback). Reorder with care.

    Side effects beyond ``self``: calls ``ModelConfig.from_server_args(self)``,
    ``self.model_specific_adjustments()`` and, for EAGLE-style speculative
    decoding, ``self.get_hf_config()``. These presumably read the model's
    config (possibly from disk or the network); confirm before calling
    ``ServerArgs(...)`` in hot paths or tests.

    Raises:
        ValueError: on unsupported combinations. Examples are TRTLLM attention
            backends without SM100, an unsupported ``kv_cache_dtype`` for
            trtllm_mla, speculative topk/page-size conflicts, and hierarchical
            cache with a disabled radix cache.
        AssertionError: for the remaining invariants written as ``assert``.
            These checks are skipped entirely when Python runs with ``-O``.
    """
    # Check deprecated arguments: each legacy boolean is translated into the
    # newer option it replaced, then a deprecation notice is printed.
    if self.enable_ep_moe:
        self.ep_size = self.tp_size
        print_deprecated_warning(
            "NOTE: --enable-ep-moe is deprecated. Please set `--ep-size` to the same value as `--tp-size` instead."
        )
    if self.enable_deepep_moe:
        self.moe_a2a_backend = "deepep"
        print_deprecated_warning(
            "NOTE: --enable-deepep-moe is deprecated. Please set `--moe-a2a-backend` to 'deepep' instead."
        )
    if self.enable_triton_kernel_moe:
        self.moe_runner_backend = "triton_kernel"
        print_deprecated_warning(
            "NOTE: --enable-triton-kernel-moe is deprecated. Please set `--moe-runner-backend` to 'triton_kernel' instead."
        )
    if self.enable_flashinfer_cutlass_moe:
        self.moe_runner_backend = "flashinfer_cutlass"
        print_deprecated_warning(
            "NOTE: --enable-flashinfer-cutlass-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_cutlass' instead."
        )
    if self.enable_flashinfer_trtllm_moe:
        self.moe_runner_backend = "flashinfer_trtllm"
        print_deprecated_warning(
            "NOTE: --enable-flashinfer-trtllm-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_trtllm' instead."
        )
    if self.enable_flashinfer_mxfp4_moe:
        self.moe_runner_backend = "flashinfer_mxfp4"
        print_deprecated_warning(
            "NOTE: --enable-flashinfer-mxfp4-moe is deprecated. Please set `--moe-runner-backend` to 'flashinfer_mxfp4' instead."
        )

    # Set missing default values
    if self.tokenizer_path is None:
        self.tokenizer_path = self.model_path
    if self.served_model_name is None:
        self.served_model_name = self.model_path
    if self.device is None:
        self.device = get_device()
    if self.random_seed is None:
        self.random_seed = random.randint(0, 1 << 30)

    # May be None (every use below guards for that). The thresholds below
    # (e.g. 20 * 1024 labelled "T4, 4080") suggest the unit is MiB; confirm
    # against get_device_memory_capacity.
    gpu_mem = get_device_memory_capacity(self.device)

    # Set mem fraction static
    if self.mem_fraction_static is None:
        if gpu_mem is not None:
            # GPU memory capacity = model weights + KV cache pool + activations + cuda graph buffers
            # mem_fraction_static = (model weights + KV cache pool) / GPU memory capacity.

            # We want mem_fraction_static to be as large as possible but still has enough room
            # for activations and cuda graph buffers. We use the following heuristic to
            # compute the needed size for activations and cuda graph buffers:
            # - The size of the activation depends on the chunked_prefill_size and model size.
            # - The size of cuda graph buffers depends on the cuda graph capture range and model size.
            # For GPUs with more memory, we use a larger chunked_prefill_size and
            # capture more cuda graphs, so they need to reserve more memory.
            parallel_size = self.tp_size * self.pp_size

            # Some adjacent tiers use the same formula; they are kept separate
            # so each GPU class can be tuned independently.
            if gpu_mem < 20 * 1024:
                # T4, 4080. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                reserved_mem = (2.8 + parallel_size / 10) * 1024
            elif gpu_mem < 35 * 1024:
                # A10, L40, 4090, 5090. (chunked_prefill_size 2k, cuda_graph_max_bs 8)
                reserved_mem = (2.8 + parallel_size / 10) * 1024
            elif gpu_mem < 90 * 1024:
                # H100, A100. (chunked_prefill_size 8k, cuda_graph_max_bs 160)
                reserved_mem = (9.5 + parallel_size / 2) * 1024
            elif gpu_mem < 100 * 1024:
                # H20. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                reserved_mem = (12 + parallel_size / 2) * 1024
            elif gpu_mem < 160 * 1024:
                # H200. (chunked_prefill_size 8k, cuda_graph_max_bs 256)
                reserved_mem = (12 + parallel_size / 2) * 1024
            else:
                # B200, MI300. (chunked_prefill_size 16k, cuda_graph_max_bs 512)
                reserved_mem = 32 * 1024

            # draft model and larger cuda graph buffers
            if self.speculative_algorithm is not None:
                if self.speculative_algorithm == "STANDALONE":
                    # Standalone speculative decoding needs more memory than other speculative
                    # decoding algorithms since the draft model is typically larger.
                    reserved_mem += 6 * 1024
                else:
                    reserved_mem += 2 * 1024
            # NOTE(review): enable_dp_attention is only forced to False for
            # dp_size == 1 further below, so this reservation can apply even
            # when DP attention ends up disabled.
            if self.enable_dp_attention:
                reserved_mem += 4 * 1024

            self.mem_fraction_static = round((gpu_mem - reserved_mem) / gpu_mem, 3)
        else:
            self.mem_fraction_static = 0.88

        # Lazy init to avoid circular import
        # Multimodal models need more memory for the image processor
        # (this VLM adjustment only runs when mem_fraction_static was auto-derived above).
        from sglang.srt.configs.model_config import ModelConfig

        model_config = ModelConfig.from_server_args(self)
        if model_config.is_multimodal:
            self.adjust_mem_fraction_for_vlm(model_config)

    # Set chunked prefill size, which depends on the gpu memory capacity
    if self.chunked_prefill_size is None:
        if gpu_mem is not None:
            if gpu_mem < 35 * 1024:  # A10, L40, 4090
                self.chunked_prefill_size = 2048
            elif gpu_mem < 160 * 1024:  # H100, H200, A100, H20
                self.chunked_prefill_size = 8192
            else:  # B200, MI300
                self.chunked_prefill_size = 16384
        else:
            self.chunked_prefill_size = 4096

    # Set cuda graph max batch size
    if self.cuda_graph_max_bs is None:
        # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs
        # with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to
        # a very small value to reduce the memory overhead of creating cuda graphs, with
        # almost no impact on performance. However, when serving models with TP4 or TP8,
        # we need to enable cuda graph to maintain high performance. In this case, we can
        # set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the
        # memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of
        # qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of
        # creating cuda graphs on lower-end GPUs compared to the original 160, avoiding
        # OOM issues.
        # On larger GPUs the value is left as None here.
        if gpu_mem is not None and gpu_mem < 35 * 1024:
            if self.tp_size < 4:
                self.cuda_graph_max_bs = 8
            else:
                self.cuda_graph_max_bs = 80

    # Set kernel backends for hpu device
    if self.device == "hpu":
        self.attention_backend = "torch_native"
        self.sampling_backend = "pytorch"

    # Model-specific adjustments
    self.model_specific_adjustments()

    # Set kernel backends
    if self.device == "cpu":
        if self.attention_backend is None:
            self.attention_backend = "intel_amx"
        self.sampling_backend = "pytorch"

    if self.sampling_backend is None:
        self.sampling_backend = (
            "flashinfer" if is_flashinfer_available() else "pytorch"
        )

    if self.attention_backend == "torch_native":
        logger.warning(
            "Cuda graph is disabled because of using torch native attention backend"
        )
        self.disable_cuda_graph = True

    # The following backend checks overwrite page_size, including a value the
    # user passed explicitly; they must run before the page_size = 1 fallback.
    if self.attention_backend == "ascend":
        logger.warning(
            "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
        )
        self.page_size = 128

    if (
        self.attention_backend == "flashmla"
        or self.decode_attention_backend == "flashmla"
    ):
        logger.warning(
            "FlashMLA only supports a page_size of 64, change page_size to 64."
        )
        self.page_size = 64

    if (
        self.attention_backend == "cutlass_mla"
        or self.decode_attention_backend == "cutlass_mla"
    ):
        logger.warning(
            "Cutlass MLA only supports a page_size of 128, change page_size to 128."
        )
        self.page_size = 128

    if (
        self.attention_backend == "trtllm_mla"
        or self.decode_attention_backend == "trtllm_mla"
    ):
        if not is_sm100_supported():
            raise ValueError(
                "TRTLLM MLA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
            )

        # page_size may still be None here (the fallback comes later), in
        # which case the warning reports "from None".
        if self.page_size not in [32, 64]:
            logger.warning(
                f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 64."
            )
            self.page_size = 64

        if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
            raise ValueError(
                "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
            )

    if (
        self.attention_backend == "trtllm_mha"
        or self.decode_attention_backend == "trtllm_mha"
        or self.prefill_attention_backend == "trtllm_mha"
    ):
        if not is_sm100_supported():
            raise ValueError(
                "TRTLLM MHA backend is only supported on Blackwell GPUs (SM100). Please use a different backend."
            )

        if self.page_size not in [16, 32, 64]:
            logger.warning(
                f"TensorRT-LLM MHA only supports page_size of 16, 32 or 64, changing page_size from {self.page_size} to 64."
            )
            self.page_size = 64

    if self.attention_backend == "dual_chunk_flash_attn":
        logger.warning(
            "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
        )
        self.enable_mixed_chunk = False
        self.disable_cuda_graph = True
        self.disable_radix_cache = True

    # Set page size
    if self.page_size is None:
        self.page_size = 1

    # AMD-specific Triton attention KV splits default number
    # (unconditionally overrides triton_attention_num_kv_splits on HIP).
    if is_hip():
        self.triton_attention_num_kv_splits = 16

    # Choose grammar backend
    if self.grammar_backend is None:
        self.grammar_backend = "xgrammar"

    if self.dp_size == 1:
        self.enable_dp_attention = False

    # Data parallelism attention
    if self.enable_dp_attention:
        self.schedule_conservativeness = self.schedule_conservativeness * 0.3
        assert self.tp_size % self.dp_size == 0
        # Each DP rank prefills its own share of the chunk.
        self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
        logger.warning(
            f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
        )

    if self.enable_dp_lm_head:
        assert (
            self.enable_dp_attention
        ), "Please enable dp attention when setting enable_dp_lm_head. "

    # MoE kernel
    if self.moe_runner_backend == "flashinfer_cutlass":
        assert (
            self.quantization == "modelopt_fp4"
        ), "modelopt_fp4 quantization is required for Flashinfer MOE"
        assert self.ep_size in [
            1,
            self.tp_size,
        ], "The expert parallel size must be 1 or the same as the tensor parallel size"

    if self.moe_runner_backend == "flashinfer_trtllm":
        # NOTE(review): fp8 is accepted here too, although the assertion
        # message only mentions modelopt_fp4.
        assert (
            self.quantization == "modelopt_fp4" or self.quantization == "fp8"
        ), "modelopt_fp4 quantization is required for Flashinfer TRTLLM MoE"
        self.disable_shared_experts_fusion = True
        logger.warning(
            "FlashInfer TRTLLM MoE is enabled. --disable-shared-experts-fusion is automatically set."
        )

    # DeepEP MoE
    if self.moe_a2a_backend == "deepep":
        if self.deepep_mode == "normal":
            logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
            self.disable_cuda_graph = True
        self.ep_size = self.tp_size
        logger.warning(
            f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
        )

    if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
        self.expert_distribution_recorder_mode = "stat"
        logger.warning(
            "EPLB is enabled. The expert_distribution_recorder_mode is automatically set."
        )

    # NOTE(review): init_expert_location defaults to "trivial" (never None),
    # so unless a caller sets it to None this branch always picks "static".
    if (self.enable_eplb or (self.init_expert_location is not None)) and (
        self.ep_dispatch_algorithm is None
    ):
        self.ep_dispatch_algorithm = "static"

    if self.enable_eplb:
        assert self.ep_size > 1

    if self.enable_expert_distribution_metrics and (
        self.expert_distribution_recorder_mode is None
    ):
        self.expert_distribution_recorder_mode = "stat"

    # NOTE(review): eplb_rebalance_num_iterations defaults to 1000, so the
    # elif fallback is only reached if a caller sets it to None explicitly.
    if self.expert_distribution_recorder_buffer_size is None:
        if (x := self.eplb_rebalance_num_iterations) is not None:
            self.expert_distribution_recorder_buffer_size = x
        elif self.expert_distribution_recorder_mode is not None:
            self.expert_distribution_recorder_buffer_size = 1000

    # Pipeline parallelism
    if self.pp_size > 1:
        self.disable_overlap_schedule = True
        logger.warning(
            "Pipeline parallelism is incompatible with overlap schedule."
        )

    # Hicache
    if self.hicache_storage_backend == "mooncake":
        # to use mooncake storage backend, the following conditions must be met:
        self.hicache_io_backend = "kernel"
        self.hicache_mem_layout = "page_first"

    # Speculative Decoding
    if self.speculative_algorithm == "NEXTN":
        # NEXTN shares the same implementation of EAGLE
        self.speculative_algorithm = "EAGLE"

    if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
        if self.speculative_algorithm == "STANDALONE":
            # TODO: support dp attention for standalone speculative decoding
            assert (
                self.enable_dp_attention is False
            ), "Currently standalone speculative decoding does not support dp attention."
        if self.max_running_requests is None:
            self.max_running_requests = 48
        self.disable_overlap_schedule = True
        logger.warning(
            "Overlap scheduler is disabled because of using "
            "eagle speculative decoding."
        )
        if self.enable_mixed_chunk:
            self.enable_mixed_chunk = False
            logger.warning(
                "Mixed chunked prefill is disabled because of using "
                "eagle speculative decoding."
            )

        model_arch = self.get_hf_config().architectures[0]
        if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]:
            # Auto set draft_model_path DeepSeek-V3/R1
            if self.speculative_draft_model_path is None:
                self.speculative_draft_model_path = self.model_path
            else:
                logger.warning(
                    "DeepSeek MTP does not require setting speculative_draft_model_path."
                )

        # Auto choose parameters: all three must be left unset together.
        if self.speculative_num_steps is None:
            assert (
                self.speculative_eagle_topk is None
                and self.speculative_num_draft_tokens is None
            )
            (
                self.speculative_num_steps,
                self.speculative_eagle_topk,
                self.speculative_num_draft_tokens,
            ) = auto_choose_speculative_params(self)

        # NOTE(review): if speculative_num_steps was given but
        # speculative_eagle_topk was not, the `> 1` comparisons below operate
        # on None and raise TypeError.
        if (
            self.attention_backend == "trtllm_mha"
            or self.decode_attention_backend == "trtllm_mha"
            or self.prefill_attention_backend == "trtllm_mha"
        ):
            if self.speculative_eagle_topk > 1:
                raise ValueError(
                    "trtllm_mha backend only supports topk = 1 for speculative decoding."
                )

        if (
            self.speculative_eagle_topk == 1
            and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
        ):
            logger.warning(
                "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
            )
            self.speculative_num_draft_tokens = self.speculative_num_steps + 1

        if (
            self.speculative_eagle_topk > 1
            and self.page_size > 1
            and self.attention_backend != "flashinfer"
        ):
            raise ValueError(
                "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
            )

        # The token generated from the verify step is counted.
        # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
        # assert self.speculative_num_steps < self.speculative_num_draft_tokens

    # GGUF: a GGUF checkpoint forces both the load format and the quantization.
    if (
        self.load_format == "auto" or self.load_format == "gguf"
    ) and check_gguf_file(self.model_path):
        self.quantization = self.load_format = "gguf"

    # Model loading
    if is_remote_url(self.model_path):
        self.load_format = "remote"
    if self.custom_weight_loader is None:
        self.custom_weight_loader = []

    # PD disaggregation
    if self.disaggregation_mode == "decode":
        assert (
            self.disaggregation_decode_tp is None
        ), "Cannot set --disaggregation-decode-tp for the decode engine."
        assert (
            self.disaggregation_decode_dp is None
        ), "Cannot set --disaggregation-decode-dp for the decode engine."

        self.disable_radix_cache = True
        logger.warning("KV cache is forced as chunk cache for decode server")

        if self.dp_size > 1 and not is_in_ci():
            assert self.prefill_round_robin_balance, (
                "Prefill round robin balance is required when dp size > 1. "
                "Please make sure that the prefill instance is launched with `--load-balance-method round_robin`"
                " and `--prefill-round-robin-balance` is set for decode server."
            )
    elif self.disaggregation_mode == "prefill":
        # The prefill server defaults the peer decode layout to its own.
        if self.disaggregation_decode_tp is None:
            self.disaggregation_decode_tp = self.tp_size
        if self.disaggregation_decode_dp is None:
            self.disaggregation_decode_dp = self.dp_size

        self.disaggregation_prefill_pp = self.pp_size
        self.validate_disagg_tp_size(self.tp_size, self.disaggregation_decode_tp)

        self.disable_cuda_graph = True
        logger.warning("Cuda graph is disabled for prefill server")

    # Propagate env vars
    os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
        "1" if self.enable_torch_compile else "0"
    )
    os.environ["SGLANG_MAMBA_SSM_DTYPE"] = self.mamba_ssm_dtype

    # Set env var before grammar backends init
    os.environ["SGLANG_DISABLE_OUTLINES_DISK_CACHE"] = (
        "1" if self.disable_outlines_disk_cache else "0"
    )

    # NOTE(review): decode-mode disaggregation forces disable_radix_cache
    # above, so combining it with enable_hierarchical_cache also lands here.
    if self.enable_hierarchical_cache and self.disable_radix_cache:
        raise ValueError(
            "The arguments enable-hierarchical-cache and disable-radix-cache are mutually exclusive "
            "and cannot be used at the same time. Please use only one of them."
        )
|
|
|
|
@staticmethod
|
|
def add_cli_args(parser: argparse.ArgumentParser):
|
|
# Model and tokenizer
|
|
parser.add_argument(
|
|
"--model-path",
|
|
"--model",
|
|
type=str,
|
|
help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
|
|
required=True,
|
|
)
|
|
parser.add_argument(
|
|
"--tokenizer-path",
|
|
type=str,
|
|
default=ServerArgs.tokenizer_path,
|
|
help="The path of the tokenizer.",
|
|
)
|
|
parser.add_argument(
|
|
"--tokenizer-mode",
|
|
type=str,
|
|
default=ServerArgs.tokenizer_mode,
|
|
choices=["auto", "slow"],
|
|
help="Tokenizer mode. 'auto' will use the fast "
|
|
"tokenizer if available, and 'slow' will "
|
|
"always use the slow tokenizer.",
|
|
)
|
|
parser.add_argument(
|
|
"--tokenizer-worker-num",
|
|
type=int,
|
|
default=ServerArgs.tokenizer_worker_num,
|
|
help="The worker num of the tokenizer manager.",
|
|
)
|
|
parser.add_argument(
|
|
"--skip-tokenizer-init",
|
|
action="store_true",
|
|
help="If set, skip init tokenizer and pass input_ids in generate request.",
|
|
)
|
|
parser.add_argument(
|
|
"--load-format",
|
|
type=str,
|
|
default=ServerArgs.load_format,
|
|
choices=LOAD_FORMAT_CHOICES,
|
|
help="The format of the model weights to load. "
|
|
'"auto" will try to load the weights in the safetensors format '
|
|
"and fall back to the pytorch bin format if safetensors format "
|
|
"is not available. "
|
|
'"pt" will load the weights in the pytorch bin format. '
|
|
'"safetensors" will load the weights in the safetensors format. '
|
|
'"npcache" will load the weights in pytorch format and store '
|
|
"a numpy cache to speed up the loading. "
|
|
'"dummy" will initialize the weights with random values, '
|
|
"which is mainly for profiling."
|
|
'"gguf" will load the weights in the gguf format. '
|
|
'"bitsandbytes" will load the weights using bitsandbytes '
|
|
"quantization."
|
|
'"layered" loads weights layer by layer so that one can quantize a '
|
|
"layer before loading another to make the peak memory envelope "
|
|
"smaller.",
|
|
)
|
|
parser.add_argument(
|
|
"--model-loader-extra-config",
|
|
type=str,
|
|
help="Extra config for model loader. "
|
|
"This will be passed to the model loader corresponding to the chosen load_format.",
|
|
default=ServerArgs.model_loader_extra_config,
|
|
)
|
|
parser.add_argument(
|
|
"--trust-remote-code",
|
|
action="store_true",
|
|
help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
|
|
)
|
|
parser.add_argument(
|
|
"--context-length",
|
|
type=int,
|
|
default=ServerArgs.context_length,
|
|
help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
|
|
)
|
|
parser.add_argument(
|
|
"--is-embedding",
|
|
action="store_true",
|
|
help="Whether to use a CausalLM as an embedding model.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-multimodal",
|
|
default=ServerArgs.enable_multimodal,
|
|
action="store_true",
|
|
help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
|
|
)
|
|
parser.add_argument(
|
|
"--revision",
|
|
type=str,
|
|
default=None,
|
|
help="The specific model version to use. It can be a branch "
|
|
"name, a tag name, or a commit id. If unspecified, will use "
|
|
"the default version.",
|
|
)
|
|
parser.add_argument(
|
|
"--model-impl",
|
|
type=str,
|
|
default=ServerArgs.model_impl,
|
|
help="Which implementation of the model to use.\n\n"
|
|
'* "auto" will try to use the SGLang implementation if it exists '
|
|
"and fall back to the Transformers implementation if no SGLang "
|
|
"implementation is available.\n"
|
|
'* "sglang" will use the SGLang model implementation.\n'
|
|
'* "transformers" will use the Transformers model '
|
|
"implementation.\n",
|
|
)
|
|
|
|
# HTTP server
|
|
parser.add_argument(
|
|
"--host",
|
|
type=str,
|
|
default=ServerArgs.host,
|
|
help="The host of the HTTP server.",
|
|
)
|
|
parser.add_argument(
|
|
"--port",
|
|
type=int,
|
|
default=ServerArgs.port,
|
|
help="The port of the HTTP server.",
|
|
)
|
|
parser.add_argument(
|
|
"--skip-server-warmup",
|
|
action="store_true",
|
|
help="If set, skip warmup.",
|
|
)
|
|
parser.add_argument(
|
|
"--warmups",
|
|
type=str,
|
|
required=False,
|
|
help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
|
|
"will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
|
|
)
|
|
parser.add_argument(
|
|
"--nccl-port",
|
|
type=int,
|
|
default=ServerArgs.nccl_port,
|
|
help="The port for NCCL distributed environment setup. Defaults to a random port.",
|
|
)
|
|
|
|
# Quantization and data type
|
|
parser.add_argument(
|
|
"--dtype",
|
|
type=str,
|
|
default=ServerArgs.dtype,
|
|
choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
|
|
help="Data type for model weights and activations.\n\n"
|
|
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
|
|
"BF16 precision for BF16 models.\n"
|
|
'* "half" for FP16. Recommended for AWQ quantization.\n'
|
|
'* "float16" is the same as "half".\n'
|
|
'* "bfloat16" for a balance between precision and range.\n'
|
|
'* "float" is shorthand for FP32 precision.\n'
|
|
'* "float32" for FP32 precision.',
|
|
)
|
|
parser.add_argument(
|
|
"--quantization",
|
|
type=str,
|
|
default=ServerArgs.quantization,
|
|
choices=QUANTIZATION_CHOICES,
|
|
help="The quantization method.",
|
|
)
|
|
parser.add_argument(
|
|
"--quantization-param-path",
|
|
type=nullable_str,
|
|
default=None,
|
|
help="Path to the JSON file containing the KV cache "
|
|
"scaling factors. This should generally be supplied, when "
|
|
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
|
|
"default to 1.0, which may cause accuracy issues. ",
|
|
)
|
|
parser.add_argument(
|
|
"--kv-cache-dtype",
|
|
type=str,
|
|
default=ServerArgs.kv_cache_dtype,
|
|
choices=["auto", "fp8_e5m2", "fp8_e4m3"],
|
|
help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
|
|
)
|
|
|
|
# Memory and scheduling
|
|
parser.add_argument(
|
|
"--mem-fraction-static",
|
|
type=float,
|
|
default=ServerArgs.mem_fraction_static,
|
|
help="The fraction of the memory used for static allocation (model weights and KV cache memory pool). Use a smaller value if you see out-of-memory errors.",
|
|
)
|
|
parser.add_argument(
|
|
"--max-running-requests",
|
|
type=int,
|
|
default=ServerArgs.max_running_requests,
|
|
help="The maximum number of running requests.",
|
|
)
|
|
parser.add_argument(
|
|
"--max-queued-requests",
|
|
type=int,
|
|
default=ServerArgs.max_queued_requests,
|
|
help="The maximum number of queued requests. This option is ignored when using disaggregation-mode.",
|
|
)
|
|
parser.add_argument(
|
|
"--max-total-tokens",
|
|
type=int,
|
|
default=ServerArgs.max_total_tokens,
|
|
help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
|
|
"This option is typically used for development and debugging purposes.",
|
|
)
|
|
parser.add_argument(
|
|
"--chunked-prefill-size",
|
|
type=int,
|
|
default=ServerArgs.chunked_prefill_size,
|
|
help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
|
|
)
|
|
parser.add_argument(
|
|
"--max-prefill-tokens",
|
|
type=int,
|
|
default=ServerArgs.max_prefill_tokens,
|
|
help="The maximum number of tokens in a prefill batch. The real bound will be the maximum of this value and the model's maximum context length.",
|
|
)
|
|
parser.add_argument(
|
|
"--schedule-policy",
|
|
type=str,
|
|
default=ServerArgs.schedule_policy,
|
|
choices=["lpm", "random", "fcfs", "dfs-weight", "lof"],
|
|
help="The scheduling policy of the requests.",
|
|
)
|
|
parser.add_argument(
|
|
"--schedule-conservativeness",
|
|
type=float,
|
|
default=ServerArgs.schedule_conservativeness,
|
|
help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
|
|
)
|
|
parser.add_argument(
|
|
"--page-size",
|
|
type=int,
|
|
default=ServerArgs.page_size,
|
|
help="The number of tokens in a page.",
|
|
)
|
|
parser.add_argument(
|
|
"--hybrid-kvcache-ratio",
|
|
nargs="?",
|
|
const=0.5,
|
|
type=float,
|
|
default=ServerArgs.hybrid_kvcache_ratio,
|
|
help=(
|
|
"Mix ratio in [0,1] between uniform and hybrid kv buffers "
|
|
"(0.0 = pure uniform: swa_size / full_size = 1)"
|
|
"(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
|
|
),
|
|
)
|
|
parser.add_argument(
|
|
"--swa-full-tokens-ratio",
|
|
type=float,
|
|
default=ServerArgs.swa_full_tokens_ratio,
|
|
help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
|
|
"E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-hybrid-swa-memory",
|
|
action="store_true",
|
|
help="Disable the hybrid SWA memory.",
|
|
)
|
|
|
|
# Runtime options
|
|
parser.add_argument(
|
|
"--device",
|
|
type=str,
|
|
default=ServerArgs.device,
|
|
help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
|
|
)
|
|
parser.add_argument(
|
|
"--tensor-parallel-size",
|
|
"--tp-size",
|
|
type=int,
|
|
default=ServerArgs.tp_size,
|
|
help="The tensor parallelism size.",
|
|
)
|
|
parser.add_argument(
|
|
"--pipeline-parallel-size",
|
|
"--pp-size",
|
|
type=int,
|
|
default=ServerArgs.pp_size,
|
|
help="The pipeline parallelism size.",
|
|
)
|
|
parser.add_argument(
|
|
"--max-micro-batch-size",
|
|
type=int,
|
|
default=ServerArgs.max_micro_batch_size,
|
|
help="The maximum micro batch size in pipeline parallelism.",
|
|
)
|
|
parser.add_argument(
|
|
"--stream-interval",
|
|
type=int,
|
|
default=ServerArgs.stream_interval,
|
|
help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
|
|
)
|
|
parser.add_argument(
|
|
"--stream-output",
|
|
action="store_true",
|
|
help="Whether to output as a sequence of disjoint segments.",
|
|
)
|
|
parser.add_argument(
|
|
"--random-seed",
|
|
type=int,
|
|
default=ServerArgs.random_seed,
|
|
help="The random seed.",
|
|
)
|
|
parser.add_argument(
|
|
"--constrained-json-whitespace-pattern",
|
|
type=str,
|
|
default=ServerArgs.constrained_json_whitespace_pattern,
|
|
help="(outlines backend only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
|
|
)
|
|
parser.add_argument(
|
|
"--watchdog-timeout",
|
|
type=float,
|
|
default=ServerArgs.watchdog_timeout,
|
|
help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
|
|
)
|
|
parser.add_argument(
|
|
"--dist-timeout",
|
|
type=int,
|
|
default=ServerArgs.dist_timeout,
|
|
help="Set timeout for torch.distributed initialization.",
|
|
)
|
|
parser.add_argument(
|
|
"--download-dir",
|
|
type=str,
|
|
default=ServerArgs.download_dir,
|
|
help="Model download directory for huggingface.",
|
|
)
|
|
parser.add_argument(
|
|
"--base-gpu-id",
|
|
type=int,
|
|
default=ServerArgs.base_gpu_id,
|
|
help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
|
|
)
|
|
parser.add_argument(
|
|
"--gpu-id-step",
|
|
type=int,
|
|
default=ServerArgs.gpu_id_step,
|
|
help="The delta between consecutive GPU IDs that are used. For example, setting it to 2 will use GPU 0,2,4,...",
|
|
)
|
|
parser.add_argument(
|
|
"--sleep-on-idle",
|
|
action="store_true",
|
|
help="Reduce CPU usage when sglang is idle.",
|
|
)
|
|
|
|
# Logging
|
|
parser.add_argument(
|
|
"--log-level",
|
|
type=str,
|
|
default=ServerArgs.log_level,
|
|
help="The logging level of all loggers.",
|
|
)
|
|
parser.add_argument(
|
|
"--log-level-http",
|
|
type=str,
|
|
default=ServerArgs.log_level_http,
|
|
help="The logging level of HTTP server. If not set, reuse --log-level by default.",
|
|
)
|
|
parser.add_argument(
|
|
"--log-requests",
|
|
action="store_true",
|
|
help="Log metadata, inputs, outputs of all requests. The verbosity is decided by --log-requests-level",
|
|
)
|
|
parser.add_argument(
|
|
"--log-requests-level",
|
|
type=int,
|
|
default=ServerArgs.log_requests_level,
|
|
help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
|
|
choices=[0, 1, 2, 3],
|
|
)
|
|
parser.add_argument(
|
|
"--crash-dump-folder",
|
|
type=str,
|
|
default=ServerArgs.crash_dump_folder,
|
|
help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
|
|
)
|
|
parser.add_argument(
|
|
"--show-time-cost",
|
|
action="store_true",
|
|
help="Show time cost of custom marks.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-metrics",
|
|
action="store_true",
|
|
help="Enable log prometheus metrics.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-metrics-for-all-schedulers",
|
|
action="store_true",
|
|
help="Enable --enable-metrics-for-all-schedulers when you want schedulers on all TP ranks (not just TP 0) "
|
|
"to record request metrics separately. This is especially useful when dp_attention is enabled, as "
|
|
"otherwise all metrics appear to come from TP 0.",
|
|
)
|
|
parser.add_argument(
|
|
"--bucket-time-to-first-token",
|
|
type=float,
|
|
nargs="+",
|
|
default=ServerArgs.bucket_time_to_first_token,
|
|
help="The buckets of time to first token, specified as a list of floats.",
|
|
)
|
|
parser.add_argument(
|
|
"--bucket-inter-token-latency",
|
|
type=float,
|
|
nargs="+",
|
|
default=ServerArgs.bucket_inter_token_latency,
|
|
help="The buckets of inter-token latency, specified as a list of floats.",
|
|
)
|
|
parser.add_argument(
|
|
"--bucket-e2e-request-latency",
|
|
type=float,
|
|
nargs="+",
|
|
default=ServerArgs.bucket_e2e_request_latency,
|
|
help="The buckets of end-to-end request latency, specified as a list of floats.",
|
|
)
|
|
parser.add_argument(
|
|
"--collect-tokens-histogram",
|
|
action="store_true",
|
|
default=ServerArgs.collect_tokens_histogram,
|
|
help="Collect prompt/generation tokens histogram.",
|
|
)
|
|
bucket_rule = (
|
|
"Supports 3 rule types: 'default' uses predefined buckets; 'tse <middle> <base> <count>' "
|
|
"generates two sides exponential distributed buckets (e.g., 'tse 1000 2 8' generates buckets "
|
|
"[984.0, 992.0, 996.0, 998.0, 1000.0, 1002.0, 1004.0, 1008.0, 1016.0]).); 'customer <value1> "
|
|
"<value2> ...' uses custom bucket values (e.g., 'customer 10 50 100 500')."
|
|
)
|
|
parser.add_argument(
|
|
"--prompt-tokens-buckets",
|
|
type=str,
|
|
nargs="+",
|
|
default=ServerArgs.prompt_tokens_buckets,
|
|
help=f"The buckets rule of prompt tokens. {bucket_rule}",
|
|
)
|
|
parser.add_argument(
|
|
"--generation-tokens-buckets",
|
|
type=str,
|
|
nargs="+",
|
|
default=ServerArgs.generation_tokens_buckets,
|
|
help=f"The buckets rule for generation tokens histogram. {bucket_rule}",
|
|
)
|
|
parser.add_argument(
|
|
"--gc-warning-threshold-secs",
|
|
type=float,
|
|
default=ServerArgs.gc_warning_threshold_secs,
|
|
help="The threshold for long GC warning. If a GC takes longer than this, a warning will be logged. Set to 0 to disable.",
|
|
)
|
|
parser.add_argument(
|
|
"--decode-log-interval",
|
|
type=int,
|
|
default=ServerArgs.decode_log_interval,
|
|
help="The log interval of decode batch.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-request-time-stats-logging",
|
|
action="store_true",
|
|
default=ServerArgs.enable_request_time_stats_logging,
|
|
help="Enable per request time stats logging",
|
|
)
|
|
parser.add_argument(
|
|
"--kv-events-config",
|
|
type=str,
|
|
default=None,
|
|
help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
|
|
)
|
|
|
|
# API related
|
|
parser.add_argument(
|
|
"--api-key",
|
|
type=str,
|
|
default=ServerArgs.api_key,
|
|
help="Set API key of the server. It is also used in the OpenAI API compatible server.",
|
|
)
|
|
parser.add_argument(
|
|
"--served-model-name",
|
|
type=str,
|
|
default=ServerArgs.served_model_name,
|
|
help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
|
|
)
|
|
parser.add_argument(
|
|
"--weight-version",
|
|
type=str,
|
|
default=ServerArgs.weight_version,
|
|
help="Version identifier for the model weights. Defaults to 'default' if not specified.",
|
|
)
|
|
parser.add_argument(
|
|
"--chat-template",
|
|
type=str,
|
|
default=ServerArgs.chat_template,
|
|
help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
|
|
)
|
|
parser.add_argument(
|
|
"--completion-template",
|
|
type=str,
|
|
default=ServerArgs.completion_template,
|
|
help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
|
|
)
|
|
parser.add_argument(
|
|
"--file-storage-path",
|
|
type=str,
|
|
default=ServerArgs.file_storage_path,
|
|
help="The path of the file storage in backend.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-cache-report",
|
|
action="store_true",
|
|
help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
|
|
)
|
|
parser.add_argument(
|
|
"--reasoning-parser",
|
|
type=str,
|
|
choices=list(ReasoningParser.DetectorMap.keys()),
|
|
default=ServerArgs.reasoning_parser,
|
|
help=f"Specify the parser for reasoning models, supported parsers are: {list(ReasoningParser.DetectorMap.keys())}.",
|
|
)
|
|
tool_call_parser_choices = list(FunctionCallParser.ToolCallParserEnum.keys())
|
|
parser.add_argument(
|
|
"--tool-call-parser",
|
|
type=str,
|
|
choices=tool_call_parser_choices,
|
|
default=ServerArgs.tool_call_parser,
|
|
help=f"Specify the parser for handling tool-call interactions. Options include: {tool_call_parser_choices}.",
|
|
)
|
|
parser.add_argument(
|
|
"--tool-server",
|
|
type=str,
|
|
default=None,
|
|
help="Either 'demo' or a comma-separated list of tool server urls to use for the model. If not specified, no tool server will be used.",
|
|
)
|
|
|
|
# Data parallelism
|
|
parser.add_argument(
|
|
"--data-parallel-size",
|
|
"--dp-size",
|
|
type=int,
|
|
default=ServerArgs.dp_size,
|
|
help="The data parallelism size.",
|
|
)
|
|
parser.add_argument(
|
|
"--load-balance-method",
|
|
type=str,
|
|
default=ServerArgs.load_balance_method,
|
|
help="The load balancing strategy for data parallelism.",
|
|
choices=[
|
|
"round_robin",
|
|
"shortest_queue",
|
|
"minimum_tokens",
|
|
],
|
|
)
|
|
parser.add_argument(
|
|
"--prefill-round-robin-balance",
|
|
default=ServerArgs.prefill_round_robin_balance,
|
|
action="store_true",
|
|
help="Prefill is round robin balanced. This is used to promise decode server can get the correct dp rank.",
|
|
)
|
|
|
|
# Multi-node distributed serving
|
|
parser.add_argument(
|
|
"--dist-init-addr",
|
|
"--nccl-init-addr", # For backward compatibility. This will be removed in the future.
|
|
type=str,
|
|
help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
|
|
)
|
|
parser.add_argument(
|
|
"--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
|
|
)
|
|
parser.add_argument(
|
|
"--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
|
|
)
|
|
|
|
# Model override args
|
|
parser.add_argument(
|
|
"--json-model-override-args",
|
|
type=str,
|
|
help="A dictionary in JSON string format used to override default model configurations.",
|
|
default=ServerArgs.json_model_override_args,
|
|
)
|
|
parser.add_argument(
|
|
"--preferred-sampling-params",
|
|
type=str,
|
|
help="json-formatted sampling settings that will be returned in /get_model_info",
|
|
)
|
|
|
|
# LoRA
|
|
parser.add_argument(
|
|
"--enable-lora",
|
|
default=ServerArgs.enable_lora,
|
|
action="store_true",
|
|
help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.",
|
|
)
|
|
parser.add_argument(
|
|
"--max-lora-rank",
|
|
default=ServerArgs.max_lora_rank,
|
|
type=int,
|
|
help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
|
|
)
|
|
parser.add_argument(
|
|
"--lora-target-modules",
|
|
type=str,
|
|
choices=SUPPORTED_LORA_TARGET_MODULES + [LORA_TARGET_ALL_MODULES],
|
|
nargs="*",
|
|
default=None,
|
|
help="The union set of all target modules where LoRA should be applied. If not specified, "
|
|
"it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, "
|
|
"all supported modules will be targeted.",
|
|
)
|
|
parser.add_argument(
|
|
"--lora-paths",
|
|
type=str,
|
|
nargs="*",
|
|
default=None,
|
|
action=LoRAPathAction,
|
|
help='The list of LoRA adapters to load. Each adapter must be specified in one of the following formats: <PATH> | <NAME>=<PATH> | JSON with schema {"lora_name":str,"lora_path":str,"pinned":bool}',
|
|
)
|
|
parser.add_argument(
|
|
"--max-loras-per-batch",
|
|
type=int,
|
|
default=8,
|
|
help="Maximum number of adapters for a running batch, include base-only request.",
|
|
)
|
|
parser.add_argument(
|
|
"--max-loaded-loras",
|
|
type=int,
|
|
default=ServerArgs.max_loaded_loras,
|
|
help="If specified, it limits the maximum number of LoRA adapters loaded in CPU memory at a time. The value must be greater than or equal to `--max-loras-per-batch`.",
|
|
)
|
|
parser.add_argument(
|
|
"--lora-backend",
|
|
type=str,
|
|
default="triton",
|
|
help="Choose the kernel backend for multi-LoRA serving.",
|
|
)
|
|
|
|
# Kernel backend
|
|
parser.add_argument(
|
|
"--attention-backend",
|
|
type=str,
|
|
choices=ATTENTION_BACKEND_CHOICES,
|
|
default=ServerArgs.attention_backend,
|
|
help="Choose the kernels for attention layers.",
|
|
)
|
|
parser.add_argument(
|
|
"--prefill-attention-backend",
|
|
type=str,
|
|
choices=ATTENTION_BACKEND_CHOICES,
|
|
default=ServerArgs.prefill_attention_backend,
|
|
help="Choose the kernels for prefill attention layers (have priority over --attention-backend).",
|
|
)
|
|
parser.add_argument(
|
|
"--decode-attention-backend",
|
|
type=str,
|
|
choices=ATTENTION_BACKEND_CHOICES,
|
|
default=ServerArgs.decode_attention_backend,
|
|
help="Choose the kernels for decode attention layers (have priority over --attention-backend).",
|
|
)
|
|
parser.add_argument(
|
|
"--sampling-backend",
|
|
type=str,
|
|
choices=["flashinfer", "pytorch"],
|
|
default=ServerArgs.sampling_backend,
|
|
help="Choose the kernels for sampling layers.",
|
|
)
|
|
parser.add_argument(
|
|
"--grammar-backend",
|
|
type=str,
|
|
choices=GRAMMAR_BACKEND_CHOICES,
|
|
default=ServerArgs.grammar_backend,
|
|
help="Choose the backend for grammar-guided decoding.",
|
|
)
|
|
parser.add_argument(
|
|
"--mm-attention-backend",
|
|
type=str,
|
|
choices=["sdpa", "fa3", "triton_attn"],
|
|
default=ServerArgs.mm_attention_backend,
|
|
help="Set multimodal attention backend.",
|
|
)
|
|
|
|
# Speculative decoding
|
|
parser.add_argument(
|
|
"--speculative-algorithm",
|
|
type=str,
|
|
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
|
|
help="Speculative algorithm.",
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-draft-model-path",
|
|
"--speculative-draft-model",
|
|
type=str,
|
|
help="The path of the draft model weights. This can be a local folder or a Hugging Face repo ID.",
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-draft-model-revision",
|
|
type=str,
|
|
default=None,
|
|
help="The specific draft model version to use. It can be a branch "
|
|
"name, a tag name, or a commit id. If unspecified, will use "
|
|
"the default version.",
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-num-steps",
|
|
type=int,
|
|
help="The number of steps sampled from draft model in Speculative Decoding.",
|
|
default=ServerArgs.speculative_num_steps,
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-eagle-topk",
|
|
type=int,
|
|
help="The number of tokens sampled from the draft model in eagle2 each step.",
|
|
default=ServerArgs.speculative_eagle_topk,
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-num-draft-tokens",
|
|
type=int,
|
|
help="The number of tokens sampled from the draft model in Speculative Decoding.",
|
|
default=ServerArgs.speculative_num_draft_tokens,
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-accept-threshold-single",
|
|
type=float,
|
|
help="Accept a draft token if its probability in the target model is greater than this threshold.",
|
|
default=ServerArgs.speculative_accept_threshold_single,
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-accept-threshold-acc",
|
|
type=float,
|
|
help="The accept probability of a draft token is raised from its target probability p to min(1, p / threshold_acc).",
|
|
default=ServerArgs.speculative_accept_threshold_acc,
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-token-map",
|
|
type=str,
|
|
help="The path of the draft model's small vocab table.",
|
|
default=ServerArgs.speculative_token_map,
|
|
)
|
|
parser.add_argument(
|
|
"--speculative-attention-mode",
|
|
type=str,
|
|
choices=["prefill", "decode"],
|
|
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
|
|
default=ServerArgs.speculative_attention_mode,
|
|
)
|
|
|
|
# Expert parallelism
|
|
parser.add_argument(
|
|
"--expert-parallel-size",
|
|
"--ep-size",
|
|
"--ep",
|
|
type=int,
|
|
default=ServerArgs.ep_size,
|
|
help="The expert parallelism size.",
|
|
)
|
|
parser.add_argument(
|
|
"--moe-a2a-backend",
|
|
type=str,
|
|
choices=["none", "deepep"],
|
|
default=ServerArgs.moe_a2a_backend,
|
|
help="Choose the backend for MoE A2A.",
|
|
)
|
|
parser.add_argument(
|
|
"--moe-runner-backend",
|
|
type=str,
|
|
choices=[
|
|
"auto",
|
|
"triton",
|
|
"triton_kernel",
|
|
"flashinfer_trtllm",
|
|
"flashinfer_cutlass",
|
|
"flashinfer_mxfp4",
|
|
],
|
|
default=ServerArgs.moe_runner_backend,
|
|
help="Choose the runner backend for MoE.",
|
|
)
|
|
parser.add_argument(
|
|
"--flashinfer-mxfp4-moe-precision",
|
|
type=str,
|
|
choices=["default", "bf16"],
|
|
default=ServerArgs.flashinfer_mxfp4_moe_precision,
|
|
help="Choose the computation precision of flashinfer mxfp4 moe",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-flashinfer-allreduce-fusion",
|
|
action="store_true",
|
|
help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
|
|
)
|
|
parser.add_argument(
|
|
"--deepep-mode",
|
|
type=str,
|
|
choices=["normal", "low_latency", "auto"],
|
|
default="auto",
|
|
help="Select the mode when enable DeepEP MoE, could be `normal`, `low_latency` or `auto`. Default is `auto`, which means `low_latency` for decode batch and `normal` for prefill batch.",
|
|
)
|
|
parser.add_argument(
|
|
"--ep-num-redundant-experts",
|
|
type=int,
|
|
default=ServerArgs.ep_num_redundant_experts,
|
|
help="Allocate this number of redundant experts in expert parallel.",
|
|
)
|
|
parser.add_argument(
|
|
"--ep-dispatch-algorithm",
|
|
type=str,
|
|
default=ServerArgs.ep_dispatch_algorithm,
|
|
help="The algorithm to choose ranks for redundant experts in expert parallel.",
|
|
)
|
|
parser.add_argument(
|
|
"--init-expert-location",
|
|
type=str,
|
|
default=ServerArgs.init_expert_location,
|
|
help="Initial location of EP experts.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-eplb",
|
|
action="store_true",
|
|
help="Enable EPLB algorithm",
|
|
)
|
|
parser.add_argument(
|
|
"--eplb-algorithm",
|
|
type=str,
|
|
default=ServerArgs.eplb_algorithm,
|
|
help="Chosen EPLB algorithm",
|
|
)
|
|
parser.add_argument(
|
|
"--eplb-rebalance-num-iterations",
|
|
type=int,
|
|
default=ServerArgs.eplb_rebalance_num_iterations,
|
|
help="Number of iterations to automatically trigger a EPLB re-balance.",
|
|
)
|
|
parser.add_argument(
|
|
"--eplb-rebalance-layers-per-chunk",
|
|
type=int,
|
|
default=ServerArgs.eplb_rebalance_layers_per_chunk,
|
|
help="Number of layers to rebalance per forward pass.",
|
|
)
|
|
parser.add_argument(
|
|
"--eplb-min-rebalancing-utilization-threshold",
|
|
type=float,
|
|
default=ServerArgs.eplb_min_rebalancing_utilization_threshold,
|
|
help="Minimum threshold for GPU average utilization to trigger EPLB rebalancing. Must be in the range [0.0, 1.0].",
|
|
)
|
|
parser.add_argument(
|
|
"--expert-distribution-recorder-mode",
|
|
type=str,
|
|
default=ServerArgs.expert_distribution_recorder_mode,
|
|
help="Mode of expert distribution recorder.",
|
|
)
|
|
parser.add_argument(
|
|
"--expert-distribution-recorder-buffer-size",
|
|
type=int,
|
|
default=ServerArgs.expert_distribution_recorder_buffer_size,
|
|
help="Circular buffer size of expert distribution recorder. Set to -1 to denote infinite buffer.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-expert-distribution-metrics",
|
|
action="store_true",
|
|
help="Enable logging metrics for expert balancedness",
|
|
)
|
|
parser.add_argument(
|
|
"--deepep-config",
|
|
type=str,
|
|
default=ServerArgs.deepep_config,
|
|
help="Tuned DeepEP config suitable for your own cluster. It can be either a string with JSON content or a file path.",
|
|
)
|
|
parser.add_argument(
|
|
"--moe-dense-tp-size",
|
|
type=int,
|
|
default=ServerArgs.moe_dense_tp_size,
|
|
help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
|
|
)
|
|
|
|
# Mamba Cache
|
|
parser.add_argument(
|
|
"--max-mamba-cache-size",
|
|
type=int,
|
|
default=ServerArgs.max_mamba_cache_size,
|
|
help="The maximum size of the mamba cache.",
|
|
)
|
|
parser.add_argument(
|
|
"--mamba-ssm-dtype",
|
|
type=str,
|
|
default=ServerArgs.mamba_ssm_dtype,
|
|
choices=["float32", "bfloat16"],
|
|
help="The data type of the SSM states in mamba cache.",
|
|
)
|
|
|
|
# Hierarchical cache
|
|
parser.add_argument(
|
|
"--enable-hierarchical-cache",
|
|
action="store_true",
|
|
help="Enable hierarchical cache",
|
|
)
|
|
parser.add_argument(
|
|
"--hicache-ratio",
|
|
type=float,
|
|
default=ServerArgs.hicache_ratio,
|
|
help="The ratio of the size of host KV cache memory pool to the size of device pool.",
|
|
)
|
|
parser.add_argument(
|
|
"--hicache-size",
|
|
type=int,
|
|
default=ServerArgs.hicache_size,
|
|
help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
|
|
)
|
|
parser.add_argument(
|
|
"--hicache-write-policy",
|
|
type=str,
|
|
choices=["write_back", "write_through", "write_through_selective"],
|
|
default=ServerArgs.hicache_write_policy,
|
|
help="The write policy of hierarchical cache.",
|
|
)
|
|
parser.add_argument(
|
|
"--hicache-io-backend",
|
|
type=str,
|
|
choices=["direct", "kernel"],
|
|
default=ServerArgs.hicache_io_backend,
|
|
help="The IO backend for KV cache transfer between CPU and GPU",
|
|
)
|
|
parser.add_argument(
|
|
"--hicache-mem-layout",
|
|
type=str,
|
|
choices=["layer_first", "page_first"],
|
|
default=ServerArgs.hicache_mem_layout,
|
|
help="The layout of host memory pool for hierarchical cache.",
|
|
)
|
|
parser.add_argument(
|
|
"--hicache-storage-backend",
|
|
type=str,
|
|
choices=["file", "mooncake", "hf3fs", "nixl"],
|
|
default=ServerArgs.hicache_storage_backend,
|
|
help="The storage backend for hierarchical KV cache.",
|
|
)
|
|
parser.add_argument(
|
|
"--hicache-storage-prefetch-policy",
|
|
type=str,
|
|
choices=["best_effort", "wait_complete", "timeout"],
|
|
default=ServerArgs.hicache_storage_prefetch_policy,
|
|
help="Control when prefetching from the storage backend should stop.",
|
|
)
|
|
parser.add_argument(
|
|
"--hicache-storage-backend-extra-config",
|
|
type=str,
|
|
default=ServerArgs.hicache_storage_backend_extra_config,
|
|
help="A dictionary in JSON string format containing extra configuration for the storage backend.",
|
|
)
|
|
# LMCache
|
|
parser.add_argument(
|
|
"--enable-lmcache",
|
|
action="store_true",
|
|
help="Using LMCache as an alternative hierarchical cache solution",
|
|
)
|
|
|
|
# Double Sparsity
|
|
parser.add_argument(
|
|
"--enable-double-sparsity",
|
|
action="store_true",
|
|
help="Enable double sparsity attention",
|
|
)
|
|
parser.add_argument(
|
|
"--ds-channel-config-path",
|
|
type=str,
|
|
default=ServerArgs.ds_channel_config_path,
|
|
help="The path of the double sparsity channel config",
|
|
)
|
|
parser.add_argument(
|
|
"--ds-heavy-channel-num",
|
|
type=int,
|
|
default=ServerArgs.ds_heavy_channel_num,
|
|
help="The number of heavy channels in double sparsity attention",
|
|
)
|
|
parser.add_argument(
|
|
"--ds-heavy-token-num",
|
|
type=int,
|
|
default=ServerArgs.ds_heavy_token_num,
|
|
help="The number of heavy tokens in double sparsity attention",
|
|
)
|
|
parser.add_argument(
|
|
"--ds-heavy-channel-type",
|
|
type=str,
|
|
default=ServerArgs.ds_heavy_channel_type,
|
|
help="The type of heavy channels in double sparsity attention",
|
|
)
|
|
parser.add_argument(
|
|
"--ds-sparse-decode-threshold",
|
|
type=int,
|
|
default=ServerArgs.ds_sparse_decode_threshold,
|
|
help="The type of heavy channels in double sparsity attention",
|
|
)
|
|
|
|
# Offloading
|
|
parser.add_argument(
|
|
"--cpu-offload-gb",
|
|
type=int,
|
|
default=ServerArgs.cpu_offload_gb,
|
|
help="How many GBs of RAM to reserve for CPU offloading.",
|
|
)
|
|
parser.add_argument(
|
|
"--offload-group-size",
|
|
type=int,
|
|
default=ServerArgs.offload_group_size,
|
|
help="Number of layers per group in offloading.",
|
|
)
|
|
parser.add_argument(
|
|
"--offload-num-in-group",
|
|
type=int,
|
|
default=ServerArgs.offload_num_in_group,
|
|
help="Number of layers to be offloaded within a group.",
|
|
)
|
|
parser.add_argument(
|
|
"--offload-prefetch-step",
|
|
type=int,
|
|
default=ServerArgs.offload_prefetch_step,
|
|
help="Steps to prefetch in offloading.",
|
|
)
|
|
parser.add_argument(
|
|
"--offload-mode",
|
|
type=str,
|
|
default=ServerArgs.offload_mode,
|
|
help="Mode of offloading.",
|
|
)
|
|
|
|
# Optimization/debug options
|
|
parser.add_argument(
|
|
"--disable-radix-cache",
|
|
action="store_true",
|
|
help="Disable RadixAttention for prefix caching.",
|
|
)
|
|
parser.add_argument(
|
|
"--cuda-graph-max-bs",
|
|
type=int,
|
|
default=ServerArgs.cuda_graph_max_bs,
|
|
help="Set the maximum batch size for cuda graph. It will extend the cuda graph capture batch size to this value.",
|
|
)
|
|
parser.add_argument(
|
|
"--cuda-graph-bs",
|
|
type=int,
|
|
nargs="+",
|
|
help="Set the list of batch sizes for cuda graph.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-cuda-graph",
|
|
action="store_true",
|
|
help="Disable cuda graph.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-cuda-graph-padding",
|
|
action="store_true",
|
|
help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-profile-cuda-graph",
|
|
action="store_true",
|
|
help="Enable profiling of cuda graph capture.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-cudagraph-gc",
|
|
action="store_true",
|
|
help="Enable garbage collection during CUDA graph capture. If disabled (default), GC is frozen during capture to speed up the process.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-nccl-nvls",
|
|
action="store_true",
|
|
help="Enable NCCL NVLS for prefill heavy requests when available.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-symm-mem",
|
|
action="store_true",
|
|
help="Enable NCCL symmetric memory for fast collectives.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-flashinfer-cutlass-moe-fp4-allgather",
|
|
action="store_true",
|
|
help="Disables quantize before all-gather for flashinfer cutlass moe.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-tokenizer-batch-encode",
|
|
action="store_true",
|
|
help="Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-outlines-disk-cache",
|
|
action="store_true",
|
|
help="Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-custom-all-reduce",
|
|
action="store_true",
|
|
help="Disable the custom all-reduce kernel and fall back to NCCL.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-mscclpp",
|
|
action="store_true",
|
|
help="Enable using mscclpp for small messages for all-reduce kernel and fall back to NCCL.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-overlap-schedule",
|
|
action="store_true",
|
|
help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-mixed-chunk",
|
|
action="store_true",
|
|
help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-dp-attention",
|
|
action="store_true",
|
|
help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently DeepSeek-V2 and Qwen 2/3 MoE models are supported.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-dp-lm-head",
|
|
action="store_true",
|
|
help="Enable vocabulary parallel across the attention TP group to avoid all-gather across DP groups, optimizing performance under DP attention.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-two-batch-overlap",
|
|
action="store_true",
|
|
help="Enabling two micro batches to overlap.",
|
|
)
|
|
parser.add_argument(
|
|
"--tbo-token-distribution-threshold",
|
|
type=float,
|
|
default=ServerArgs.tbo_token_distribution_threshold,
|
|
help="The threshold of token distribution between two batches in micro-batch-overlap, determines whether to two-batch-overlap or two-chunk-overlap. Set to 0 denote disable two-chunk-overlap.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-torch-compile",
|
|
action="store_true",
|
|
help="Optimize the model with torch.compile. Experimental feature.",
|
|
)
|
|
parser.add_argument(
|
|
"--torch-compile-max-bs",
|
|
type=int,
|
|
default=ServerArgs.torch_compile_max_bs,
|
|
help="Set the maximum batch size when using torch compile.",
|
|
)
|
|
parser.add_argument(
|
|
"--torchao-config",
|
|
type=str,
|
|
default=ServerArgs.torchao_config,
|
|
help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-nan-detection",
|
|
action="store_true",
|
|
help="Enable the NaN detection for debugging purposes.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-p2p-check",
|
|
action="store_true",
|
|
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
|
|
)
|
|
parser.add_argument(
|
|
"--triton-attention-reduce-in-fp32",
|
|
action="store_true",
|
|
help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
|
|
"This only affects Triton attention kernels.",
|
|
)
|
|
parser.add_argument(
|
|
"--triton-attention-num-kv-splits",
|
|
type=int,
|
|
default=ServerArgs.triton_attention_num_kv_splits,
|
|
help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.",
|
|
)
|
|
parser.add_argument(
|
|
"--num-continuous-decode-steps",
|
|
type=int,
|
|
default=ServerArgs.num_continuous_decode_steps,
|
|
help="Run multiple continuous decoding steps to reduce scheduling overhead. "
|
|
"This can potentially increase throughput but may also increase time-to-first-token latency. "
|
|
"The default value is 1, meaning only run one decoding step at a time.",
|
|
)
|
|
parser.add_argument(
|
|
"--delete-ckpt-after-loading",
|
|
action="store_true",
|
|
help="Delete the model checkpoint after loading the model.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-memory-saver",
|
|
action="store_true",
|
|
help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
|
|
)
|
|
parser.add_argument(
|
|
"--allow-auto-truncate",
|
|
action="store_true",
|
|
help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-custom-logit-processor",
|
|
action="store_true",
|
|
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
|
)
|
|
parser.add_argument(
|
|
"--flashinfer-mla-disable-ragged",
|
|
action="store_true",
|
|
help="Not using ragged prefill wrapper when running flashinfer mla",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-shared-experts-fusion",
|
|
action="store_true",
|
|
help="Disable shared experts fusion optimization for deepseek v3/r1.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-chunked-prefix-cache",
|
|
action="store_true",
|
|
help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
|
|
)
|
|
parser.add_argument(
|
|
"--disable-fast-image-processor",
|
|
action="store_true",
|
|
help="Adopt base image processor instead of fast image processor.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-return-hidden-states",
|
|
action="store_true",
|
|
help="Enable returning hidden states with responses.",
|
|
)
|
|
parser.add_argument(
|
|
"--scheduler-recv-interval",
|
|
type=int,
|
|
default=ServerArgs.scheduler_recv_interval,
|
|
help="The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this.",
|
|
)
|
|
parser.add_argument(
|
|
"--numa-node",
|
|
type=int,
|
|
nargs="+",
|
|
help="Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess.",
|
|
)
|
|
|
|
# Debug tensor dumps
|
|
parser.add_argument(
|
|
"--debug-tensor-dump-output-folder",
|
|
type=str,
|
|
default=ServerArgs.debug_tensor_dump_output_folder,
|
|
help="The output folder for dumping tensors.",
|
|
)
|
|
parser.add_argument(
|
|
"--debug-tensor-dump-input-file",
|
|
type=str,
|
|
default=ServerArgs.debug_tensor_dump_input_file,
|
|
help="The input filename for dumping tensors",
|
|
)
|
|
parser.add_argument(
|
|
"--debug-tensor-dump-inject",
|
|
type=str,
|
|
default=ServerArgs.debug_tensor_dump_inject,
|
|
help="Inject the outputs from jax as the input of every layer.",
|
|
)
|
|
parser.add_argument(
|
|
"--debug-tensor-dump-prefill-only",
|
|
action="store_true",
|
|
help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
|
|
)
|
|
|
|
# PD disaggregation
|
|
parser.add_argument(
|
|
"--disaggregation-mode",
|
|
type=str,
|
|
default="null",
|
|
choices=["null", "prefill", "decode"],
|
|
help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
|
|
)
|
|
parser.add_argument(
|
|
"--disaggregation-transfer-backend",
|
|
type=str,
|
|
default=ServerArgs.disaggregation_transfer_backend,
|
|
choices=DISAGG_TRANSFER_BACKEND_CHOICES,
|
|
help="The backend for disaggregation transfer. Default is mooncake.",
|
|
)
|
|
parser.add_argument(
|
|
"--disaggregation-bootstrap-port",
|
|
type=int,
|
|
default=ServerArgs.disaggregation_bootstrap_port,
|
|
help="Bootstrap server port on the prefill server. Default is 8998.",
|
|
)
|
|
parser.add_argument(
|
|
"--disaggregation-decode-tp",
|
|
type=int,
|
|
default=ServerArgs.disaggregation_decode_tp,
|
|
help="Decode tp size. If not set, it matches the tp size of the current engine. This is only set on the prefill server.",
|
|
)
|
|
parser.add_argument(
|
|
"--disaggregation-decode-dp",
|
|
type=int,
|
|
default=ServerArgs.disaggregation_decode_dp,
|
|
help="Decode dp size. If not set, it matches the dp size of the current engine. This is only set on the prefill server.",
|
|
)
|
|
parser.add_argument(
|
|
"--disaggregation-prefill-pp",
|
|
type=int,
|
|
default=ServerArgs.disaggregation_prefill_pp,
|
|
help="Prefill pp size. If not set, it is default to 1. This is only set on the decode server.",
|
|
)
|
|
parser.add_argument(
|
|
"--disaggregation-ib-device",
|
|
type=str,
|
|
default=ServerArgs.disaggregation_ib_device,
|
|
help="The InfiniBand devices for disaggregation transfer, accepts single device (e.g., --disaggregation-ib-device mlx5_0) "
|
|
"or multiple comma-separated devices (e.g., --disaggregation-ib-device mlx5_0,mlx5_1). "
|
|
"Default is None, which triggers automatic device detection when mooncake backend is enabled.",
|
|
)
|
|
parser.add_argument(
|
|
"--num-reserved-decode-tokens",
|
|
type=int,
|
|
default=ServerArgs.num_reserved_decode_tokens,
|
|
help="Number of decode tokens that will have memory reserved when adding new request to the running batch.",
|
|
)
|
|
|
|
# Custom weight loader
|
|
parser.add_argument(
|
|
"--custom-weight-loader",
|
|
type=str,
|
|
nargs="*",
|
|
default=None,
|
|
help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
|
|
)
|
|
parser.add_argument(
|
|
"--weight-loader-disable-mmap",
|
|
action="store_true",
|
|
help="Disable mmap while loading weight using safetensors.",
|
|
)
|
|
|
|
# For PD-Multiplexing
|
|
parser.add_argument(
|
|
"--enable-pdmux",
|
|
action="store_true",
|
|
help="Enable PD-Multiplexing, PD running on greenctx stream.",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--sm-group-num",
|
|
type=int,
|
|
default=ServerArgs.sm_group_num,
|
|
help="Number of sm partition groups.",
|
|
)
|
|
|
|
# Deprecated arguments
|
|
parser.add_argument(
|
|
"--enable-ep-moe",
|
|
action="store_true",
|
|
help="(Deprecated) Enabling expert parallelism for moe. The ep size is equal to the tp size.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-deepep-moe",
|
|
action="store_true",
|
|
help="(Deprecated) Enabling DeepEP MoE implementation for EP MoE.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-flashinfer-cutlass-moe",
|
|
action="store_true",
|
|
help="(Deprecated) Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-flashinfer-trtllm-moe",
|
|
action="store_true",
|
|
help="(Deprecated) Enable FlashInfer TRTLLM MoE backend on Blackwell. Supports BlockScale FP8 MoE-EP",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-triton-kernel-moe",
|
|
action="store_true",
|
|
help="(Deprecated) Use triton moe grouped gemm kernel.",
|
|
)
|
|
parser.add_argument(
|
|
"--enable-flashinfer-mxfp4-moe",
|
|
action="store_true",
|
|
help="(Deprecated) Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
|
|
)
|
|
|
|
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
    """Construct ``cls`` from a parsed ``argparse.Namespace``.

    The long-form parallelism options (``tensor_parallel_size``,
    ``pipeline_parallel_size``, ``data_parallel_size``,
    ``expert_parallel_size``) are first copied onto their short aliases
    (``tp_size``, ``pp_size``, ``dp_size``, ``ep_size``) directly on ``args``
    (the caller's namespace is mutated). Then every dataclass field of
    ``cls`` is read from ``args`` by name and passed to the constructor.
    """
    alias_pairs = (
        ("tp_size", "tensor_parallel_size"),
        ("pp_size", "pipeline_parallel_size"),
        ("dp_size", "data_parallel_size"),
        ("ep_size", "expert_parallel_size"),
    )
    for short_name, long_name in alias_pairs:
        setattr(args, short_name, getattr(args, long_name))

    field_values = {}
    for fld in dataclasses.fields(cls):
        field_values[fld.name] = getattr(args, fld.name)
    return cls(**field_values)
|
|
|
|
def url(self):
    """Return the HTTP base URL (``http://host:port``) of this server.

    IPv6 literal hosts are wrapped in square brackets, as URL syntax requires.
    """
    host = self.host
    if is_valid_ipv6_address(host):
        host = f"[{host}]"
    return f"http://{host}:{self.port}"
|
|
|
|
def get_hf_config(self):
    """Load the Hugging Face model config for ``self.model_path``.

    ``self.json_model_override_args`` is decoded with ``json.loads`` and passed
    to ``get_config`` as ``model_override_args``, together with
    ``trust_remote_code`` and ``revision``.
    """
    override_args = json.loads(self.json_model_override_args)
    return get_config(
        self.model_path,
        trust_remote_code=self.trust_remote_code,
        revision=self.revision,
        model_override_args=override_args,
    )
|
|
|
|
def check_server_args(self):
    """Validate cross-argument constraints of the server configuration.

    Also delegates to ``self.check_lora_server_args()`` for the LoRA options
    and to ``self.validate_buckets_rule(...)`` for the prompt/generation token
    bucket options.

    Raises:
        AssertionError: if any constraint is violated.
    """
    # Check parallel size constraints
    assert (
        self.tp_size * self.pp_size
    ) % self.nnodes == 0, "tp_size * pp_size must be divisible by number of nodes"

    if self.pp_size > 1:
        assert (
            self.disable_overlap_schedule
            and self.speculative_algorithm is None
            and not self.enable_mixed_chunk
        ), "Pipeline parallelism is not compatible with overlap schedule, speculative decoding, mixed chunked prefill."

    assert not (
        self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
    ), "multi-node data parallel is not supported unless dp attention!"

    assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
    assert self.gpu_id_step >= 1, "gpu_id_step must be positive"

    assert self.moe_dense_tp_size in {
        1,
        None,
    }, "moe_dense_tp_size only supports 1 and None currently"

    # Check LoRA
    self.check_lora_server_args()

    # Check speculative decoding: mixed chunked prefill must be OFF with it.
    if self.speculative_algorithm is not None:
        assert (
            not self.enable_mixed_chunk
        ), "enable_mixed_chunk is not compatible with speculative decoding"

    # Check chunked prefill
    # Skip validation if chunked prefill is unset or disabled (size <= 0).
    if self.chunked_prefill_size is not None and self.chunked_prefill_size > 0:
        assert (
            self.chunked_prefill_size % self.page_size == 0
        ), "chunked_prefill_size must be divisible by page_size"

    # Check multi tokenizer
    assert self.tokenizer_worker_num > 0, "Tokenizer worker num must be >= 1"
    self.validate_buckets_rule(
        "--prompt-tokens-buckets", self.prompt_tokens_buckets
    )
    self.validate_buckets_rule(
        "--generation-tokens-buckets", self.generation_tokens_buckets
    )
|
|
|
|
def check_lora_server_args(self):
    """Validate and normalize the LoRA-related server arguments in place.

    If ``lora_paths`` is given and ``enable_lora`` is unset, LoRA is enabled
    automatically. When LoRA is enabled:

    * ``self.lora_paths`` becomes a list of ``LoRARef``. Accepted inputs are
      ``None`` (no initial adapters), a dict ``{name: path}``, or a list whose
      items are ``"name=path"`` strings, bare path strings (used as both name
      and path), dicts with ``lora_name``/``lora_path`` (optional ``pinned``)
      keys, or already-built ``LoRARef`` objects (so running this check twice
      is harmless).
    * ``self.lora_target_modules`` becomes a set, with ``"all"`` expanded to
      ``SUPPORTED_LORA_TARGET_MODULES``.

    Raises:
        AssertionError: on inconsistent LoRA limits or missing information.
        ValueError: if ``lora_paths`` or one of its items has an unsupported
            type.
    """
    assert self.max_loras_per_batch > 0, "max_loras_per_batch must be positive"

    # Enable LoRA if any LoRA paths are provided for backward compatibility.
    if self.lora_paths:
        if self.enable_lora is None:
            self.enable_lora = True
            logger.warning(
                "--enable-lora is set to True because --lora-paths is provided."
            )
        elif self.enable_lora is False:
            logger.warning(
                "--enable-lora is set to False, any provided lora_paths will be ignored."
            )

    if not self.enable_lora:
        return

    if isinstance(self.lora_paths, list):
        # Build into a fresh list so self.lora_paths is left untouched if an
        # invalid item raises part-way through.
        normalized = []
        for lora_path in self.lora_paths:
            if isinstance(lora_path, LoRARef):
                # Already normalized, e.g. by a previous run of this check.
                lora_ref = lora_path
            elif isinstance(lora_path, str):
                if "=" in lora_path:
                    name, path = lora_path.split("=", 1)
                    lora_ref = LoRARef(lora_name=name, lora_path=path, pinned=False)
                else:
                    lora_ref = LoRARef(
                        lora_name=lora_path, lora_path=lora_path, pinned=False
                    )
            elif isinstance(lora_path, dict):
                assert (
                    "lora_name" in lora_path and "lora_path" in lora_path
                ), f"When providing LoRA paths as a list of dict, each dict should contain 'lora_name' and 'lora_path' keys. Got: {lora_path}"
                lora_ref = LoRARef(
                    lora_name=lora_path["lora_name"],
                    lora_path=lora_path["lora_path"],
                    pinned=lora_path.get("pinned", False),
                )
            else:
                raise ValueError(
                    f"Invalid type for item in --lora-paths list: {type(lora_path)}. "
                    "Expected a string or a dictionary."
                )
            normalized.append(lora_ref)
        self.lora_paths = normalized
    elif isinstance(self.lora_paths, dict):
        self.lora_paths = [
            LoRARef(lora_name=k, lora_path=v, pinned=False)
            for k, v in self.lora_paths.items()
        ]
    elif self.lora_paths is None:
        self.lora_paths = []
    else:
        raise ValueError(
            f"Invalid type for --lora-paths: {type(self.lora_paths)}. "
            "Expected a list or a dictionary."
        )

    # Expand target modules
    if self.lora_target_modules:
        self.lora_target_modules = set(self.lora_target_modules)
        if "all" in self.lora_target_modules:
            assert (
                len(self.lora_target_modules) == 1
            ), "If 'all' is specified in --lora-target-modules, it should be the only module specified."
            self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES)

    # Ensure sufficient information is provided for LoRA initialization.
    assert self.lora_paths or (
        self.max_lora_rank and self.lora_target_modules
    ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

    # Validate max_loaded_loras
    if self.max_loaded_loras is not None:
        assert self.max_loaded_loras >= self.max_loras_per_batch, (
            "max_loaded_loras should be greater than or equal to max_loras_per_batch. "
            f"max_loaded_loras={self.max_loaded_loras}, max_loras_per_batch={self.max_loras_per_batch}"
        )
        assert len(self.lora_paths) <= self.max_loaded_loras, (
            "The number of LoRA paths should not exceed max_loaded_loras. "
            f"max_loaded_loras={self.max_loaded_loras}, lora_paths={len(self.lora_paths)}"
        )
|
|
|
|
def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
    """Check prefill/decode TP sizes are compatible for PD disaggregation.

    The two sizes may differ only when the larger is an integer multiple of
    the smaller.

    Raises:
        AssertionError: if neither TP size divides the other.
    """
    smaller_tp, larger_tp = sorted((prefill_tp, decode_tp))
    assert larger_tp % smaller_tp == 0, (
        "Different tp size is supported only when one tp is multiple of the other. "
        f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
    )
|
|
|
|
def validate_buckets_rule(self, arg_name: str, buckets_rule: List[str]):
    """Validate a histogram-bucket rule given on the command line.

    The first element of ``buckets_rule`` names the rule:

    * ``["tse", middle, base, count]``: ``middle``/``base`` parse as floats
      with ``middle > 0`` and ``base > 1``; ``count`` parses as an int > 0.
    * ``["default"]``: no further parameters allowed.
    * ``["customer", v1, v2, ...]``: at least one value; all values parse as
      floats, are non-negative, and are unique.

    An empty or ``None`` rule is accepted without further checks.

    Args:
        arg_name: CLI flag name, used only in error messages.
        buckets_rule: The raw rule tokens.

    Raises:
        AssertionError: if the rule is malformed.
    """
    if not buckets_rule:
        return

    rule = buckets_rule[0]
    assert rule in [
        "tse",
        "default",
        "customer",
    ], f"Unsupported {arg_name} rule type: '{rule}'. Must be one of: 'tse', 'default', 'customer'"

    if rule == "tse":
        assert (
            len(buckets_rule) == 4
        ), f"{arg_name} TSE rule requires exactly 4 parameters: ['tse', middle, base, count], got {len(buckets_rule)}"
        try:
            middle = float(buckets_rule[1])
            base = float(buckets_rule[2])
            count = int(buckets_rule[3])
        except (ValueError, TypeError, IndexError):
            # Raise explicitly instead of `assert False`: asserts are stripped
            # under `python -O`, which would fall through with unbound names.
            raise AssertionError(
                f"{arg_name} TSE rule parameters must be: ['tse', <float:middle>, <float:base>, <int:count>]"
            ) from None
        assert base > 1, f"{arg_name} TSE base must be larger than 1, got: {base}"
        assert count > 0, f"{arg_name} TSE count must be positive, got: {count}"
        assert middle > 0, f"{arg_name} TSE middle must be positive, got: {middle}"

    elif rule == "default":
        assert (
            len(buckets_rule) == 1
        ), f"{arg_name} default rule should only have one parameter: ['default'], got {len(buckets_rule)}"

    elif rule == "customer":
        assert (
            len(buckets_rule) >= 2
        ), f"{arg_name} customer rule requires at least one bucket value: ['customer', value1, ...]"
        try:
            bucket_values = [float(x) for x in buckets_rule[1:]]
        except (ValueError, TypeError):
            # Explicit raise: see the TSE branch for why not `assert False`.
            raise AssertionError(
                f"{arg_name} customer rule bucket values must be numeric"
            ) from None
        assert len(set(bucket_values)) == len(
            bucket_values
        ), f"{arg_name} customer rule bucket values should not contain duplicates"
        assert all(
            val >= 0 for val in bucket_values
        ), f"{arg_name} customer rule bucket values should be non-negative"
|
|
|
|
def model_specific_adjustments(self):
    """Apply architecture-specific overrides to the server arguments in place.

    The architecture is the first entry of ``hf_config.architectures``.

    * ``GptOssForCausalLM``:
      - If ``attention_backend`` is unset, it defaults to trtllm_mha on CUDA
        SM100, fa3 on CUDA SM90, and triton otherwise. The backend is then
        asserted to be one of triton/trtllm_mha/fa3.
      - On SM100 without DP attention, FlashInfer all-reduce fusion is enabled.
      - The MoE runner backend is chosen from the hardware and from whether
        the checkpoint's ``quant_method`` is ``"mxfp4"``.
      - Hybrid SWA memory is disabled.
      - ``dtype`` is forced to ``"bfloat16"`` for MXFP4 checkpoints.
    * Llama4 architectures: asserts the attention backend is fa3, aiter or
      triton.
    * Gemma2 / Gemma3 / Gemma3n: disables hybrid SWA memory.

    Raises:
        AssertionError: if a backend constraint above is not met.
    """
    hf_config = self.get_hf_config()
    model_arch = hf_config.architectures[0]
    if model_arch in ["GptOssForCausalLM"]:
        if self.attention_backend is None:
            # Pick a default attention backend by GPU generation.
            if is_cuda() and is_sm100_supported():
                self.attention_backend = "trtllm_mha"
            elif is_cuda() and is_sm90_supported():
                self.attention_backend = "fa3"
            else:
                self.attention_backend = "triton"
        supported_backends = ["triton", "trtllm_mha", "fa3"]
        logger.info(
            f"Use {self.attention_backend} as attention backend for GptOssForCausalLM"
        )
        assert (
            self.attention_backend in supported_backends
        ), f"GptOssForCausalLM requires one of {supported_backends} attention backend, but got '{self.attention_backend}'"

        if is_sm100_supported():
            if not self.enable_dp_attention:
                self.enable_flashinfer_allreduce_fusion = True
                logger.info(
                    "Enable FlashInfer AllReduce Fusion on sm100 for GptOssForCausalLM"
                )
        quantization_config = getattr(hf_config, "quantization_config", None)
        # NOTE(review): assumes quantization_config, when present, is a dict
        # (`.get` is called on it) -- confirm for all GPT-OSS checkpoints.
        is_mxfp4_quant_format = (
            quantization_config is not None
            and quantization_config.get("quant_method") == "mxfp4"
        )

        if is_sm100_supported() and is_mxfp4_quant_format:
            # SM100 + MXFP4: this overrides any user-chosen MoE runner backend.
            self.moe_runner_backend = "flashinfer_mxfp4"
            logger.warning(
                "Detected SM100 and MXFP4 quantization format for GPT-OSS model, enabling FlashInfer MXFP4 MOE kernel."
            )
        else:
            if self.moe_runner_backend == "triton_kernel":
                assert (
                    self.ep_size == 1
                ), "Triton kernel MoE is only supported when ep_size == 1"
            # Auto-select triton_kernels only when the user left the choice as "auto".
            if (
                self.moe_runner_backend == "auto"
                and self.ep_size == 1
                and is_triton_kernels_available()
            ):
                self.moe_runner_backend = "triton_kernel"
                logger.warning(
                    "Detected GPT-OSS model, enabling triton_kernels MOE kernel."
                )
        self.disable_hybrid_swa_memory = True
        if is_mxfp4_quant_format:
            # use bf16 for mxfp4 triton kernels
            self.dtype = "bfloat16"

    elif "Llama4" in model_arch:
        # NOTE(review): an unset attention_backend (None) also fails this
        # check, so Llama4 users must pass a backend explicitly.
        assert self.attention_backend in {
            "fa3",
            "aiter",
            "triton",
        }, "fa3, aiter, or triton is required for Llama4 model"
    elif model_arch in [
        "Gemma2ForCausalLM",
        "Gemma3ForCausalLM",
        "Gemma3ForConditionalGeneration",
        "Gemma3nForCausalLM",
        "Gemma3nForConditionalGeneration",
    ]:
        # FIXME: https://github.com/sgl-project/sglang/pull/7367 is not compatible with gemma2 model.
        # It failed at this test: https://github.com/sgl-project/sglang/actions/runs/16255155597/job/45890331952#step:4:736
        logger.warning(
            f"Disable hybrid SWA memory for {model_arch} as it is not yet supported."
        )
        self.disable_hybrid_swa_memory = True
|
|
|
|
def adjust_mem_fraction_for_vlm(self, model_config):
    """Shrink ``self.mem_fraction_static`` to leave room for a vision encoder.

    Nothing happens when ``model_config.hf_config`` has no (or a ``None``)
    ``vision_config``. Otherwise the fraction is multiplied by
    ``0.95 * f`` where ``f = 1 - 0.1 * (ratio - 1)`` clamped to
    ``[0.8, 1.05]``. ``ratio`` compares ``layers * hidden**2`` of the ViT with
    a ViT-L/14 baseline (24 layers, hidden size 1024); missing attributes fall
    back to the baseline values.
    """
    vision_cfg = getattr(model_config.hf_config, "vision_config", None)
    if vision_cfg is not None:
        # Reference ViT-L/14 dimensions, also used as attribute fallbacks.
        ref_layers, ref_hidden = 24, 1024
        layers = getattr(vision_cfg, "num_hidden_layers", ref_layers)
        hidden = getattr(vision_cfg, "hidden_size", ref_hidden)

        # Weight-count proxy: layers * hidden^2.
        ref_score = ref_layers * (ref_hidden**2)
        score = layers * (hidden**2)
        ratio = score / ref_score if ref_score > 0 else 1.0

        # Each +100% of complexity costs 10% of the fraction, within [0.8, 1.05].
        scale = 0.1
        adjustment = max(0.8, min(1.05, 1.0 - scale * (ratio - 1.0)))

        self.mem_fraction_static = self.mem_fraction_static * (0.95 * adjustment)
|
|
|
|
|
|
def prepare_server_args(argv: List[str]) -> ServerArgs:
    """
    Prepare the server arguments from the command line arguments.

    Args:
        argv: The command line arguments. Typically, it should be `sys.argv[1:]`
            to ensure compatibility with `parse_args` when no arguments are passed.

    Returns:
        The server arguments.
    """
    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    raw_args = parser.parse_args(argv)
    return ServerArgs.from_cli_args(raw_args)
|
|
|
|
|
|
# Offset added to the server port to form the local dist-init address
# ("127.0.0.1", port + ZMQ_TCP_PORT_DELTA) that PortArgs.init_new uses for
# DP attention on a single node when no --dist-init-addr is given.
ZMQ_TCP_PORT_DELTA = 233
|
|
|
|
|
|
@dataclasses.dataclass
class PortArgs:
    """ZMQ endpoints and the NCCL port used for inter-process communication.

    Endpoint fields hold ``ipc://<path>`` or ``tcp://<host>:<port>`` addresses.
    """

    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
    tokenizer_ipc_name: str
    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
    scheduler_input_ipc_name: str
    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
    detokenizer_ipc_name: str

    # The port for nccl initialization (torch.dist)
    nccl_port: int

    # The ipc filename for rpc call between Engine and Scheduler
    rpc_ipc_name: str

    # The ipc filename for Scheduler to send metrics
    metrics_ipc_name: str

    # The ipc filename for Tokenizer and worker tokenizer
    tokenizer_worker_ipc_name: Optional[str]

    @staticmethod
    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        """Allocate a fresh set of endpoints for ``server_args``.

        Without DP attention every endpoint is a unique ``ipc://`` path backed
        by a temporary file. With DP attention, endpoints are TCP ports laid
        out after the dist-init port so all nodes derive the same addresses::

            port_base = dist_init_port + 1   tokenizer
            port_base + 1                    detokenizer
            port_base + 2                    rpc
            port_base + 3                    metrics
            port_base + 4                    scheduler input (dp_rank is None)
            port_base + 5 + dp_rank          scheduler input of that DP rank

        Args:
            server_args: Reads ``nccl_port``, ``port``, ``enable_dp_attention``,
                ``nnodes`` and ``dist_init_addr``.
            dp_rank: DP rank whose scheduler input port to use; ``None``
                selects the TokenizerManager -> DataParallelController port.

        Raises:
            AssertionError: if ``dist_init_addr`` is missing or not ``host:port``
                when DP attention needs it.
        """
        if server_args.nccl_port is None:
            nccl_port = server_args.port + random.randint(100, 1000)
            # Probe for a free port, stepping down once near the top of the range.
            while not is_port_available(nccl_port):
                if nccl_port < 60000:
                    nccl_port += 42
                else:
                    nccl_port -= 43
        else:
            nccl_port = server_args.nccl_port

        if not server_args.enable_dp_attention:
            # Normal case, use IPC within a single node
            return PortArgs(
                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                nccl_port=nccl_port,
                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                tokenizer_worker_ipc_name=None,
            )

        # DP attention. Use TCP + port to handle both single-node and multi-node.
        if server_args.nnodes == 1 and server_args.dist_init_addr is None:
            dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
        else:
            # Multi-node needs an explicit head-node address; fail with a clear
            # message instead of an AttributeError on None.
            assert (
                server_args.dist_init_addr is not None
            ), "please provide --dist-init-addr as host:port of head node"
            if server_args.dist_init_addr.startswith("["):  # ipv6 address
                port_num, host = configure_ipv6(server_args.dist_init_addr)
                dist_init_addr = (host, str(port_num))
            else:
                dist_init_addr = server_args.dist_init_addr.split(":")

        assert (
            len(dist_init_addr) == 2
        ), "please provide --dist-init-addr as host:port of head node"

        dist_init_host, dist_init_port = dist_init_addr
        port_base = int(dist_init_port) + 1
        detokenizer_port = port_base + 1
        rpc_port = port_base + 2
        metrics_port = port_base + 3
        if dp_rank is None:
            # TokenizerManager to DataParallelController
            scheduler_input_port = port_base + 4
        else:
            scheduler_input_port = port_base + 4 + 1 + dp_rank

        return PortArgs(
            tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
            scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
            detokenizer_ipc_name=f"tcp://{dist_init_host}:{detokenizer_port}",
            nccl_port=nccl_port,
            rpc_ipc_name=f"tcp://{dist_init_host}:{rpc_port}",
            metrics_ipc_name=f"tcp://{dist_init_host}:{metrics_port}",
            tokenizer_worker_ipc_name=None,
        )
|
|
|
|
|
|
class LoRAPathAction(argparse.Action):
    """argparse action that collects ``--lora-paths`` values into a list.

    Each value is stripped of surrounding whitespace. A value that starts with
    ``{`` and ends with ``}`` is decoded as JSON and must contain both
    ``lora_name`` and ``lora_path`` keys; any other value is kept as a plain
    string. The resulting list (empty when no values are given) is stored on
    the namespace under ``self.dest``.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        parsed = []
        if values:
            assert isinstance(values, list), "Expected a list of LoRA paths."
            for raw in values:
                item = raw.strip()
                looks_like_json = item.startswith("{") and item.endswith("}")
                if not looks_like_json:
                    parsed.append(item)
                    continue
                obj = json.loads(item)
                assert "lora_path" in obj and "lora_name" in obj, (
                    f"{repr(item)} looks like a JSON str, "
                    "but it does not contain 'lora_name' and 'lora_path' keys."
                )
                parsed.append(obj)

        setattr(namespace, self.dest, parsed)
|
|
|
|
|
|
class DeprecatedAction(argparse.Action):
    """argparse action for removed flags: passing the flag is an error.

    By default the action consumes no values (``nargs=0``). When the flag
    appears on the command line it raises ``ValueError`` whose message is the
    flag's ``help`` text.
    """

    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        raise ValueError(self.help)
|
|
|
|
|
|
def print_deprecated_warning(message: str):
    """Log ``message`` as a warning wrapped in ANSI yellow color codes."""
    yellow, reset = "\033[33m", "\033[0m"
    logger.warning(f"{yellow}{message}{reset}")
|
|
|
|
|
|
def auto_choose_speculative_params(self: ServerArgs):
    """
    Choose default speculative-decoding parameters for this model.

    Returns ``(3, 1, 4)`` for the STANDALONE algorithm and for the
    DeepSeek-V2/V3 and GPT-OSS architectures, and ``(5, 4, 8)`` for every other
    architecture (Llama, Grok and the rest). The HF config is always loaded
    first.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
    """
    arch = self.get_hf_config().architectures[0]
    compact_archs = (
        "DeepseekV3ForCausalLM",
        "DeepseekV2ForCausalLM",
        "GptOssForCausalLM",
    )
    if self.speculative_algorithm == "STANDALONE" or arch in compact_archs:
        return (3, 1, 4)
    # Llama, Grok, and all other architectures share these defaults.
    return (5, 4, 8)
|