# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging from typing import Any, Callable, Dict, List, Optional import torch from torch.nn import Module from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.linear import ( LinearBase, LinearMethodBase, UnquantizedLinearMethod, ) from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs ACTIVATION_SCHEMES = ["static", "dynamic"] logger = logging.getLogger(__name__) class BlockInt8Config(QuantizationConfig): """Config class for INT8.""" def __init__( self, is_checkpoint_int8_serialized: bool = False, activation_scheme: str = "dynamic", ignored_layers: Optional[List[str]] = None, weight_block_size: List[int] = None, ) -> None: self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized if is_checkpoint_int8_serialized: logger.warning( "Detected int8 checkpoint. Please note that the " "format is experimental and subject to change." ) if activation_scheme not in ACTIVATION_SCHEMES: raise ValueError(f"Unsupported activation scheme {activation_scheme}") self.activation_scheme = activation_scheme self.ignored_layers = ignored_layers or [] if weight_block_size is not None: if not is_checkpoint_int8_serialized: raise ValueError( f"The block-wise quantization only supports int8-serialized checkpoint for now." ) if len(weight_block_size) != 2: raise ValueError( f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions." ) if activation_scheme != "dynamic": raise ValueError( f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." ) self.weight_block_size = weight_block_size @classmethod def get_name(cls) -> str: return "blockwise_int8" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: return 80 @classmethod def get_config_filenames(cls) -> List[str]: return [] @classmethod def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config": quant_method = cls.get_from_keys(config, ["quant_method"]) is_checkpoint_int8_serialized = "int8" in quant_method activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) return cls( is_checkpoint_int8_serialized=is_checkpoint_int8_serialized, activation_scheme=activation_scheme, ignored_layers=ignored_layers, weight_block_size=weight_block_size, ) def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: from sglang.srt.layers.moe.fused_moe_triton import FusedMoE if isinstance(layer, LinearBase): if is_layer_skipped(prefix, self.ignored_layers): return UnquantizedLinearMethod() return BlockInt8LinearMethod(self) elif isinstance(layer, FusedMoE): return BlockInt8MoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: return [] class BlockInt8LinearMethod(LinearMethodBase): """Linear method for INT8. Supports loading INT8 checkpoints with static weight scale and dynamic activation scale. Limitations: Only support block-wise int8 quantization and int8 checkpoint Args: quant_config: The quantization config. """ def __init__(self, quant_config: BlockInt8Config): self.quant_config = quant_config assert self.quant_config.weight_block_size is not None assert self.quant_config.is_checkpoint_int8_serialized def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") tp_size = get_tensor_model_parallel_world_size() block_n, block_k = ( self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[1], ) # Required by row parallel if tp_size > 1 and input_size // input_size_per_partition == tp_size: if input_size_per_partition % block_k != 0: raise ValueError( f"Weight input_size_per_partition = " f"{input_size_per_partition} is not divisible by " f"weight quantization block_k = {block_k}." ) # Required by collum parallel or enabling merged weights if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len( output_partition_sizes ) > 1: for output_partition_size in output_partition_sizes: if output_partition_size % block_n != 0: raise ValueError( f"Weight output_partition_size = " f"{output_partition_size} is not divisible by " f"weight quantization block_n = {block_n}." ) layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.orig_dtype = params_dtype # WEIGHT weight_dtype = ( torch.int8 if self.quant_config.is_checkpoint_int8_serialized else params_dtype ) weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=weight_dtype ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) # WEIGHT SCALE scale = BlockQuantScaleParameter( data=torch.empty( (output_size_per_partition + block_n - 1) // block_n, (input_size_per_partition + block_k - 1) // block_k, dtype=torch.float32, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale_inv", scale) # INPUT ACTIVATION SCALE assert self.quant_config.activation_scheme == "dynamic" layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: # Block quant doesn't need to process weights after loading # Use torch Parameter to avoid cuda graph capturing issue layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) layer.weight_scale_inv = torch.nn.Parameter( layer.weight_scale_inv.data, requires_grad=False ) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return apply_w8a8_block_int8_linear( input=x, weight=layer.weight, block_size=self.quant_config.weight_block_size, weight_scale=layer.weight_scale_inv, input_scale=None, bias=bias, ) class BlockInt8MoEMethod: """MoE method for INT8. Supports loading INT8 checkpoints with static weight scale and dynamic activation scale. Limitations: Only support block-wise int8 quantization and int8 checkpoint Args: quant_config: The quantization config. """ def __new__(cls, *args, **kwargs): from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase if not hasattr(cls, "_initialized"): original_init = cls.__init__ new_cls = type( cls.__name__, (FusedMoEMethodBase,), { "__init__": original_init, **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, }, ) obj = super(new_cls, new_cls).__new__(new_cls) obj.__init__(*args, **kwargs) return obj return super().__new__(cls) def __init__(self, quant_config): self.quant_config = quant_config assert self.quant_config.weight_block_size is not None assert self.quant_config.is_checkpoint_int8_serialized def create_weights( self, layer: Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported if self.quant_config.is_checkpoint_int8_serialized: params_dtype = torch.int8 tp_size = get_tensor_model_parallel_world_size() block_n, block_k = ( self.quant_config.weight_block_size[0], self.quant_config.weight_block_size[1], ) # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. # Required by collum parallel or enabling merged weights if intermediate_size % block_n != 0: raise ValueError( f"The output_size of gate's and up's weight = " f"{intermediate_size} is not divisible by " f"weight quantization block_n = {block_n}." ) if tp_size > 1: # Required by row parallel if intermediate_size % block_k != 0: raise ValueError( f"The input_size of down's weight = " f"{intermediate_size} is not divisible by " f"weight quantization block_k = {block_k}." ) # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( torch.empty( num_experts, hidden_size, intermediate_size, dtype=params_dtype ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, 2 * ((intermediate_size + block_n - 1) // block_n), (hidden_size + block_k - 1) // block_k, dtype=torch.float32, ), requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( torch.ones( num_experts, (hidden_size + block_n - 1) // block_n, (intermediate_size + block_k - 1) // block_k, dtype=torch.float32, ), requires_grad=False, ) layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES assert self.quant_config.activation_scheme == "dynamic" layer.w13_input_scale = None layer.w2_input_scale = None def process_weights_after_loading(self, layer: Module) -> None: # Block quant doesn't need to process weights after loading return def apply( self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.topk import select_experts # Expert selection topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, correction_bias=correction_bias, ) # Expert fusion with INT8 quantization return fused_experts( x, layer.w13_weight, layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=inplace, activation=activation, use_int8_w8a8=True, w1_scale=(layer.w13_weight_scale_inv), w2_scale=(layer.w2_weight_scale_inv), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, block_shape=self.quant_config.weight_block_size, no_combine=no_combine, )