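# Tests that the sgl_kernel AWQ dequantization kernel matches a pure-PyTorch
# reference implementation across a grid of weight shapes, with both float16
# and bfloat16 scales.
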
import itertools
from typing import Optional, Tuple

import pytest
import torch
from sgl_kernel import awq_dequantize


def reverse_awq_order(t: torch.Tensor):
    bits = 4
    # AWQ stores the 8 nibbles of each packed int32 in an interleaved order;
    # gathering with AWQ_REVERSE_ORDER restores the logical order 0..7.
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    reverse_order_tensor = torch.arange(
        t.shape[-1],
        dtype=torch.int32,
        device=t.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    # Reorder columns and keep only the low 4 bits of each entry.
    t = t[:, reverse_order_tensor] & 0xF
    return t


# qweights - [R     , C // 8], int32
# scales   - [R // G, C     ], float16
# zeros    - [R // G, C // 8], int32
def awq_dequantize_torch(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor, group_size: int
) -> torch.Tensor:

    # group_size == -1 means a single quantization group spanning all rows.
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    # Unpack the 8 nibbles of every packed int32 weight via right shifts.
    iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )

    iweights = iweights.view(iweights.shape[0], -1)

    # Unpack the zero points the same way.
    zeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
        torch.int8
    )
    zeros = zeros.view(qzeros.shape[0], -1)
    zeros = reverse_awq_order(zeros)

    iweights = reverse_awq_order(iweights)

    # Mask down to 4 bits after the shifts.
    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    # Broadcast per-group scales/zeros to per-row values and dequantize.
    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales


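# Thin wrapper around the sgl_kernel CUDA kernel so the test below calls both
# implementations with the same tensors.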
def sglang_awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    return awq_dequantize(qweight, scales, qzeros)


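# Shapes constructed below (4-bit AWQ, 8 packed values per int32):
#   qweight: [qweight_row, qweight_col], int32
#   scales : [qweight_row // group_size, qweight_col * 8], float16 or bfloat16
#   qzeros : [qweight_row // group_size, qweight_col], int32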
@pytest.mark.parametrize(
    "qweight_row,qweight_col,is_bf16_act",
    list(
        itertools.product(
            [3584, 18944, 128, 256, 512, 1024, 1536],
            [448, 576, 4736, 16, 32, 64, 128, 72],
            [True, False],
        )
    ),
)
def test_awq_dequant_compare_implementations(
    qweight_row: int, qweight_col: int, is_bf16_act: bool
):
    device = torch.device("cuda")
    # Random packed weights: any int32 bit pattern is a valid stack of 8 nibbles.
    qweight = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (qweight_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )
    # One quantization group covering every row, so scales/zeros have a single row.
    group_size = qweight_row
    scales_row = qweight_row // group_size
    scales_col = qweight_col * 8

    if is_bf16_act:
        scales = torch.rand(scales_row, scales_col, dtype=torch.bfloat16, device=device)
    else:
        scales = torch.rand(scales_row, scales_col, dtype=torch.float16, device=device)

    qzeros = torch.randint(
        0,
        torch.iinfo(torch.int32).max,
        (scales_row, qweight_col),
        dtype=torch.int32,
        device=device,
    )

    # Run both implementations
    torch_out = awq_dequantize_torch(qweight, scales, qzeros, group_size)
    sglang_out = sglang_awq_dequantize(qweight, scales, qzeros)

    # Compare results
    torch.testing.assert_close(
        torch_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    )


if __name__ == "__main__":
    pytest.main([__file__])