# sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_persistent_gemm.py
#
# 79 lines
# 2.4 KiB
# Python

import numpy as np
import torch
import triton
import flashinfer
import flashinfer.triton
from flashinfer.testing.utils import bench_gpu_time
def is_cuda():
    """Return True when Triton's currently active target uses the "cuda" backend."""
    target = triton.runtime.driver.active.get_current_target()
    return target.backend == "cuda"
def supports_tma():
    """Return True on a CUDA backend whose device compute-capability major is >= 9.

    The check is used by this script as the gate for running the benchmarks.
    """
    if not is_cuda():
        return False
    # Only the major version of the (major, minor) capability is compared.
    major = torch.cuda.get_device_capability()[0]
    return major >= 9
def bench_gemm_persistent(num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000):
    """Benchmark ``gemm_persistent`` for one configuration and print its throughput.

    Args:
        num_sms: Forwarded to the kernel as its ``num_sms`` argument.
        dtype: Torch dtype the float16 random operands are cast to.
        M: Number of rows of ``a``.
        N: Number of rows of ``b``.
        K: Number of columns of both ``a`` and ``b``.
        reps: Forwarded to ``bench_gpu_time`` as ``repeat_time_ms``.
        warmup_reps: Forwarded to ``bench_gpu_time`` as ``dry_run_time_ms``.

    Returns:
        None. The result is printed to stdout.
    """
    # Build the operands once, outside the timed callable. Creating them inside
    # it would add random-number generation, allocation and the dtype cast to
    # every measurement, while the FLOP count below only covers the GEMM.
    # NOTE(review): assumes the kernel does not modify ``a`` or ``b`` in place,
    # so reusing them across iterations is safe - confirm.
    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
    b = torch.randn((N, K), device="cuda", dtype=torch.float16).to(dtype)

    def run_gemm():
        return flashinfer.triton.sm_constraint_gemm.gemm_persistent(
            a=a,
            b=b,
            alpha=1.0,
            beta=0.0,
            num_sms=num_sms,
        )

    measurements = bench_gpu_time(
        run_gemm,
        dry_run_time_ms=warmup_reps,
        repeat_time_ms=reps,
    )
    ms = np.median(measurements)

    # matmul: 2 * M * N * K
    # scale and add: 3 * M * N
    # With the time in milliseconds (presumably what bench_gpu_time returns -
    # confirm), ops / ms / 1e9 is TFLOPs/s.
    tflops = (2 * M * N * K + 3 * M * N) / ms / 1e9
    print(
        f"GEMM Persistent | num_sms: {num_sms}, M: {M}, N: {N}, K: {K}, {dtype}: {tflops:.3f} TFLOPs/s"
    )
def bench_gemm_descriptor_persistent(
    num_sms, dtype, M, N, K, reps=1000, warmup_reps=10000
):
    """Benchmark ``gemm_descriptor_persistent`` for one configuration and print its throughput.

    ``torch.float32`` is skipped: the function returns immediately without
    benchmarking or printing anything.

    Args:
        num_sms: Forwarded to the kernel as its ``num_sms`` argument.
        dtype: Torch dtype the float16 random operands are cast to.
        M: Number of rows of ``a``.
        N: Number of rows of ``b``.
        K: Number of columns of both ``a`` and ``b``.
        reps: Forwarded to ``bench_gpu_time`` as ``repeat_time_ms``.
        warmup_reps: Forwarded to ``bench_gpu_time`` as ``dry_run_time_ms``.

    Returns:
        None. The result is printed to stdout.
    """
    # NOTE(review): presumably float32 is unsupported by the descriptor kernel -
    # confirm against the kernel implementation.
    if dtype == torch.float32:
        return

    # Build the operands once, outside the timed callable. Creating them inside
    # it would add random-number generation, allocation and the dtype cast to
    # every measurement, while the FLOP count below only covers the GEMM.
    # NOTE(review): assumes the kernel does not modify ``a`` or ``b`` in place,
    # so reusing them across iterations is safe - confirm.
    a = torch.randn((M, K), device="cuda", dtype=torch.float16).to(dtype)
    b = torch.randn((N, K), device="cuda", dtype=torch.float16).to(dtype)

    def run_gemm():
        return flashinfer.triton.sm_constraint_gemm.gemm_descriptor_persistent(
            a=a,
            b=b,
            alpha=1.0,
            beta=0.0,
            num_sms=num_sms,
        )

    measurements = bench_gpu_time(
        run_gemm,
        dry_run_time_ms=warmup_reps,
        repeat_time_ms=reps,
    )
    ms = np.median(measurements)

    # matmul: 2 * M * N * K
    # scale and add: 3 * M * N
    # With the time in milliseconds (presumably what bench_gpu_time returns -
    # confirm), ops / ms / 1e9 is TFLOPs/s.
    tflops = (2 * M * N * K + 3 * M * N) / ms / 1e9
    print(
        f"GEMM Descriptor | num_sms: {num_sms}, M: {M}, N: {N}, K: {K}, {dtype}: {tflops:.3f} TFLOPs/s"
    )
if __name__ == "__main__":
    # Explicit check rather than `assert`: assertions are stripped when Python
    # runs with -O, which would let the benchmarks start on unsupported
    # hardware and fail later with a less helpful error.
    if not supports_tma():
        raise RuntimeError(
            "bench_persistent_gemm requires the Triton CUDA backend and a GPU "
            "with compute capability >= 9.0"
        )

    # Sweep square problem sizes, input dtypes and SM budgets; both kernels are
    # benchmarked back to back for every combination.
    for M, N, K in [(4096, 4096, 4096), (8192, 8192, 8192)]:
        for dtype in [
            torch.float8_e4m3fn,
            torch.float16,
            torch.bfloat16,
            torch.float32,
        ]:
            for num_sms in [1, 16, 32, 64, 128, 132, 133, 256]:
                bench_gemm_persistent(num_sms, dtype, M, N, K)
                bench_gemm_descriptor_persistent(num_sms, dtype, M, N, K)