sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_hopper_fp8_attention.py
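"""Benchmark FP16 vs. FP8 single-prefill attention kernels on Hopper (SM90).

Usage (assumes an SM90 GPU, e.g. H100, with flashinfer installed):
    python bench_hopper_fp8_attention.py
"""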

import numpy as np
import torch
import flashinfer
from flashinfer.testing.utils import bench_gpu_time


def bench_single_prefill(seq_len, num_heads, causal, head_dim):
    num_qo_heads = num_kv_heads = num_heads
    q = torch.randn(seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda")
    k = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")
    v = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda")

    # Time the FP16 kernel under the FA2 (SM80) and FA3 (SM90) templates.
    # np.median reduces the repeated timings from bench_gpu_time to a single
    # number; the lambda is invoked within each generator step, so `backend`
    # is bound to the current loop value.
    sm80_ms, sm90_ms = (
        np.median(
            bench_gpu_time(
                lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
                    q, k, v, causal=causal, backend=backend
                ),
                dry_run_time_ms=100,
                repeat_time_ms=1000,
            )
        )
        for backend in ["fa2", "fa3"]
    )
    # Re-draw Q/K/V and cast to FP8 (e4m3) for the Hopper FP8 kernel.
    # The cast is unscaled; o_dtype=torch.half requests FP16 output.
    q = torch.randn(
        seq_len, num_qo_heads, head_dim, dtype=torch.half, device="cuda"
    ).to(dtype=torch.float8_e4m3fn)
    k = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
    ).to(dtype=torch.float8_e4m3fn)
    v = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.half, device="cuda"
    ).to(dtype=torch.float8_e4m3fn)
    fp8_sm90_ms = np.median(
        bench_gpu_time(
            lambda: flashinfer.single_prefill_with_kv_cache_return_lse(
                q, k, v, causal=causal, backend="fa3", o_dtype=torch.half
            ),
            dry_run_time_ms=100,
            repeat_time_ms=1000,
        )
    )
    def flops(ms):
        # Attention does two matmuls (Q @ K^T and P @ V), each costing
        # 2 * seq_len^2 * head_dim FLOPs per head; a causal mask halves the
        # work. Dividing FLOPs by (ms * 1e9) yields TFLOPs/s.
        if causal:
            return seq_len * seq_len * num_qo_heads * head_dim * 2 / ms / 1e9
        else:
            return seq_len * seq_len * num_qo_heads * head_dim * 4 / ms / 1e9

    print(
        f"bench_single_prefill (seq_len={seq_len}, num_heads={num_heads}, "
        f"causal={causal}, head_dim={head_dim}), "
        f"fa2-template: {flops(sm80_ms):.3f} TFLOPs/s, "
        f"fa3-template: {flops(sm90_ms):.3f} TFLOPs/s, "
        f"fa3-fp8: {flops(fp8_sm90_ms):.3f} TFLOPs/s"
    )
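

# The FP8 path above casts FP16 inputs to e4m3 with no scaling. A per-tensor
# scaled cast (a hypothetical sketch, not used by this benchmark) would look
# like:
#
#   def to_float8(x, dtype=torch.float8_e4m3fn):
#       finfo = torch.finfo(dtype)
#       scale = finfo.max / x.abs().amax().clamp(min=1e-12)
#       return (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype), scale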


if __name__ == "__main__":
    device_capability = torch.cuda.get_device_capability()
    if device_capability[0] != 9:
        print(f"Current device capability: {device_capability}.")
        print("Current benchmark targets capability (9, 0). Returning...")
        exit()

    for seq_len in [4096, 8192, 16384]:
        for num_heads in [24, 32]:
            for causal in [True, False]:
                for head_dim in [64, 128, 256]:
                    bench_single_prefill(seq_len, num_heads, causal, head_dim)