# sglang_v0.5.2/sglang/test/srt/quant/test_awq_dequant.py
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/quantization/test_awq_triton.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
unittest version of the AWQ Triton kernel tests.
Run with:
python -m unittest test_awq_dequant.py
"""
import unittest

import torch

from sglang.srt.layers.quantization.awq_triton import (
    AWQ_TRITON_SUPPORTED_GROUP_SIZES,
    awq_dequantize_triton,
    awq_gemm_triton,
)
from sglang.test.test_utils import CustomTestCase

device = "cuda"
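

# AWQ packs eight 4-bit values into each int32, interleaved in the order
# [0, 2, 4, 6, 1, 3, 5, 7]; reverse_awq_order undoes that interleaving after
# the nibbles have been split out by right-shifting.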
def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    bits = 4
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    idx = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    idx = idx.view(-1, 32 // bits)[:, AWQ_REVERSE_ORDER].view(-1)
    return (t[:, idx] & 0xF).contiguous()
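

# Eager-mode PyTorch reference for awq_dequantize_triton: unpack the 4-bit
# weights and zero points, subtract the zero point, and scale per group.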
def awq_dequantize_torch(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    # group_size == -1 means a single quantization group spanning all rows.
    if group_size == -1:
        group_size = qweight.shape[0]
    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)
    # Split each int32 into eight 4-bit nibbles, then restore logical order.
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    iweights = reverse_awq_order(iweights.view(iweights.shape[0], -1))
    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    zeros = reverse_awq_order(zeros.view(qzeros.shape[0], -1))
    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)
    # Broadcast the per-group scales/zeros over the group_size rows they cover.
    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales
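

# Dequantization tests: sweep (rows, cols, group_size) combinations and check
# the Triton kernel against the reference implementation elementwise.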
class TestAWQTriton(CustomTestCase):
    def test_dequantize(self):
        rows_list = [3584, 18944, 128, 256, 512, 1024]
        cols_list = [448, 576, 4736, 16, 32, 64, 128]
        for qweight_rows in rows_list:
            for qweight_cols in cols_list:
                for group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                    with self.subTest(
                        rows=qweight_rows, cols=qweight_cols, g=group_size
                    ):
                        self._run_dequant_case(
                            qweight_rows=qweight_rows,
                            qweight_cols=qweight_cols,
                            group_size=group_size,
                        )
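
    # qweight_cols counts packed int32 columns, so the dequantized weight (and
    # the scales) have qweight_cols * 8 columns.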
    def _run_dequant_case(self, qweight_rows, qweight_cols, group_size):
        if group_size == -1:
            group_size = qweight_rows
        torch.manual_seed(0)
        qweight = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (qweight_rows, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            qweight_rows // group_size,
            qweight_cols * 8,
            dtype=torch.float16,
            device=device,
        )
        zeros = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (qweight_rows // group_size, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        ref = awq_dequantize_torch(qweight, scales, zeros, group_size)
        tri = awq_dequantize_triton(qweight, scales, zeros)
        # Sanity check: the kernel output must be free of infs and NaNs.
        self.assertFalse(torch.any(torch.isinf(tri)) or torch.any(torch.isnan(tri)))
        torch.testing.assert_close(ref, tri)

    # GEMM tests: awq_gemm_triton fuses dequantization with the matmul, so we
    # compare it against an explicit dequantize-then-matmul reference.
    def test_gemm(self):
        N_list = [1, 2, 4, 8, 14, 17, 23, 32]
        K_list = [128]
        M_list = [16, 24, 32]
        splitK_list = [1, 8]
        for N in N_list:
            for K in K_list:
                for M in M_list:
                    for group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                        for splitK in splitK_list:
                            with self.subTest(N=N, K=K, M=M, g=group_size, sk=splitK):
                                self._run_gemm_case(
                                    N=N,
                                    K=K,
                                    M=M,
                                    group_size=group_size,
                                    splitK=splitK,
                                )
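
    # Shapes follow the kernel's convention: x is (N, K), qweight is packed as
    # (K, M // 8), and splitK is the split-K factor passed to the kernel.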
    def _run_gemm_case(self, N, K, M, group_size, splitK):
        if group_size == -1:
            group_size = K
        torch.manual_seed(0)
        x = torch.rand((N, K), dtype=torch.float32, device=device)
        qweight = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (K, M // 8),
            dtype=torch.int32,
            device=device,
        )
        qzeros = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (K // group_size, M // 8),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand((K // group_size, M), dtype=torch.float32, device=device)
        tri_out = awq_gemm_triton(x, qweight, scales, qzeros, splitK)
        self.assertFalse(
            torch.any(torch.isinf(tri_out)) or torch.any(torch.isnan(tri_out))
        )
        # Reference: dequantize the weight, then run a plain matmul.
        w_deq = awq_dequantize_triton(qweight, scales, qzeros)
        ref_out = torch.matmul(x, w_deq)
        self.assertFalse(
            torch.any(torch.isinf(ref_out)) or torch.any(torch.isnan(ref_out))
        )
        # Loose tolerances: the fused kernel accumulates in a different order.
        torch.testing.assert_close(tri_out.cpu(), ref_out.cpu(), atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    unittest.main(verbosity=2)