# Source: sglang0.4.5.post1/python/sglang/srt/layers/quantization/modelopt_quant.py
# (197 lines, 6.8 KiB, Python)

# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
import logging
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
cutlass_fp8_supported,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
convert_to_channelwise,
requantize_with_max_scale,
)
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(__name__)
# Activation quantization schemes listed for this backend; only "static" is present.
# NOTE(review): this constant is not referenced anywhere in the visible part of
# this file -- confirm whether other modules rely on it.
ACTIVATION_SCHEMES = ["static"]
class ModelOptFp8Config(QuantizationConfig):
    """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""

    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
        """
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        if is_checkpoint_fp8_serialized:
            logger.warning(
                "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
            )

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "modelopt"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts (bfloat16 and half)."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum hardware capability required, as an integer."""
        # 89 presumably encodes CUDA compute capability 8.9 (Ada Lovelace);
        # Hopper is 9.0, so Hopper and newer also satisfy this minimum.
        return 89

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["hf_quant_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
        """Build a config object from a parsed quantization config dictionary.

        Args:
            config: Dictionary expected to hold a ``"quantization"`` entry whose
                ``"quant_algo"`` value contains ``"FP8"``.

        Returns:
            A ``ModelOptFp8Config`` with ``is_checkpoint_fp8_serialized=True``.

        Raises:
            ValueError: If ``"quant_algo"`` is missing or does not contain ``"FP8"``.
        """
        quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
        # ``.get`` yields None when "quant_algo" is absent; guard explicitly so the
        # caller receives the descriptive ValueError instead of a TypeError raised
        # by the ``in`` operator on None.
        if quant_method is None or "FP8" not in quant_method:
            raise ValueError(
                "ModelOpt only supports static FP8 quantization in SGLang. "
                "Check the `hf_quant_config.json` file for your model's configuration."
            )
        return cls(is_checkpoint_fp8_serialized=True)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Select the quantize method for ``layer``, or None if it is not quantized.

        ``prefix`` is accepted for interface compatibility and is not used here.
        """
        if isinstance(layer, LinearBase):
            return ModelOptFp8LinearMethod(self)
        # NOTE(review): ``layer`` is annotated as torch.nn.Module while this check
        # targets AttentionBackend -- confirm this branch is reachable for the
        # attention layers that should receive KV-cache scales.
        if isinstance(layer, AttentionBackend):
            return ModelOptFp8KVCacheMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations; none for this method."""
        return []
class ModelOptFp8LinearMethod(LinearMethodBase):
    """Linear-layer method for ModelOpt static FP8 checkpoints.

    Loads FP8 checkpoints that ship static weight scales and static
    activation (input) scales. Dynamic scales may be supported later.

    **Limitations**:
    1. Per-tensor quantization only, due to `torch._scaled_mm` limitations.
    2. `float8_e4m3fn` is the only supported data type.

    Args:
        quant_config (ModelOptFp8Config): The ModelOpt quantization configuration.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        super().__init__()
        # Probe once at construction; the result is reused by
        # process_weights_after_loading() and apply().
        self.cutlass_fp8_supported = cutlass_fp8_supported()
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the weight and, for FP8 checkpoints, the scale parameters on ``layer``.

        Args:
            layer: Module that receives the parameters and size attributes.
            input_size_per_partition: Number of weight columns for this partition.
            output_partition_sizes: Row counts of the logical sub-weights; their
                sum is the number of weight rows.
            params_dtype: Weight dtype used when the checkpoint is not FP8-serialized.
            **extra_weight_attrs: Only ``"weight_loader"`` is read (None if absent).
        """
        fp8_serialized = self.quant_config.is_checkpoint_fp8_serialized
        loader = extra_weight_attrs.get("weight_loader")
        total_out = sum(output_partition_sizes)

        if fp8_serialized:
            storage_dtype = torch.float8_e4m3fn
        else:
            storage_dtype = params_dtype

        # Bookkeeping read back later (logical_widths is used after loading).
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_out

        weight_param = ModelWeightParameter(
            data=torch.empty(total_out, input_size_per_partition, dtype=storage_dtype),
            input_dim=1,
            output_dim=0,
            weight_loader=loader,
        )
        layer.register_parameter("weight", weight_param)

        # Scales are only registered for FP8-serialized checkpoints.
        if not fp8_serialized:
            return

        num_partitions = len(output_partition_sizes)
        # Placeholder is the most negative float32; loaded scales overwrite it.
        placeholder = torch.finfo(torch.float32).min
        for scale_name in ("weight_scale", "input_scale"):
            # One entry per logical sub-weight; a fresh tensor per parameter.
            scale_param = PerTensorScaleParameter(
                data=torch.full((num_partitions,), placeholder, dtype=torch.float32),
                weight_loader=loader,
            )
            layer.register_parameter(scale_name, scale_param)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Requantize the loaded weight with the maximum scale and finalize parameters.

        Replaces ``layer.weight`` with the transposed requantized weight,
        ``layer.weight_scale`` with the resulting scale, and ``layer.input_scale``
        with the maximum of the loaded input scales. All are stored as
        non-trainable parameters.
        """
        widths = layer.logical_widths
        merged_scale, merged_weight = requantize_with_max_scale(
            layer.weight, layer.weight_scale, widths
        )
        layer.weight = Parameter(merged_weight.t(), requires_grad=False)

        # cutlass sgl-kernel only supports per-channel scale
        if self.cutlass_fp8_supported:
            merged_scale = convert_to_channelwise(merged_scale, widths)
        layer.weight_scale = Parameter(merged_scale, requires_grad=False)

        largest_input_scale = layer.input_scale.max()
        layer.input_scale = Parameter(largest_input_scale, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the FP8 linear transformation of ``x`` using the parameters on ``layer``.

        Args:
            layer: Module holding ``weight``, ``weight_scale`` and ``input_scale``.
            x: Input activations.
            bias: Optional bias forwarded to the FP8 linear helper.

        Returns:
            The tensor produced by ``apply_fp8_linear``.
        """
        output = apply_fp8_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
            cutlass_fp8_supported=self.cutlass_fp8_supported,
        )
        return output
class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
    """
    Handles loading FP8 kv-cache scaling factors from modelopt quantized checkpoints.

    The visible code adds no behavior of its own: construction only forwards
    ``quant_config`` to ``BaseKVCacheMethod``.

    Args:
        quant_config (ModelOptFp8Config): The ModelOpt quantization configuration.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        # All setup is delegated to the base class constructor.
        super().__init__(quant_config)