import numpy as np
import torch

import flashinfer
from flashinfer.testing.utils import bench_gpu_time_with_cudagraph
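
# MLA head geometry used throughout the benchmark; these dimensions correspond to a
# DeepSeek-V3-style MLA configuration (128 query heads, a single latent KV head,
# 512-dim compressed KV latent, 64-dim RoPE component).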
num_q_heads = 128
num_kv_heads = 1
qk_nope_head_dim = 128
qk_rope_head_dim = 64
kv_lora_rank = 512


def bench_trtllm_mla(batch_size, q_len_per_request, seq_len, page_size, dtype):
    torch.manual_seed(42)
    device = "cuda:0"

    # Initialize tensors
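    # The per-head query dimension is kv_lora_rank + qk_rope_head_dim: in the absorbed
    # MLA decode form, each head's query is its projection into the compressed KV
    # latent space concatenated with its RoPE part.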
    query = torch.randn(
        batch_size,
        q_len_per_request,
        num_q_heads,
        kv_lora_rank + qk_rope_head_dim,
        device=device,
    ).to(dtype)

    num_tokens = seq_len * batch_size
    num_blocks = (num_tokens + page_size - 1) // page_size

    # Sequence lengths and block tables
    seq_lens = [torch.randint(1, seq_len, (1,)).item() for _ in range(batch_size)]
    seq_lens[-1] = seq_len
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)

    blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size
    max_num_blocks_per_seq = blocks_per_seq.max().item()

    # Generate random but unique block IDs for all sequences
    total_blocks_needed = sum(blocks_per_seq)
    all_block_ids = torch.randperm(
        total_blocks_needed, device=device
    )  # Random permutation
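
    # block_tables maps each request's logical page index to a physical page ID in
    # kv_cache; only the first blocks_per_seq[i] entries of row i are meaningful,
    # the rest stay zero-padded.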
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device
    )

    # Populate block tables and track block assignments
    block_id = 0
    for i in range(batch_size):
        num_blocks_needed = blocks_per_seq[i]
        block_tables[i, :num_blocks_needed] = all_block_ids[
            block_id : block_id + num_blocks_needed
        ]
        block_id += num_blocks_needed

    # Create interleaved KV cache
    # Allocate more blocks than strictly needed (block_id would be just enough)
    # to mimic real-world cases
    kv_cache = torch.randn(
        size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device
    ).to(dtype)
    # (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim)
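    # The kernel expects the paged cache as (num_blocks, num_kv_heads, page_size,
    # kv_lora_rank + qk_rope_head_dim); MLA has a single latent KV head, hence the
    # kv_cache.unsqueeze(1) at the call sites below.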

    # Allocate workspace buffer
    # todo(Yingyi): calculate the actual size of workspace buffer
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)

    # Run decode-MLA
    # warmup
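    # The eager warmup call lets kernel loading and auto-tuning happen outside the
    # timed region; bench_gpu_time_with_cudagraph then times the call via CUDA-graph
    # capture and replay.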
    flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache.unsqueeze(1),
        workspace_buffer=workspace_buffer,
        qk_nope_head_dim=qk_nope_head_dim,
        kv_lora_rank=kv_lora_rank,
        qk_rope_head_dim=qk_rope_head_dim,
        block_tables=block_tables,
        seq_lens=seq_lens_tensor,
        max_seq_len=max_seq_len,
        bmm1_scale=1.0 / ((128 + 64) ** 0.5),  # 1 / sqrt(qk_nope_head_dim + qk_rope_head_dim)
        bmm2_scale=1.0,
    )
    # benchmark
    measurements = bench_gpu_time_with_cudagraph(
        lambda: flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
            query=query,
            kv_cache=kv_cache.unsqueeze(1),
            workspace_buffer=workspace_buffer,
            qk_nope_head_dim=qk_nope_head_dim,
            kv_lora_rank=kv_lora_rank,
            qk_rope_head_dim=qk_rope_head_dim,
            block_tables=block_tables,
            seq_lens=seq_lens_tensor,
            max_seq_len=max_seq_len,
            bmm1_scale=1.0 / ((128 + 64) ** 0.5),
            bmm2_scale=1.0,
        ),
        dry_run_time_ms=100,
        repeat_time_ms=1000,
    )
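
    # Rough I/O model: bytes of the query plus the entire KV cache, each assumed to be
    # read once per call; used below to report an effective memory bandwidth.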
    io = (
        query.numel() * query.element_size()
        + kv_cache.numel() * kv_cache.element_size()
    )
    ms = np.median(measurements)
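
    # FLOPs model for absorbed MLA decode: per query token, head, and KV token, the
    # QK^T GEMM costs 2 * (kv_lora_rank + qk_rope_head_dim) and the PV GEMM costs
    # 2 * kv_lora_rank, i.e. 2 * (2 * kv_lora_rank + qk_rope_head_dim) in total;
    # the nominal seq_len is used for every request.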
    flops = (
        2
        * batch_size
        * num_q_heads
        * (2 * kv_lora_rank + qk_rope_head_dim)
        * seq_len
        * q_len_per_request
    )
    print(
        f"batch_size={batch_size}, q_len_per_request={q_len_per_request}, seq_len={seq_len}, "
        f"num_q_heads={num_q_heads}, num_kv_heads={num_kv_heads}, "
        f"qk_nope_head_dim={qk_nope_head_dim}, qk_rope_head_dim={qk_rope_head_dim}, "
        f"kv_lora_rank={kv_lora_rank}, page_size={page_size}"
    )
    print(f"execution time: {ms} ms")
    print(f"memory bandwidth: {io / ms / 1024 / 1024:.2f} GB/s")
    print(f"FLOPs: {flops * 1e-9 / ms:.2f} TFLOPs/s")


if __name__ == "__main__":
    # Sweep data types, page sizes, batch sizes, KV sequence lengths, and query lengths.
    for dtype in [torch.bfloat16, torch.float8_e4m3fn]:
        for page_size in [32, 64]:
            for batch_size in [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024]:
                for seq_len in [1024, 4096, 8192]:
                    for q_len_per_request in [1, 2, 4, 8, 16]:
                        bench_trtllm_mla(
                            batch_size, q_len_per_request, seq_len, page_size, dtype
                        )