# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
import builtins
import inspect
import re
from copy import deepcopy
from typing import Callable, Dict, Optional, Type, Union

import torch

try:
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinConfig,
        AWQMoEMethod,
    )
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
        CompressedTensorsW8A8Fp8MoEMethod,
        CompressedTensorsWNA16MoEMethod,
    )
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
    from vllm.model_executor.layers.quantization.qqq import QQQConfig
    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    # Define empty classes as placeholders when vllm is not available
    class DummyConfig:
        def override_quantization_method(self, *args, **kwargs):
            return None

    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
        DeepSpeedFPConfig
    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
        MarlinConfig
    ) = QQQConfig = Int8TpuConfig = DummyConfig


from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    UnquantizedEmbeddingMethod,
)

# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,
    "w8a8_int8": W8A8Int8Config,
    "w8a8_fp8": W8A8Fp8Config,
    "compressed-tensors": CompressedTensorsConfig,
}

# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    "fbgemm_fp8": FBGEMMFp8Config,
    "marlin": MarlinConfig,
    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "awq_marlin": AWQMarlinConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
}

QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(
            f"Invalid quantization method: {quantization}. "
            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
        )
    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(
            f"{quantization} quantization requires some operators from vllm. "
            "Please install vllm with `pip install vllm==0.7.2`"
        )

    return QUANTIZATION_METHODS[quantization]
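

# Illustrative usage (not part of the original module): look up a registered
# method name and get back its config class, e.g.
#   fp8_config_cls = get_quantization_config("fp8")  # -> Fp8Config
# An unknown name raises ValueError, and a vllm-backed name such as "awq_marlin"
# raises if vllm is not installed.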


# Match dynamic rules against the module name (prefix) and override the
# quantization config if the module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = 32 // config.weight_bits  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}"
            )

        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits."
            )


def get_dynamic_override(
    config: QuantizationConfig,
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules have quant property overrides that
        # replace values from the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value
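

# Illustrative shape of `config.dynamic` (hypothetical values, matching the
# pattern handling above): keys are regex rules prefixed with "+:" (override)
# or "-:" (exclude), values are per-module property overrides.
#
#   config.dynamic = {
#       "+:model\\.layers\\.0\\..*": {"bits": 8, "group_size": 64},
#       "-:model\\.layers\\.1\\.mlp\\..*": {},
#   }
#
# With this config, get_dynamic_override(config, "model.layers.0.self_attn.q_proj",
# "bits", 4) returns 8, while any module under model.layers.1.mlp returns False,
# which get_linear_quant_method below treats as "leave unquantized".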


def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = (
        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
    )

    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        # False = skip module, None = no override, else = positive match
        if (
            get_dynamic_override(  # noqa: E712
                cloned_config, layer_name=prefix  # noqa: E712
            )
            == False  # noqa: E712
        ):
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)
    return None


def gptq_get_quant_method(self, layer, prefix):
    if isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)

    if isinstance(self, GPTQConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
        )
    elif isinstance(self, GPTQMarlinConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
        )
    return None


original_isinstance = builtins.isinstance


def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
    """
    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
    can recognize sglang layers.
    """
    if not VLLM_AVAILABLE:
        return

    if reverse:
        builtins.isinstance = original_isinstance
        return

    from vllm.model_executor.layers.fused_moe import FusedMoE
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding,
    )

    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
    )

    def patched_isinstance(obj, classinfo):
        if classinfo is LinearBase:
            return original_isinstance(obj, PatchedLinearBase)
        if classinfo is FusedMoE:
            return original_isinstance(obj, PatchedFusedMoE)
        if classinfo is VocabParallelEmbedding:
            return original_isinstance(obj, PatchedVocabParallelEmbedding)
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance


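# A minimal usage sketch (hypothetical call site, not from this module): wrap a
# call into a vllm quant config so its isinstance checks see sglang layers, then
# restore the builtin afterwards. `vllm_quant_config` and `layer` are illustrative.
#
#   monkey_patch_isinstance_for_vllm_base_layer()
#   try:
#       quant_method = vllm_quant_config.get_quant_method(layer, prefix)
#   finally:
#       monkey_patch_isinstance_for_vllm_base_layer(reverse=True)

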
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    """
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)
    param_names = list(sig.parameters.keys())
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ):
        assert activation == "silu"
        assert inplace and not no_combine

        kwargs = {
            "self": self,
            "layer": layer,
            "x": x,
            "router_logits": router_logits,
            "top_k": top_k,
            "renormalize": renormalize,
            "use_grouped_topk": use_grouped_topk,
            "topk_group": topk_group,
            "num_expert_group": num_expert_group,
            "custom_routing_function": custom_routing_function,
        }
        if correction_bias is not None:
            if not has_correction_bias:
                raise ValueError(
                    "Please upgrade to a newer vllm version, e.g. `pip install vllm==0.7.2`"
                )
            kwargs["e_score_correction_bias"] = correction_bias
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)


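# Hypothetical call after patching (names are illustrative, not from this module):
# sglang-style arguments, including `correction_bias`, are remapped by `new_apply`
# onto the original vllm signature (`e_score_correction_bias`) when vllm supports it.
#
#   out = moe_layer.quant_method.apply(
#       moe_layer, hidden_states, router_logits,
#       top_k=2, renormalize=True, use_grouped_topk=False,
#       correction_bias=bias,
#   )

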
def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""
    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

    monkey_patch_moe_apply(AWQMoEMethod)
    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
    monkey_patch_quant_configs()