# sglang_v0.5.2/flashinfer_0.3.1/tests/test_non_contiguous_prefill.py
# (200 lines, 7.0 KiB, Python)

"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pytest
import torch
from jit_utils import gen_prefill_attention_modules
import flashinfer
@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
    """Build the fp16 prefill JIT modules once, before any test in this module.

    The requested spec set is: fp16 queries and KV, head dims 64/128/256,
    ``pos_encoding_modes=[0]``, and ``False`` for sliding window, logits
    soft-cap and fp16 QK reduction. Nothing is torn down after the yield.
    """
    head_dims = [64, 128, 256]
    jit_specs = gen_prefill_attention_modules(
        [torch.float16],  # q_dtypes
        [torch.float16],  # kv_dtypes
        head_dims,  # head_dims
        [0],  # pos_encoding_modes
        [False],  # use_sliding_windows
        [False],  # use_logits_soft_caps
        [False],  # use_fp16_qk_reductions
    )
    flashinfer.jit.build_jit_specs(jit_specs, verbose=False)
    yield
@pytest.mark.parametrize("seq_len", [1, 7, 127, 999, 3579])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
@pytest.mark.parametrize("num_qo_heads", [4, 8, 32])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("causal", [True, False])
def test_single_prefill_packed_input(
    seq_len, num_kv_heads, num_qo_heads, head_dim, causal
):
    """Single-request prefill must give the same result on packed and contiguous q/k/v.

    q, k and v are column ranges of one ``(seq_len, (Hq + 2 * Hkv) * D)``
    tensor. Each therefore keeps the full packed width as its row stride
    instead of its own width. The kernel output on these strided views is
    compared against the output on contiguous copies of the same data.
    """
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be a multiple of num_kv_heads")

    q_width = num_qo_heads * head_dim
    kv_width = num_kv_heads * head_dim
    qkv_packed = torch.randn(
        seq_len,
        q_width + 2 * kv_width,
        dtype=torch.float16,
        device="cuda:0",
    )

    # split() yields narrow views along the last dim: [q | k | v].
    q_cols, k_cols, v_cols = qkv_packed.split([q_width, kv_width, kv_width], dim=-1)
    q = q_cols.reshape(seq_len, num_qo_heads, head_dim)
    k = k_cols.reshape(seq_len, num_kv_heads, head_dim)
    v = v_cols.reshape(seq_len, num_kv_heads, head_dim)

    def run(*qkv):
        return flashinfer.single_prefill_with_kv_cache(*qkv, causal=causal)

    o_packed = run(q, k, v)
    o_contiguous = run(*(t.contiguous() for t in (q, k, v)))
    torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99])
@pytest.mark.parametrize("seq_len", [1, 7, 127, 257])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("causal", [True, False])
def test_batch_ragged_prefill_packed_input(
    batch_size, seq_len, num_kv_heads, num_qo_heads, head_dim, causal
):
    """Ragged-KV batch prefill must give the same result on packed and contiguous q/k/v.

    All ``batch_size`` requests have ``seq_len`` tokens. Their q/k/v are
    column ranges of one packed ``(nnz, (Hq + 2 * Hkv) * D)`` tensor. The
    wrapper output for these strided views is compared with its output for
    contiguous copies.
    """
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be a multiple of num_kv_heads")

    device = "cuda:0"
    nnz = batch_size * seq_len
    q_width = num_qo_heads * head_dim
    kv_width = num_kv_heads * head_dim

    qkv_packed = torch.randn(
        nnz,
        q_width + 2 * kv_width,
        dtype=torch.float16,
        device=device,
    )
    q_cols, k_cols, v_cols = qkv_packed.split([q_width, kv_width, kv_width], dim=-1)
    q = q_cols.reshape(nnz, num_qo_heads, head_dim)
    k = k_cols.reshape(nnz, num_kv_heads, head_dim)
    v = v_cols.reshape(nnz, num_kv_heads, head_dim)

    # Equal-length requests: boundaries are 0, seq_len, 2*seq_len, ..., nnz.
    # Queries and keys share this layout, so one indptr tensor serves both.
    indptr = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len

    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device)
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer)
    wrapper.plan(indptr, indptr, num_qo_heads, num_kv_heads, head_dim, causal=causal)

    o_packed = wrapper.run(q, k, v)
    o_contiguous = wrapper.run(*(t.contiguous() for t in (q, k, v)))
    torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99])
@pytest.mark.parametrize("page_size", [1, 5])
@pytest.mark.parametrize("seq_len", [1, 7, 127, 257])
@pytest.mark.parametrize("num_kv_heads", [1, 4, 8])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("causal", [True, False])
def test_batch_paged_prefill_packed_input(
    batch_size,
    page_size,
    seq_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    causal,
):
    """Paged-KV batch prefill must give the same result for a strided q and ``q.contiguous()``.

    The paged K/V cache is filled with random data directly. Only the query
    columns of a packed QKV tensor are used; its K/V columns are ignored.
    Every request has ``seq_len`` tokens laid out over consecutive pages.
    """
    if num_qo_heads % num_kv_heads != 0:
        pytest.skip("num_qo_heads must be a multiple of num_kv_heads")

    device = "cuda:0"
    nnz = batch_size * seq_len

    # Pages per request (ceil-div), and how many slots the final page holds.
    full_pages, tail = divmod(seq_len, page_size)
    num_pages_per_req = full_pages + (1 if tail else 0)
    last_page_len = tail if tail else page_size
    num_pages = batch_size * num_pages_per_req

    # Random draws happen in this order: k_cache, v_cache, qkv_packed.
    # k_cache is shaped (num_pages, page_size, num_kv_heads, head_dim).
    k_cache = torch.randn(
        size=(num_pages, page_size, num_kv_heads, head_dim),
        dtype=torch.float16,
        device=device,
    )
    v_cache = torch.randn_like(k_cache)

    q_width = num_qo_heads * head_dim
    kv_width = num_kv_heads * head_dim
    qkv_packed = torch.randn(
        size=(nnz, q_width + 2 * kv_width),
        dtype=torch.float16,
        device=device,
    )
    # Pretend K/V for these tokens were already appended to the paged cache;
    # only the (strided) query columns are fed to the kernel.
    q = qkv_packed[:, :q_width].view(-1, num_qo_heads, head_dim)

    qo_indptr = torch.arange(batch_size + 1, dtype=torch.int32, device=device) * seq_len
    paged_kv_indptr = (
        torch.arange(batch_size + 1, dtype=torch.int32, device=device)
        * num_pages_per_req
    )
    paged_kv_indices = torch.arange(num_pages, dtype=torch.int32, device=device)
    paged_kv_last_page_len = torch.full(
        (batch_size,), last_page_len, dtype=torch.int32, device=device
    )

    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device)
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer)
    wrapper.plan(
        qo_indptr=qo_indptr,
        paged_kv_indptr=paged_kv_indptr,
        paged_kv_indices=paged_kv_indices,
        paged_kv_last_page_len=paged_kv_last_page_len,
        num_qo_heads=num_qo_heads,
        num_kv_heads=num_kv_heads,
        head_dim_qk=head_dim,
        page_size=page_size,
        causal=causal,
    )

    paged_kv_cache = (k_cache, v_cache)
    o_packed = wrapper.run(q, paged_kv_cache)
    o_contiguous = wrapper.run(q.contiguous(), paged_kv_cache)
    torch.testing.assert_close(o_packed, o_contiguous, rtol=1e-3, atol=2e-3)
if __name__ == "__main__":
    # Quick manual run of one configuration per test, outside pytest.
    # When the tests are called directly like this, pytest fixtures (including
    # the autouse ``warmup_jit`` fixture) are not invoked.
    test_single_prefill_packed_input(
        seq_len=127,
        num_kv_heads=4,
        num_qo_heads=4,
        head_dim=64,
        causal=True,
    )
    test_batch_ragged_prefill_packed_input(
        batch_size=37,
        seq_len=127,
        num_kv_heads=4,
        num_qo_heads=4,
        head_dim=64,
        causal=True,
    )
    test_batch_paged_prefill_packed_input(
        batch_size=37,
        page_size=5,
        seq_len=127,
        num_kv_heads=4,
        num_qo_heads=4,
        head_dim=64,
        causal=True,
    )