# sglang_v0.5.2/flashinfer_0.3.1/benchmarks/routines/gemm.py


from collections import defaultdict
import numpy as np
import torch
import torch.nn.functional as F
from einops import einsum
import flashinfer
from flashinfer.testing.utils import (
bench_gpu_time,
bench_gpu_time_with_cudagraph,
dequantize_fp8,
quantize_fp8,
)
from .flashinfer_benchmark_utils import (
dtype_str_to_torch_dtype,
get_device,
print_perf_metrics,
)
def run_gemm_test(args):
"""
Run a gemm test.
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.routine == "gemm_fp8_nt_groupwise":
return testGemmFp8NtGroupwise(args)
elif args.routine == "group_gemm_fp8_nt_groupwise":
return testGroupGemmFp8NtGroupwise(args)
elif args.routine == "bmm_fp8":
return testBmmFp8(args)
elif args.routine == "mm_fp4":
return testMmFp4(args)
else:
raise ValueError(f"Unsupported routine: {args.routine}")
def parse_gemm_args(line, parser):
"""
Parse command line arguments for gemm test configuration.
Args:
line: Command line arguments
parser: ArgumentParser object already populated with shared arguments
Returns:
Parsed argument namespace
"""
parser.add_argument(
"--batch_size",
type=int,
required=False,
default=1,
help="Batch size of test case.",
)
parser.add_argument(
"--m", type=int, required=True, help="Number of rows in the first matrix."
)
parser.add_argument(
"--n", type=int, required=True, help="Number of columns in the second matrix."
)
parser.add_argument(
"--k",
type=int,
required=True,
help="Number of columns in the first matrix and number of rows in the second matrix.",
)
parser.add_argument(
"--tile_size",
type=int,
required=False,
default=128,
help="Tile size for the gemm operation.",
)
parser.add_argument(
"--group_size",
type=int,
required=False,
default=1,
help="Group size for the group gemm operation.",
)
parser.add_argument(
"--scale_major_mode",
type=str,
required=False,
default="MN",
choices=["MN", "K"],
help="Scale major mode.",
)
parser.add_argument(
"--input_dtype",
type=str,
required=False,
default="fp8_e4m3",
help="Data type of the input.",
)
parser.add_argument(
"--mat2_dtype",
type=str,
required=False,
default="fp8_e4m3",
help="Data type of the mat2.",
)
parser.add_argument(
"--out_dtype",
type=str,
required=False,
default="bfloat16",
help="Data type of the output.",
)
parser.add_argument(
"--mma_sm",
type=int,
required=False,
default=1,
choices=[1, 2],
help="How many SMs to use for the MMA operation, must be 1 or 2",
)
parser.add_argument(
"--backends",
type=str,
required=False,
nargs="+",
default=["cudnn"],
choices=["cudnn", "cublas", "trtllm", "cutlass"],
help="Kernel backends to test. Default: cudnn",
)
parser.add_argument(
"--use_128x4_sf_layout",
action="store_true",
help="Use 128x4 SF layout for the input and mat2.",
)
args = parser.parse_args(line)
if args.verbose >= 1:
print(f"[INFO] {args = }")
return args
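# Example invocation (illustrative sketch; assumes the top-level flashinfer_benchmark
# driver script and the spelling of its shared flags such as --refcheck):
#   python flashinfer_benchmark.py --routine bmm_fp8 --batch_size 16 \
#       --m 1024 --n 1024 --k 4096 --backends cudnn cublas --refcheck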
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
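# Worked example for to_float8 (illustrative): for a tensor whose absolute maximum
# is 3.0 and dtype torch.float8_e4m3fn (finfo.max == 448), scale = 448 / 3.0 ~= 149.3,
# so values are multiplied by ~149.3 before casting, and the returned reciprocal
# (~0.0067) is the per-tensor descale later passed to bmm_fp8 as A_scale / B_scale.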
def testGemmFp8NtGroupwise(args):
"""
Test gemm_fp8_nt_groupwise API.
This test:
1. Generates random input tensors
2. Quantizes input tensors to FP8
3. Runs gemm_fp8_nt_groupwise
4. Runs reference check
5. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testGemmFp8NtGroupwise")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
## Parse input arguments
backends = args.backends
m = args.m
n = args.n
k = args.k
tile_size = args.tile_size
scale_major_mode = args.scale_major_mode
mma_sm = args.mma_sm
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
if out_dtype not in [torch.bfloat16, torch.float16]:
raise ValueError(f"Unsupported output dtype: {args.out_dtype}")
## Done parsing input arguments
if "trtllm" in backends:
remove_trtllm = False
if scale_major_mode != "MN":
print(
"[INFO] trtllm only supports MN scale_major_mode, removing trtllm from backends"
)
remove_trtllm = True
if k < 256:
print("[INFO] trtllm only supports k >= 256, removing trtllm from backends")
remove_trtllm = True
if remove_trtllm:
backends.remove("trtllm")
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return []
## Prepare input tensors
a_val = torch.randn((m, k), dtype=torch.float, device=device)
b_val = torch.randn((n, k), dtype=torch.float, device=device) / np.sqrt(k)
if args.verbose >= 2:
print(f"[VVERBOSE] {a_val.shape = }")
print(f"[VVERBOSE] {b_val.shape = }")
if scale_major_mode == "K":
a_scale_shape = (m, k // tile_size)
b_scale_shape = (n // tile_size, k // tile_size)
else:
a_scale_shape = (k // tile_size, m)
b_scale_shape = (k // tile_size, n // tile_size)
a_tile_shape = (1, tile_size)
b_tile_shape = (tile_size, tile_size)
a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode)
b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode)
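# Shape sketch (illustrative): with m = n = k = 4096 and tile_size = 128, a is
# quantized per (1, 128) tile and b per (128, 128) tile, so a_scale is (4096, 32)
# and b_scale is (32, 32) in K-major mode, or (32, 4096) and (32, 32) in MN-major
# mode, matching the scale shapes constructed above.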
if args.verbose >= 2:
print(f"[VVERBOSE] {a_fp8.shape = }")
print(f"[VVERBOSE] {b_fp8.shape = }")
print(f"[VVERBOSE] {a_scale.shape = }")
print(f"[VVERBOSE] {b_scale.shape = }")
a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode)
b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode)
def run_backend(backend):
if backend in ["cutlass", "trtllm"]:
return flashinfer.gemm.gemm_fp8_nt_groupwise(
a=a_fp8,
b=b_fp8,
a_scale=a_scale,
b_scale=b_scale,
scale_major_mode=scale_major_mode,
out_dtype=out_dtype,
mma_sm=mma_sm,
backend=backend,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
has_reference_output = False
if run_refcheck:
reference_output = einsum(a_dequant, b_dequant, "m k, n k -> m n").to(out_dtype)
has_reference_output = True
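# Note: the reference multiplies the dequantized FP8 operands (A @ B^T), so the
# assert_close below with rtol/atol = 1e-2 measures kernel error rather than
# quantization error.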
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend(cur_backend).detach()
if is_cuda_graph_compatible:
backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
fn=lambda: run_backend(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True, # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling.
)
else:
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True, # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling.
)
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 0:
if run_refcheck and has_reference_output:
for i in range(len(tested_backends)):
try:
torch.testing.assert_close(
reference_output, tested_outputs[i], rtol=1e-2, atol=1e-2
)
except AssertionError as e:
print(
f"[ERROR] Output tensor mismatch from backend {tested_backends[i]}"
)
if not args.allow_output_mismatch:
print(e)
raise
res = []
for backend in backends:
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
problem_flops = 2 * m * n * k
problem_bytes = (m * k + n * k) * torch.float8_e4m3fn.itemsize + (
m * n
) * out_dtype.itemsize
tflops = problem_flops / (10**9 * median_time)  # TFLOPs/sec (median_time in ms)
tb_per_sec = problem_bytes / (10**9 * median_time)  # TB/sec (median_time in ms)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["m"] = m
cur_res["n"] = n
cur_res["k"] = k
cur_res["tile_size"] = tile_size
cur_res["scale_major_mode"] = scale_major_mode
cur_res["out_dtype"] = out_dtype
cur_res["mma_sm"] = mma_sm
cur_res["backend"] = backend
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res
def testGroupGemmFp8NtGroupwise(args):
"""
Test group_gemm_fp8_nt_groupwise API.
This test:
1. Generates random input tensors
2. Quantizes input tensors to FP8
3. Runs group_gemm_fp8_nt_groupwise
4. Runs reference check
5. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testGroupGemmFp8NtGroupwise")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
## Parse input arguments
backends = ["cutlass"] # Cutlass is currently the only supported backend
m = args.m
n = args.n
k = args.k
group_size = args.group_size
tile_size = args.tile_size
scale_major_mode = args.scale_major_mode
mma_sm = args.mma_sm
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
out_dtype = dtype_str_to_torch_dtype(args.out_dtype)
if out_dtype not in [torch.bfloat16, torch.float16]:
raise ValueError(f"Unsupported output dtype: {args.out_dtype}")
## Done parsing input arguments
## Prepare input tensors
a_val = torch.randn((group_size * m, k), dtype=torch.float, device=device)
b_val = torch.randn((group_size, n, k), dtype=torch.float, device=device) / np.sqrt(k)
if args.verbose >= 2:
print(f"[VVERBOSE] {a_val.shape = }")
print(f"[VVERBOSE] {b_val.shape = }")
if scale_major_mode == "K":
a_scale_shape = (group_size * m, k // tile_size)
b_scale_shape = (group_size, n // tile_size, k // tile_size)
else:
a_scale_shape = (k // tile_size, m * group_size)
b_scale_shape = (group_size, k // tile_size, n // tile_size)
a_tile_shape = (1, tile_size)
b_tile_shape = (1, tile_size, tile_size)
a_fp8, a_scale = quantize_fp8(a_val, a_scale_shape, a_tile_shape, scale_major_mode)
b_fp8, b_scale = quantize_fp8(b_val, b_scale_shape, b_tile_shape, scale_major_mode)
a_dequant = dequantize_fp8(a_fp8, a_scale, scale_major_mode)
b_dequant = dequantize_fp8(b_fp8, b_scale, scale_major_mode)
m_indptr = torch.arange(0, group_size + 1, dtype=torch.int32, device=device) * m
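# Illustrative example: with group_size = 4 and m = 128, m_indptr is
# tensor([0, 128, 256, 384, 512]); row segment [m_indptr[i], m_indptr[i+1]) of the
# stacked A operand is multiplied with b[i], matching the grouped reference below.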
if args.verbose >= 2:
print(f"[VVERBOSE] {a_fp8.shape = }")
print(f"[VVERBOSE] {b_fp8.shape = }")
print(f"[VVERBOSE] {a_scale.shape = }")
print(f"[VVERBOSE] {b_scale.shape = }")
print(f"[VVERBOSE] {m_indptr.shape = }")
def run_backend(backend):
if backend == "cutlass":
return flashinfer.gemm.group_gemm_fp8_nt_groupwise(
a=a_fp8,
b=b_fp8,
a_scale=a_scale,
b_scale=b_scale,
m_indptr=m_indptr,
scale_major_mode=scale_major_mode,
out_dtype=out_dtype,
mma_sm=mma_sm,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
has_reference_output = False
if run_refcheck:
reference_output = (
einsum(
a_dequant.view((group_size, m, k)), b_dequant, "b m k, b n k -> b m n"
)
.view((group_size * m, n))
.to(out_dtype)
)
has_reference_output = True
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend(cur_backend).detach()
if is_cuda_graph_compatible:
backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
fn=lambda: run_backend(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True, # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling.
)
else:
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True, # GEMMs are very MMA-heavy, so prefer sleep to reduce throttling.
)
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 0:
if run_refcheck and has_reference_output:
for i in range(len(tested_backends)):
try:
torch.testing.assert_close(
reference_output, tested_outputs[i], rtol=1e-2, atol=1e-2
)
except AssertionError as e:
print(
f"[ERROR] Output tensor mismatch from backend {tested_backends[i]}"
)
if not args.allow_output_mismatch:
print(e)
raise
res = []
for backend in backends:
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
problem_flops = 2 * m * n * k * group_size
problem_bytes = (
group_size * m * k + group_size * n * k
) * torch.float8_e4m3fn.itemsize + (group_size * m * n) * out_dtype.itemsize
tflops = problem_flops / (10**9 * median_time)  # TFLOPs/sec (median_time in ms)
tb_per_sec = problem_bytes / (10**9 * median_time)  # TB/sec (median_time in ms)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["m"] = m
cur_res["n"] = n
cur_res["k"] = k
cur_res["group_size"] = group_size
cur_res["tile_size"] = tile_size
cur_res["scale_major_mode"] = scale_major_mode
cur_res["out_dtype"] = out_dtype
cur_res["mma_sm"] = mma_sm
cur_res["backend"] = backend
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res
def testBmmFp8(args):
"""
Test bmm_fp8 API.
This test:
1. Generates random input tensors
2. Quantizes input tensors to FP8
3. Runs bmm_fp8
4. Runs reference check
5. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testBmmFp8")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
## Parse input arguments
backends = args.backends
batch_size = args.batch_size
m = args.m
n = args.n
k = args.k
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
input_dtype = dtype_str_to_torch_dtype(args.input_dtype)
if input_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
raise ValueError(
f"Unsupported input dtype: {input_dtype}. Supported dtypes are fp8_e4m3 and fp8_e5m2."
)
mat2_dtype = dtype_str_to_torch_dtype(args.mat2_dtype)
if mat2_dtype not in [torch.float8_e4m3fn, torch.float8_e5m2]:
raise ValueError(
f"Unsupported mat2 dtype: {mat2_dtype}. Supported dtypes are fp8_e4m3 and fp8_e5m2."
)
res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
if res_dtype not in [torch.bfloat16, torch.float16]:
raise ValueError(
f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
)
## Done parsing input arguments
## Prepare input tensors
input = torch.randn([batch_size, m, k], device=device, dtype=torch.bfloat16)
input_fp8, input_inv_s = to_float8(input, dtype=input_dtype)
mat2 = torch.randn(
[batch_size, n, k], device=device, dtype=torch.bfloat16
).transpose(-2, -1)
mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype)
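# Dequantization sketch (conceptual): bmm_fp8 computes roughly
#   out ~= (input_fp8.float() * input_inv_s) @ (mat2_fp8.float() * mat2_inv_s)
# so the per-tensor reciprocal scales returned by to_float8 restore the original
# magnitudes in the epilogue.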
if args.verbose >= 2:
print(f"[VVERBOSE] {input_fp8.shape = }")
print(f"[VVERBOSE] {input_fp8.dtype = }")
print(f"[VVERBOSE] {mat2_fp8.shape = }")
print(f"[VVERBOSE] {mat2_fp8.dtype = }")
print(f"[VVERBOSE] {input_inv_s = }")
print(f"[VVERBOSE] {input_inv_s.dtype = }")
print(f"[VVERBOSE] {mat2_inv_s = }")
print(f"[VVERBOSE] {mat2_inv_s.dtype = }")
def run_backend(backend):
if backend in ["cudnn", "cublas", "cutlass"]:
return flashinfer.gemm.bmm_fp8(
A=input_fp8,
B=mat2_fp8,
A_scale=input_inv_s,
B_scale=mat2_inv_s,
dtype=res_dtype,
backend=backend,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
has_reference_output = False
if run_refcheck:
reference_output = torch.bmm(input, mat2)
has_reference_output = True
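# The reference here is the unquantized bfloat16 bmm, so the comparison below
# tolerates FP8 quantization error by using cosine similarity (> 0.99) instead of
# an elementwise assert_close.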
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend(cur_backend).detach()
if is_cuda_graph_compatible:
backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
fn=lambda: run_backend(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True,
)
else:
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True,
)
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 0:
if run_refcheck and has_reference_output:
if reference_output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
print(
"[INFO] Reference output is FP8. Converting to float32 for reference check."
)
reference_output = reference_output.to(torch.float32)
tested_outputs = [output.to(torch.float32) for output in tested_outputs]
for i in range(len(tested_backends)):
try:
cos_sim = F.cosine_similarity(
reference_output.reshape(-1),
tested_outputs[i].reshape(-1),
dim=0,
)
assert cos_sim > 0.99
except AssertionError as e:
print(
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
)
if not args.allow_output_mismatch:
print(e)
raise
res = []
for backend in backends:
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
problem_flops = 2 * m * n * k * batch_size
problem_bytes = batch_size * (
m * k * input_dtype.itemsize
+ n * k * mat2_dtype.itemsize
+ m * n * res_dtype.itemsize
)  # scale with batch_size, matching problem_flops above
tflops = problem_flops / (10**9 * median_time)  # TFLOPs/sec (median_time in ms)
tb_per_sec = problem_bytes / (10**9 * median_time)  # TB/sec (median_time in ms)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["batch_size"] = batch_size
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["m"] = m
cur_res["n"] = n
cur_res["k"] = k
cur_res["input_dtype"] = input_dtype
cur_res["mat2_dtype"] = mat2_dtype
cur_res["out_dtype"] = res_dtype
cur_res["backend"] = backend
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res
def testMmFp4(args):
"""
Test mm_fp4 API.
This test:
1. Generates random input tensors
2. Quantizes input tensors to FP4
3. Runs mm_fp4
4. Runs reference check
5. Measures performance metrics (TFLOPS, TB/sec)
Args:
args: Parsed command line arguments containing test configuration
Returns:
list[dict]: List of dictionaries containing performance results
"""
if args.verbose >= 1:
print("[INFO] Running testMmFp4")
print(f"[INFO] FlashInfer version: {flashinfer.__version__}")
device = get_device(args)
if args.generate_repro_command:
print(
f"[INFO] To reproduce this test case, run the following command: {args.repro_command}"
)
## Parse input arguments
backends = args.backends
m = args.m
n = args.n
k = args.k
is_cuda_graph_compatible = not args.no_cuda_graph
run_refcheck = args.refcheck
use_128x4_sf_layout = args.use_128x4_sf_layout
res_dtype = dtype_str_to_torch_dtype(args.out_dtype)
if res_dtype not in [torch.bfloat16, torch.float16]:
raise ValueError(
f"Unsupported res dtype: {res_dtype}. Supported dtypes are bfloat16 and float16."
)
## Done parsing input arguments
if "trtllm" in backends:
remove_trtllm = False
if res_dtype == torch.float16:
print("[INFO] trtllm backend does not suppot float16 output")
remove_trtllm = True
if remove_trtllm:
backends.remove("trtllm")
if "cutlass" in backends:
remove_cutlass = False
if not use_128x4_sf_layout:
print("[INFO] cutlass backend does not suppot use_128x4_sf_layout=False")
remove_cutlass = True
if remove_cutlass:
backends.remove("cutlass")
if "cudnn" in backends:
remove_cudnn = False
if not use_128x4_sf_layout:
print("[INFO] cudnn backend does not suppot use_128x4_sf_layout=False")
remove_cudnn = True
if remove_cudnn:
backends.remove("cudnn")
if len(backends) == 0:
print("[ERROR] No backends to test. Exiting.")
return []
input = torch.randn([m, k], device=device, dtype=torch.bfloat16)
mat2 = torch.randn([n, k], device=device, dtype=torch.bfloat16)
a_sf_layout = (
flashinfer.SfLayout.layout_128x4
if use_128x4_sf_layout
else flashinfer.SfLayout.layout_8x4
)
global_sf_input = (448 * 6) / input.float().abs().nan_to_num().max()
global_sf_mat2 = (448 * 6) / mat2.float().abs().nan_to_num().max()
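# Global scale sketch (following the NVFP4 recipe, stated here as an assumption):
# 448 is the max magnitude of an FP8 E4M3 block scale and 6 is the max magnitude of
# an FP4 E2M1 value, so (448 * 6) / amax maps each tensor's absolute maximum onto
# the largest representable quantized product.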
input_fp4, input_inv_s = flashinfer.nvfp4_quantize(
input, global_sf_input, sfLayout=a_sf_layout, do_shuffle=False
)
mat2_fp4, mat2_inv_s = flashinfer.nvfp4_quantize(
mat2,
global_sf_mat2,
sfLayout=flashinfer.SfLayout.layout_128x4,
do_shuffle=False,
)
if "trtllm" in backends:
mat2_fp4_trtllm, mat2_inv_s_trtllm = flashinfer.nvfp4_quantize(
mat2,
global_sf_mat2,
sfLayout=flashinfer.SfLayout.layout_128x4,
do_shuffle=True,
)
if args.verbose >= 2:
print(f"[VVERBOSE] {input_fp4.shape = }")
print(f"[VVERBOSE] {input_fp4.dtype = }")
print(f"[VVERBOSE] {mat2_fp4.shape = }")
print(f"[VVERBOSE] {mat2_fp4.dtype = }")
alpha = 1.0 / (global_sf_input * global_sf_mat2)
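# alpha = 1 / (global_sf_input * global_sf_mat2) folds the two global quantization
# scales back out of the product, so the output is returned in the original
# bfloat16 magnitude range.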
# res = torch.empty([m, n], device="cuda", dtype=res_dtype)
def run_backend(backend):
if backend in ["cudnn", "trtllm", "cutlass"]:
return flashinfer.gemm.mm_fp4(
a=input_fp4,
b=mat2_fp4.T if backend != "trtllm" else mat2_fp4_trtllm.T,
a_descale=input_inv_s,
b_descale=mat2_inv_s.T if backend != "trtllm" else mat2_inv_s_trtllm.T,
alpha=alpha,
out_dtype=res_dtype,
block_size=16, # Only supports 16
use_8x4_sf_layout=not use_128x4_sf_layout,
backend=backend,
)
else:
raise ValueError(f"Unsupported backend: {backend}")
has_reference_output = False
if run_refcheck:
reference_output = torch.mm(input, mat2.T)
has_reference_output = True
# Storage for timing results and outputs
backend_times = {backend: [] for backend in backends}
outputs = {}
for cur_backend in backends:
if run_refcheck:
outputs[cur_backend] = run_backend(cur_backend).detach()
if is_cuda_graph_compatible:
backend_times[cur_backend] = bench_gpu_time_with_cudagraph(
fn=lambda: run_backend(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
num_iters_within_graph=20,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True,
)
else:
backend_times[cur_backend] = bench_gpu_time(
fn=lambda: run_backend(cur_backend),
dry_run_iters=args.dry_run_iters,
repeat_iters=args.num_iters,
l2_flush=True,
l2_flush_size_mb=256,
l2_flush_device=device,
sleep_after_run=True,
)
tested_backends = list(outputs.keys())
tested_outputs = list(outputs.values())
if len(tested_backends) > 0:
if run_refcheck and has_reference_output:
for i in range(len(tested_backends)):
try:
cos_sim = F.cosine_similarity(
reference_output.reshape(-1),
tested_outputs[i].reshape(-1),
dim=0,
)
assert cos_sim > 0.97
except AssertionError as e:
print(
f"[ERROR] Output tensor mismatch between backends {tested_backends[0]} and {tested_backends[i]}"
)
if not args.allow_output_mismatch:
print(e)
raise
res = []
for backend in backends:
if len(backend_times[backend]) > 0:
median_time = np.median(backend_times[backend])
std_time = np.std(backend_times[backend])
problem_flops = 2 * m * n * k
problem_bytes = (
m * k * 0.5 + n * k * 0.5 + m * n * res_dtype.itemsize
)  # 0.5 bytes per FP4 element; scale-factor tensors are not counted
tflops = problem_flops / (10**9 * median_time)  # TFLOPs/sec (median_time in ms)
tb_per_sec = problem_bytes / (10**9 * median_time)  # TB/sec (median_time in ms)
print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec)
if args.output_path is not None:
cur_res = defaultdict(str)
cur_res["routine"] = args.routine
cur_res["median_time"] = median_time
cur_res["std_time"] = std_time
cur_res["tflops"] = tflops
cur_res["tb_per_sec"] = tb_per_sec
cur_res["m"] = m
cur_res["n"] = n
cur_res["k"] = k
cur_res["out_dtype"] = res_dtype
cur_res["use_128x4_sf_layout"] = use_128x4_sf_layout
cur_res["backend"] = backend
cur_res["case_tag"] = args.case_tag
res.append(cur_res)
return res