sglang_v0.5.2/sglang/sgl-kernel/benchmark/bench_fp4_gemm.py

import argparse
import csv

import torch
import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

# Largest representable magnitudes of the two low-precision formats used below.
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max


def get_weight_shapes(args):
    # (N, packed_K) GEMM shapes for DeepSeek-R1-0528-FP4 weights, already
    # sharded for the requested tensor-parallel size. K is stored packed
    # (two FP4 values per byte) and is unpacked inside benchmark().
    models_tps = args.tp_sizes
    if models_tps == [4]:
        return [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]
    if models_tps == [8]:
        return [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]
    return [
        [1024, 3584],
        [7168, 256],
        [7168, 2304],
        [9216, 3584],
        [512, 3584],
        [7168, 128],
        [7168, 1152],
        [4608, 3584],
    ]
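

# triton.testing.perf_report registers the Benchmark configuration below and
# exposes benchmark.run(), which sweeps batch_size over x_vals, measures one
# latency curve per provider in line_vals, and prints/plots/saves the results.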
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1, 2, 4, 8, 16, 32, 64, 128,
            256, 512, 1024, 2048, 3072, 4096, 8192, 16384,
        ],
        # x_vals = [64],
        x_log=False,
        line_arg="provider",
        line_vals=["cutlass", "cudnn", "trtllm"],
        line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
        styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
        ylabel="latency (ms)",
        plot_name="fp4_gemm_benchmark",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
    M = batch_size
    # K from get_weight_shapes() is the packed width (two FP4 values per
    # byte); the logical GEMM reduction dimension is twice that.
    packed_k = K
    K = 2 * packed_k
    a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
    b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    alpha = 1.0 / (a_global_scale * b_global_scale)
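    # Two-level scaling: the global FP32 scale maps each tensor so its absolute
    # maximum fits the combined range of an FP8-E4M3 block scale times an
    # FP4-E2M1 value; scaled_fp4_quant then emits the packed FP4 data plus the
    # interleaved per-block scales. alpha folds the inverse of both global
    # scales into the GEMM epilogue so the output returns to the input's range.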
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    # Preallocated output buffer for the flashinfer mm_fp4 backends.
    res_fi = torch.empty((M, N), dtype=dtype, device="cuda")

    # do_bench_cudagraph captures the kernel in a CUDA graph and replays it;
    # the quantiles map to (median, 20th percentile, 80th percentile) latency.
    quantiles = [0.5, 0.2, 0.8]
    if provider == "cutlass":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
            ),
            quantiles=quantiles,
        )
    elif provider == "cudnn":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
                backend="cudnn",
            ),
            quantiles=quantiles,
        )
    elif provider == "trtllm":
        # The trtllm backend takes the interleaved block scales as uint8.
        a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
        b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
                backend="trtllm",
            ),
            quantiles=quantiles,
        )
    if correctness:
        # Use the cutlass kernel as the reference and check both flashinfer
        # backends against it.
        res_cutlass = cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="cudnn",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "cudnn fp4 doesn't match cutlass fp4"
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="trtllm",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "trtllm fp4 doesn't match cutlass fp4"

    if csv_file:
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([provider, M, N, K, ms])

    return ms, min_ms, max_ms


def parse_dtype(s: str) -> torch.dtype:
    # argparse cannot instantiate torch.dtype from a string, so resolve names
    # like "bfloat16" or "float16" as attributes of the torch module.
    dtype = getattr(torch, s, None)
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f"invalid torch dtype: {s}")
    return dtype


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parser.add_argument(
        "--dtype",
        type=parse_dtype,
        default=torch.bfloat16,
        help="Data type (e.g. bfloat16, float16)",
    )
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="Check correctness",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="results_cutlass_cudnn.csv",
        help="CSV file to save results",
    )
    args = parser.parse_args()

    if args.csv:
        # Start the CSV with a header row; benchmark() appends one row per run.
        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["provider", "m", "n", "k", "time_ms"])

    NKs = get_weight_shapes(args)
    for N, K in NKs:
        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}:")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
            dtype=args.dtype,
            correctness=args.correctness,
            csv_file=args.csv,
        )

    print("Benchmark finished!")