#!/usr/bin/env python3
import argparse
import copy
import csv
import itertools

import pytest
import torch
import triton
from flashinfer import mm_fp4
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

# NOTE(review): `copy`, `itertools`, and `pytest` appear unused in this file
# as visible here — confirm before removing (this may be a chunk of a larger
# file).

# Largest representable magnitude of the FP4 E2M1 format; combined with the
# FP8 max below to derive per-tensor global quantization scales.
FLOAT4_E2M1_MAX = 6.0
# Largest finite value of torch.float8_e4m3fn, queried from torch.finfo.
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
def get_weight_shapes(args):
    """Return the [N, K] weight shapes to benchmark for the requested TP sizes.

    ``args.tp_sizes == [4]`` selects the TP-4 shard shapes and ``[8]`` the
    TP-8 ones; any other value (including the default ``[1]``) benchmarks
    the union of both sets, TP-4 shapes first.
    """
    tp4_shapes = [[1024, 3584], [7168, 256], [7168, 2304], [9216, 3584]]
    tp8_shapes = [[512, 3584], [7168, 128], [7168, 1152], [4608, 3584]]

    tp_sizes = args.tp_sizes
    if tp_sizes == [4]:
        return tp4_shapes
    if tp_sizes == [8]:
        return tp8_shapes
    # Fallback: benchmark every shape from both TP configurations.
    return tp4_shapes + tp8_shapes
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[
            1,
            2,
            4,
            8,
            16,
            32,
            64,
            128,
            256,
            512,
            1024,
            2048,
            3072,
            4096,
            8192,
            16384,
        ],
        # x_vals = [64],
        x_log=False,
        line_arg="provider",
        line_vals=["cutlass", "cudnn", "trtllm"],
        line_names=["baseline cutlass fp4", "cudnn fp4", "trtllm fp4"],
        styles=[("red", "solid"), ("blue", "solid"), ("green", "solid")],
        ylabel="latency (ms)",
        plot_name="fp4_gemm_benchmark",
        args={},
    )
)
def benchmark(batch_size, provider, N, K, dtype, correctness, csv_file):
    """Benchmark one fp4 GEMM (M=batch_size, N, K) for the given provider.

    Providers: "cutlass" (sgl_kernel cutlass_scaled_fp4_mm baseline),
    "cudnn" and "trtllm" (flashinfer mm_fp4 backends). Optionally checks
    that both flashinfer backends match the cutlass result, and appends a
    CSV row per measurement. Returns (median_ms, low_ms, high_ms) from
    triton's quantile-based bench.
    """
    M = batch_size
    # NOTE(review): K arrives as the packed-fp4 dimension (2 values/byte);
    # the unpacked element count is 2*K — confirm against the shapes passed
    # by the caller.
    packed_k = K
    K = 2 * packed_k
    a_dtype = torch.randn((M, K), dtype=dtype, device="cuda")
    b_dtype = torch.randn((N, K), dtype=dtype, device="cuda")
    # Per-tensor global scale: map the tensor's absolute max into the
    # fp8-scale * fp4-value representable range.
    a_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(a_dtype.flatten(), dim=-1)
    ).to(torch.float32)
    b_global_scale = (
        (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.amax(b_dtype.flatten(), dim=-1)
    ).to(torch.float32)

    # Dequantization factor that undoes both global scales after the GEMM.
    alpha = 1.0 / (a_global_scale * b_global_scale)
    a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
    # print("a_fp4", a_fp4)
    b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
    # Output buffer written in place by the flashinfer mm_fp4 calls.
    res_fi = torch.empty((M, N), dtype=dtype, device="cuda")

    # Quantiles reported by do_bench_cudagraph, in this order:
    # median, 20th percentile, 80th percentile.
    quantiles = [0.5, 0.2, 0.8]
    if provider == "cutlass":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: cutlass_scaled_fp4_mm(
                a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
            ),
            quantiles=quantiles,
        )
    if provider == "cudnn":
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
            ),
            quantiles=quantiles,
        )
    if provider == "trtllm":
        # NOTE(review): .to(torch.uint8) is a value conversion, not a
        # bitwise reinterpret of the fp8 scale bytes — confirm this is what
        # the trtllm backend expects (a .view(torch.uint8) would reinterpret).
        a_scale_interleaved = a_scale_interleaved.to(torch.uint8)
        b_scale_interleaved = b_scale_interleaved.to(torch.uint8)
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: mm_fp4(
                a_fp4,
                b_fp4.T,
                a_scale_interleaved,
                b_scale_interleaved.T,
                alpha,
                dtype,
                res_fi,
                backend="trtllm",
            ),
            quantiles=quantiles,
        )
    if correctness:
        # NOTE(review): when provider == "trtllm", the scale tensors were
        # rebound to uint8 above, so the cudnn correctness call below also
        # receives uint8 scales — verify that is intended.
        res_cutlass = cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, dtype
        )
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="cudnn",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "cudnn fp4 doesn't match cutlass fp4"
        mm_fp4(
            a_fp4,
            b_fp4.T,
            a_scale_interleaved,
            b_scale_interleaved.T,
            alpha,
            dtype,
            res_fi,
            backend="trtllm",
        )
        assert torch.allclose(
            res_fi, res_cutlass, atol=1e-3, rtol=1e-3
        ), "trtllm fp4 doesn't match cutlass fp4"

    # Append one result row; the header is written once by the __main__ block.
    if csv_file:
        with open(csv_file, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([provider, M, N, K, ms])

    return ms, min_ms, max_ms
if __name__ == "__main__":

    def _parse_dtype(name):
        """argparse converter: map a dtype name like 'bfloat16' to torch.bfloat16.

        Bug fix: the previous ``type=torch.dtype`` made argparse call
        ``torch.dtype("bfloat16")`` on any user-supplied value, which raises
        because torch.dtype is not constructible from a string (only the
        non-string default escaped conversion). Resolve the attribute on the
        torch module and validate it instead.
        """
        dtype = getattr(torch, name, None)
        if not isinstance(dtype, torch.dtype):
            raise argparse.ArgumentTypeError(f"unknown torch dtype: {name!r}")
        return dtype

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tp-sizes",
        nargs="+",
        type=int,
        default=[1],
        help="List of tensor parallel sizes",
    )
    parser.add_argument(
        "--dtype",
        type=_parse_dtype,
        default=torch.bfloat16,
        help="Data type (e.g. bfloat16, float16)",
    )
    parser.add_argument(
        "--correctness",
        action="store_true",
        help="Check correctness",
    )
    parser.add_argument(
        "--csv",
        type=str,
        default="results_cutlass_cudnn.csv",
        help="CSV file to save results",
    )
    args = parser.parse_args()

    # Truncate the CSV and write the header once; benchmark() appends one
    # row per (provider, shape, batch-size) measurement.
    if args.csv:
        with open(args.csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["provider", "m", "n", "k", "time_ms"])

    NKs = get_weight_shapes(args)
    for N, K in NKs:
        print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
        benchmark.run(
            print_data=True,
            show_plots=True,
            save_path="bench_fp4_res",
            N=N,
            K=K,
            dtype=args.dtype,
            correctness=args.correctness,
            csv_file=args.csv,
        )

    print("Benchmark finished!")