# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

from types import MappingProxyType
from typing import List, Mapping, Tuple, Union

import torch

from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
else:
    from vllm import _custom_ops as vllm_ops


def is_fp8_fnuz() -> bool:
    # Only device 0 is checked; this assumes MI300 platforms are homogeneous.
    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName


def is_layer_skipped(
    prefix: str,
    ignored_layers: List[str],
    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
) -> bool:
    # prefix: model.layers.0.self_attn.q_proj
    # proj_name: q_proj
    proj_name = prefix.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping:
        shard_prefixes = [
            prefix.replace(proj_name, shard_proj_name)
            for shard_proj_name in fused_mapping[proj_name]
        ]

        is_skipped = None
        for shard_prefix in shard_prefixes:
            is_shard_skipped = shard_prefix in ignored_layers

            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "must have the same precision."
                )
    else:
        is_skipped = prefix in ignored_layers

    assert is_skipped is not None
    return is_skipped
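# Example usage (illustrative names; the mapping mirrors how q/k/v projections
# are fused into qkv_proj at load time). The fused prefix is skipped only if
# every unfused shard prefix appears in ignored_layers:
#
#     >>> fused = MappingProxyType({"qkv_proj": ["q_proj", "k_proj", "v_proj"]})
#     >>> ignored = [
#     ...     "model.layers.0.self_attn.q_proj",
#     ...     "model.layers.0.self_attn.k_proj",
#     ...     "model.layers.0.self_attn.v_proj",
#     ... ]
#     >>> is_layer_skipped("model.layers.0.self_attn.qkv_proj", ignored, fused)
#     True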


def per_tensor_dequantize(
    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
) -> torch.Tensor:
    fake_qweight = tensor.to(torch.float16)
    dq_weight = fake_qweight * inv_scale
    return dq_weight
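# A minimal sketch of what this enables (values illustrative; requires a
# PyTorch build with float8 support): upcasting an FP8 tensor to float16 and
# multiplying by the inverse scale recovers the approximate original weights:
#
#     >>> q = torch.full((2, 2), 1.0).to(torch.float8_e4m3fn)
#     >>> per_tensor_dequantize(q, 0.5)
#     tensor([[0.5000, 0.5000],
#             [0.5000, 0.5000]], dtype=torch.float16)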


def all_close_1d(x: torch.Tensor) -> bool:
    assert len(x.shape) == 1
    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
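# For example, all_close_1d(torch.tensor([0.1, 0.1, 0.1])) is True, while
# all_close_1d(torch.tensor([0.1, 0.2, 0.1])) is False. The Python-level loop
# performs one allclose per element, which is acceptable for the short
# per-shard scale vectors this is applied to.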


def convert_to_channelwise(
    weight_scale: torch.Tensor, logical_widths: List[int]
) -> torch.Tensor:
    # Create channelwise buffer
    weight_scale_channel = torch.empty(
        (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
    )

    # Handle scalar tensor case: broadcast same scale to all channels
    if weight_scale.dim() == 0:
        weight_scale_channel.fill_(weight_scale.item())
        return weight_scale_channel

    # Expand each scale to match the size of each logical matrix.
    start = 0
    for idx, logical_width in enumerate(logical_widths):
        end = start + logical_width
        weight_scale_channel[start:end, :] = weight_scale[idx]
        start = end

    return weight_scale_channel
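# Example (illustrative widths): per-shard scales [0.1, 0.2, 0.3] for fused
# shards of 4, 2, and 2 output channels expand to an (8, 1) channelwise
# buffer:
#
#     >>> scales = torch.tensor([0.1, 0.2, 0.3])
#     >>> convert_to_channelwise(scales, [4, 2, 2]).squeeze()
#     tensor([0.1000, 0.1000, 0.1000, 0.1000, 0.2000, 0.2000, 0.3000, 0.3000])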


def requantize_with_max_scale(
    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Max scale to be used for requantization.
    max_w_scale = weight_scale.max()

    # QKV / MLP is fused in the on-disk checkpoint if any of the weight
    # scales are still set to the default: we initialize N weight scales
    # for N shards, but in the fused case only 1 weight scale is loaded
    # from disk. Skip requantization then, since the weight is already
    # quantized with the single scale.
    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
    unfused_module_in_checkpoint = (
        weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
    )

    # If the checkpoint is unfused, requantize each shard with the single
    # max scale.
    if unfused_module_in_checkpoint:
        start = 0
        for idx, logical_width in enumerate(logical_widths):
            end = start + logical_width
            weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
            if _is_cuda:
                weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
            else:
                weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
                    weight_dq, max_w_scale
                )
            start = end

    return max_w_scale, weight
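# Sketch of the intended call pattern (all names here are illustrative): for a
# fused QKV weight `w_fp8` of shape (q_dim + k_dim + v_dim, hidden) with one
# loaded scale per shard, every shard is requantized so the fused GEMM can use
# a single per-tensor scale:
#
#     >>> scales = torch.tensor([s_q, s_k, s_v])        # doctest: +SKIP
#     >>> max_scale, w_fp8 = requantize_with_max_scale(  # doctest: +SKIP
#     ...     w_fp8, scales, [q_dim, k_dim, v_dim])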


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/layer_utils.py
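# `replace_parameter` below relies on `update_tensor_inplace`, which is
# defined alongside it in the vLLM layer_utils.py referenced above; a close
# sketch of that helper is reproduced here so the call below resolves.
def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor) -> None:
    assert dst.dtype == src.dtype, "Tensors must have the same dtype"

    # Update the destination's shape and stride to match the source.
    dst.as_strided_(src.shape, src.stride())

    # If the underlying storage differs, copy the data over.
    if dst.data_ptr() != src.data_ptr():
        dst.copy_(src)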


# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_parameter(
    mod: torch.nn.Module, name: str, new: Union[torch.Tensor, torch.nn.Parameter]
) -> None:
    old = getattr(mod, name)
    if (
        type(old) is type(new)
        and old.dtype == new.dtype
        and old.untyped_storage().nbytes() == new.untyped_storage().nbytes()
    ):
        # Update in-place to avoid re-registering the parameter;
        # can be faster if the underlying storage is the same
        update_tensor_inplace(old, new)
    else:
        # Fallback: re-register the parameter. Wrapping unconditionally not
        # only ensures we don't register a bare tensor as a parameter, but
        # also ensures that all parameter subclasses get re-registered as
        # plain Parameters for `torch.compile` compatibility.
        mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))
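# Example usage (hypothetical module; the zero tensor is illustrative): after
# producing a requantized weight, swap the registered parameter:
#
#     >>> linear = torch.nn.Linear(8, 8)
#     >>> replace_parameter(linear, "weight", torch.zeros(8, 8))
#     >>> linear.weight.requires_grad
#     False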