# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/quantization/test_awq_triton.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
unittest version of the AWQ Triton kernel tests.

Run with:

    python -m unittest test_awq_dequant.py
"""
import unittest
|
|
|
|
import torch
|
|
|
|
from sglang.srt.layers.quantization.awq_triton import (
|
|
AWQ_TRITON_SUPPORTED_GROUP_SIZES,
|
|
awq_dequantize_triton,
|
|
awq_gemm_triton,
|
|
)
|
|
from sglang.test.test_utils import CustomTestCase
|
|
|
|
# Target device for every tensor in these tests; the Triton kernels under
# test require a CUDA device.
device = "cuda"
|
def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    """Undo AWQ's interleaved nibble ordering along the last dimension.

    AWQ packs eight 4-bit values per int32 word in an interleaved column
    order; this permutes the columns back to logical order and masks each
    entry down to its low nibble.
    """
    pack_factor = 32 // 4  # eight 4-bit nibbles per int32 word
    reverse_order = [0, 4, 1, 5, 2, 6, 3, 7]
    col_idx = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    col_idx = col_idx.reshape(-1, pack_factor)[:, reverse_order].reshape(-1)
    return (t[:, col_idx] & 0xF).contiguous()
|
def awq_dequantize_torch(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    """Pure-PyTorch reference implementation of AWQ 4-bit dequantization.

    Args:
        qweight: int32 packed weights, shape (K, M // 8).
        scales: per-group scales, shape (K // group_size, M).
        qzeros: int32 packed zero points, shape (K // group_size, M // 8).
        group_size: quantization group size; -1 means a single group
            spanning all K rows.

    Returns:
        Dequantized weights of shape (K, M), computed as (w - z) * s.
    """
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    nibble_mask = (2**bits) - 1
    shift_amounts = torch.arange(0, 32, bits, device=qzeros.device)

    # Unpack eight nibbles from every int32 word, then restore AWQ's
    # logical column order.
    unpacked_w = torch.bitwise_right_shift(
        qweight[:, :, None], shift_amounts[None, None, :]
    ).to(torch.int8)
    unpacked_w = reverse_awq_order(unpacked_w.view(unpacked_w.shape[0], -1))

    unpacked_z = torch.bitwise_right_shift(
        qzeros[:, :, None], shift_amounts[None, None, :]
    ).to(torch.int8)
    unpacked_z = reverse_awq_order(unpacked_z.view(qzeros.shape[0], -1))

    unpacked_w = torch.bitwise_and(unpacked_w, nibble_mask)
    unpacked_z = torch.bitwise_and(unpacked_z, nibble_mask)

    # Broadcast each group's scale/zero across all rows of that group.
    expanded_scales = scales.repeat_interleave(group_size, dim=0)
    expanded_zeros = unpacked_z.repeat_interleave(group_size, dim=0)
    return (unpacked_w - expanded_zeros) * expanded_scales
|
class TestAWQTriton(CustomTestCase):
    """Check the Triton AWQ kernels against the pure-PyTorch reference."""

    def test_dequantize(self):
        row_sizes = [3584, 18944, 128, 256, 512, 1024]
        col_sizes = [448, 576, 4736, 16, 32, 64, 128]

        for rows in row_sizes:
            for cols in col_sizes:
                for gs in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                    with self.subTest(rows=rows, cols=cols, g=gs):
                        self._run_dequant_case(
                            qweight_rows=rows,
                            qweight_cols=cols,
                            group_size=gs,
                        )

    def _run_dequant_case(self, qweight_rows, qweight_cols, group_size):
        # group_size == -1 means one group covering all rows.
        if group_size == -1:
            group_size = qweight_rows

        torch.manual_seed(0)

        int32_max = torch.iinfo(torch.int32).max
        qweight = torch.randint(
            0,
            int32_max,
            (qweight_rows, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            qweight_rows // group_size,
            qweight_cols * 8,
            dtype=torch.float16,
            device=device,
        )
        zeros = torch.randint(
            0,
            int32_max,
            (qweight_rows // group_size, qweight_cols),
            dtype=torch.int32,
            device=device,
        )

        ref = awq_dequantize_torch(qweight, scales, zeros, group_size)
        tri = awq_dequantize_triton(qweight, scales, zeros)

        # sanity: the kernel must produce only finite values
        self.assertFalse(torch.any(torch.isinf(tri)) or torch.any(torch.isnan(tri)))
        torch.testing.assert_close(ref, tri)

    # GEMM
    def test_gemm(self):
        n_sizes = [1, 2, 4, 8, 14, 17, 23, 32]
        k_sizes = [128]
        m_sizes = [16, 24, 32]
        split_k_values = [1, 8]

        for n in n_sizes:
            for k in k_sizes:
                for m in m_sizes:
                    for gs in AWQ_TRITON_SUPPORTED_GROUP_SIZES:
                        for sk in split_k_values:
                            with self.subTest(N=n, K=k, M=m, g=gs, sk=sk):
                                self._run_gemm_case(
                                    N=n,
                                    K=k,
                                    M=m,
                                    group_size=gs,
                                    splitK=sk,
                                )

    def _run_gemm_case(self, N, K, M, group_size, splitK):
        # group_size == -1 means one group covering the whole K dimension.
        if group_size == -1:
            group_size = K

        torch.manual_seed(0)

        int32_max = torch.iinfo(torch.int32).max
        x = torch.rand((N, K), dtype=torch.float32, device=device)
        qweight = torch.randint(
            0,
            int32_max,
            (K, M // 8),
            dtype=torch.int32,
            device=device,
        )
        qzeros = torch.randint(
            0,
            int32_max,
            (K // group_size, M // 8),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand((K // group_size, M), dtype=torch.float32, device=device)

        tri_out = awq_gemm_triton(x, qweight, scales, qzeros, splitK)

        self.assertFalse(
            torch.any(torch.isinf(tri_out)) or torch.any(torch.isnan(tri_out))
        )

        # dequantize & compare
        w_deq = awq_dequantize_triton(qweight, scales, qzeros)
        ref_out = torch.matmul(x, w_deq)

        self.assertFalse(
            torch.any(torch.isinf(ref_out)) or torch.any(torch.isnan(ref_out))
        )

        torch.testing.assert_close(tri_out.cpu(), ref_out.cpu(), atol=1e-1, rtol=1e-1)
|
if __name__ == "__main__":
    # verbosity=2 prints each test (and subTest failure) as it runs.
    unittest.main(verbosity=2)