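# Benchmark for FlashInfer's CuTe-DSL block-scaled grouped GEMM through the
# DeepGEMM-style masked-m interface (grouped_gemm_nt_masked): NVFP4 A/B operands
# with FP8 scale factors and a BF16 output, timed on the SM100 persistent kernel
# via Kineto.
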
import json
import random

import cutlass
import cutlass.torch as cutlass_torch
import torch

from flashinfer.cute_dsl.blockscaled_gemm import (
    create_scale_factor_tensor,
    grouped_gemm_nt_masked,  # deepgemm-like python interface for DLFW integration
)
from flashinfer.cute_dsl.utils import get_cutlass_dtype
from flashinfer.testing.utils import bench_kineto, count_bytes


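# Block-scaled dtype configuration: NVFP4 (E2M1) operands with one FP8 (E4M3)
# scale factor per sf_vec_size = 16 elements along K, and a BF16 output.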
ab_dtype = "float4_e2m1fn"
sf_dtype = "float8_e4m3fn"
c_dtype = "bfloat16"
sf_vec_size = 16

# DeepGEMM case
a_major = "k"
b_major = "k"
c_major = "n"


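# Benchmark a single (num_groups, max_m, n, k) case: build masked inputs, time
# the kernel with Kineto, and report latency, TFLOPS, and memory bandwidth.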
def bench_one(num_groups, max_m, expected_m_per_group, n, k):
    data = create_data(
        num_groups=num_groups,
        max_m=max_m,
        expected_m_per_group=expected_m_per_group,
        n=n,
        k=k,
    )

    def test_func():
        grouped_gemm_nt_masked(
            lhs=data["a"],
            rhs=data["b"],
            out=data["c"],
            masked_m=data["masked_m"],
            ab_dtype=ab_dtype,
            sf_dtype=sf_dtype,
            c_dtype=c_dtype,
            sf_vec_size=sf_vec_size,
            alpha_dtype="float32",
        )

    t = bench_kineto(
        test_func,
        "Sm100BlockScaledPersistentDenseGemmKernel",
        suppress_kineto_output=True,
    )

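    # masked_m gives each group a random number of valid rows; t_calibrated
    # rescales the measured time to the expected workload of
    # expected_m_per_group rows in every group.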
    valid_m = data["masked_m"].sum().item()
    t_calibrated = t / valid_m * (expected_m_per_group * num_groups)

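    # FLOPs count only the valid rows; A/C bytes are scaled by the fraction of
    # rows actually computed, while B is always read in full.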
    tflops = 2 * valid_m * n * k / t / 1e12
    gb_per_s = (
        (
            count_bytes(data["a"], data["c"]) * valid_m / (max_m * num_groups)
            + count_bytes(data["b"])
        )
        / 1e9
        / t
    )

    print(
        f" > Perf ({num_groups=}, expected_m_per_group={expected_m_per_group:4}, n={n:4}, k={k:4}): "
        f"{t * 1e6:4.0f} us | {tflops:4.0f} TFLOPS | {gb_per_s:4.0f} GB/s"
    )

    metrics = dict(
        num_groups=num_groups,
        m_per_group=expected_m_per_group,
        valid_m=valid_m,
        n=n,
        k=k,
        t_us_raw=t * 1e6,
        t_us_calibrated=t_calibrated * 1e6,
        tflops=tflops,
        gb_per_s=gb_per_s,
    )
    print(f"MAIN_OUTPUT={json.dumps(metrics)}")


# ref: DeepGEMM
def enumerate_m_grouped_masked():
    max_m = 4096

    cases = [
        # GB200 cases
        (6, 1024),
        (6, 512),
        # DeepGEMM default cases
        (1, 1024),
        (2, 512),
        (4, 256),
    ]
    # more GB200 cases
    num_experts = 288
    num_experts_per_token = 8
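    # Expert-parallel-style sweep: 288 experts with 8 experts per token, sharded
    # over num_ranks devices; each rank holds num_experts // num_ranks groups and
    # expects num_tokens * num_experts_per_token // num_groups rows per group.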
    for num_ranks in [4, 8, 16, 32, 36, 48, 72]:
        for num_tokens in [64, 128, 256, 384, 512, 768, 1024]:
            num_groups = num_experts // num_ranks
            expected_m_per_group = num_tokens * num_experts_per_token // num_groups
            cases.append((num_groups, expected_m_per_group))

    for num_groups, expected_m_per_group in cases:
        for n, k in (
            (4096, 7168),
            (7168, 2048),
        ):
            yield dict(
                num_groups=num_groups,
                max_m=max_m,
                expected_m_per_group=expected_m_per_group,
                n=n,
                k=k,
            )


# Copied and modified from test_cute_dsl_blockscaled_gemm.py; common logic may be
# extracted later if needed.
def create_data(num_groups, max_m, expected_m_per_group, n, k, device="cuda:0"):
    device = torch.device(device)
    l = num_groups
    m = max_m

    a_ref = cutlass_torch.matrix(l, m, k, a_major == "m", cutlass.Float32)
    b_ref = cutlass_torch.matrix(l, n, k, b_major == "n", cutlass.Float32)
    c_ref = cutlass_torch.matrix(l, m, n, c_major == "m", cutlass.Float32)

    a_tensor, a_torch = cutlass_torch.cute_tensor_like(
        a_ref,
        get_cutlass_dtype(ab_dtype),
        is_dynamic_layout=True,
        assumed_align=16,
    )
    b_tensor, b_torch = cutlass_torch.cute_tensor_like(
        b_ref,
        get_cutlass_dtype(ab_dtype),
        is_dynamic_layout=True,
        assumed_align=16,
    )
    c_tensor, c_torch = cutlass_torch.cute_tensor_like(
        c_ref,
        get_cutlass_dtype(c_dtype),
        is_dynamic_layout=True,
        assumed_align=16,
    )

    # for deepgemm-like python interface
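    # FP4 packs two values per byte, so keep only the packed half of the backing
    # bytes and re-view A as (m, k // 2, l) and B as (n, k // 2, l).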
    if ab_dtype == "float4_e2m1fn":
        m, k, l = a_torch.shape
        n, k, l = b_torch.shape
        # slice into half after flatten
        half_len_a = a_torch.numel() // 2
        half_len_b = b_torch.numel() // 2
        a_torch = (
            a_torch.permute(2, 0, 1)
            .flatten()[:half_len_a]
            .reshape(l, m, k // 2)
            .permute(1, 2, 0)
        )
        b_torch = (
            b_torch.permute(2, 0, 1)
            .flatten()[:half_len_b]
            .reshape(l, n, k // 2)
            .permute(1, 2, 0)
        )

    sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor(
        l, m, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
    )
    sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor(
        l, n, k, sf_vec_size, get_cutlass_dtype(sf_dtype), device
    )

    masked_m_tensor = create_masked_m(
        num_groups=num_groups, expected_m_per_group=expected_m_per_group, max_m=max_m
    )

    return dict(
        a=(a_torch, sfa_torch),
        b=(b_torch, sfb_torch),
        c=c_torch,
        masked_m=masked_m_tensor,
    )


def create_masked_m(num_groups, expected_m_per_group, max_m):
    """Align with DeepGEMM :: generate_m_grouped_masked"""
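    # Each group is assigned between 70% and 130% of the expected row count.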
    masked_m = torch.empty((num_groups,), device="cuda", dtype=torch.int)
    for j in range(num_groups):
        masked_m[j] = int(expected_m_per_group * random.uniform(0.7, 1.3))
    assert masked_m.amax().item() <= max_m
    return masked_m


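# Fixed seeds keep the randomized masked_m values reproducible across runs.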
if __name__ == "__main__":
    torch.manual_seed(42)
    random.seed(42)
    for config in enumerate_m_grouped_masked():
        bench_one(**config)