import pytest
import torch

import flashinfer
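

# Cross-check the cuDNN-backed ragged batch prefill kernel against the
# BatchPrefillWithRaggedKVCacheWrapper reference path on DeepSeek-style
# asymmetric head dimensions, using packed (ragged) Q/K/V layouts.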
@pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("s_qo", [32, 64, 87])
@pytest.mark.parametrize("s_kv", [32, 64, 87])
@pytest.mark.parametrize("num_kv_heads", [1])
@pytest.mark.parametrize("num_qo_heads", [1, 16])
@pytest.mark.parametrize("causal", [True, False])
def test_cudnn_prefill_deepseek(
    batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal
):
    if s_qo > s_kv:
        pytest.skip("s_qo > s_kv: KV lengths are sampled in [s_qo, s_kv]")

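    # DeepSeek-style head dims: 192 for Q/K, 128 for V/O.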
    head_dim_qk = 192
    head_dim_vo = 128

    return_lse = True

    # Test setup basics: fix the seed so the random tensors are reproducible.
    seed = 0
    torch.manual_seed(seed)
    device = "cuda:0"

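    # Ragged per-request lengths: query lengths fall in [1, s_qo] and KV
    # lengths in [s_qo, s_kv], so kv_len >= q_len holds for every request.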
    actual_seq_lens_q = torch.randint(
        1, s_qo + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
    )

    actual_seq_lens_kv = torch.randint(
        s_qo, s_kv + 1, (batch_size, 1, 1, 1), dtype=torch.int32, device=device
    )

    cumsum_s_qo = torch.sum(actual_seq_lens_q)

    q = torch.randn(
        cumsum_s_qo, num_qo_heads, head_dim_qk, device=device, dtype=torch.bfloat16
    )

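    # Per-request offsets into the packed Q/K/V/O buffers for the cuDNN path,
    # expressed in elements: cumulative token counts scaled by heads * head_dim.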
    q_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_q.view(-1), dim=0)
            * head_dim_qk
            * num_qo_heads,
        ]
    ).int()

    k_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_kv.view(-1), dim=0)
            * head_dim_qk
            * num_kv_heads,
        ]
    ).int()

    v_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_kv.view(-1), dim=0)
            * head_dim_vo
            * num_kv_heads,
        ]
    ).int()

    o_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_q.view(-1), dim=0)
            * head_dim_vo
            * num_qo_heads,
        ]
    ).int()

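    # Offsets for the LSE ("stats") output: one value per token per Q head,
    # so only the num_qo_heads factor applies (no head_dim scaling).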
    batch_offsets_stats = torch.cat(
        [
            torch.zeros(
                1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype
            ),
            torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads,
        ]
    ).cuda()

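    # The K/V buffers are over-allocated to batch_size * s_kv rows; both code
    # paths below read only the packed region addressed by the KV offsets.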
    k_cache = torch.randn(
        batch_size * s_kv,
        num_kv_heads,
        head_dim_qk,
        device=device,
        dtype=torch.bfloat16,
    )
    v_cache = torch.randn(
        batch_size * s_kv,
        num_kv_heads,
        head_dim_vo,
        device=device,
        dtype=torch.bfloat16,
    )

    # Softmax scale: 1 / sqrt(head_dim_qk).
    scale = float(1.0 / (head_dim_qk**0.5))

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device)

    output, lse = flashinfer.prefill.cudnn_batch_prefill_with_kv_cache(
        q,
        k_cache,
        v_cache,
        scale,
        workspace_buffer,
        max_token_per_sequence=s_qo,
        max_sequence_kv=s_kv,
        actual_seq_lens_q=actual_seq_lens_q,
        actual_seq_lens_kv=actual_seq_lens_kv,
        causal=causal,
        return_lse=return_lse,
        batch_offsets_q=q_indptr,
        batch_offsets_k=k_indptr,
        batch_offsets_v=v_indptr,
        batch_offsets_o=o_indptr,
        batch_offsets_stats=batch_offsets_stats,
        is_cuda_graph_compatible=True,
    )

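    # Reference path: BatchPrefillWithRaggedKVCacheWrapper takes token-based
    # indptr arrays (no head / head_dim scaling, unlike the offsets above).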
    qo_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_q.view(-1), dim=0),
        ]
    ).int()

    # Create kv_indptr as the cumulative sum of actual_seq_lens_kv.
    kv_indptr = torch.cat(
        [
            torch.tensor([0], device=device),
            torch.cumsum(actual_seq_lens_kv.view(-1), dim=0),
        ]
    ).int()

    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
        torch.empty(128 * 1024 * 1024, device="cuda", dtype=torch.uint8),
        kv_layout="NHD",
    )
    wrapper.plan(
        qo_indptr,
        kv_indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim_qk,
        head_dim_vo=head_dim_vo,
        causal=causal,
        sm_scale=scale,
        q_data_type=torch.bfloat16,
        kv_data_type=torch.bfloat16,
    )
    output_ref, lse_ref = wrapper.run(q, k_cache, v_cache, return_lse=True)

    torch.testing.assert_close(
        output,
        output_ref,
        atol=1e-2,
        rtol=1e-2,
    )
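    # Note: lse and lse_ref are returned but not compared against each other;
    # only the attention outputs are checked, with tolerances kept loose,
    # presumably to accommodate bfloat16 precision.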