from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_utils import (
    apply_fp8_linear,
    cutlass_fp8_supported,
    input_to_float8,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip

_is_hip = is_hip()


class W8A8Fp8Config(QuantizationConfig):
    """Config class for W8A8 FP8 quantization.

    Weight quantization:
    - Method: static quantization
    - Granularity: per-channel
    - Type: symmetric

    Activation quantization:
    - Method: dynamic quantization
    - Granularity: per-token
    - Type: symmetric

    Note:
    - For models without offline quantization, weights are quantized during model loading.
    - If CUTLASS is supported: per-channel weight quantization is used.
    - If CUTLASS is not supported: falls back to per-tensor weight quantization.
    """

    def __init__(self, is_checkpoint_fp8_serialized: bool = False):
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 89

    @classmethod
    def get_name(cls) -> str:
        return "w8a8_fp8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Fp8Config":
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "compressed-tensors" in quant_method
        return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized)

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            return W8A8Fp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
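
# Illustrative reference only (not used by the runtime path): a minimal plain-torch
# sketch of the scheme described in the docstring above, i.e. static per-channel
# symmetric weight quantization and dynamic per-token symmetric activation
# quantization. The function name and scale shapes are assumptions for
# illustration; the production kernels are per_token_group_quant_fp8 and
# apply_fp8_linear.
def _reference_w8a8_fp8_quant(weight: torch.Tensor, x: torch.Tensor):
    """Return (qweight, weight_scale, qx, x_scale) for a weight of shape (out, in)."""
    fp8_dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(fp8_dtype).max
    w = weight.float()
    a = x.float()
    # Static, symmetric, per-channel weight scales: one scale per output row, shape (out, 1).
    w_scale = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    qweight = (w / w_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    # Dynamic, symmetric, per-token activation scales: one scale per token, shape (..., 1).
    x_scale = a.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    qx = (a / x_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return qweight, w_scale, qx, x_scale
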
class W8A8Fp8LinearMethod(LinearMethodBase):
    def __init__(self, quantization_config: W8A8Fp8Config):
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        weight = layer.weight
        if self.quantization_config.is_checkpoint_fp8_serialized:
            # The checkpoint was offline-quantized with w8a8_fp8: load the weight
            # and weight_scale directly.
            weight_scale = layer.weight_scale.detach()
            if _is_hip:
                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight, weight_scale=weight_scale
                )
            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
        else:
            # The checkpoint was not offline-quantized: quantize the weights during loading.
            if self.cutlass_fp8_supported:
                # If CUTLASS is supported, we use cutlass_scaled_mm, which requires
                # per-channel quantization of the weight.
                qweight, weight_scale = per_token_group_quant_fp8(
                    layer.weight, layer.weight.shape[-1]
                )
                weight_scale = weight_scale.t().contiguous()
                if _is_hip:
                    qweight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                        weight=qweight, weight_scale=weight_scale
                    )
            else:
                # If CUTLASS is not supported, we fall back to torch._scaled_mm,
                # which requires per-tensor quantization of the weight.
                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
                qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)

            # Update the layer with the new values.
            layer.weight = Parameter(qweight.t(), requires_grad=False)
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
            layer.input_scale = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quantization_config.is_checkpoint_fp8_serialized
            else params_dtype
        )

        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        if self.quantization_config.is_checkpoint_fp8_serialized:
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
            layer.register_parameter("weight_scale", weight_scale)
        else:
            layer.weight_scale = None

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        return apply_fp8_linear(
            x,
            layer.weight,
            layer.weight_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )
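

# Minimal usage sketch (assumption: run where sglang and its GPU dependencies import
# cleanly). It only exercises the config plumbing; actually applying
# W8A8Fp8LinearMethod requires an sglang LinearBase layer on a device with compute
# capability >= 8.9.
if __name__ == "__main__":
    hf_quant_config = {"quant_method": "compressed-tensors"}
    config = W8A8Fp8Config.from_config(hf_quant_config)
    print(config.get_name())  # "w8a8_fp8"
    print(config.get_min_capability())  # 89
    print(config.is_checkpoint_fp8_serialized)  # True, since "compressed-tensors" matched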