sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_fused_add_rmsnorm.py
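
"""Benchmark flashinfer.fused_add_rmsnorm across batch sizes, hidden sizes, and dtypes.

For each configuration, the script times the kernel with bench_gpu_time and reports
the median latency together with the effective memory throughput.

Example invocation (assumes a CUDA-capable GPU with flashinfer installed):

    python bench_fused_add_rmsnorm.py --batch-sizes 1 16 --hidden-sizes 4096 --dtypes bfloat16
"""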

import argparse

import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time


@torch.inference_mode()
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 19, 99, 989])
    parser.add_argument(
        "--hidden-sizes",
        nargs="+",
        type=int,
        default=[111, 500, 1024, 3072, 4096, 8192],
    )
    parser.add_argument(
        "--dtypes", nargs="+", choices=["float16", "bfloat16"], default=["float16"]
    )
    args = parser.parse_args()

    eps = 1e-6

    # Loop over each combination of batch_size, hidden_size, and dtype
    for batch_size in args.batch_sizes:
        for hidden_size in args.hidden_sizes:
            for dtype_str in args.dtypes:
                dtype = getattr(torch, dtype_str)

                # Define tensors with the correct dtype
                x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
                residual = torch.randn_like(x)
                weight = torch.randn(hidden_size, dtype=dtype, device="cuda")
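
                # Semantics sketch, assuming the fused kernel updates x and residual
                # in place:
                #   residual = x + residual
                #   x = residual / sqrt(mean(residual**2) + eps) * weight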
                # Label the kernel call with an NVTX range so it shows up in profiler traces
                @torch.cuda.nvtx.range(
                    f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}"
                )
                def fn() -> None:
                    flashinfer.fused_add_rmsnorm(x, residual, weight, eps)

                # Run benchmarking: take the median of the measured times as the latency in ms
                measurements = bench_gpu_time(fn)
                latency_ms = np.median(measurements)
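                # Effective memory traffic per call, assuming x and residual are each
                # read and written once (hence the factor of 2), while weight is only read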
                throughput = (
                    x.numel() * x.element_size() * 2
                    + residual.numel() * residual.element_size() * 2
                    + weight.numel() * weight.element_size()
                ) / (latency_ms * 1e-3)
                print(
                    f"batch_size: {batch_size:3},",
                    f"hidden_size: {hidden_size:5},",
                    f"dtype: {dtype_str:8},",
                    f"latency: {latency_ms * 1e3:2.0f}us,",
                    f"throughput: {throughput * 1e-9:7.3f}GB/s",
                )

        print("---")

    torch.cuda.profiler.stop()


if __name__ == "__main__":
    main()