"""
Copyright (c) 2024 by FlashInfer team.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import pytest
import torch
from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules

import flashinfer


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
    """Build the fp16 decode and prefill JIT modules once for this test module.

    Covers head dims 64/128/256, pos-encoding mode 0, with and without sliding
    window, no logits soft-cap (and no fp16 QK reduction for prefill) --
    presumably so kernel compilation happens here rather than inside a test.
    """
    flashinfer.jit.build_jit_specs(
        gen_decode_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [64, 128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False, True],  # use_sliding_windows
            [False],  # use_logits_soft_caps
        )
        + gen_prefill_attention_modules(
            [torch.float16],  # q_dtypes
            [torch.float16],  # kv_dtypes
            [64, 128, 256],  # head_dims
            [0],  # pos_encoding_modes
            [False, True],  # use_sliding_windows
            [False],  # use_logits_soft_caps
            [False],  # use_fp16_qk_reductions
        ),
        verbose=False,
    )
    yield


def _gather_paged_kv(
    paged_data: torch.Tensor,
    kv_indptr: torch.Tensor,
    kv_indices: torch.Tensor,
    kv_last_page_len: torch.Tensor,
    request_idx: int,
) -> torch.Tensor:
    """Reassemble one request's K (or V) from a paged NHD cache into a ragged tensor.

    Args:
        paged_data: ``(num_pages, page_size, num_kv_heads, head_dim)`` cache, as
            built by the paged tests in this file.
        kv_indptr: page-table offsets; request ``i`` owns entries
            ``kv_indices[kv_indptr[i]:kv_indptr[i + 1]]``.
        kv_indices: physical page ids, in logical order per request.
        kv_last_page_len: number of valid slots in each request's final page.
        request_idx: which request to gather.

    Returns:
        ``(kv_len_i, num_kv_heads, head_dim)`` tensor: all full pages followed by
        the valid prefix of the last page.
    """
    start = int(kv_indptr[request_idx])
    end = int(kv_indptr[request_idx + 1])
    # Look pages up through kv_indices instead of assuming page id == position,
    # so the reference stays correct for any page-table permutation.
    page_ids = kv_indices[start:end].to(torch.long)
    last_len = int(kv_last_page_len[request_idx])
    num_heads, head_dim = paged_data.shape[2], paged_data.shape[3]

    # index_select with an empty index (single-page request) yields 0 rows.
    full_pages = paged_data.index_select(0, page_ids[:-1]).reshape(
        -1, num_heads, head_dim
    )
    last_page = paged_data[int(page_ids[-1]), :last_len]
    return torch.cat([full_pages, last_page], dim=0)


@pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999])
@pytest.mark.parametrize("window_left", [3, 13, 23, 43])
@pytest.mark.parametrize("num_kv_heads", [1, 4])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_single_decode_sliding_window(
    seq_len, window_left, num_kv_heads, num_qo_heads, head_dim
):
    """Windowed decode must equal plain decode over the last ``window_left + 1`` keys."""
    q = torch.randn(num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    k = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
    )
    v = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
    )
    # When seq_len <= window_left + 1 this slice is the whole sequence.
    k_sliced = k[-(window_left + 1) :]
    v_sliced = v[-(window_left + 1) :]
    o_ref = flashinfer.single_decode_with_kv_cache(q, k_sliced, v_sliced)
    o = flashinfer.single_decode_with_kv_cache(q, k, v, window_left=window_left)
    torch.testing.assert_close(o.cpu(), o_ref.cpu(), rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [1, 3, 13, 32])
@pytest.mark.parametrize("kv_len", [1, 3, 99, 199, 1999])
@pytest.mark.parametrize("window_left", [33, 533])
@pytest.mark.parametrize("num_kv_heads", [1, 4])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_decode_sliding_window(
    batch_size, kv_len, window_left, num_kv_heads, num_qo_heads, head_dim, page_size
):
    """Batched paged decode with a window must match per-request single decode."""
    q = torch.randn(
        batch_size, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0"
    )
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    k_data = torch.randn(
        total_num_pages,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    v_data = torch.randn(
        total_num_pages,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
        * num_pages_per_seq
    )
    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
    # Slots used in the final page, in [1, page_size].
    kv_last_page_len = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
    )

    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        window_left=window_left,
    )
    o = wrapper.run(q, (k_data, v_data))

    for i in range(batch_size):
        ki = _gather_paged_kv(k_data, kv_indptr, kv_indices, kv_last_page_len, i)
        vi = _gather_paged_kv(v_data, kv_indptr, kv_indices, kv_last_page_len, i)
        o_ref_i = flashinfer.single_decode_with_kv_cache(
            q[i],
            ki,
            vi,
            window_left=window_left,
        )
        torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("seq_len", [1, 3, 19, 99, 199, 1999])
@pytest.mark.parametrize("window_left", [3, 13, 23, 43])
@pytest.mark.parametrize("num_kv_heads", [1, 4])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_single_decode_prefill_sliding_window_match(
    seq_len, window_left, num_kv_heads, num_qo_heads, head_dim
):
    """Single-query causal windowed prefill must agree with windowed decode."""
    q = torch.randn(1, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0")
    k = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
    )
    v = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
    )
    o = flashinfer.single_prefill_with_kv_cache(
        q, k, v, window_left=window_left, causal=True
    )
    o_decoded = flashinfer.single_decode_with_kv_cache(
        q[0], k, v, window_left=window_left
    )
    torch.testing.assert_close(o.cpu()[0], o_decoded.cpu(), rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("seq_len", [99, 199, 1999])
@pytest.mark.parametrize("window_left", [43, 233])
@pytest.mark.parametrize("num_kv_heads", [1, 4])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_single_prefill_sliding_window(
    seq_len, window_left, num_kv_heads, num_qo_heads, head_dim
):
    """Causal windowed prefill must match an equivalent explicit custom mask."""
    q = torch.randn(
        seq_len, num_qo_heads, head_dim, dtype=torch.float16, device="cuda:0"
    )
    k = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
    )
    v = torch.randn(
        seq_len, num_kv_heads, head_dim, dtype=torch.float16, device="cuda:0"
    )

    # Query row r may attend key column c iff r - window_left <= c <= r.
    row_idx = torch.arange(seq_len, dtype=torch.int32, device="cuda:0")[:, None]
    col_idx = torch.arange(seq_len, dtype=torch.int32, device="cuda:0")[None, :]
    mask = (row_idx >= col_idx) & (row_idx - window_left <= col_idx)

    o_ref = flashinfer.single_prefill_with_kv_cache(q, k, v, custom_mask=mask)
    o = flashinfer.single_prefill_with_kv_cache(
        q, k, v, window_left=window_left, causal=True
    )
    torch.testing.assert_close(o.cpu(), o_ref.cpu(), rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("kv_len", [54, 397])
@pytest.mark.parametrize("qo_len", [37, 47])
@pytest.mark.parametrize("window_left", [13, 33])
@pytest.mark.parametrize("num_kv_heads", [1, 4])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
@pytest.mark.parametrize("page_size", [1, 16])
def test_batch_paged_prefill_sliding_window(
    batch_size,
    kv_len,
    qo_len,
    window_left,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    page_size,
):
    """Batched paged causal prefill with a window must match per-request single prefill."""
    q = torch.randn(
        batch_size * qo_len,
        num_qo_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    q_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
    )
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    k_data = torch.randn(
        total_num_pages,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    v_data = torch.randn(
        total_num_pages,
        page_size,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
        * num_pages_per_seq
    )
    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
    # Slots used in the final page, in [1, page_size].
    kv_last_page_len = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
    )

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        window_left=window_left,
        causal=True,
    )
    o = wrapper.run(
        q,
        (k_data, v_data),
    )

    for i in range(batch_size):
        qi = q[q_indptr[i] : q_indptr[i + 1]]
        ki = _gather_paged_kv(k_data, kv_indptr, kv_indices, kv_last_page_len, i)
        vi = _gather_paged_kv(v_data, kv_indptr, kv_indices, kv_last_page_len, i)
        o_ref_i = flashinfer.single_prefill_with_kv_cache(
            qi, ki, vi, window_left=window_left, causal=True, backend="fa2"
        )
        o_i = o[q_indptr[i] : q_indptr[i + 1]]
        torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)


@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("kv_len", [54, 397])
@pytest.mark.parametrize("qo_len", [37, 47])
@pytest.mark.parametrize("window_left", [13, 33])
@pytest.mark.parametrize("num_kv_heads", [1, 4])
@pytest.mark.parametrize("num_qo_heads", [4, 8])
@pytest.mark.parametrize("head_dim", [64, 128, 256])
def test_batch_ragged_prefill_sliding_window(
    batch_size, kv_len, qo_len, window_left, num_kv_heads, num_qo_heads, head_dim
):
    """Batched ragged causal prefill with a window must match per-request single prefill."""
    q = torch.randn(
        batch_size * qo_len,
        num_qo_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    q_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
    )
    k = torch.randn(
        batch_size * kv_len,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    v = torch.randn(
        batch_size * kv_len,
        num_kv_heads,
        head_dim,
        dtype=torch.float16,
        device="cuda:0",
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len
    )

    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    wrapper = flashinfer.BatchPrefillWithRaggedKVCacheWrapper(workspace_buffer, "NHD")
    wrapper.plan(
        q_indptr,
        kv_indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        window_left=window_left,
        causal=True,
    )
    o = wrapper.run(q, k, v)

    for i in range(batch_size):
        qi = q[q_indptr[i] : q_indptr[i + 1]]
        ki = k[kv_indptr[i] : kv_indptr[i + 1]]
        vi = v[kv_indptr[i] : kv_indptr[i + 1]]
        o_ref_i = flashinfer.single_prefill_with_kv_cache(
            qi,
            ki,
            vi,
            window_left=window_left,
            causal=True,
        )
        o_i = o[q_indptr[i] : q_indptr[i + 1]]
        torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
    # Manual smoke run: one representative configuration per test.
    test_single_decode_sliding_window(13, 20, 1, 4, 128)
    test_batch_decode_sliding_window(3, 199, 33, 1, 4, 128, 16)
    test_single_decode_prefill_sliding_window_match(13, 20, 1, 4, 128)
    test_single_prefill_sliding_window(13, 20, 1, 4, 128)
    test_batch_paged_prefill_sliding_window(12, 54, 37, 13, 1, 4, 128, 1)
    test_batch_ragged_prefill_sliding_window(12, 54, 37, 13, 1, 4, 128)