sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_deepgemm_blackwell.py

"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import numpy as np
import torch

from flashinfer.gemm import (
    batch_deepgemm_fp8_nt_groupwise,
    group_deepgemm_fp8_nt_groupwise,
)
from flashinfer.testing.utils import bench_gpu_time, quantize_fp8
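
# Benchmarks FlashInfer's DeepGEMM-backed FP8 GEMMs (grouped and masked-batch
# variants) on Blackwell GPUs, reporting median-latency TFLOPs/s and an
# approximate achieved memory bandwidth for each problem shape.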


def bench_deepgemm_grouped_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype):
    """Benchmark DeepGEMM-based grouped GEMM with FP8 quantization."""
    # Create float32 input tensors
    a_f32 = torch.randn(batch_size * m, k, device="cuda", dtype=torch.float32)
    b_f32 = torch.randn(batch_size, n, k, device="cuda", dtype=torch.float32)
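
    # Groupwise FP8 scaling: A carries one scale per (row, 128-element K tile),
    # B one scale per 128 x 128 block, matching the *_nt_groupwise kernels.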
    # Quantize tensor A using per-token quantization
    a_fp8, a_scale = quantize_fp8(a_f32, (batch_size * m, k // 128), (1, 128), "K")

    # Quantize tensor B using per-block quantization
    b_fp8, b_scale = quantize_fp8(
        b_f32, (batch_size, n // 128, k // 128), (1, 128, 128), "K"
    )

    # Create group assignment indices
    m_indices = torch.arange(
        batch_size, device="cuda", dtype=torch.int32
    ).repeat_interleave(m)
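    # Row i of the flattened A is paired with B[m_indices[i]]; here each group
    # is a contiguous block of m rows.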

    # Pre-allocate output tensor
    out = torch.empty(batch_size * m, n, device="cuda", dtype=out_dtype)

    # Benchmark the DeepGEMM function
    measurements = bench_gpu_time(
        lambda: group_deepgemm_fp8_nt_groupwise(
            a_fp8, b_fp8, a_scale, b_scale, m_indices, out=out, out_dtype=out_dtype
        ),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
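
    # bench_gpu_time returns per-iteration latencies in milliseconds; the median
    # damps outliers. With time in ms, scaling FLOPs (2*M*N*K per GEMM) or bytes
    # by 1e-9 / ms gives TFLOPs/s and TB/s respectively.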
    ms = np.median(measurements)
    tflops_per_second = 2 * batch_size * m * n * k * 1e-9 / ms
    memory_bandwidth_per_second = (
        sum(
            t.numel() * t.element_size()
            for t in [a_fp8, b_fp8, a_scale, b_scale, m_indices, out]
        )
        * 1e-9
        / ms
    )
    print(
        f"group_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} "
        f"in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s, "
        f"memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s"
    )
    return tflops_per_second


def bench_deepgemm_batch_fp8_blackwell(batch_size, m, n, k, in_dtype, out_dtype):
    """Benchmark DeepGEMM-based batch GEMM with FP8 quantization."""
    a = torch.randn((batch_size, m, k), device="cuda", dtype=torch.float32)
    b = torch.randn((batch_size, n, k), device="cuda", dtype=torch.float32)
    masked_m = torch.randint(0, m, (batch_size,), device="cuda", dtype=torch.int32)
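    # masked_m[i] is the number of valid rows in batch entry i; rows beyond it
    # are padding that the masked kernel may skip.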
    a_fp8, a_scale = quantize_fp8(a, (batch_size, m, k // 128), (1, 1, 128), "K")
    b_fp8, b_scale = quantize_fp8(
        b, (batch_size, n // 128, k // 128), (1, 128, 128), "K"
    )
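
    # expected_m approximates the mean number of valid rows; it is a performance
    # hint for the kernel, while correctness is governed by masked_m alone.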
    expected_m = min(int(masked_m.float().mean()) + 1, m)
    out = torch.empty((batch_size, m, n), device="cuda", dtype=out_dtype)

    # Benchmark the DeepGEMM function
    measurements = bench_gpu_time(
        lambda: batch_deepgemm_fp8_nt_groupwise(
            a_fp8,
            b_fp8,
            a_scale,
            b_scale,
            masked_m,
            expected_m,
            out=out,
            out_dtype=out_dtype,
        ),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
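
    # Note: the metrics below assume the full m rows per batch entry, so for the
    # masked case they are upper bounds on useful work.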
    ms = np.median(measurements)
    tflops_per_second = 2 * batch_size * m * n * k * 1e-9 / ms
    memory_bandwidth_per_second = (
        sum(
            t.numel() * t.element_size()
            for t in [a_fp8, b_fp8, a_scale, b_scale, masked_m, out]
        )
        * 1e-9
        / ms
    )
    print(
        f"batch_deepgemm_fp8_nt_groupwise batch_size={batch_size} m={m} n={n} k={k} "
        f"in_dtype={in_dtype} out_dtype={out_dtype}: {tflops_per_second:.2f} TFLOPs/s, "
        f"memory_bandwidth: {memory_bandwidth_per_second:.2f} TB/s"
    )
    return tflops_per_second


if __name__ == "__main__":
    print("=== DeepGEMM Grouped FP8 GEMM Benchmark ===\n")
    for batch_size in [1, 4, 8, 64, 128, 256]:
        for m in [128, 256, 1024, 8192, 16384]:
            for n, k in [(128, 512), (512, 128), (4096, 7168), (7168, 2048)]:
                if m // batch_size < 128:
                    continue
                if m * batch_size <= 16384:  # Limit total problem size
                    bench_deepgemm_grouped_fp8_blackwell(
                        batch_size, m, n, k, torch.float8_e4m3fn, torch.bfloat16
                    )

    print("\n=== DeepGEMM Batch FP8 GEMM Benchmark ===\n")
    for batch_size in [1, 4, 8, 64, 128, 256]:
        for m in [128, 256, 1024, 8192, 16384]:
            for n, k in [(128, 512), (512, 128), (4096, 7168), (7168, 2048)]:
                if m * batch_size <= 16384:  # Limit total problem size
                    bench_deepgemm_batch_fp8_blackwell(
                        batch_size, m, n, k, torch.float8_e4m3fn, torch.bfloat16
                    )