""" Copyright (c) 2023 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import numpy import pytest import torch from jit_utils import gen_prefill_attention_modules import flashinfer @pytest.fixture(autouse=True, scope="module") def warmup_jit(): flashinfer.jit.build_jit_specs( gen_prefill_attention_modules( [torch.float16], # q_dtypes [ torch.float16, torch.float8_e4m3fn, torch.float8_e5m2, ], # kv_dtypes [128, 256], # head_dims [0, 1], # pos_encoding_modes [False], # use_sliding_windows [False], # use_logits_soft_caps [False], # use_fp16_qk_reductions ), verbose=False, ) yield @pytest.mark.parametrize("batch_size", [12, 17, 128]) @pytest.mark.parametrize("kv_len", [54, 97, 512, 2048]) @pytest.mark.parametrize("qo_len", [37, 17, 127, 577]) @pytest.mark.parametrize("page_size", [1, 5, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("kv_layout", ["NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) @pytest.mark.parametrize("use_cuda_graph", [True]) @pytest.mark.parametrize("logits_soft_cap", [0.0]) @pytest.mark.parametrize("return_lse", [True]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_batch_prefill_with_paged_kv_cache( batch_size, kv_len, qo_len, page_size, num_kv_heads, num_qo_heads, head_dim, causal, kv_layout, pos_encoding_mode, use_cuda_graph, logits_soft_cap, return_lse, contiguous_kv, ): if qo_len > kv_len and causal: pytest.skip("qo_len > kv_len and causal is not supported") q = torch.randn( batch_size * qo_len, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16, ) q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim] else: kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] if not contiguous_kv: tmp = [kv_shape[0]] for v in kv_shape[1:]: tmp.append(2) tmp.append(v) kv_shape = tmp kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") kv_data = kv_data_fp32.half() kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] # actual data is stored in non-contiguous memory assert ( kv_data.stride(-4) != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] ) else: kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") kv_data = kv_data_fp32.half() kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq kv_indices_cpu = torch.arange(0, total_num_pages).int() kv_last_page_len_cpu = torch.full( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 ) workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0") if not use_cuda_graph: q_indptr_gpu = q_indptr_cpu.to(0) kv_indptr_gpu = kv_indptr_cpu.to(0) kv_indices_gpu = 
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048])
@pytest.mark.parametrize("qo_len", [37, 17, 127, 577])
@pytest.mark.parametrize("page_size", [1, 5, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("kv_layout", ["NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
@pytest.mark.parametrize("use_cuda_graph", [True])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
@pytest.mark.parametrize("contiguous_kv", [True])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    page_size,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    causal,
    kv_layout,
    pos_encoding_mode,
    use_cuda_graph,
    logits_soft_cap,
    return_lse,
    contiguous_kv,
):
    if qo_len > kv_len and causal:
        pytest.skip("qo_len > kv_len and causal is not supported")
    q = torch.randn(
        batch_size * qo_len,
        num_qo_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    if kv_layout == "HND":
        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
    else:
        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
    if not contiguous_kv:
        tmp = [kv_shape[0]]
        for v in kv_shape[1:]:
            tmp.append(2)
            tmp.append(v)
        kv_shape = tmp
        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
        kv_data = kv_data_fp32.half()
        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
        # actual data is stored in non-contiguous memory
        assert (
            kv_data.stride(-4)
            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
        )
    else:
        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
        kv_data = kv_data_fp32.half()
    kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq
    kv_indices_cpu = torch.arange(0, total_num_pages).int()
    kv_last_page_len_cpu = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
    )
    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    if not use_cuda_graph:
        q_indptr_gpu = q_indptr_cpu.to(0)
        kv_indptr_gpu = kv_indptr_cpu.to(0)
        kv_indices_gpu = kv_indices_cpu.to(0)
        kv_last_page_len_gpu = kv_last_page_len_cpu.to(0)
        wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer, kv_layout
        )
        wrapper.plan(
            q_indptr_gpu,
            kv_indptr_gpu,
            kv_indices_gpu,
            kv_last_page_len_gpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        if return_lse:
            o, _ = wrapper.run(q, kv_data, return_lse=True)
        else:
            o = wrapper.run(q, kv_data)

        # test with pre-allocated output
        o_buffer = torch.empty_like(o)
        wrapper.run(q, kv_data, out=o_buffer)
        torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
    else:
        q_indptr_buffer = torch.empty(
            batch_size + 1, device="cuda:0", dtype=torch.int32
        )
        kv_indptr_buffer = torch.empty(
            batch_size + 1, device="cuda:0", dtype=torch.int32
        )
        kv_indices_buffer = torch.empty(
            total_num_pages, device="cuda:0", dtype=torch.int32
        )
        kv_last_page_len_buffer = torch.empty(
            batch_size, device="cuda:0", dtype=torch.int32
        )
        wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer,
            kv_layout,
            use_cuda_graph=True,
            qo_indptr_buf=q_indptr_buffer,
            paged_kv_indptr_buf=kv_indptr_buffer,
            paged_kv_indices_buf=kv_indices_buffer,
            paged_kv_last_page_len_buf=kv_last_page_len_buffer,
        )
        q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len
        kv_indptr_warmup = torch.arange(0, batch_size + 1).int()
        kv_indices_warmup = torch.arange(0, batch_size).int()
        kv_last_page_len_warmup = torch.full(
            (batch_size,), page_size, dtype=torch.int32
        )
        wrapper.plan(
            q_indptr_warmup,
            kv_indptr_warmup,
            kv_indices_warmup,
            kv_last_page_len_warmup,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        # warmup
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                if return_lse:
                    o, _ = wrapper.run(q, kv_data, return_lse=True)
                else:
                    o = wrapper.run(q, kv_data)
        torch.cuda.current_stream().wait_stream(s)
        # capture
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            if return_lse:
                o, _ = wrapper.run(q, kv_data, return_lse=True)
            else:
                o = wrapper.run(q, kv_data)
        wrapper.plan(
            q_indptr_cpu,
            kv_indptr_cpu,
            kv_indices_cpu,
            kv_last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        g.replay()

    for i in range(batch_size):
        perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
        perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
        qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
        ki = torch.cat(
            [
                kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0]
                .permute(*perm_dims)
                .reshape(-1, num_kv_heads, head_dim),
                (
                    kv_data_fp32[
                        kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]
                    ]
                    if kv_layout == "HND"
                    else kv_data_fp32[
                        kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], :
                    ]
                )
                .permute(*perm_dims_last)
                .reshape(-1, num_kv_heads, head_dim),
            ],
            dim=0,
        ).half()
        vi = torch.cat(
            [
                kv_data_fp32[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1]
                .permute(*perm_dims)
                .reshape(-1, num_kv_heads, head_dim),
                (
                    kv_data_fp32[
                        kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]
                    ]
                    if kv_layout == "HND"
                    else kv_data_fp32[
                        kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], :
                    ]
                )
                .permute(*perm_dims_last)
                .reshape(-1, num_kv_heads, head_dim),
            ],
            dim=0,
        ).half()
        o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache(
            qi,
            ki,
            vi,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
        torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)

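# Illustrative sketch (not used by the tests; helper name and signature are
# hypothetical): how the reference check above reassembles one request's K or V
# from the paged cache in the NHD layout, mirroring the torch.cat(...).reshape(...)
# logic inside the verification loop.
def _gather_kv_pages_nhd(kv_data, kv_indptr, kv_last_page_len, i, kv_idx):
    # kv_data: [num_pages, 2, page_size, num_kv_heads, head_dim]; kv_idx 0 -> K, 1 -> V
    num_kv_heads, head_dim = kv_data.shape[-2], kv_data.shape[-1]
    full_pages = kv_data[kv_indptr[i] : kv_indptr[i + 1] - 1, kv_idx]
    last_page = kv_data[kv_indptr[i + 1] - 1, kv_idx, : kv_last_page_len[i], :]
    return torch.cat(
        [
            full_pages.reshape(-1, num_kv_heads, head_dim),  # all full pages
            last_page.reshape(-1, num_kv_heads, head_dim),  # valid rows of the last page
        ],
        dim=0,
    )
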
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048])
@pytest.mark.parametrize("qo_len", [37, 17, 127, 577])
@pytest.mark.parametrize("page_size", [1, 5, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("kv_layout", ["NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
@pytest.mark.parametrize("use_cuda_graph", [False, True])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
@pytest.mark.parametrize("contiguous_kv", [True])
def test_batch_prefill_with_tuple_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    page_size,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    causal,
    kv_layout,
    pos_encoding_mode,
    use_cuda_graph,
    logits_soft_cap,
    return_lse,
    contiguous_kv,
):
    if qo_len > kv_len and causal:
        pytest.skip("qo_len > kv_len and causal is not supported")
    q = torch.randn(
        batch_size * qo_len,
        num_qo_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    if kv_layout == "HND":
        kv_shape = [total_num_pages, num_kv_heads, page_size, head_dim]
    else:
        kv_shape = [total_num_pages, page_size, num_kv_heads, head_dim]
    if not contiguous_kv:
        tmp = [kv_shape[0]]
        for v in kv_shape[1:]:
            tmp.append(2)
            tmp.append(v)
        kv_shape = tmp
        kv_data_fp32 = [
            torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
            for _ in range(2)
        ]
        kv_data = [kv_data_fp32[i].half() for i in range(2)]
        for i in range(2):
            kv_data_fp32[i] = kv_data_fp32[i][:, 1, :, 1, :, 1, :]
            kv_data[i] = kv_data[i][:, 1, :, 1, :, 1, :]
            # actual data is stored in non-contiguous memory
            assert (
                kv_data[i].stride(-4)
                != kv_data[i].shape[-3] * kv_data[i].shape[-2] * kv_data[i].shape[-1]
            )
    else:
        kv_data_fp32 = [
            torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
            for _ in range(2)
        ]
        kv_data = [kv_data_fp32[i].half() for i in range(2)]
    kv_data = tuple(kv_data)
    kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq
    kv_indices_cpu = torch.arange(0, total_num_pages).int()
    kv_last_page_len_cpu = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
    )
    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    if not use_cuda_graph:
        q_indptr_gpu = q_indptr_cpu.to(0)
        kv_indptr_gpu = kv_indptr_cpu.to(0)
        kv_indices_gpu = kv_indices_cpu.to(0)
        kv_last_page_len_gpu = kv_last_page_len_cpu.to(0)
        wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer, kv_layout
        )
        wrapper.plan(
            q_indptr_gpu,
            kv_indptr_gpu,
            kv_indices_gpu,
            kv_last_page_len_gpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        if return_lse:
            o, _ = wrapper.run(q, kv_data, return_lse=True)
        else:
            o = wrapper.run(q, kv_data)
    else:
        q_indptr_buffer = torch.empty(
            batch_size + 1, device="cuda:0", dtype=torch.int32
        )
        kv_indptr_buffer = torch.empty(
            batch_size + 1, device="cuda:0", dtype=torch.int32
        )
        kv_indices_buffer = torch.empty(
            total_num_pages, device="cuda:0", dtype=torch.int32
        )
        kv_last_page_len_buffer = torch.empty(
            batch_size, device="cuda:0", dtype=torch.int32
        )
        wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
            workspace_buffer,
            kv_layout,
            use_cuda_graph=True,
            qo_indptr_buf=q_indptr_buffer,
            paged_kv_indptr_buf=kv_indptr_buffer,
            paged_kv_indices_buf=kv_indices_buffer,
            paged_kv_last_page_len_buf=kv_last_page_len_buffer,
        )
        q_indptr_warmup = torch.arange(0, batch_size + 1).int() * qo_len
        kv_indptr_warmup = torch.arange(0, batch_size + 1).int()
        kv_indices_warmup = torch.arange(0, batch_size).int()
        kv_last_page_len_warmup = torch.full(
            (batch_size,), page_size, dtype=torch.int32
        )
        wrapper.plan(
            q_indptr_warmup,
            kv_indptr_warmup,
            kv_indices_warmup,
            kv_last_page_len_warmup,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        # warmup
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                if return_lse:
                    o, _ = wrapper.run(q, kv_data, return_lse=True)
                else:
                    o = wrapper.run(q, kv_data)
        torch.cuda.current_stream().wait_stream(s)
        # capture
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            if return_lse:
                o, _ = wrapper.run(q, kv_data, return_lse=True)
            else:
                o = wrapper.run(q, kv_data)
        wrapper.plan(
            q_indptr_cpu,
            kv_indptr_cpu,
            kv_indices_cpu,
            kv_last_page_len_cpu,
            num_qo_heads,
            num_kv_heads,
            head_dim,
            page_size,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        g.replay()

    k_cache, v_cache = kv_data_fp32
    for i in range(batch_size):
        perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
        perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
        qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
        ki = torch.cat(
            [
                k_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1]
                .permute(*perm_dims)
                .reshape(-1, num_kv_heads, head_dim),
                (
                    k_cache[kv_indptr_cpu[i + 1] - 1, :, : kv_last_page_len_cpu[i]]
                    if kv_layout == "HND"
                    else k_cache[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i], :]
                )
                .permute(*perm_dims_last)
                .reshape(-1, num_kv_heads, head_dim),
            ],
            dim=0,
        ).half()
        vi = torch.cat(
            [
                v_cache[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1]
                .permute(*perm_dims)
                .reshape(-1, num_kv_heads, head_dim),
                (
                    v_cache[kv_indptr_cpu[i + 1] - 1, :, : kv_last_page_len_cpu[i]]
                    if kv_layout == "HND"
                    else v_cache[kv_indptr_cpu[i + 1] - 1, : kv_last_page_len_cpu[i], :]
                )
                .permute(*perm_dims_last)
                .reshape(-1, num_kv_heads, head_dim),
            ],
            dim=0,
        ).half()
        o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache(
            qi,
            ki,
            vi,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        o_i = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
        torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)

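# The custom-mask tests below build a packed boolean mask of shape
# (batch_size, qo_len, kv_len) flattened to 1D, constructed with
# torch.tril(..., diagonal=kv_len - qo_len) so that query i may attend to kv
# positions <= kv_len - qo_len + i, i.e. the same pattern the causal path
# applies. For qo_len=2, kv_len=3 the per-request mask is
# [[1, 1, 0], [1, 1, 1]]. Running once with custom_mask and once with
# causal=True should therefore produce matching outputs.
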
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048])
@pytest.mark.parametrize("qo_len", [37, 17, 127, 577])
@pytest.mark.parametrize("page_size", [1, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("kv_layout", ["NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
@pytest.mark.parametrize("contiguous_kv", [True])
def test_batch_prefill_with_paged_kv_cache_custom_mask(
    batch_size,
    kv_len,
    qo_len,
    page_size,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    kv_layout,
    pos_encoding_mode,
    logits_soft_cap,
    return_lse,
    contiguous_kv,
):
    q = torch.randn(
        batch_size * qo_len,
        num_qo_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    q_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
    )
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    if kv_layout == "HND":
        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
    else:
        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
    if not contiguous_kv:
        tmp = [kv_shape[0]]
        for v in kv_shape[1:]:
            tmp.append(2)
            tmp.append(v)
        kv_shape = tmp
        kv_data = torch.randn(*kv_shape, dtype=torch.float16, device="cuda:0")
        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
        # actual data is stored in non-contiguous memory
        assert (
            kv_data.stride(-4)
            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
        )
    else:
        kv_data = torch.randn(*kv_shape, dtype=torch.float16, device="cuda:0")
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
        * num_pages_per_seq
    )
    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
    kv_last_page_len = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
    )
    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    custom_mask = torch.tril(
        torch.full((batch_size, qo_len, kv_len), True, device="cuda:0"),
        diagonal=(kv_len - qo_len),
    ).reshape(-1)

    # use custom mask
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        custom_mask=custom_mask,
        pos_encoding_mode=pos_encoding_mode,
        logits_soft_cap=logits_soft_cap,
    )
    if return_lse:
        o_custom, _ = wrapper.run(q, kv_data, return_lse=True)
    else:
        o_custom = wrapper.run(q, kv_data)

    # use causal
    wrapper.plan(
        q_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        causal=True,
        pos_encoding_mode=pos_encoding_mode,
        logits_soft_cap=logits_soft_cap,
    )
    if return_lse:
        o_causal, _ = wrapper.run(q, kv_data, return_lse=True)
    else:
        o_causal = wrapper.run(q, kv_data)
    torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3)

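# The ragged-KV tests skip paging altogether: k and v are packed as
# [total_kv_len, num_kv_heads, head_dim] tensors and kv_indptr marks each
# request's boundaries, so plan() only needs q_indptr and kv_indptr.
# Illustrative sketch (not part of the tests, hypothetical helper): with
# variable-length requests the indptr would typically be a cumulative sum of
# lengths; the tests use a uniform kv_len, so arange(batch_size + 1) * kv_len
# suffices.
def _lengths_to_indptr(lengths):
    indptr = torch.zeros(len(lengths) + 1, dtype=torch.int32)
    indptr[1:] = torch.cumsum(torch.tensor(lengths, dtype=torch.int32), dim=0)
    return indptr
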
@pytest.mark.parametrize("batch_size", [12, 17, 128])
@pytest.mark.parametrize("kv_len", [54, 97, 512, 2048])
@pytest.mark.parametrize("qo_len", [37, 17, 127, 577])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("causal", [False, True])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"])
@pytest.mark.parametrize("logits_soft_cap", [0.0])
@pytest.mark.parametrize("return_lse", [True])
def test_batch_prefill_with_ragged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    causal,
    pos_encoding_mode,
    logits_soft_cap,
    return_lse,
):
    if qo_len > kv_len and causal:
        pytest.skip("qo_len > kv_len and causal is not supported")
    kv_layout = "NHD"
    q = torch.randn(
        batch_size * qo_len,
        num_qo_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    q_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
    )
    k = torch.randn(
        batch_size * kv_len,
        num_kv_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    v = torch.randn(
        batch_size * kv_len,
        num_kv_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len
    )
    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    wrapper.plan(
        q_indptr,
        kv_indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        causal=causal,
        pos_encoding_mode=pos_encoding_mode,
        logits_soft_cap=logits_soft_cap,
    )
    if return_lse:
        o, _ = wrapper.run(q, k, v, return_lse=True)
    else:
        o = wrapper.run(q, k, v)

    for i in range(batch_size):
        o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache(
            q[q_indptr[i] : q_indptr[i + 1]],
            k[kv_indptr[i] : kv_indptr[i + 1]],
            v[kv_indptr[i] : kv_indptr[i + 1]],
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
        )
        o_i = o[q_indptr[i] : q_indptr[i + 1]]
        torch.testing.assert_close(o_i, o_ref_i, rtol=1e-3, atol=1e-3)

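# Same custom-mask-vs-causal equivalence check as the paged variant above, but
# for the ragged wrapper; this test additionally sweeps ALIBI positional
# encoding, a non-zero logits_soft_cap (30.0), and both return_lse settings.
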
@pytest.mark.parametrize("batch_size", [12, 17])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128, 256])
@pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA", "ALIBI"])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
def test_batch_prefill_with_ragged_kv_cache_custom_mask(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    pos_encoding_mode,
    logits_soft_cap,
    return_lse,
):
    kv_layout = "NHD"
    q = torch.randn(
        batch_size * qo_len,
        num_qo_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    q_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * qo_len
    )
    k = torch.randn(
        batch_size * kv_len,
        num_kv_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    v = torch.randn(
        batch_size * kv_len,
        num_kv_heads,
        head_dim,
        device="cuda:0",
        dtype=torch.float16,
    )
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * kv_len
    )
    workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    custom_mask = torch.tril(
        torch.full((batch_size, qo_len, kv_len), True, device="cuda:0"),
        diagonal=(kv_len - qo_len),
    ).reshape(-1)

    # use custom mask
    wrapper.plan(
        q_indptr,
        kv_indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        custom_mask=custom_mask,
        pos_encoding_mode=pos_encoding_mode,
        logits_soft_cap=logits_soft_cap,
    )
    if return_lse:
        o_custom, _ = wrapper.run(q, k, v, return_lse=True)
    else:
        o_custom = wrapper.run(q, k, v)

    # use causal
    wrapper.plan(
        q_indptr,
        kv_indptr,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        causal=True,
        pos_encoding_mode=pos_encoding_mode,
        logits_soft_cap=logits_soft_cap,
    )
    if return_lse:
        o_causal, _ = wrapper.run(q, k, v, return_lse=True)
    else:
        o_causal = wrapper.run(q, k, v)
    torch.testing.assert_close(o_custom, o_causal, rtol=1e-3, atol=1e-3)

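# Multi-item scoring: a single request whose query/KV sequence is a shared
# prefix (length prefix_len_ptr) followed by several candidate items, with
# delimiter tokens in between (token_pos_in_items_ptr holds each token's
# position within its item and a 0 marks a delimiter). Each item token may
# attend to the cached prefix and, causally, to earlier tokens of its own item,
# but not to other items or to delimiter tokens. The extra plan() arguments
# below describe this structure to the kernel; the reference result is computed
# with an equivalent dense mask built by create_2D_multi_item_mask_dense inside
# the test.
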
@pytest.mark.parametrize("batch_size", [1])
@pytest.mark.parametrize(
    "kv_len, qo_len, prefix_len_ptr, token_pos_in_items_ptr, token_pos_in_items_len, max_item_len_ptr",
    [
        (54, 37, 17, list(range(17)) + list(range(19)) + [0], 100, [18]),
        (97, 81, 16, list(range(80)) + [0], 97, [79]),
    ],
)
@pytest.mark.parametrize("page_size", [1, 5, 16])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [4, 32])
@pytest.mark.parametrize("head_dim", [128])
@pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("kv_layout", ["NHD"])
@pytest.mark.parametrize("pos_encoding_mode", ["ROPE_LLAMA"])
@pytest.mark.parametrize("logits_soft_cap", [0.0, 30.0])
@pytest.mark.parametrize("return_lse", [True, False])
def test_batch_prefill_with_paged_kv_cache_multi_item_scoring(
    batch_size,
    kv_len,
    qo_len,
    prefix_len_ptr,
    token_pos_in_items_ptr,
    token_pos_in_items_len,
    max_item_len_ptr,
    page_size,
    num_kv_heads,
    num_qo_heads,
    head_dim,
    causal,
    kv_layout,
    pos_encoding_mode,
    logits_soft_cap,
    return_lse,
):
    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    q_indptr_cpu = torch.arange(0, batch_size + 1).int() * qo_len
    num_pages_per_seq = (kv_len + page_size - 1) // page_size
    total_num_pages = num_pages_per_seq * batch_size
    kv_data = (
        torch.randn(total_num_pages, 2, num_kv_heads, page_size, head_dim).to(0).half()
        if kv_layout == "HND"
        else torch.randn(total_num_pages, 2, page_size, num_kv_heads, head_dim)
        .to(0)
        .half()
    )
    kv_indptr_cpu = torch.arange(0, batch_size + 1).int() * num_pages_per_seq
    kv_indices_cpu = torch.arange(0, total_num_pages).int()
    kv_last_page_len_cpu = torch.full(
        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32
    )
    workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
    q_indptr_gpu = q_indptr_cpu.to(0)
    kv_indptr_gpu = kv_indptr_cpu.to(0)
    kv_indices_gpu = kv_indices_cpu.to(0)
    kv_last_page_len_gpu = kv_last_page_len_cpu.to(0)
    wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper(
        workspace_buffer, kv_layout
    )
    wrapper.plan(
        q_indptr_gpu,
        kv_indptr_gpu,
        kv_indices_gpu,
        kv_last_page_len_gpu,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        causal=causal,
        pos_encoding_mode=pos_encoding_mode,
        logits_soft_cap=logits_soft_cap,
        prefix_len_ptr=torch.tensor(prefix_len_ptr).to(dtype=torch.uint32).to(0),
        token_pos_in_items_ptr=torch.tensor(token_pos_in_items_ptr)
        .to(dtype=torch.uint16)
        .to(0),
        token_pos_in_items_len=torch.tensor(token_pos_in_items_len)
        .to(dtype=torch.uint32)
        .to(0),
        max_item_len_ptr=torch.tensor(max_item_len_ptr).to(dtype=torch.uint16).to(0),
    )
    if return_lse:
        o, _ = wrapper.run_return_lse(q, kv_data)
    else:
        o = wrapper.run(q, kv_data)

    for i in range(batch_size):
        perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
        perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
        qi = q[q_indptr_cpu[i] : q_indptr_cpu[i + 1]]
        ki = torch.cat(
            [
                kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 0]
                .permute(*perm_dims)
                .reshape(-1, num_kv_heads, head_dim),
                (
                    kv_data[kv_indptr_cpu[i + 1] - 1, 0, :, : kv_last_page_len_cpu[i]]
                    if kv_layout == "HND"
                    else kv_data[
                        kv_indptr_cpu[i + 1] - 1, 0, : kv_last_page_len_cpu[i], :
                    ]
                )
                .permute(*perm_dims_last)
                .reshape(-1, num_kv_heads, head_dim),
            ],
            dim=0,
        )
        vi = torch.cat(
            [
                kv_data[kv_indptr_cpu[i] : kv_indptr_cpu[i + 1] - 1, 1]
                .permute(*perm_dims)
                .reshape(-1, num_kv_heads, head_dim),
                (
                    kv_data[kv_indptr_cpu[i + 1] - 1, 1, :, : kv_last_page_len_cpu[i]]
                    if kv_layout == "HND"
                    else kv_data[
                        kv_indptr_cpu[i + 1] - 1, 1, : kv_last_page_len_cpu[i], :
                    ]
                )
                .permute(*perm_dims_last)
                .reshape(-1, num_kv_heads, head_dim),
            ],
            dim=0,
        )

        def create_2D_multi_item_mask_dense(
            is_delimiter, sliding_window_size=-1, prefix_cache_len=None
        ):
            # Function to create custom_mask for multi-item scoring.
            #
            # Note: the sliding window implementation assumes that
            # candidate_i_size < sliding_window_size < prefix_size.
            # Args:
            #     is_delimiter: a boolean torch vector marking the delimiter positions,
            #         used to build the custom attention mask for multi-item scoring;
            #         currently assumes qo len and kv len are the same, 1D (bsz=1) case.
            #     sliding_window_size: the window size for sliding window attention;
            #         -1 means no sliding window attention.
            delimiter_idx = is_delimiter.nonzero(as_tuple=True)[0]
            if len(delimiter_idx) == 0:
                return None
            else:
                first_delimiter_pos = delimiter_idx[0]
                seq_len = len(is_delimiter)
                pos = torch.arange(seq_len, device=is_delimiter.device)
                group_ids = torch.cumsum(is_delimiter, 0)
                # Get mask for within-group causal attention
                within_group_causal = (
                    group_ids.unsqueeze(1) == group_ids.unsqueeze(0)
                ) & (pos.unsqueeze(0) <= pos.unsqueeze(1))
                # Combine all conditions
                attention_mask = (
                    (
                        within_group_causal
                        | (
                            (pos >= first_delimiter_pos).unsqueeze(1)
                            & (pos < first_delimiter_pos).unsqueeze(0)
                        )  # Prefix attention
                    )
                    & ~is_delimiter.unsqueeze(0)
                    & ~is_delimiter.unsqueeze(1)
                )  # No delimiter attention

                if sliding_window_size > 0 and sliding_window_size < len(is_delimiter):
                    # Calculate how many positions from right of prefix each token can attend to
                    group_size = torch.sum(
                        within_group_causal & ~is_delimiter.unsqueeze(0), dim=1
                    )
                    # For prefix: after sliding_window_size position, can see window_size tokens
                    # For candidate items: can see (sliding_window_size - group_size) tokens from prefix end
                    prefix_window = torch.where(
                        pos >= first_delimiter_pos,
                        sliding_window_size - group_size,
                        torch.where(
                            pos < sliding_window_size,
                            first_delimiter_pos,
                            sliding_window_size,
                        ),
                    )
                    # Starting index of attention window relative to token position for candidate item/group
                    prefix_start = first_delimiter_pos - prefix_window.unsqueeze(1)
                    attention_mask = attention_mask & (pos >= prefix_start)

                if prefix_cache_len:
                    patch = torch.ones(
                        seq_len,
                        prefix_cache_len,
                        device=is_delimiter.device,
                        dtype=torch.bool,
                    )
                    attention_mask = torch.concat([patch, attention_mask], dim=1)

                return attention_mask.unsqueeze(0).reshape(-1)

        custom_mask = create_2D_multi_item_mask_dense(
            is_delimiter=torch.tensor(token_pos_in_items_ptr).to(0) == 0,
            sliding_window_size=-1,
            prefix_cache_len=prefix_len_ptr,
        )
        o_ref_i = flashinfer.prefill.single_prefill_with_kv_cache(
            qi,
            ki,
            vi,
            causal=causal,
            pos_encoding_mode=pos_encoding_mode,
            logits_soft_cap=logits_soft_cap,
            custom_mask=custom_mask,
        )
        o_i_np = o[q_indptr_cpu[i] : q_indptr_cpu[i + 1]].cpu().numpy()
        o_ref_i_np = o_ref_i.cpu().numpy()
        numpy.testing.assert_allclose(o_i_np, o_ref_i_np, rtol=1e-3, atol=1e-3)

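# Direct invocation runs a handful of representative configurations outside of
# pytest, including the HND layout and the non-CUDA-graph path.
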
if __name__ == "__main__":
    test_batch_prefill_with_paged_kv_cache(
        12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False, True
    )
    test_batch_prefill_with_tuple_paged_kv_cache(
        12, 54, 37, 16, 8, 8, 128, True, "HND", "NONE", True, 0.0, False, True
    )
    test_batch_prefill_with_paged_kv_cache(
        12, 54, 37, 1, 8, 8, 128, True, "HND", "NONE", False, 0.0, False, True
    )
    test_batch_prefill_with_paged_kv_cache_custom_mask(
        1, 137, 137, 1, 8, 8, 128, "HND", "NONE", 0.0, False, True
    )
    test_batch_prefill_with_ragged_kv_cache(
        12, 54, 37, 8, 8, 128, True, "NONE", 0.0, False
    )
    test_batch_prefill_with_ragged_kv_cache_custom_mask(
        1, 137, 137, 8, 8, 128, "NONE", 0.0, False
    )