from typing import Any, Callable, Dict, List, Optional

import torch

from sglang.srt.utils import is_cuda_available, set_weight_attrs

is_cuda = is_cuda_available()
if is_cuda:
    from sgl_kernel import int8_scaled_mm

from torch.nn.parameter import Parameter

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8


class W8A8Int8Config(QuantizationConfig):
    """Config class for W8A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self):
        pass

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 75

    @classmethod
    def get_name(cls) -> str:
        return "w8a8_int8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
        return cls()

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            return W8A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W8A8Int8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


class W8A8Int8LinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: W8A8Int8Config):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        x_q, x_scale = per_token_quant_int8(x)

        return int8_scaled_mm(
            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
        )


class W8A8Int8MoEMethod:
    """MoE method for INT8.

    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.

    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
""" def __new__(cls, *args, **kwargs): from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase if not hasattr(cls, "_initialized"): original_init = cls.__init__ new_cls = type( cls.__name__, (FusedMoEMethodBase,), { "__init__": original_init, **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, }, ) obj = super(new_cls, new_cls).__new__(new_cls) obj.__init__(*args, **kwargs) return obj return super().__new__(cls) def __init__(self, quant_config): self.quant_config = quant_config def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported tp_size = get_tensor_model_parallel_world_size() # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) w13_weight_scale = torch.nn.Parameter( torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) w13_input_scale = None layer.register_parameter("w13_input_scale", w13_input_scale) w2_input_scale = None layer.register_parameter("w2_input_scale", w2_input_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) layer.w13_weight_scale = Parameter( layer.w13_weight_scale.data, requires_grad=False ) layer.w2_weight_scale = Parameter( layer.w2_weight_scale.data, requires_grad=False ) def apply( self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, activation: str = "silu", inplace: bool = True, no_combine: bool = False, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts from sglang.srt.layers.moe.topk import select_experts # Expert selection topk_weights, topk_ids = select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, correction_bias=correction_bias, ) return fused_experts( x, layer.w13_weight, layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=inplace, activation=activation, use_int8_w8a8=True, w1_scale=(layer.w13_weight_scale), w2_scale=(layer.w2_weight_scale), a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, 
            no_combine=no_combine,
        )
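
# ---------------------------------------------------------------------------
# Illustrative-only reference sketch (not used by the kernel path above):
# a minimal pure-PyTorch picture of what W8A8Int8LinearMethod.apply computes,
# assuming symmetric int8 quantization. The `_reference_*` names below are
# hypothetical and exist only for documentation; the real implementations are
# `per_token_quant_int8` (int8_kernel) and `int8_scaled_mm` (sgl_kernel), whose
# exact rounding/clamping behavior may differ.
# ---------------------------------------------------------------------------


def _reference_per_token_quant_int8(x: torch.Tensor):
    """Dynamic, per-token, symmetric int8 quantization (reference only)."""
    # One scale per token (row): max absolute value mapped to 127.
    x_scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / x_scale), -128, 127).to(torch.int8)
    return x_q, x_scale.to(torch.float32)


def _reference_int8_scaled_mm(
    x_q: torch.Tensor,  # [num_tokens, in_features], int8
    weight: torch.Tensor,  # [in_features, out_features], int8 (already transposed
    #                        in process_weights_after_loading)
    x_scale: torch.Tensor,  # [num_tokens, 1], float32, per-token activation scales
    weight_scale: torch.Tensor,  # [out_features, 1], float32, per-channel scales
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Reference only: (x_q * x_scale) @ (w_q * w_scale)^T with scales factored out."""
    acc = torch.matmul(x_q.to(torch.float32), weight.to(torch.float32))
    # Dequantize: scale rows by the per-token activation scale and columns by
    # the per-output-channel weight scale.
    out = acc * x_scale * weight_scale.t()
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)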