import logging from fractions import Fraction from typing import Any, Dict, List, Optional, Union import torch from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import is_cuda _is_cuda = is_cuda() try: from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod, GPTQMarlinMoEMethod, ) from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, ) from vllm.scalar_type import scalar_types VLLM_AVAILABLE = True except ImportError: VLLM_AVAILABLE = False GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any class scalar_types: uint4b8 = "uint4b8" uint8b128 = "uint8b128" logger = logging.getLogger(__name__) class GPTQConfig(QuantizationConfig): """Config class for GPTQ. Reference: https://arxiv.org/abs/2210.17323 """ def __init__( self, weight_bits: int, group_size: int, desc_act: bool, lm_head_quantized: bool, dynamic: Dict[str, Dict[str, Union[int, bool]]], ) -> None: # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. # Format is Dict[str, Dict] where key is a regex string that can # perform both positive ("+:" prefixed) or negative ("-:" prefixed) # matching of a module. # Default to positive match, override base quant config mode, if no # prefix is used. Value is in dict format of field key and override # value. # Negative matching will skip quantization init for this module # entirely: # non-quantized inference. More details and quantization examples can be # found at: https://github.com/ModelCloud/GPTQModel # Example: # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 # # last 1/4 of the layers 16-21 has 8bit and group_size 64 # dynamic = { # #`.*\.` matches the layers_node prefix # # positive match layer 10-15 # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, # # positive match layer 16-21 # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers # } super().__init__() self.dynamic = dynamic self.weight_bits = weight_bits self.group_size = group_size self.desc_act = desc_act self.lm_head_quantized = lm_head_quantized self.pack_factor = Fraction(32, self.weight_bits) if self.weight_bits not in [2, 3, 4, 8]: raise ValueError( "Currently, only 2/3/4/8-bit weight quantization is " f"supported for GPTQ, but got {self.weight_bits} bits." ) def __repr__(self) -> str: return ( f"GPTQConfig(weight_bits={self.weight_bits}, " f"group_size={self.group_size}, " f"desc_act={self.desc_act})," f"lm_head_quantized={self.lm_head_quantized}), " f"dynamic={self.dynamic}" ) def get_scaled_act_names(self) -> List[str]: """Returns the activation function names that should be post-scaled. For now, this is only used by AWQ. """ raise NotImplementedError @classmethod def get_name(cls) -> str: return "gptq" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.half] @classmethod # Need to figure it out def get_min_capability(cls) -> int: return 60 @classmethod def get_config_filenames(cls) -> List[str]: return ["quantize_config.json"] @classmethod def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = {} if dynamic is None else dynamic weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic) def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[GPTQLinearMethod]: # Delay the import to avoid circular dependency from sglang.srt.layers.quantization import get_linear_quant_method return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod) class GPTQMarlinConfig(QuantizationConfig): """Config class for GPTQ Marlin""" # (num_bits, is_sym) -> quant_type TYPE_MAP = { (4, True): scalar_types.uint4b8, (8, True): scalar_types.uint8b128, } def __init__( self, weight_bits: int, group_size: int, desc_act: bool, is_sym: bool, lm_head_quantized: bool, dynamic: Dict[str, Dict[str, Union[int, bool]]], full_config: Dict[str, Any], ) -> None: super().__init__() if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) desc_act = False # GPTQModel use `dynamic` config property to allow per module # quantization config so each module can be individually optimized. # Format is Dict[str, Dict] where key is a regex string that can # perform both positive ("+:" prefixed) or negative ("-:" prefixed) # matching of a module. # Default to positive match, override base quant config mode, if no # prefix is used. Value is in dict format of field key and override # value. # Negative matching will skip quantization init for this module # entirely: # non-quantized inference. More details and quantization examples can be # found at: https://github.com/ModelCloud/GPTQModel # Example: # # last 1/2 of the layers 10-21 has 8bit vs 4bit for 0-9 # # last 1/4 of the layers 16-21 has 8bit and group_size 64 # dynamic = { # #`.*\.` matches the layers_node prefix # # positive match layer 10-15 # r"+:.*\.(?:1[0-5])\..*": {"bits": 8,}, # # positive match layer 16-21 # r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64,}, # r"-:.*\.moe\..*": {}, # negative match (skip) all `moe` layers # } self.dynamic = dynamic self.weight_bits = weight_bits self.is_sym = is_sym self.pack_factor = 32 // weight_bits # packed into int32 self.group_size = group_size self.desc_act = desc_act self.lm_head_quantized = lm_head_quantized self.full_config = full_config if (weight_bits, is_sym) not in self.TYPE_MAP: raise ValueError( "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}" ) # (num_bits, is_sym) -> quant_type self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] def __repr__(self) -> str: return ( f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " f"desc_act={self.desc_act}, " f"lm_head_quantized={self.lm_head_quantized}), " f"dynamic={self.dynamic}" ) def get_scaled_act_names(self) -> List[str]: """Returns the activation function names that should be post-scaled. For now, this is only used by AWQ. """ raise NotImplementedError @classmethod def get_name(cls) -> str: return "gptq_marlin" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.half, torch.bfloat16] @classmethod def get_min_capability(cls) -> int: return 80 @classmethod def get_config_filenames(cls) -> List[str]: return ["quantize_config.json"] @classmethod def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": dynamic = cls.get_from_keys_or(config, ["dynamic"], default={}) dynamic = {} if dynamic is None else dynamic weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) desc_act = cls.get_from_keys(config, ["desc_act"]) is_sym = cls.get_from_keys(config, ["sym"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) return cls( weight_bits, group_size, desc_act, is_sym, lm_head_quantized, dynamic, config, ) @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) is_valid_user_quant = ( user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin" ) if can_convert and is_valid_user_quant: msg = ( "The model is convertible to {} during runtime." " Using {} kernel.".format(cls.get_name(), cls.get_name()) ) logger.info(msg) return cls.get_name() if can_convert and user_quant == "gptq": logger.info( "Detected that the model can run with gptq_marlin" ", however you specified quantization=gptq explicitly," " so forcing gptq. Use quantization=gptq_marlin for" " faster inference" ) return None def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[QuantizeMethodBase]: # Delay the import to avoid circular dependency from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization import get_linear_quant_method if isinstance(layer, FusedMoE): return GPTQMarlinMoEMethod(self) # TODO: re-enable after SGLang syncs with vllm >= 0.7.3 # if layer.num_experts > 32: # # For MoEs with many experts the moe_wna16 kernel is faster # return MoeWNA16Config.from_config(self.full_config).get_quant_method( # layer, prefix # ) # else: # return GPTQMarlinMoEMethod(self) return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod) @classmethod def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") group_size = quant_config.get("group_size") sym = quant_config.get("sym") desc_act = quant_config.get("desc_act") if not _is_cuda: return False if quant_method != "gptq": return False # Marlin conversion is only valid if required properties are found if num_bits is None or group_size is None or sym is None or desc_act is None: return False if (num_bits, sym) not in cls.TYPE_MAP: return False return check_marlin_supported( quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size ) class MarlinConfig(QuantizationConfig): """Config class for Marlin. Reference: https://github.com/IST-DASLab/marlin/tree/master """ def __init__( self, group_size: int, lm_head_quantized: bool, ) -> None: # Group size for the quantization. self.group_size = group_size self.lm_head_quantized = lm_head_quantized if self.group_size != 128 and self.group_size != -1: raise ValueError( "Currently, only group size 128 and -1 (channelwise) " "is supported for Marlin, but got group_size of " f"{self.group_size}" ) # 4 Bits packed into 32 bit datatype. self.pack_factor = 32 // 4 # Tile size used by marlin kernels. self.tile_size = 16 # Min out_features dim self.min_n_threads = 64 # Min in_features dim self.min_k_threads = 128 # Max parallel problems to solve at once (improves large # batch performance) self.max_parallel = 16 # Permutation length used by the marlin kernels. self.perm_len = 1024 def __repr__(self) -> str: return ( f"MarlinConfig(group_size={self.group_size}, " f"lm_head_quantized={self.lm_head_quantized})" ) @classmethod def get_name(cls) -> str: return "marlin" @classmethod def get_supported_act_dtypes(cls) -> List[torch.dtype]: return [torch.half] @classmethod # Need to figure it out def get_min_capability(cls) -> int: return 80 @classmethod def get_config_filenames(cls) -> List[str]: return ["quantize_config.json"] @classmethod def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) return cls(group_size, lm_head_quantized) @classmethod def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]: # compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq <=0.7.1 is_marlin_format: bool is_marlin_format = hf_quant_cfg.get( "checkpoint_format" ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False) is_valid_user_quant = ( user_quant is None or user_quant == "gptq" or user_quant == "marlin" ) if is_marlin_format and is_valid_user_quant: msg = "The model is serialized in {} format. Using {} kernel.".format( cls.get_name(), cls.get_name() ) logger.info(msg) return cls.get_name() return None def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional[MarlinLinearMethod]: # Delay the import to avoid circular dependency from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead if isinstance(layer, LinearBase) or ( isinstance(layer, ParallelLMHead) and self.lm_head_quantized ): return MarlinLinearMethod(self) return None