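"""Benchmark SM-constrained GEMM kernels from ``flashinfer.triton.sm_constraint_gemm``.

Runs the persistent and descriptor-based persistent GEMM variants across problem
sizes, dtypes, and SM-count constraints, and reports median throughput in TFLOPs/s.
Requires a CUDA device with compute capability >= 9 (Hopper or newer).
"""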
import numpy as np
import torch
import triton

import flashinfer
import flashinfer.triton
from flashinfer.testing.utils import bench_gpu_time


def is_cuda():
    """Return True if Triton's active backend targets an NVIDIA CUDA device."""
    return triton.runtime.driver.active.get_current_target().backend == "cuda"


def supports_tma():
    """Return True on CUDA devices with compute capability >= 9 (Hopper+), where TMA is available."""
    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9

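
# Both benchmark helpers below time a single kernel launch via bench_gpu_time and
# report the median. Note that the fp16 inputs are allocated and cast to the target
# dtype inside the timed lambda, so that overhead is included in the measurement.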
def bench_gemm_persistent(num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000):
    """Benchmark gemm_persistent and print its median throughput in TFLOPs/s.

    ``reps`` and ``warmup_reps`` are forwarded to bench_gpu_time as the
    measurement and warm-up budgets (repeat_time_ms / dry_run_time_ms).
    """
    measurements = bench_gpu_time(
        lambda: flashinfer.triton.sm_constraint_gemm.gemm_persistent(
            a=torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype),
            b=torch.randn((N, K), device="cuda", dtype=torch.float16).to(dtype),
            alpha=1.0,
            beta=0.0,
            num_sms=num_sms,
        ),
        dry_run_time_ms=warmup_reps,
        repeat_time_ms=reps,
    )
    ms = np.median(measurements)

    # matmul: 2 * M * N * K
    # scale and add: 3 * M * N
    # ms is in milliseconds, so dividing ops/ms by 1e9 yields TFLOPs/s.
    flops = (2 * M * N * K + 3 * M * N) / ms / 1e9
    print(
        f"GEMM Persistent | num_sms: {num_sms}, M: {M}, N: {N}, K: {K}, {dtype}: {flops:.3f} TFLOPs/s"
    )


def bench_gemm_descriptor_persistent(
    num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000
):
    """Benchmark gemm_descriptor_persistent and print its median throughput in TFLOPs/s.

    float32 is skipped for this variant.
    """
    if dtype == torch.float32:
        return
    measurements = bench_gpu_time(
        lambda: flashinfer.triton.sm_constraint_gemm.gemm_descriptor_persistent(
            a=torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype),
            b=torch.randn((N, K), device="cuda", dtype=torch.float16).to(dtype),
            alpha=1.0,
            beta=0.0,
            num_sms=num_sms,
        ),
        dry_run_time_ms=warmup_reps,
        repeat_time_ms=reps,
    )
    ms = np.median(measurements)

    # matmul: 2 * M * N * K
    # scale and add: 3 * M * N
    # ms is in milliseconds, so dividing ops/ms by 1e9 yields TFLOPs/s.
    flops = (2 * M * N * K + 3 * M * N) / ms / 1e9
    print(
        f"GEMM Descriptor | num_sms: {num_sms}, M: {M}, N: {N}, K: {K}, {dtype}: {flops:.3f} TFLOPs/s"
    )


if __name__ == "__main__":
    assert supports_tma(), "These benchmarks require a CUDA device with compute capability >= 9 (TMA)."

    for M, N, K in [(4096, 4096, 4096), (8192, 8192, 8192)]:
        for dtype in [
            torch.float8_e4m3fn,
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ]:
            # Sweep SM-count constraints; 132 matches a full H100 SXM.
            for num_sms in [1, 16, 32, 64, 128, 132, 133, 256]:
                bench_gemm_persistent(num_sms, dtype, M, N, K)
                bench_gemm_descriptor_persistent(num_sms, dtype, M, N, K)
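
# Each call prints one line per configuration, e.g. (value illustrative, not measured):
# GEMM Persistent | num_sms: 132, M: 4096, N: 4096, K: 4096, torch.float16: <TFLOPs/s>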