import logging
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union

import torch

from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

try:
    from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
        check_marlin_supported,
    )
    from vllm.scalar_type import scalar_types

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any

    class scalar_types:
        uint4b8 = "uint4b8"
        uint8b128 = "uint8b128"


logger = logging.getLogger(__name__)

class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        lm_head_quantized: bool,
        dynamic: Dict[str, Dict[str, Union[int, bool]]],
    ) -> None:
        # GPTQModel uses the `dynamic` config property to allow a per-module
        # quantization config, so each module can be individually optimized.
        # The format is Dict[str, Dict], where the key is a regex string that
        # can perform either positive ("+:" prefixed) or negative ("-:"
        # prefixed) matching of a module; without a prefix, the match defaults
        # to positive and overrides the base quant config. The value is a dict
        # of field keys and override values.
        # Negative matching skips quantization init for the module entirely,
        # i.e. it runs non-quantized inference. More details and quantization
        # examples can be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        # # the last 1/2 of layers (10-21) use 8-bit vs 4-bit for layers 0-9
        # # the last 1/4 of layers (16-21) use 8-bit and group_size 64
        # dynamic = {
        #     # `.*\.` matches the layers_node prefix
        #     # positive match: layers 10-15
        #     r"+:.*\.(?:1[0-5])\..*": {"bits": 8},
        #     # positive match: layers 16-21
        #     r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64},
        #     # negative match: skip all `moe` layers
        #     r"-:.*\.moe\..*": {},
        # }
        super().__init__()
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.pack_factor = Fraction(32, self.weight_bits)
        if self.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {self.weight_bits} bits."
            )

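    # Worked example (illustrative, not part of the config API): with 4-bit
    # weights, pack_factor = Fraction(32, 4) = 8, i.e. eight quantized weights
    # fit in each 32-bit word. A Fraction is used because 3-bit quantization
    # does not divide 32 evenly, so the pack factor must stay exact rather
    # than be truncated by integer division.
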
    def __repr__(self) -> str:
        return (
            f"GPTQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"dynamic={self.dynamic})"
        )

    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError

    @classmethod
    def get_name(cls) -> str:
        return "gptq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # TODO: confirm the minimum compute capability actually required.
    def get_min_capability(cls) -> int:
        return 60

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(weight_bits, group_size, desc_act, lm_head_quantized, dynamic)

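    # A minimal `quantize_config.json` that `from_config` can parse might look
    # like the sketch below (field names follow the keys read above; files
    # emitted by a given GPTQ toolchain may carry additional fields, and
    # "lm_head"/"dynamic" are optional):
    #
    # {
    #     "bits": 4,
    #     "group_size": 128,
    #     "desc_act": true,
    #     "lm_head": false,
    #     "dynamic": {}
    # }
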
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[GPTQLinearMethod]:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.quantization import get_linear_quant_method

        return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)

class GPTQMarlinConfig(QuantizationConfig):
    """Config class for GPTQ Marlin"""

    # (num_bits, is_sym) -> quant_type
    TYPE_MAP = {
        (4, True): scalar_types.uint4b8,
        (8, True): scalar_types.uint8b128,
    }

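    # Note: in vLLM's scalar-type naming, `uint4b8` denotes an unsigned 4-bit
    # integer stored with a bias of 8, and `uint8b128` an unsigned 8-bit
    # integer with a bias of 128, which is how symmetric GPTQ weights are
    # represented for the Marlin kernels. Only symmetric (sym=True)
    # checkpoints have an entry in TYPE_MAP and are therefore convertible.
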
    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        desc_act: bool,
        is_sym: bool,
        lm_head_quantized: bool,
        dynamic: Dict[str, Dict[str, Union[int, bool]]],
        full_config: Dict[str, Any],
    ) -> None:
        super().__init__()
        if desc_act and group_size == -1:
            # In this case, act_order == True is the same as act_order == False
            # (since we have only one group per output channel).
            desc_act = False

        # GPTQModel uses the `dynamic` config property to allow a per-module
        # quantization config, so each module can be individually optimized.
        # The format is Dict[str, Dict], where the key is a regex string that
        # can perform either positive ("+:" prefixed) or negative ("-:"
        # prefixed) matching of a module; without a prefix, the match defaults
        # to positive and overrides the base quant config. The value is a dict
        # of field keys and override values.
        # Negative matching skips quantization init for the module entirely,
        # i.e. it runs non-quantized inference. More details and quantization
        # examples can be found at: https://github.com/ModelCloud/GPTQModel
        # Example:
        # # the last 1/2 of layers (10-21) use 8-bit vs 4-bit for layers 0-9
        # # the last 1/4 of layers (16-21) use 8-bit and group_size 64
        # dynamic = {
        #     # `.*\.` matches the layers_node prefix
        #     # positive match: layers 10-15
        #     r"+:.*\.(?:1[0-5])\..*": {"bits": 8},
        #     # positive match: layers 16-21
        #     r"+:.*\.(?:1[6-9]|20|21)\..*": {"bits": 8, "group_size": 64},
        #     # negative match: skip all `moe` layers
        #     r"-:.*\.moe\..*": {},
        # }
        self.dynamic = dynamic

        self.weight_bits = weight_bits
        self.is_sym = is_sym

        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.desc_act = desc_act
        self.lm_head_quantized = lm_head_quantized
        self.full_config = full_config

        if (weight_bits, is_sym) not in self.TYPE_MAP:
            raise ValueError(
                f"Unsupported quantization config: bits={weight_bits}, sym={is_sym}"
            )

        # (num_bits, is_sym) -> quant_type
        self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

    def __repr__(self) -> str:
        return (
            f"GPTQMarlinConfig(quant_type={self.quant_type}, "
            f"group_size={self.group_size}, "
            f"desc_act={self.desc_act}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"dynamic={self.dynamic})"
        )

    def get_scaled_act_names(self) -> List[str]:
        """Returns the activation function names that should be post-scaled.

        For now, this is only used by AWQ.
        """
        raise NotImplementedError

    @classmethod
    def get_name(cls) -> str:
        return "gptq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
        dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
        dynamic = {} if dynamic is None else dynamic

        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        desc_act = cls.get_from_keys(config, ["desc_act"])
        is_sym = cls.get_from_keys(config, ["sym"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(
            weight_bits,
            group_size,
            desc_act,
            is_sym,
            lm_head_quantized,
            dynamic,
            config,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

        is_valid_user_quant = (
            user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
        )

        if can_convert and is_valid_user_quant:
            msg = (
                "The model is convertible to {} during runtime."
                " Using {} kernel.".format(cls.get_name(), cls.get_name())
            )
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "gptq":
            logger.info(
                "Detected that the model can run with gptq_marlin"
                ", however you specified quantization=gptq explicitly,"
                " so forcing gptq. Use quantization=gptq_marlin for"
                " faster inference"
            )
        return None

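    # Illustrative behaviour (a sketch of the logic above, not an exhaustive
    # spec): for a symmetric 4-bit GPTQ checkpoint on a CUDA device that
    # passes `check_marlin_supported`, leaving the quantization method unset
    # (or setting it to "marlin"/"gptq_marlin") resolves to "gptq_marlin",
    # while explicitly requesting "gptq" keeps the plain GPTQ kernels and
    # only logs a hint about the faster path.
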
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.quantization import get_linear_quant_method

        if isinstance(layer, FusedMoE):
            return GPTQMarlinMoEMethod(self)
            # TODO: re-enable after SGLang syncs with vllm >= 0.7.3
            # if layer.num_experts > 32:
            #     # For MoEs with many experts the moe_wna16 kernel is faster
            #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
            #         layer, prefix
            #     )
            # else:
            #     return GPTQMarlinMoEMethod(self)
        return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)

    @classmethod
    def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        sym = quant_config.get("sym")
        desc_act = quant_config.get("desc_act")

        if not _is_cuda:
            return False

        if quant_method != "gptq":
            return False

        # Marlin conversion is only valid if required properties are found
        if num_bits is None or group_size is None or sym is None or desc_act is None:
            return False

        if (num_bits, sym) not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(
            quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
        )

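# Illustrative compatibility examples for `is_gptq_marlin_compatible` (a
# sketch based on the checks above): a config with quant_method="gptq",
# bits=4, group_size=128, sym=True, desc_act=True is typically convertible on
# CUDA (subject to `check_marlin_supported`), whereas sym=False or bits=3
# fails because (num_bits, sym) has no entry in TYPE_MAP, and any non-CUDA
# build returns False immediately.
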
class MarlinConfig(QuantizationConfig):
    """Config class for Marlin.

    Reference: https://github.com/IST-DASLab/marlin/tree/master
    """

    def __init__(
        self,
        group_size: int,
        lm_head_quantized: bool,
    ) -> None:
        # Group size for the quantization.
        self.group_size = group_size
        self.lm_head_quantized = lm_head_quantized
        if self.group_size != 128 and self.group_size != -1:
            raise ValueError(
                "Currently, only group sizes 128 and -1 (channelwise) "
                "are supported for Marlin, but got group_size of "
                f"{self.group_size}"
            )

        # 4 bits packed into a 32-bit datatype.
        self.pack_factor = 32 // 4

        # Tile size used by the marlin kernels.
        self.tile_size = 16

        # Minimum out_features dim.
        self.min_n_threads = 64

        # Minimum in_features dim.
        self.min_k_threads = 128

        # Max parallel problems to solve at once (improves large
        # batch performance).
        self.max_parallel = 16

        # Permutation length used by the marlin kernels.
        self.perm_len = 1024

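    # Worked shape example (illustrative; assumes the Marlin kernel enforces
    # the minimum-dimension constants above): a 4-bit layer with
    # in_features=4096 and out_features=11008 satisfies 4096 % 128 == 0 and
    # 11008 % 64 == 0, and packs 32 // 4 = 8 weights per int32 word.
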
    def __repr__(self) -> str:
        return (
            f"MarlinConfig(group_size={self.group_size}, "
            f"lm_head_quantized={self.lm_head_quantized})"
        )

    @classmethod
    def get_name(cls) -> str:
        return "marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    # TODO: confirm the minimum compute capability actually required.
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        return cls(group_size, lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        # compat: autogptq >=0.8.0 uses checkpoint_format: str
        # compat: autogptq <=0.7.1 uses is_marlin_format: bool
        is_marlin_format = hf_quant_cfg.get(
            "checkpoint_format"
        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)

        is_valid_user_quant = (
            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
        )

        if is_marlin_format and is_valid_user_quant:
            msg = "The model is serialized in {} format. Using {} kernel.".format(
                cls.get_name(), cls.get_name()
            )
            logger.info(msg)
            return cls.get_name()

        return None

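    # A checkpoint serialized directly in Marlin format is typically flagged
    # in its quantization config in one of the two ways read above, e.g.
    # (sketch; other fields omitted):
    #
    # {"group_size": 128, "lm_head": false, "checkpoint_format": "marlin"}
    #   or, for older autogptq releases:
    # {"group_size": 128, "lm_head": false, "is_marlin_format": true}
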
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[MarlinLinearMethod]:
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            return MarlinLinearMethod(self)
        return None