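# Correctness test: compares gptq_marlin_gemm against a float16 torch.matmul
# reference for AWQ-style (uint4 + zero-point) and GPTQ-style (uint4b8) weights.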
import pytest
import torch
from sgl_kernel import gptq_marlin_gemm
from sgl_kernel.scalar_type import scalar_types

from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
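
# (m_factor, n_factor, k_factor) triples; m is used directly as size_m, while
# n and k are multiplied by the parametrized n_chunk / k_chunk sizes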
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]


# uint4 for awq
# uint4b8 for gptq
@pytest.mark.parametrize("k_chunk", [128])
@pytest.mark.parametrize("n_chunk", [64, 256])
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("is_k_full", [False, True])
@pytest.mark.parametrize("use_atomic_add", [False, True])
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    m_factor, n_factor, k_factor = mnk_factors
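    # AWQ-style unsigned types carry a runtime zero-point; GPTQ-style uint4b8 does not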
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor
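
    # skip parameter combinations that are invalid for this quantization scheme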
    if act_order:
        if group_size == -1:
            return
        if group_size == size_k:
            return
        if has_zp:
            return

    if size_k % group_size != 0:
        return

    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")
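
    # quantize b_weight into Marlin's packed layout; w_ref is the dequantized reference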
    if has_zp:
        # AWQ style, unsigned + runtime zero-point
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
        )
        g_idx = None
        sort_indices = None
        marlin_s2 = None
    else:
        # GPTQ style, unsigned + symmetric bias
        if group_size == 16:
            return
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
        )
        marlin_zp = None
        marlin_s2 = None
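
    # scratch workspace buffer required by the Marlin kernel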
    workspace = marlin_make_workspace(w_ref.device)

    # marlin gemm
    output = gptq_marlin_gemm(
        a_input,
        None,
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    # ref gemm
    output_ref = torch.matmul(a_input, w_ref)

    torch.cuda.synchronize()
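
    # relative error of the kernel output vs. the reference matmul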
    max_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )

    assert max_diff < 0.04


if __name__ == "__main__":
    import subprocess

    subprocess.call(["pytest", "--tb=short", str(__file__)])