# Source listing metadata: 67 lines, 2.1 KiB, Python
import argparse
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import flashinfer
|
|
from flashinfer.testing.utils import bench_gpu_time
|
|
|
|
|
|
@torch.inference_mode()
def main():
    """Benchmark ``flashinfer.fused_add_rmsnorm`` over a grid of configurations.

    The grid is selected on the command line:

    * ``--batch-sizes``: one or more ints (default ``1 19 99 989``)
    * ``--hidden-sizes``: one or more ints
      (default ``111 500 1024 3072 4096 8192``)
    * ``--dtypes``: any of ``float16`` / ``bfloat16`` (default ``float16``)

    For every (batch size, hidden size, dtype) combination, random CUDA
    tensors are created, the kernel is timed through ``bench_gpu_time``, and
    one line is printed with the median measurement and the throughput
    derived from it. A ``---`` line is printed after each batch size, and
    ``torch.cuda.profiler.stop()`` is called once the sweep is done.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 19, 99, 989])
    cli.add_argument(
        "--hidden-sizes",
        nargs="+",
        type=int,
        default=[111, 500, 1024, 3072, 4096, 8192],
    )
    cli.add_argument(
        "--dtypes", nargs="+", choices=["float16", "bfloat16"], default=["float16"]
    )
    args = cli.parse_args()

    eps = 1e-6

    # Sweep the full grid: batch size (outer), hidden size, dtype (inner).
    for batch_size in args.batch_sizes:
        for hidden_size in args.hidden_sizes:
            for dtype_str in args.dtypes:
                torch_dtype = getattr(torch, dtype_str)

                # Operands are allocated directly on the GPU in the requested dtype.
                inp = torch.randn(
                    (batch_size, hidden_size), dtype=torch_dtype, device="cuda"
                )
                res = torch.randn_like(inp)
                gamma = torch.randn(hidden_size, dtype=torch_dtype, device="cuda")

                # Label each timed call with its configuration for NVTX traces.
                range_label = f"fused_add_rmsnorm batch_size={batch_size}, hidden_size={hidden_size}, dtype={dtype_str}"

                @torch.cuda.nvtx.range(range_label)
                def run_kernel() -> None:
                    flashinfer.fused_add_rmsnorm(inp, res, gamma, eps)

                # The median is treated as milliseconds (see the unit
                # conversions below); NOTE(review): confirm against
                # bench_gpu_time's documented return unit.
                latency_ms = np.median(bench_gpu_time(run_kernel))

                # Traffic model: input and residual are each counted twice
                # (presumably one read plus one write), the weight once.
                bytes_moved = (
                    2 * inp.numel() * inp.element_size()
                    + 2 * res.numel() * res.element_size()
                    + gamma.numel() * gamma.element_size()
                )
                throughput = bytes_moved / (latency_ms * 1e-3)

                report = (
                    f"batch_size: {batch_size:3},",
                    f"hidden_size: {hidden_size:5},",
                    f"dtype: {dtype_str:8},",
                    f"latency: {latency_ms * 1e3:2.0f}us,",
                    f"throughput: {throughput * 1e-9:7.3f}GB/s",
                )
                print(*report)

        # Visual separator between batch sizes.
        print("---")

    torch.cuda.profiler.stop()
|
|
|
|
|
|
# Run the benchmark sweep only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|