# sglang_v0.5.2/flashinfer_0.3.1/tests/test_pod_kernels.py
"""
Copyright (c) 2023 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pytest
import torch
from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules

import flashinfer
from flashinfer.jit.attention.pytorch import gen_pod_module


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
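    # Pre-build the JIT kernel specs once per module so compilation time is not
    # attributed to the individual parametrized test cases below.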
    flashinfer.jit.build_jit_specs(
        gen_decode_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [128],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False],  # use_fp16_qk_reductions
        )
        + gen_prefill_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [128],  # head_dims
            [0],  # pos_encoding_modes
            [False],  # use_sliding_windows
            [False],  # use_logits_soft_cap
            [False],  # use_fp16_qk_reductions
        )
        + [
            gen_pod_module(
                torch.float16,  # dtype_q
                torch.float16,  # dtype_kv
                torch.float16,  # dtype_o
                128,  # head_dim
                0,  # pos_encoding_mode_p
                False,  # use_sliding_window_p
                False,  # use_logits_soft_cap_p
                False,  # use_fp16_qk_reduction
                torch.int32,  # dtype_idx
                0,  # pos_encoding_mode_d
                False,  # use_sliding_window_d
                False,  # use_logits_soft_cap_d
            )
        ],
        verbose=False,
    )
    yield


@pytest.mark.parametrize("kv_len_p", [127, 12288])
@pytest.mark.parametrize("qo_len_p", [127, 12288])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("batch_size_d", [1, 17, 127])
@pytest.mark.parametrize("kv_len_d", [127, 12288])
@pytest.mark.parametrize("page_size_d", [1, 16])
@pytest.mark.parametrize("kv_layout_d", ["NHD"])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("num_qo_heads", [8, 32])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE"])
@pytest.mark.parametrize("q_dtype", [torch.float16])
@pytest.mark.parametrize("kv_dtype", [torch.float16])
@pytest.mark.parametrize("contiguous_kv", [True])
def test_pod_with_paged_kv_cache(
# Prefill params
kv_len_p,
qo_len_p,
causal,
# Decode params
batch_size_d,
kv_len_d,
page_size_d,
kv_layout_d,
# Shared params
num_kv_heads,
num_qo_heads,
head_dim,
pos_encoding_mode,
q_dtype,
kv_dtype,
contiguous_kv,
):
if causal and qo_len_p > kv_len_p:
pytest.skip("Causal prefill with qo_len_p > kv_len_p is not supported")
    # Prefill inputs
    q_p = torch.randn(
        qo_len_p, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16
    )
    k_p = torch.randn(
        kv_len_p, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16
    )
    v_p = torch.randn(
        kv_len_p, num_kv_heads, head_dim, device="cuda:0", dtype=torch.float16
    )
    # Generate prefill reference output
    o_ref_p = flashinfer.prefill.single_prefill_with_kv_cache(
        q_p,
        k_p,
        v_p,
        causal=causal,
        pos_encoding_mode=pos_encoding_mode,
    )
    # Decode inputs
    q_d = torch.randn(
        batch_size_d, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16
    )
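    # Paged KV cache for the decode batch: every sequence owns num_pages_per_seq
    # consecutive pages, laid out as [page, 2, page_size, heads, head_dim] for NHD
    # (heads and page_size swapped for HND), so the indptr/indices built below
    # describe a simple sequential page assignment.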
    num_pages_per_seq = (kv_len_d + page_size_d - 1) // page_size_d
    total_num_pages = num_pages_per_seq * batch_size_d
    if kv_layout_d == "HND":
        kv_shape = [total_num_pages, 2, num_kv_heads, page_size_d, head_dim]
    else:
        kv_shape = [total_num_pages, 2, page_size_d, num_kv_heads, head_dim]
    if not contiguous_kv:
        tmp = [kv_shape[0]]
        for v_d in kv_shape[1:]:
            tmp.append(2)
            tmp.append(v_d)
        kv_shape = tmp
        kv_data_fp32 = torch.randn(*kv_shape, device="cuda:0", dtype=torch.float32)
        kv_data = kv_data_fp32.to(kv_dtype)
        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
        # actual data is stored in non-contiguous memory
        assert (
            kv_data.stride(-4)
            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
        )
    else:
        kv_data_fp32 = torch.randn(*kv_shape, device="cuda:0", dtype=torch.float32)
        kv_data = kv_data_fp32.to(kv_dtype)
    kv_indptr_d = (
        torch.arange(0, batch_size_d + 1, device="cuda:0", dtype=torch.int32)
        * num_pages_per_seq
    )
    kv_indices_d = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
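    # Valid entries in each sequence's last page: (kv_len_d - 1) % page_size_d + 1
    # equals page_size_d when the length divides evenly, otherwise the remainder.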
    kv_last_page_len = torch.full(
        (batch_size_d,),
        (kv_len_d - 1) % page_size_d + 1,
        device="cuda:0",
        dtype=torch.int32,
    )
    # Generate decode reference output
    decode_workspace_buffer = torch.empty(
        32 * 1024 * 1024, device="cuda:0", dtype=torch.int8
    )
    decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
        decode_workspace_buffer, kv_layout_d
    )
    decode_wrapper.plan(
        kv_indptr_d,
        kv_indices_d,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size_d,
        pos_encoding_mode=pos_encoding_mode,
        data_type=kv_dtype,
        q_data_type=q_dtype,
    )
    o_ref_d = decode_wrapper.run(q_d, kv_data)
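    # POD-Attention path: plan once, then run the single prefill request and the
    # batched decode requests together through the POD wrapper, and compare both
    # outputs against the independently computed prefill/decode references.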
    workspace_buffer = torch.empty(32 * 1024 * 1024, device="cuda:0", dtype=torch.int8)
    pod_wrapper = flashinfer.PODWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout_d,
    )
    pod_wrapper.plan(
        kv_indptr_d,
        kv_indices_d,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size_d,
        pos_encoding_mode=pos_encoding_mode,
        data_type=kv_dtype,
        q_data_type=q_dtype,
    )
    o_p, o_d = pod_wrapper.run(
        q_p,
        k_p,
        v_p,
        q_d,
        kv_data,
        pos_encoding_mode_p=pos_encoding_mode,
        causal_p=causal,
    )
    # Prefill is run with batch size 1
    torch.testing.assert_close(
        o_p, o_ref_p, rtol=1e-3, atol=1e-3, msg="Prefill mismatch"
    )
    # Decode uses all batches at once.
    torch.testing.assert_close(
        o_d, o_ref_d, rtol=1e-3, atol=1e-3, msg="Decode mismatch"
    )
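

# Run a few representative configurations directly when invoked without pytest.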
if __name__ == "__main__":
test_pod_with_paged_kv_cache(
# Prefill params
128,
128,
True,
# Decode params
80,
12288,
16,
"NHD",
# Other shared params
8,
8,
128,
"NONE",
torch.float16,
torch.float16,
True,
)
test_pod_with_paged_kv_cache(
# Prefill params
12288,
12288,
True,
# Decode params
220,
12288,
16,
"NHD",
# Other shared params
4,
16,
128,
"NONE",
torch.float16,
torch.float16,
True,
)
test_pod_with_paged_kv_cache(
# Prefill params
16384,
16384,
True,
# Decode params
250,
12288,
16,
"NHD",
# Other shared params
4,
16,
128,
"NONE",
torch.float16,
torch.float16,
True,
)
print("POD test(s) passed!")