# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py import logging from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.linear import LinearBase, LinearMethodBase from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, cutlass_fp8_supported, ) from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod from sglang.srt.layers.quantization.utils import ( convert_to_channelwise, requantize_with_max_scale, ) # Initialize logger for the module logger = logging.getLogger(__name__) # Supported activation schemes for the current configuration ACTIVATION_SCHEMES = ["static"] class ModelOptFp8Config(QuantizationConfig): """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks.""" def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None: """ Args: is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format. """ self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized if is_checkpoint_fp8_serialized: logger.warning( "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change." ) @classmethod def get_name(cls) -> str: return "modelopt" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: return 89 # Minimum hardware capability (e.g., Hopper GPUs). @classmethod def get_config_filenames(cls) -> List[str]: return ["hf_quant_config.json"] @classmethod def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config": quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo") if "FP8" not in quant_method: raise ValueError( "ModelOpt only supports static FP8 quantization in SGLang. " "Check the `hf_quant_config.json` file for your model's configuration." ) return cls(is_checkpoint_fp8_serialized=True) def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): return ModelOptFp8LinearMethod(self) if isinstance(layer, AttentionBackend): return ModelOptFp8KVCacheMethod(self) return None def get_scaled_act_names(self) -> List[str]: return [] class ModelOptFp8LinearMethod(LinearMethodBase): """Linear method for ModelOpt static FP8 quantization. Supports loading FP8 checkpoints with static weight and activation scales. Future support may include dynamic scales. **Limitations**: 1. Only supports per-tensor quantization due to `torch._scaled_mm` limitations. 2. Only supports the `float8_e4m3fn` data type. Args: quant_config (ModelOptFp8Config): The ModelOpt quantization configuration. """ def __init__(self, quant_config: ModelOptFp8Config): super().__init__() self.quant_config = quant_config self.cutlass_fp8_supported = cutlass_fp8_supported() def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: """Creates and registers weights, weight scales, and input scales for FP8 quantization.""" output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype ) # Set layer attributes layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition # Register weight layer.register_parameter( "weight", ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=weight_dtype, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ), ) if self.quant_config.is_checkpoint_fp8_serialized: # Register weight and input scales for scale_name in ["weight_scale", "input_scale"]: layer.register_parameter( scale_name, PerTensorScaleParameter( data=torch.full( (len(output_partition_sizes),), torch.finfo(torch.float32).min, dtype=torch.float32, ), weight_loader=weight_loader, ), ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """Requantizes weights after loading using the maximum scale.""" max_w_scale, quantized_weight = requantize_with_max_scale( layer.weight, layer.weight_scale, layer.logical_widths ) layer.weight = Parameter(quantized_weight.t(), requires_grad=False) # cutlass sgl-kernel only supports per-channel scale if self.cutlass_fp8_supported: max_w_scale = convert_to_channelwise(max_w_scale, layer.logical_widths) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Applies FP8 linear transformation.""" return apply_fp8_linear( input=x, weight=layer.weight, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, ) class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): """ Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints. """ def __init__(self, quant_config: ModelOptFp8Config): super().__init__(quant_config)