# sglang_v0.5.2/flashinfer_0.3.1/tests/test_trtllm_gen_mla.py
#
# 217 lines
# 6.5 KiB
# Python

import math
from typing import Optional

import pytest
import torch

import flashinfer
# Size in bytes of each int8 workspace buffer allocated by the test: 128 MiB.
workspace_size = 128 << 20
# Kernel workspace shared across all parametrized test cases; it starts unset
# and is allocated on first use inside the test.
global_workspace_buffer = None
@pytest.mark.parametrize(
    "batch_size",
    [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
)
@pytest.mark.parametrize("scale", [1.0, 0.5])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("page_size", [32, 64])
@pytest.mark.parametrize(
    "q_len_per_request", [1, 2]
)  # todo(Yingyi): verify larger q_len_per_request
@pytest.mark.parametrize("dynamic_scale", [False])
@pytest.mark.parametrize("enable_pdl", [True, False, None])
def test_trtllm_batch_decode_mla(
    batch_size: int,
    scale: float,
    dtype: torch.dtype,
    page_size: int,
    q_len_per_request: int,
    dynamic_scale: bool,
    enable_pdl: Optional[bool],
):
    """Compare trtllm-gen decode-MLA against the FA2 MLA paged-attention wrapper.

    Builds a random paged KV cache with DeepSeek decode-MLA dimensions and runs
    ``flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla``. The result is
    checked against ``flashinfer.mla.BatchMLAPagedAttentionWrapper``
    (``backend="fa2"``) on the same inputs. For fp8, those inputs are upcast
    to bfloat16.

    Args:
        batch_size: number of requests in the decode batch.
        scale: multiplier applied on top of 1/sqrt(qk head dim).
        dtype: dtype of the query and KV cache passed to the trtllm-gen kernel.
        page_size: tokens per KV-cache page.
        q_len_per_request: query tokens per request.
        dynamic_scale: pass BMM scales as device tensors (skipped unless fp8).
        enable_pdl: forwarded unchanged to the kernel (``None`` included).
    """
    if dynamic_scale and dtype != torch.float8_e4m3fn:
        pytest.skip("Dynamic scale is not supported for non-fp8 dtype")
    torch.manual_seed(42)
    device = "cuda:0"

    # Fixed max sequence length
    MAX_SEQ_LEN = 1024

    # Deepseek attention config (decode-MLA)
    num_q_heads = 128
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    kv_lora_rank = 512

    # Initialize tensors
    query = torch.randn(
        batch_size,
        q_len_per_request,
        num_q_heads,
        kv_lora_rank + qk_rope_head_dim,
        device=device,
    ).to(dtype)

    num_tokens = MAX_SEQ_LEN * batch_size
    num_blocks = (num_tokens + page_size - 1) // page_size

    # Sequence lengths and block tables. The last request is pinned to
    # MAX_SEQ_LEN, so max_seq_len always equals MAX_SEQ_LEN.
    seq_lens = [torch.randint(1, MAX_SEQ_LEN, (1,)).item() for _ in range(batch_size)]
    seq_lens[-1] = MAX_SEQ_LEN
    max_seq_len = max(seq_lens)
    seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int, device=device)

    blocks_per_seq = (seq_lens_tensor + page_size - 1) // page_size
    # Host copy of the per-request page counts. Summing or slicing with
    # device elements would force one CUDA sync per request.
    blocks_per_seq_host = blocks_per_seq.tolist()
    max_num_blocks_per_seq = max(blocks_per_seq_host)

    # Generate random but unique block IDs for all sequences
    total_blocks_needed = sum(blocks_per_seq_host)
    all_block_ids = torch.randperm(total_blocks_needed, device=device)

    # Populate block tables. Request i owns the next consecutive slice of
    # all_block_ids, so concatenating the valid prefix of each row in order
    # reproduces all_block_ids. The reference reuses it as kv_indices below.
    block_tables = torch.zeros(
        (batch_size, max_num_blocks_per_seq), dtype=torch.int, device=device
    )
    block_id = 0
    for i, num_blocks_needed in enumerate(blocks_per_seq_host):
        block_tables[i, :num_blocks_needed] = all_block_ids[
            block_id : block_id + num_blocks_needed
        ]
        block_id += num_blocks_needed

    # Create interleaved KV cache: the first kv_lora_rank channels are the
    # compressed KV and the remaining qk_rope_head_dim channels the rope part
    # (see the ckv / kpe split below). Allocate more blocks than needed
    # (block_id is just enough) to mimic real-world cases.
    kv_cache = torch.randn(
        size=(num_blocks, page_size, kv_lora_rank + qk_rope_head_dim), device=device
    ).to(dtype)
    # The kernel receives kv_cache.unsqueeze(1):
    # (num_blocks, 1, page_size, kv_lora_rank + qk_rope_head_dim)

    # Allocate workspace buffer once and reuse it across parametrized cases.
    # todo(Yingyi): calculate the actual size of workspace buffer
    global global_workspace_buffer
    if global_workspace_buffer is None:
        global_workspace_buffer = torch.zeros(
            workspace_size, dtype=torch.int8, device=device
        )
    workspace_buffer = global_workspace_buffer

    # Softmax scale uses the head dimension before matrix absorption
    # (qk_nope_head_dim + qk_rope_head_dim). It is shared by kernel and reference.
    qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
    sm_scale = scale / (qk_head_dim**0.5)

    # NOTE(review): these tensors are only built when dynamic_scale=True, which
    # is not currently parametrized. Confirm the log2 convention against the
    # kernel before enabling it.
    bmm1_log2_scale_tensor = (
        torch.tensor(
            [scale / (qk_head_dim**0.5 * math.log2(math.e))],
            dtype=torch.float32,
            device=device,
        )
        if dynamic_scale
        else None
    )
    bmm2_scale_tensor = (
        torch.tensor([1.0], dtype=torch.float32, device=device)
        if dynamic_scale
        else None
    )

    # Run decode-MLA
    output = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
        query=query,
        kv_cache=kv_cache.unsqueeze(1),
        workspace_buffer=workspace_buffer,
        qk_nope_head_dim=qk_nope_head_dim,
        kv_lora_rank=kv_lora_rank,
        qk_rope_head_dim=qk_rope_head_dim,
        block_tables=block_tables,
        seq_lens=seq_lens_tensor,
        max_seq_len=max_seq_len,
        bmm1_scale=sm_scale,
        bmm2_scale=1.0,
        bmm1_scale_log2_tensor=bmm1_log2_scale_tensor,
        bmm2_scale_tensor=bmm2_scale_tensor,
        enable_pdl=enable_pdl,
    )

    # Run reference attention and align output
    workspace_buffer_ref = torch.empty(workspace_size, dtype=torch.int8, device=device)
    wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
        workspace_buffer_ref,
        backend="fa2",
    )

    if dtype == torch.float8_e4m3fn:
        # convert query and kv_cache to bfloat16
        query = query.to(torch.bfloat16)
        kv_cache = kv_cache.to(torch.bfloat16)

    q_indptr = (
        torch.arange(0, batch_size + 1, device=device, dtype=torch.int32)
        * q_len_per_request
    )
    kv_indptr = torch.zeros_like(q_indptr)
    kv_indptr[1:] = torch.cumsum(blocks_per_seq, dim=0)
    kv_indices = all_block_ids.int()

    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        seq_lens_tensor,
        num_q_heads,
        kv_lora_rank,
        qk_rope_head_dim,
        page_size,
        True,  # causal
        sm_scale,
        query.dtype,
        kv_cache.dtype,
    )
    q_nope = query[..., :kv_lora_rank].view(
        batch_size * q_len_per_request, num_q_heads, kv_lora_rank
    )
    q_pe = query[..., kv_lora_rank:].view(
        batch_size * q_len_per_request, num_q_heads, qk_rope_head_dim
    )

    # todo: fix kv_cache
    ckv = kv_cache[..., :kv_lora_rank]
    kpe = kv_cache[..., kv_lora_rank:]
    o_ref = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=False)

    # check is nan
    assert not torch.isnan(o_ref).any(), "o_ref is nan"
    assert not torch.isnan(output).any(), "output is nan"

    # fp8 runs are compared with a looser tolerance than bf16 runs.
    # todo: do reference with normal attention?
    tol = 1e-1 if dtype == torch.float8_e4m3fn else 1e-2
    try:
        torch.testing.assert_close(
            output,
            o_ref.view(batch_size, q_len_per_request, num_q_heads, -1),
            rtol=tol,
            atol=tol,
        )
    except AssertionError:
        print("output:", output)
        print("o_ref:", o_ref)
        raise