sglang_v0.5.2/flashinfer_0.3.1/benchmarks/bench_trtllm_gen_mla.py

134 lines
4.5 KiB
Python

import numpy as np
import torch
import flashinfer
from flashinfer.testing.utils import bench_gpu_time_with_cudagraph
# Head geometry shared by every benchmark configuration in this file.
# NOTE(review): these look like DeepSeek-style MLA dimensions — confirm against the target model.

# Width of the compressed (latent) KV part; together with the RoPE part it
# forms the last dimension of both the query and the KV cache tensors.
kv_lora_rank = 512
# Width of the rotary-embedded part appended to the latent part.
qk_rope_head_dim = 64
# Non-rotary per-head width; forwarded to the decode kernel as-is.
qk_nope_head_dim = 128

# Number of query heads (third dimension of the query tensor).
num_q_heads = 128
# Only reported in the benchmark printout; the KV cache itself has a single head axis.
num_kv_heads = 1
def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
    """Benchmark one configuration of the TRT-LLM-gen MLA decode kernel and print the results.

    Builds a random query, a paged KV cache and a block table on ``cuda:0``,
    runs the kernel once as warmup, times it under CUDA graphs, and prints the
    median latency together with derived bandwidth and throughput figures.

    Args:
        batch_size: Number of requests in the batch.
        q_len_per_request: Number of query tokens per request (second
            dimension of the query tensor).
        seq_len: Maximum KV length. The last request uses exactly ``seq_len``
            tokens; every other request draws its length from ``[1, seq_len)``.
        page_size: Number of tokens stored per KV cache block.
        dtype: Torch dtype the query and KV cache are cast to.

    Returns:
        None. Results are written to stdout.

    Note:
        The reported bandwidth counts the whole allocated KV pool and the
        reported FLOPs use the nominal ``seq_len`` for every request, not the
        randomly drawn lengths, so both are upper-bound estimates.
    """
    torch.manual_seed(42)
    device = "cuda:0"
    # Last dimension shared by the query and the KV cache.
    head_dim = kv_lora_rank + qk_rope_head_dim

    # Initialize tensors
    query = torch.randn(
        batch_size,
        q_len_per_request,
        num_q_heads,
        head_dim,
        device=device,
    ).to(dtype)

    # Size the pool per request rather than from the total token count:
    # each request may need ceil(seq_len / page_size) blocks, and rounding up
    # only once over batch_size * seq_len under-allocates whenever seq_len is
    # not a multiple of page_size, letting block IDs point past the pool.
    max_blocks_per_request = (seq_len + page_size - 1) // page_size
    num_blocks = batch_size * max_blocks_per_request

    # Sequence lengths and block tables
    seq_lens = [torch.randint(1, seq_len, (1,)).item() for _ in range(batch_size)]
    seq_lens[-1] = seq_len  # guarantee at least one request of full length
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)

    # Block counts are computed on the host from the Python list so that the
    # loop below slices with plain ints instead of 0-d CUDA tensors (each of
    # which would force a device synchronisation).
    blocks_per_seq = [(n + page_size - 1) // page_size for n in seq_lens]
    max_num_blocks_per_seq = max(blocks_per_seq)
    total_blocks_needed = sum(blocks_per_seq)

    # Generate random but unique block IDs for all sequences
    all_block_ids = torch.randperm(total_blocks_needed, device=device)

    # Populate block tables: each request takes the next contiguous slice of
    # the permutation, so no block ID is shared between requests.
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device
    )
    next_block = 0
    for row, n_blocks in enumerate(blocks_per_seq):
        block_tables[row, :n_blocks] = all_block_ids[next_block : next_block + n_blocks]
        next_block += n_blocks

    # Create interleaved KV cache.
    # Allocate more blocks than the block tables reference (they use exactly
    # total_blocks_needed) to mimic real-world cases.
    kv_cache = torch.randn(
        size=(num_blocks, page_size, head_dim), device=device
    ).to(dtype)
    # The kernel is given a 4-D view: (num_blocks, 1, page_size, head_dim).
    kv_cache_4d = kv_cache.unsqueeze(1)

    # Allocate workspace buffer
    # todo(Yingyi): calculate the actual size of workspace buffer
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)

    # Softmax scale derived from the head-dim constants instead of literals.
    bmm1_scale = 1.0 / ((qk_nope_head_dim + qk_rope_head_dim) ** 0.5)

    def run_decode():
        # Single definition of the kernel call, shared by warmup and timing.
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
            query=query,
            kv_cache=kv_cache_4d,
            workspace_buffer=workspace_buffer,
            qk_nope_head_dim=qk_nope_head_dim,
            kv_lora_rank=kv_lora_rank,
            qk_rope_head_dim=qk_rope_head_dim,
            block_tables=block_tables,
            seq_lens=seq_lens_tensor,
            max_seq_len=max_seq_len,
            bmm1_scale=bmm1_scale,
            bmm2_scale=1.0,
        )

    # Run decode-MLA
    # warmup
    run_decode()

    # benchmark
    measurements = bench_gpu_time_with_cudagraph(
        run_decode,
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )

    io = (
        query.numel() * query.element_size()
        + kv_cache.numel() * kv_cache.element_size()
    )
    ms = np.median(measurements)
    flops = (
        2
        * batch_size
        * num_q_heads
        * (2 * kv_lora_rank + qk_rope_head_dim)
        * seq_len
        * q_len_per_request
    )

    print(
        f"batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, num_q_heads={num_q_heads}, num_kv_heads={num_kv_heads}, qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, kv_lora_rank={kv_lora_rank}, page_size={page_size}"
    )
    print(f"execution time: {ms} ms")
    # bytes per millisecond -> decimal GB/s (1 byte/ms == 1e-6 GB/s); dividing
    # by 1024 * 1024 would yield MiB per millisecond, which is not GB/s.
    print(f"memory bandwidth: {io / ms / 1e6:.2f} GB/s")
    # FLOP per millisecond * 1e-9 == TFLOP/s.
    print(f"FLOPs: {flops * 1e-9 / ms:.2f} TFLOPs/s")
if __name__ == "__main__":
    # Sweep axes, listed from the slowest-varying (outermost) to the
    # fastest-varying (innermost) dimension of the benchmark grid.
    dtype_axis = (torch.bfloat16, torch.float8_e4m3fn)
    page_size_axis = (32, 64)
    batch_size_axis = (1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024)
    kv_len_axis = (1024, 4096, 8192)
    q_len_axis = (1, 2, 4, 8, 16)

    # Flatten the grid into positional argument tuples in the order expected
    # by bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype).
    configs = [
        (bs, q_len, kv_len, page, dt)
        for dt in dtype_axis
        for page in page_size_axis
        for bs in batch_size_axis
        for kv_len in kv_len_axis
        for q_len in q_len_axis
    ]

    for config in configs:
        bench_trtllm_mla(*config)