# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
import builtins
import inspect
import re
from copy import deepcopy
from typing import Callable, Dict, Optional, Type, Union

import torch

try:
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinConfig,
        AWQMoEMethod,
    )
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
        CompressedTensorsW8A8Fp8MoEMethod,
        CompressedTensorsWNA16MoEMethod,
    )
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
        GPTQMarlinMoEMethod,
    )
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
    from vllm.model_executor.layers.quantization.qqq import QQQConfig
    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

    # Define empty classes as placeholders when vllm is not available
    class DummyConfig:
        def override_quantization_method(self, *args, **kwargs):
            return None

    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
        DeepSpeedFPConfig
    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
        MarlinConfig
    ) = QQQConfig = Int8TpuConfig = DummyConfig

from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.awq import AWQConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
    CompressedTensorsConfig,
)
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    UnquantizedEmbeddingMethod,
)

# Base quantization methods that don't depend on vllm
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "fp8": Fp8Config,
    "blockwise_int8": BlockInt8Config,
    "modelopt": ModelOptFp8Config,
    "w8a8_int8": W8A8Int8Config,
    "w8a8_fp8": W8A8Fp8Config,
    "compressed-tensors": CompressedTensorsConfig,
}

# VLLM-dependent quantization methods
VLLM_QUANTIZATION_METHODS = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "deepspeedfp": DeepSpeedFPConfig,
    "tpu_int8": Int8TpuConfig,
    "fbgemm_fp8": FBGEMMFp8Config,
    "marlin": MarlinConfig,
    "gguf": GGUFConfig,
    "gptq_marlin_24": GPTQMarlin24Config,
    "awq_marlin": AWQMarlinConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "experts_int8": ExpertsInt8Config,
    "gptq_marlin": GPTQMarlinConfig,
    "gptq": GPTQConfig,
}

QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in QUANTIZATION_METHODS:
        raise ValueError(
            f"Invalid quantization method: {quantization}. "
            f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
        )
    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
        raise ValueError(
            f"{quantization} quantization requires some operators from vllm. "
            "Please install vllm with `pip install vllm==0.7.2`."
        )

    return QUANTIZATION_METHODS[quantization]
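

# Example usage (illustrative sketch, not executed by this module): resolve a
# config class by name, then build it from checkpoint metadata via the
# `from_config` classmethod on QuantizationConfig. `hf_quant_cfg` is a
# hypothetical dict parsed from the model's quantization config file.
#
#   quant_cls = get_quantization_config("fp8")
#   quant_config = quant_cls.from_config(hf_quant_cfg)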


# Match dynamic rules with module name (prefix) and override quantize
# config if module (prefix) matches a rule
def override_config(config: QuantizationConfig, prefix: str):
    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
    if isinstance(weight_bits, int):
        config.weight_bits = weight_bits
    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
    if isinstance(group_size, int):
        config.group_size = group_size
    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
    if isinstance(desc_act, bool):
        config.desc_act = desc_act

    config.pack_factor = 32 // config.weight_bits  # packed into int32
    if config.get_name() == "gptq_marlin":
        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
        if isinstance(is_sym, bool):
            config.is_sym = is_sym

        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
            raise ValueError(
                "Unsupported quantization config: "
                f"bits={config.weight_bits}, sym={config.is_sym}"
            )

        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
    elif config.get_name() == "gptq":
        if config.weight_bits not in [2, 3, 4, 8]:
            raise ValueError(
                "Currently, only 2/3/4/8-bit weight quantization is "
                f"supported for GPTQ, but got {config.weight_bits} bits."
            )
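

# Illustrative shape of the `config.dynamic` rules consumed above (an assumed
# example in the style of GPTQModel checkpoints, not taken from this repo):
# regex patterns prefixed with "+:" override fields, "-:" exclude modules.
#
#   config.dynamic = {
#       "+:model\.layers\.0\..*": {"bits": 8, "group_size": 32},
#       "-:lm_head": {},
#   }
#
# With these rules, layer-0 modules are re-packed at 8 bits / group size 32,
# while `lm_head` is skipped by the quantized init entirely.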


def get_dynamic_override(
    config: QuantizationConfig,
    layer_name: str,
    key: Optional[str] = None,
    default_value: Union[int, bool, None] = None,
) -> Union[Dict, int, bool, None]:
    for pattern, pattern_dict in config.dynamic.items():
        # Negative match: matched modules are excluded from quantized init
        if pattern.startswith("-:"):
            if re.match(pattern.removeprefix("-:"), layer_name):
                return False
        # Positive match: matched modules have quant properties that
        # override the base quant config
        elif re.match(pattern.removeprefix("+:"), layer_name):
            if key is None:
                return pattern_dict
            else:
                return pattern_dict.get(key, default_value)
    return default_value
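

# Behavior sketch (using the assumed `dynamic` rules from the comment above):
#
#   get_dynamic_override(cfg, "model.layers.0.q_proj", "bits", 4)  -> 8
#   get_dynamic_override(cfg, "lm_head")                           -> False
#   get_dynamic_override(cfg, "model.layers.1.q_proj", "bits", 4)  -> 4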


def get_linear_quant_method(
    config: QuantizationConfig,
    layer: torch.nn.Module,
    prefix: str,
    linear_method_cls: type,
):
    cloned_config = deepcopy(config)
    parallel_lm_head_quantized = (
        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
    )

    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
        # False = skip module, None = no override, else = Positive match
        if (
            get_dynamic_override(cloned_config, layer_name=prefix)
            == False  # noqa: E712
        ):
            if parallel_lm_head_quantized:
                return UnquantizedEmbeddingMethod()
            return UnquantizedLinearMethod()

        if prefix:
            # Dynamic per module/layer rules may override base config
            override_config(cloned_config, prefix=prefix)

        return linear_method_cls(cloned_config)

    return None


def gptq_get_quant_method(self, layer, prefix):
    if isinstance(layer, FusedMoE):
        return GPTQMarlinMoEMethod(self)

    if isinstance(self, GPTQConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
        )
    elif isinstance(self, GPTQMarlinConfig):
        return get_linear_quant_method(
            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
        )
    return None
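

# Note: gptq_get_quant_method is not called directly; monkey_patch_quant_configs
# (at the bottom of this file) installs it as the `get_quant_method` of
# GPTQConfig and GPTQMarlinConfig, so GPTQ checkpoints route through the
# linear/MoE method selection above.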


original_isinstance = builtins.isinstance


def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
    """
    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
    can recognize sglang layers.
    """
    if not VLLM_AVAILABLE:
        return

    if reverse:
        builtins.isinstance = original_isinstance
        return

    from vllm.model_executor.layers.fused_moe import FusedMoE
    from vllm.model_executor.layers.linear import LinearBase
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding,
    )

    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
    from sglang.srt.layers.vocab_parallel_embedding import (
        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
    )

    def patched_isinstance(obj, classinfo):
        if classinfo is LinearBase:
            return original_isinstance(obj, PatchedLinearBase)
        if classinfo is FusedMoE:
            return original_isinstance(obj, PatchedFusedMoE)
        if classinfo is VocabParallelEmbedding:
            return original_isinstance(obj, PatchedVocabParallelEmbedding)
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance
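

# Typical call pattern (illustrative sketch; `vllm_quant_config` and `layer`
# are hypothetical stand-ins): install the patch around code that asks a vllm
# QuantizationConfig for a quant method, then restore the builtin.
#
#   monkey_patch_isinstance_for_vllm_base_layer()
#   try:
#       method = vllm_quant_config.get_quant_method(layer, prefix="...")
#   finally:
#       monkey_patch_isinstance_for_vllm_base_layer(reverse=True)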


def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    """
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)
    param_names = list(sig.parameters.keys())
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        inplace: bool = True,
        no_combine: bool = False,
    ):
        assert activation == "silu"
        assert inplace and not no_combine

        kwargs = {
            "self": self,
            "layer": layer,
            "x": x,
            "router_logits": router_logits,
            "top_k": top_k,
            "renormalize": renormalize,
            "use_grouped_topk": use_grouped_topk,
            "topk_group": topk_group,
            "num_expert_group": num_expert_group,
            "custom_routing_function": custom_routing_function,
        }
        if correction_bias is not None:
            if not has_correction_bias:
                raise ValueError(
                    "Please upgrade your vllm version. Try `pip install vllm==0.7.2`."
                )
            kwargs["e_score_correction_bias"] = correction_bias
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)
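

# Design note: `inspect.signature` above lets the shim adapt across vllm
# versions -- `e_score_correction_bias` is forwarded only when the installed
# vllm's `apply` actually accepts that parameter.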


def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""
    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

    monkey_patch_moe_apply(AWQMoEMethod)
    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
    monkey_patch_quant_configs()