# sglang_v0.5.2/flashinfer_0.3.1/tests/test_cudnn_prefill_deepseek.py

import pytest
import torch
import flashinfer
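

# DeepSeek-style prefill head dims (MLA): head_dim_qk = 192, head_dim_vo = 128.
# The cuDNN ragged-batch prefill kernel is checked against FlashInfer's
# BatchPrefillWithRaggedKVCacheWrapper over the same packed inputs.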
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("s_qo", [32, 64, 87])
@pytest.mark.parametrize("s_kv", [32, 64, 87])
@pytest.mark.parametrize("num_kv_heads", [1])
@pytest.mark.parametrize("num_qo_heads", [1, 16])
@pytest.mark.parametrize("causal", [True, False])
def test_cudnn_prefill_deepseek(
    batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal
):
    if s_qo > s_kv:
        pytest.skip("s_qo > s_kv: KV lengths are sampled in [s_qo, s_kv]")

    head_dim_qk = 192
    head_dim_vo = 128
    return_lse = True

    # Test setup: fixed seed for reproducibility.
    seed = 0
    torch.manual_seed(seed)
    device = "cuda:0"
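
    # Per-request query lengths in [1, s_qo] and KV lengths in [s_qo, s_kv],
    # so len_kv >= len_qo holds for every request.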
    actual_seq_lens_q = torch.randint(
        1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
    )
    actual_seq_lens_kv = torch.randint(
        s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
    )
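
    # q is packed (ragged): all requests' query tokens concatenated along dim 0.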
    cumsum_s_qo = torch.sum(actual_seq_lens_q)
    q = torch.randn(
        cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16
    )
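
    # The cuDNN batch offsets are *element* offsets into the packed buffers:
    # prefix sums of per-request token counts scaled by num_heads * head_dim.
    # Note the asymmetric dims: Q/K use head_dim_qk, V/O use head_dim_vo.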
    q_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_q.view(-1), dim=0)
            * head_dim_qk
            * num_qo_heads,
        ]
    ).int()
    k_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_kv.view(-1), dim=0)
            * head_dim_qk
            * num_kv_heads,
        ]
    ).int()
    v_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_kv.view(-1), dim=0)
            * head_dim_vo
            * num_kv_heads,
        ]
    ).int()
    o_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_q.view(-1), dim=0)
            * head_dim_vo
            * num_qo_heads,
        ]
    ).int()
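
    # Stats (LSE) offsets: one value per query token per query head, so no
    # head_dim factor here.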
    batch_offsets_stats = torch.cat(
        [
            torch.zeros(
                1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype
            ),
            torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads,
        ]
    )
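
    # Packed ragged K/V buffers, addressed via k_indptr / v_indptr; sizing them
    # for batch_size * s_kv tokens leaves room for any sampled lengths.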
    k_cache = torch.randn(
        batch_size * s_kv,
        num_kv_heads,
        head_dim_qk,
        device=device,
        dtype=torch.bfloat16,
    )
    v_cache = torch.randn(
        batch_size * s_kv,
        num_kv_heads,
        head_dim_vo,
        device=device,
        dtype=torch.bfloat16,
    )

    # Softmax scale: 1 / sqrt(head_dim_qk).
    scale = float(1.0 / (head_dim_qk**0.5))
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)

    output, lse = flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
        q,
        k_cache,
        v_cache,
        scale,
        workspace_buffer,
        max_token_per_sequence=s_qo,
        max_sequence_kv=s_kv,
        actual_seq_lens_q=actual_seq_lens_q,
        actual_seq_lens_kv=actual_seq_lens_kv,
        causal=causal,
        return_lse=return_lse,
        batch_offsets_q=q_indptr,
        batch_offsets_k=k_indptr,
        batch_offsets_v=v_indptr,
        batch_offsets_o=o_indptr,
        batch_offsets_stats=batch_offsets_stats,
        is_cuda_graph_compatible=True,
    )
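
    # Reference: FlashInfer's ragged-KV prefill wrapper on the same packed
    # q/k/v, using plain token-count indptrs (no head/head_dim scaling).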
    qo_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_q.view(-1), dim=0),
        ]
    ).int()
    # kv_indptr is the cumulative sum of actual_seq_lens_kv.
    kv_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_kv.view(-1), dim=0),
        ]
    ).int()

    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
        torch.empty(128 * 1024 * 1024, device=device, dtype=torch.uint8),
        kv_layout="NHD",
    )
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim_qk,
        head_dim_vo=head_dim_vo,
        causal=causal,
        sm_scale=scale,
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )
    output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True)

    torch.testing.assert_close(
        output,
        output_ref,
        atol=1e-2,
        rtol=1e-2,
    )
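    # lse and lse_ref are produced (return_lse=True) but not compared here;
    # only the attention outputs are checked.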