# File: sglang_v0.5.2/sglang/sgl-kernel/tests/test_marlin_gemm.py
# (viewer metadata: 122 lines, 3.1 KiB, Python)

import pytest
import torch
from sgl_kernel import gptq_marlin_gemm
from sgl_kernel.scalar_type import scalar_types
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
# (m_factor, n_factor, k_factor) triples. The GEMM test uses m_factor directly
# as the activation row count and scales n_factor / k_factor by the n_chunk /
# k_chunk parametrization to obtain N and K.
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    # Growing row counts over a fixed (n_factor, k_factor) = (13, 11).
    *((rows, 13, 11) for rows in (67, 257, 658)),
]
# Quantization types exercised below:
#   scalar_types.uint4   -> AWQ-style path (unsigned weights + zero-point)
#   scalar_types.uint4b8 -> GPTQ-style path (unsigned weights + symmetric bias)
@pytest.mark.parametrize("k_chunk", [128])
@pytest.mark.parametrize("n_chunk", [64, 256])
@pytest.mark.parametrize("quant_type", [scalar_types.uint4, scalar_types.uint4b8])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("mnk_factors", MNK_FACTORS)
@pytest.mark.parametrize("act_order", [False, True])
@pytest.mark.parametrize("is_k_full", [False, True])
@pytest.mark.parametrize("use_atomic_add", [False, True])
@pytest.mark.parametrize("use_fp32_reduce", [False, True])
def test_gptq_marlin_gemm(
    k_chunk,
    n_chunk,
    quant_type,
    group_size,
    mnk_factors,
    act_order,
    is_k_full,
    use_atomic_add,
    use_fp32_reduce,
):
    """Compare ``gptq_marlin_gemm`` against a dense fp16 reference matmul.

    A random fp16 activation (size_m x size_k) and weight (size_k x size_n)
    are generated on CUDA. The weight is quantized with ``awq_marlin_quantize``
    (zero-point types) or ``marlin_quantize`` (otherwise). The kernel output is
    then checked against ``a_input @ w_ref``, where ``w_ref`` is the reference
    weight returned by the quantizer. The check uses a mean relative error
    below 0.04.

    Parameter combinations that are not exercised are reported via
    ``pytest.skip`` rather than silently counted as passes.
    """
    m_factor, n_factor, k_factor = mnk_factors
    # Zero-point types take the AWQ path; everything else takes the GPTQ path.
    has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    # Filter out combinations that are not tested. pytest.skip raises, so
    # nothing below runs for a skipped case.
    if act_order:
        if group_size == -1:
            pytest.skip("act_order is not tested with group_size=-1")
        if group_size == size_k:
            pytest.skip("act_order is not tested with group_size == size_k")
        if has_zp:
            pytest.skip("act_order is not tested with zero-point quant types")
    # For group_size == -1 this is `size_k % -1`, which is always 0.
    if size_k % group_size != 0:
        pytest.skip("size_k is not a multiple of group_size")
    # 16 is not in the current parametrization; this guard is kept in case it
    # is added later (both quantization paths previously bailed out on it).
    if group_size == 16:
        pytest.skip("group_size=16 is not tested")

    a_input = torch.randn((size_m, size_k), dtype=torch.float16, device="cuda")
    b_weight = torch.randn((size_k, size_n), dtype=torch.float16, device="cuda")

    if has_zp:
        # AWQ style: unsigned weights + runtime zero-point (act_order is
        # excluded for this path by the skip above).
        w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
            b_weight, quant_type, group_size
        )
        g_idx = None
        sort_indices = None
    else:
        # GPTQ style: unsigned weights with symmetric bias, optional act_order.
        w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
            b_weight, quant_type, group_size, act_order
        )
        marlin_zp = None
    # Both quantization paths pass None for marlin_s2.
    marlin_s2 = None

    workspace = marlin_make_workspace(w_ref.device)

    # Marlin GEMM under test.
    output = gptq_marlin_gemm(
        a_input,
        None,  # NOTE(review): presumably the optional output buffer - confirm
        marlin_q_w,
        marlin_s,
        marlin_s2,
        marlin_zp,
        g_idx,
        sort_indices,
        workspace,
        quant_type,
        a_input.shape[0],
        b_weight.shape[1],
        a_input.shape[1],
        is_k_full=is_k_full,
        use_atomic_add=use_atomic_add,
        use_fp32_reduce=use_fp32_reduce,
        is_zp_float=False,
    )
    # Dense reference GEMM against the quantizer's reference weight.
    output_ref = torch.matmul(a_input, w_ref)
    torch.cuda.synchronize()

    # Mean absolute error normalised by the mean magnitude of the reference.
    rel_mean_diff = torch.mean(torch.abs(output - output_ref)) / torch.mean(
        torch.abs(output_ref)
    )
    assert rel_mean_diff < 0.04, (
        f"mean relative error {rel_mean_diff.item():.4f} exceeds 0.04"
    )
if __name__ == "__main__":
    # Run this file's tests in-process with the current interpreter. Propagate
    # pytest's exit status so that running the script directly exits non-zero
    # when tests fail. The original subprocess.call discarded that status.
    raise SystemExit(pytest.main(["--tb=short", __file__]))