"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|

import pytest
import torch
from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules

import flashinfer
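

# Build every decode/prefill JIT kernel variant exercised in this file once per
# module (autouse fixture), so individual parametrized tests do not pay the
# compilation cost.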
@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
    flashinfer.jit.build_jit_specs(
        gen_decode_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [64, 128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False, True],  # use_sliding_windows
            [False],  # use_logits_soft_caps
        )
        + gen_prefill_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [64, 128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False, True],  # use_sliding_windows
            [False],  # use_logits_soft_caps
            [False],  # use_fp16_qk_reductions
        ),
        verbose=False,
    )
    yield
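

# Single-request decode: with a sliding window, the kernel attends to only the
# last (window_left + 1) tokens, so its output must match dense decode run on
# exactly that slice of the KV cache.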
@pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999])
|
|
@pytest.mark.parametrize("window_left", [3, 13, 23, 43])
|
|
@pytest.mark.parametrize("num_kv_heads", [1, 4])
|
|
@pytest.mark.parametrize("num_qo_heads", [4, 8])
|
|
@pytest.mark.parametrize("head_dim", [64, 128, 256])
|
|
def test_single_decode_sliding_window(
|
|
seq_len, window_left, num_kv_heads, num_qo_heads, head_dim
|
|
):
|
|
q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
|
|
k = torch.randn(
|
|
seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
|
|
)
|
|
v = torch.randn(
|
|
seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
|
|
)
|
|
|
|
k_sliced = k[-(window_left + 1) :]
|
|
v_sliced = v[-(window_left + 1) :]
|
|
|
|
o_ref = flashinfer.single_decode_with_kv_cache(q, k_sliced, v_sliced)
|
|
o = flashinfer.single_decode_with_kv_cache(q, k, v, window_left=window_left)
|
|
|
|
torch.testing.assert_close(o.cpu(), o_ref.cpu(), rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
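

# Batch decode over a paged KV cache: each request's windowed output is checked
# against single_decode_with_kv_cache on the keys/values gathered from that
# request's pages.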
@pytest.mark.parametrize("batch_size", [1, 3, 13, 32])
|
|
@pytest.mark.parametrize("kv_len", [1, 3, 99, 199, 1999])
|
|
@pytest.mark.parametrize("window_left", [33, 533])
|
|
@pytest.mark.parametrize("num_kv_heads", [1, 4])
|
|
@pytest.mark.parametrize("num_qo_heads", [4, 8])
|
|
@pytest.mark.parametrize("head_dim", [64, 128, 256])
|
|
@pytest.mark.parametrize("page_size", [1, 16])
|
|
def test_batch_decode_sliding_window(
|
|
batch_size, kv_len, window_left, num_kv_heads, num_qo_heads, head_dim, page_size
|
|
):
|
|
q = torch.randn(
|
|
batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0"
|
|
)
|
|
num_pages_per_seq = (kv_len + page_size - 1) // page_size
|
|
total_num_pages = num_pages_per_seq * batch_size
|
|
k_data = torch.randn(
|
|
total_num_pages,
|
|
page_size,
|
|
num_kv_heads,
|
|
head_dim,
|
|
dtype=torch.float16,
|
|
device="cuda:0",
|
|
)
|
|
v_data = torch.randn(
|
|
total_num_pages,
|
|
page_size,
|
|
num_kv_heads,
|
|
head_dim,
|
|
dtype=torch.float16,
|
|
device="cuda:0",
|
|
)
|
|
kv_indptr = (
|
|
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
|
|
* num_pages_per_seq
|
|
)
|
|
kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
|
|
kv_last_page_len = torch.full(
|
|
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
|
|
)
|
|
|
|
workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
|
|
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
|
wrapper.plan(
|
|
kv_indptr,
|
|
kv_indices,
|
|
kv_last_page_len,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim,
|
|
page_size,
|
|
window_left=window_left,
|
|
)
|
|
o = wrapper.run(q, (k_data, v_data))
|
|
|
|
for i in range(batch_size):
|
|
qi = q[i]
|
|
ki = torch.cat(
|
|
[
|
|
k_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape(
|
|
-1, num_kv_heads, head_dim
|
|
),
|
|
k_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :],
|
|
],
|
|
dim=0,
|
|
)
|
|
vi = torch.cat(
|
|
[
|
|
v_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape(
|
|
-1, num_kv_heads, head_dim
|
|
),
|
|
v_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :],
|
|
],
|
|
dim=0,
|
|
)
|
|
o_ref_i = flashinfer.single_decode_with_kv_cache(
|
|
qi,
|
|
ki,
|
|
vi,
|
|
window_left=window_left,
|
|
)
|
|
torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
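

# Cross-check between kernels: a causal prefill with qo_len == 1 and a decode
# call should produce identical output for the same sliding window.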
@pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999])
|
|
@pytest.mark.parametrize("window_left", [3, 13, 23, 43])
|
|
@pytest.mark.parametrize("num_kv_heads", [1, 4])
|
|
@pytest.mark.parametrize("num_qo_heads", [4, 8])
|
|
@pytest.mark.parametrize("head_dim", [64, 128, 256])
|
|
def test_single_decode_prefill_sliding_window_match(
|
|
seq_len, window_left, num_kv_heads, num_qo_heads, head_dim
|
|
):
|
|
q = torch.randn(1, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
|
|
k = torch.randn(
|
|
seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
|
|
)
|
|
v = torch.randn(
|
|
seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
|
|
)
|
|
o = flashinfer.single_prefill_with_kv_cache(
|
|
q, k, v, window_left=window_left, causal=True
|
|
)
|
|
o_decoded = flashinfer.single_decode_with_kv_cache(
|
|
q[0], k, v, window_left=window_left
|
|
)
|
|
torch.testing.assert_close(o.cpu()[0], o_decoded.cpu(), rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
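

# Single-request prefill: the fused sliding-window path should match a run with
# an explicit banded causal mask keeping columns in [row - window_left, row].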
@pytest.mark.parametrize("seq_len", [99, 199, 1999])
|
|
@pytest.mark.parametrize("window_left", [43, 233])
|
|
@pytest.mark.parametrize("num_kv_heads", [1, 4])
|
|
@pytest.mark.parametrize("num_qo_heads", [4, 8])
|
|
@pytest.mark.parametrize("head_dim", [64, 128, 256])
|
|
def test_single_prefill_sliding_window(
|
|
seq_len, window_left, num_kv_heads, num_qo_heads, head_dim
|
|
):
|
|
q = torch.randn(
|
|
seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0"
|
|
)
|
|
k = torch.randn(
|
|
seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
|
|
)
|
|
v = torch.randn(
|
|
seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
|
|
)
|
|
|
|
row_idx = torch.arange(seq_len, dtype=torch.int32, device="cuda:0")[:, None]
|
|
col_idx = torch.arange(seq_len, dtype=torch.int32, device="cuda:0")[None, :]
|
|
mask = (row_idx >= col_idx) & (row_idx - window_left <= col_idx)
|
|
|
|
o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
|
|
o = flashinfer.single_prefill_with_kv_cache(
|
|
q, k, v, window_left=window_left, causal=True
|
|
)
|
|
torch.testing.assert_close(o.cpu(), o_ref.cpu(), rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
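

# Batch prefill over a paged KV cache, checked per request against the
# single-request prefill kernel (fa2 backend) on the gathered pages.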
@pytest.mark.parametrize("batch_size", [12, 17])
|
|
@pytest.mark.parametrize("kv_len", [54, 397])
|
|
@pytest.mark.parametrize("qo_len", [37, 47])
|
|
@pytest.mark.parametrize("window_left", [13, 33])
|
|
@pytest.mark.parametrize("num_kv_heads", [1, 4])
|
|
@pytest.mark.parametrize("num_qo_heads", [4, 8])
|
|
@pytest.mark.parametrize("head_dim", [64, 128, 256])
|
|
@pytest.mark.parametrize("page_size", [1, 16])
|
|
def test_batch_paged_prefill_sliding_window(
|
|
batch_size,
|
|
kv_len,
|
|
qo_len,
|
|
window_left,
|
|
num_kv_heads,
|
|
num_qo_heads,
|
|
head_dim,
|
|
page_size,
|
|
):
|
|
q = torch.randn(
|
|
batch_size * qo_len,
|
|
num_qo_heads,
|
|
head_dim,
|
|
dtype=torch.float16,
|
|
device="cuda:0",
|
|
)
|
|
q_indptr = (
|
|
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
|
|
)
|
|
num_pages_per_seq = (kv_len + page_size - 1) // page_size
|
|
total_num_pages = num_pages_per_seq * batch_size
|
|
k_data = torch.randn(
|
|
total_num_pages,
|
|
page_size,
|
|
num_kv_heads,
|
|
head_dim,
|
|
dtype=torch.float16,
|
|
device="cuda:0",
|
|
)
|
|
v_data = torch.randn(
|
|
total_num_pages,
|
|
page_size,
|
|
num_kv_heads,
|
|
head_dim,
|
|
dtype=torch.float16,
|
|
device="cuda:0",
|
|
)
|
|
kv_indptr = (
|
|
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
|
|
* num_pages_per_seq
|
|
)
|
|
kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
|
|
kv_last_page_len = torch.full(
|
|
(batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
|
|
)
|
|
|
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
|
|
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
|
|
wrapper.plan(
|
|
q_indptr,
|
|
kv_indptr,
|
|
kv_indices,
|
|
kv_last_page_len,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim,
|
|
page_size,
|
|
window_left=window_left,
|
|
causal=True,
|
|
)
|
|
o = wrapper.run(
|
|
q,
|
|
(k_data, v_data),
|
|
)
|
|
|
|
for i in range(batch_size):
|
|
qi = q[q_indptr[i] : q_indptr[i + 1]]
|
|
ki = torch.cat(
|
|
[
|
|
k_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape(
|
|
-1, num_kv_heads, head_dim
|
|
),
|
|
k_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :],
|
|
],
|
|
dim=0,
|
|
)
|
|
vi = torch.cat(
|
|
[
|
|
v_data[kv_indptr[i] : kv_indptr[i + 1] - 1].reshape(
|
|
-1, num_kv_heads, head_dim
|
|
),
|
|
v_data[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :],
|
|
],
|
|
dim=0,
|
|
)
|
|
o_ref_i = flashinfer.single_prefill_with_kv_cache(
|
|
qi, ki, vi, window_left=window_left, causal=True, backend="fa2"
|
|
)
|
|
o_i = o[q_indptr[i] : q_indptr[i + 1]]
|
|
torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
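

# Batch prefill over a ragged (unpaged) KV cache, checked per request against
# the single-request prefill kernel.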
@pytest.mark.parametrize("batch_size", [12, 17])
|
|
@pytest.mark.parametrize("kv_len", [54, 397])
|
|
@pytest.mark.parametrize("qo_len", [37, 47])
|
|
@pytest.mark.parametrize("window_left", [13, 33])
|
|
@pytest.mark.parametrize("num_kv_heads", [1, 4])
|
|
@pytest.mark.parametrize("num_qo_heads", [4, 8])
|
|
@pytest.mark.parametrize("head_dim", [64, 128, 256])
|
|
def test_batch_ragged_prefill_sliding_window(
|
|
batch_size, kv_len, qo_len, window_left, num_kv_heads, num_qo_heads, head_dim
|
|
):
|
|
q = torch.randn(
|
|
batch_size * qo_len,
|
|
num_qo_heads,
|
|
head_dim,
|
|
dtype=torch.float16,
|
|
device="cuda:0",
|
|
)
|
|
q_indptr = (
|
|
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
|
|
)
|
|
k = torch.randn(
|
|
batch_size * kv_len,
|
|
num_kv_heads,
|
|
head_dim,
|
|
dtype=torch.float16,
|
|
device="cuda:0",
|
|
)
|
|
v = torch.randn(
|
|
batch_size * kv_len,
|
|
num_kv_heads,
|
|
head_dim,
|
|
dtype=torch.float16,
|
|
device="cuda:0",
|
|
)
|
|
kv_indptr = (
|
|
torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len
|
|
)
|
|
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
|
|
wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")
|
|
wrapper.plan(
|
|
q_indptr,
|
|
kv_indptr,
|
|
num_qo_heads,
|
|
num_kv_heads,
|
|
head_dim,
|
|
window_left=window_left,
|
|
causal=True,
|
|
)
|
|
o = wrapper.run(q, k, v)
|
|
|
|
for i in range(batch_size):
|
|
qi = q[q_indptr[i] : q_indptr[i + 1]]
|
|
ki = k[kv_indptr[i] : kv_indptr[i + 1]]
|
|
vi = v[kv_indptr[i] : kv_indptr[i + 1]]
|
|
o_ref_i = flashinfer.single_prefill_with_kv_cache(
|
|
qi,
|
|
ki,
|
|
vi,
|
|
window_left=window_left,
|
|
causal=True,
|
|
)
|
|
o_i = o[q_indptr[i] : q_indptr[i + 1]]
|
|
torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)
|
|
|
|
|
|
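

# Smoke-test entry point: runs one configuration of each kernel family without
# going through pytest.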
if __name__ == "__main__":
|
|
test_single_decode_sliding_window(13, 20, 1, 4, 128)
|
|
test_single_prefill_sliding_window(13, 20, 1, 4, 128)
|
|
test_batch_paged_prefill_sliding_window(12, 54, 37, 13, 1, 4, 128, 1)
|
|
test_batch_ragged_prefill_sliding_window(12, 54, 37, 13, 1, 4, 128)
|