"""
|
|
This is the test file for MaskedBatchedMatmulCuteDSL kernel.
|
|
`test_blockscaled_gemm_python_interface` is the python interface test. For pytorch DLFW, refer to this.
|
|
"""

from typing import Optional, Tuple

import cutlass
import cutlass.cute as cute
import cutlass.torch as cutlass_torch
import pytest
import torch
from cutlass.cute.runtime import from_dlpack

from flashinfer.cute_dsl.blockscaled_gemm import (
    Sm100BlockScaledPersistentDenseGemmKernel,  # not used by the python interface
    grouped_gemm_nt_masked,  # deepgemm-like python interface for DLFW integration
    create_scale_factor_tensor,
)
from flashinfer.cute_dsl.utils import (
    get_cutlass_dtype,
    is_cute_dsl_available,
)
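
# A minimal sketch of the deepgemm-like call exercised below (the full keyword
# set appears in the test body): operands arrive as (tensor, scale_factor)
# pairs plus a per-group count of valid rows, e.g.
#
#     grouped_gemm_nt_masked(
#         (a, sfa), (b, sfb), out, masked_m,
#         ab_dtype="float8_e4m3fn", sf_dtype="float8_e8m0fnu",
#         c_dtype="bfloat16", sf_vec_size=32,
#     )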


@pytest.mark.skipif(
    not is_cute_dsl_available(), reason="Please `pip install nvidia-cutlass-dsl`"
)
@pytest.mark.parametrize("lm", [(1, 1024), (2, 512), (4, 256)])
@pytest.mark.parametrize("kn", [(7168, 4096), (2048, 7168)])
@pytest.mark.parametrize(
    "ab_dtype,sf_dtype,c_dtype,sf_vec_size",
    [
        ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
        ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
        ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
        ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
        ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
        ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
        ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
        ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
        ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e4m3fn", 32),
        ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
        ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
        ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
        ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
        ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
        ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
    ],
)
@pytest.mark.parametrize("a_major", ["k"])
@pytest.mark.parametrize("b_major", ["k"])
@pytest.mark.parametrize("c_major", ["n"])
@pytest.mark.parametrize("fuse_alpha", [False, True])
@pytest.mark.parametrize("alpha_dtype", ["float32"])
@pytest.mark.parametrize("mma_tiler_mn", [(128, 128)])
@pytest.mark.parametrize("cluster_shape_mn", [(1, 1)])
@pytest.mark.parametrize("sm_count", [132, None])
@pytest.mark.parametrize("tolerance", [1e-01])
@pytest.mark.parametrize("iterations", [3])
def test_blockscaled_gemm_python_interface(
    lm: Tuple[int, int],
    kn: Tuple[int, int],
    ab_dtype: str,
    sf_dtype: str,
    sf_vec_size: int,
    c_dtype: str,
    a_major: str,
    b_major: str,
    c_major: str,
    fuse_alpha: bool,
    alpha_dtype: str,
    mma_tiler_mn: Tuple[int, int],
    cluster_shape_mn: Tuple[int, int],
    sm_count: Optional[int],
    tolerance: float,
    iterations: int,
):
    torch.manual_seed(42)
    device = torch.device("cuda:0")
    major, minor = torch.cuda.get_device_capability(device)

    if not (major == 10 and minor == 0):
        pytest.skip("Cute-dsl backend is only supported on SM100.")

    l, m = lm
    k, n = kn

    print(f"device: {device}")
    if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
        get_cutlass_dtype(ab_dtype),
        get_cutlass_dtype(sf_dtype),
        sf_vec_size,
        get_cutlass_dtype(c_dtype),
        mma_tiler_mn,
        cluster_shape_mn,
        m,
        n,
        k,
        l,
        a_major,
        b_major,
        c_major,
    ):
        pytest.skip(
            f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, "
            f"{mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, "
            f"{a_major}, {b_major}, {c_major}"
        )

    if not (a_major == "k" and b_major == "k" and c_major == "n"):
        # Not supported yet: we align with deepgemm's layout for now.
        pytest.skip(
            f"Skip non deepgemm-like cases {a_major}, {b_major}, {c_major}; "
            "might be added later."
        )

    a_ref = cutlass_torch.matrix(
        l, m, k, a_major == "m", cutlass.Float32, device=device
    )
    b_ref = cutlass_torch.matrix(
        l, n, k, b_major == "n", cutlass.Float32, device=device
    )
    c_ref = cutlass_torch.matrix(
        l, m, n, c_major == "m", cutlass.Float32, device=device
    )
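
    # Quantize the f32 references into cute tensors of the operand dtypes; the
    # *_torch results are the backing torch tensors handed to the interface.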
    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
        a_ref,
        get_cutlass_dtype(ab_dtype),
        is_dynamic_layout=True,
        assumed_align=16,
    )
    b_tensor, b_torch = cutlass_torch.cute_tensor_like(
        b_ref,
        get_cutlass_dtype(ab_dtype),
        is_dynamic_layout=True,
        assumed_align=16,
    )
    c_tensor, c_torch = cutlass_torch.cute_tensor_like(
        c_ref,
        get_cutlass_dtype(c_dtype),
        is_dynamic_layout=True,
        assumed_align=16,
    )
    alpha_tensor = (
        torch.randn(l, dtype=torch.float32, device=device) if fuse_alpha else None
    )
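
    # For the deepgemm-like python interface, fp4 operands are passed packed
    # (two e2m1 values per byte), so a and b are reshaped to a k // 2 extent.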
    if ab_dtype == "float4_e2m1fn":
        m, k, l = a_torch.shape
        n, k, l = b_torch.shape
        # slice into half after flattening
        half_len_a = a_torch.numel() // 2
        half_len_b = b_torch.numel() // 2
        a_torch = (
            a_torch.permute(2, 0, 1)
            .flatten()[:half_len_a]
            .reshape(l, m, k // 2)
            .permute(1, 2, 0)
        )
        b_torch = (
            b_torch.permute(2, 0, 1)
            .flatten()[:half_len_b]
            .reshape(l, n, k // 2)
            .permute(1, 2, 0)
        )

    sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
        l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
    )
    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
        l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
    )
    masked_m_tensor = torch.randint(0, m, (l,), dtype=torch.int32, device=device)
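
    # Each group i computes only its first masked_m_tensor[i] rows; run the
    # kernel a few times and verify correctness once afterwards.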
    for _ in range(iterations):
        # deepgemm-like python interface: fp4 packed, for DLFW integration
        grouped_gemm_nt_masked(
            (a_torch, sfa_torch),
            (b_torch, sfb_torch),
            c_torch,
            masked_m_tensor,
            ab_dtype=ab_dtype,
            sf_dtype=sf_dtype,
            c_dtype=c_dtype,
            sf_vec_size=sf_vec_size,
            mma_tiler_mn=mma_tiler_mn,
            cluster_shape_mn=cluster_shape_mn,
            alpha=alpha_tensor,
            alpha_dtype=alpha_dtype,
            sm_count=sm_count,
        )
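
    # Compute the reference output: dequantize a/b with their broadcast scale
    # factors, contract over k for each group l, then apply per-group alpha.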
    if not fuse_alpha:
        alpha_tensor = torch.ones(l, dtype=torch.float32, device=device)
    res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref)
    res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref)
    ref = torch.einsum("mkl,nkl->mnl", res_a, res_b)
    ref = torch.einsum("mnl,l->mnl", ref, alpha_tensor)

    # Convert c back to f32 for comparison.
    cute.testing.convert(
        c_tensor,
        from_dlpack(c_ref, assumed_align=16).mark_layout_dynamic(
            leading_dim=(1 if c_major == "n" else 0)
        ),
    )
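
    # Rows beyond masked_m_tensor[i] are never written, so compare only the
    # valid prefix of each group.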
    if c_dtype in ("float32", "float16", "bfloat16"):
        for i in range(l):
            torch.testing.assert_close(
                c_ref[: masked_m_tensor[i].item(), :, i],
                ref[: masked_m_tensor[i].item(), :, i],
                atol=tolerance,
                rtol=1e-02,
            )
    elif c_dtype in ("float8_e5m2", "float8_e4m3fn"):
        # Round-trip the reference through f8 (f32 -> f8 -> f32) so both sides
        # carry the same quantization error.
        ref_f8_ = torch.empty((l, m, n), dtype=torch.uint8, device=device).permute(
            1, 2, 0
        )
        ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic(
            leading_dim=1
        )
        ref_f8.element_type = get_cutlass_dtype(c_dtype)
        ref = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0)
        ref_tensor = from_dlpack(ref, assumed_align=16).mark_layout_dynamic(
            leading_dim=1
        )
        cute.testing.convert(ref_tensor, ref_f8)
        cute.testing.convert(ref_f8, ref_tensor)
        for i in range(l):
            torch.testing.assert_close(
                c_ref[: masked_m_tensor[i].item(), :, i],
                ref[: masked_m_tensor[i].item(), :, i],
                atol=tolerance,
                rtol=1e-02,
            )
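

# Direct invocation runs a single configuration without pytest's
# parametrization; pytest.skip raises if a precondition (e.g. SM100) fails.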
if __name__ == "__main__":
    test_blockscaled_gemm_python_interface(
        lm=(1, 1024),
        kn=(7168, 4096),
        ab_dtype="float4_e2m1fn",
        sf_dtype="float8_e8m0fnu",
        sf_vec_size=16,
        c_dtype="float16",
        a_major="k",
        b_major="k",
        c_major="n",
        fuse_alpha=False,
        alpha_dtype="float32",
        mma_tiler_mn=(128, 128),
        cluster_shape_mn=(2, 1),
        tolerance=1e-01,
        iterations=3,
        sm_count=132,
    )