# sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_trtllm_fmha.py

import numpy as np
import torch
import flashinfer
from flashinfer.testing.utils import bench_gpu_time, bench_gpu_time_with_cudagraph
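
# Module-level problem shape used by the simple bench_trtllm_fmha() benchmark below;
# bench_trtllm_fmha_wrapper() takes these as explicit arguments instead.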
page_size = 16
num_kv_heads = 4
num_qo_heads = 32
head_dim = 128

workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")


def bench_trtllm_fmha(batch_size, seq_len, kv_cache_dtype):
    torch.manual_seed(42)
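
    # Uniform batch: every request has exactly seq_len tokens, so each sequence
    # owns ceil(seq_len / page_size) pages and the page table is contiguous.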
    seq_lens = torch.full((batch_size,), seq_len, device="cuda:0", dtype=torch.int32)
    seq_lens_blocks = torch.ceil(seq_lens / page_size).int()
    kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int, device="cuda:0")
    kv_indptr[1:] = torch.cumsum(seq_lens_blocks, dim=0)
    last_page_len = (seq_lens - (seq_lens_blocks - 1) * page_size).int()
    last_page_len[last_page_len == 0] = page_size
    num_blocks = kv_indptr[-1].item()
    kv_indices = torch.arange(num_blocks, dtype=torch.int32, device="cuda:0")
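
    # Single decode query token per request; the paged KV cache uses the HND
    # layout (num_blocks, 2, num_kv_heads, page_size, head_dim).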
    q = torch.rand(batch_size, num_qo_heads, head_dim, device="cuda:0").to(
        torch.bfloat16
    )
    kv_data = torch.randn(
        num_blocks, 2, num_kv_heads, page_size, head_dim, device="cuda:0"
    ).to(torch.float8_e4m3fn if kv_cache_dtype == "fp8" else torch.float16)

    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "HND", backend="trtllm-gen"
    )
    wrapper.plan(
        kv_indptr,
        kv_indices,
        last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        pos_encoding_mode="NONE",
        q_data_type=q.dtype,
        kv_data_type=kv_data.dtype,
    )

    # Warm up once so one-time kernel setup is excluded from the timing loop.
    wrapper.run(q, kv_data)
    torch.cuda.synchronize()

    measurements = bench_gpu_time(lambda: wrapper.run(q, kv_data))
    ms = np.median(measurements)
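
    # Rough bandwidth estimate: bytes of q plus the full KV cache per decode step
    # (1000 * 1024 * 1024 ~= 1e9, so io / ms / 1024 / 1024 is ~GB/s).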
    io = q.numel() * q.element_size() + kv_data.numel() * kv_data.element_size()
    print(
        f"batch_size={batch_size}, seq_len={seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_size={page_size}"
    )
    print(f"execution time: {ms}ms")
    print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")


def to_float8(x, dtype=torch.float8_e4m3fn):
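    # Symmetric per-tensor FP8 quantization: returns the quantized tensor and the
    # dequantization scale (reciprocal of the scale applied here). The extra 0.1
    # factor keeps values well below the FP8 max.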
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax * 0.1
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()


def bench_trtllm_fmha_wrapper(
    kv_layout,
    batch_size,
    max_seq_len,
    page_size,
    num_kv_heads,
    head_dim,
    q_dtype,
    head_grp_size,
    kv_cache_dtype,
    window_left,
    bench_with_sink,
):
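    # End-to-end decode benchmark: build a randomly permuted paged KV cache,
    # plan the trtllm-gen decode wrapper, and time wrapper.run() under CUDA graphs.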
    torch.manual_seed(42)
    device = "cuda:0"
    num_qo_heads = num_kv_heads * head_grp_size

    # Initialize tensors
    num_tokens = max_seq_len * batch_size
    num_blocks = (num_tokens + page_size - 1) // page_size
    dtype_map = {
        "half": torch.float16,
        "bf16": torch.bfloat16,
        "fp8": torch.float8_e4m3fn,
    }
    q = torch.randn(batch_size, num_qo_heads, head_dim, device=device).to(
        dtype_map[q_dtype]
    )

    # Sequence lengths and block tables
    seq_lens = torch.full((batch_size,), max_seq_len)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)
    blocks_per_seq = [(seq_len + page_size - 1) // page_size for seq_len in seq_lens]

    # Generate random but unique block IDs for all sequences
    total_blocks_needed = sum(blocks_per_seq)
    all_block_ids = torch.randperm(
        total_blocks_needed, device=device
    )  # Random permutation
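
    # Allocate the paged KV cache in HND layout; when an fp8 cache is requested
    # with a non-fp8 query, quantize it with to_float8().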
    kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
    # Allocate on the benchmark device so wrapper.run() sees a device-resident cache.
    kv_cache = torch.randn(size=kv_cache_shape, device=device).to(q.dtype)
    if kv_cache_dtype.startswith("fp8") and q_dtype != "fp8":
        kv_cache, _ = to_float8(kv_cache)

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)
    blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size
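
    # Optional per-query-head attention-sink values, forwarded to wrapper.run().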
    sinks = (
        torch.randn(num_qo_heads, device=device, dtype=torch.float32)
        if bench_with_sink
        else None
    )

    # Compute kv_indptr as cumulative sum of blocks per sequence
    kv_indptr = (
        torch.cat(
            [torch.tensor([0], device=device), torch.cumsum(blocks_per_seq, dim=0)]
        )
        .int()
        .to(device)
    )
    kv_indices = all_block_ids.int()

    # Calculate last page lengths
    kv_last_page_len = seq_lens_tensor % page_size
    kv_last_page_len[kv_last_page_len == 0] = page_size

    # trtllm-gen decode path: plan once, then run repeatedly with the same metadata.
    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer, "HND", backend="trtllm-gen"
    )
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        pos_encoding_mode="NONE",
        data_type=kv_cache.dtype,
        q_data_type=q.dtype,
        window_left=window_left,
    )

    # Warm up once so one-time kernel setup is excluded from the timing loop.
    wrapper.run(q, kv_cache, sinks=sinks)
    torch.cuda.synchronize()
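
    # Time decode under CUDA graph capture/replay to strip per-launch overhead.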
    measurements = bench_gpu_time_with_cudagraph(
        lambda: wrapper.run(q, kv_cache, sinks=sinks),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
    ms = np.median(measurements)
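
    # Rough bandwidth estimate: q plus the entire KV cache read once per decode step.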
    io = q.numel() * q.element_size() + kv_cache.numel() * kv_cache.element_size()
    print(
        f"batch_size={batch_size}, seq_len={max_seq_len}, num_qo_heads={num_qo_heads}, num_kv_heads={num_kv_heads}, head_dim={head_dim}, page_size={page_size}"
    )
    print(f"execution time: {ms}ms")
    print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Benchmark TRTLLM FMHA")
    parser.add_argument(
        "--head_dim", type=int, default=64, help="Dimension of each head"
    )
    parser.add_argument(
        "--num_kv_heads", type=int, default=8, help="Number of key/value heads"
    )
    parser.add_argument(
        "--page_size", type=int, default=16, help="Size of each page [16, 32, 64]"
    )
    parser.add_argument(
        "--head_grp_size",
        type=int,
        default=8,
        help="Number of query heads per key-value head (group size)",
    )
    parser.add_argument("--sink", action="store_true", help="Whether to test with sink")
    parser.add_argument(
        "--batch_sizes",
        type=int,
        nargs="+",
        default=[4, 128, 256],
        help="List of batch sizes to test",
    )
    parser.add_argument(
        "--seq_lens",
        type=int,
        nargs="+",
        default=[1024, 4096, 8192, 16384],
        help="List of sequence lengths to test",
    )
    args = parser.parse_args()
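
    # Sweep batch size x sequence length with the CLI-configured head/page shape,
    # using the bf16 query / auto KV-cache dtype path without a sliding window.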
    for batch_size in args.batch_sizes:
        for seq_len in args.seq_lens:
            bench_trtllm_fmha_wrapper(
                kv_layout="HND",
                batch_size=batch_size,
                max_seq_len=seq_len,
                page_size=args.page_size,
                num_kv_heads=args.num_kv_heads,
                head_dim=args.head_dim,
                q_dtype="bf16",
                head_grp_size=args.head_grp_size,
                kv_cache_dtype="auto",
                window_left=-1,
                bench_with_sink=args.sink,
            )