""" Copyright (c) 2023 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import pytest import torch from jit_utils import gen_decode_attention_modules, gen_prefill_attention_modules import flashinfer @pytest.fixture(autouse=True, scope="module") def warmup_jit(): flashinfer.jit.build_jit_specs( gen_decode_attention_modules( [torch.float16], # q_dtypes [ torch.float16, torch.float8_e4m3fn, ], # kv_dtypes [128, 256], # head_dims [0, 1], # pos_encoding_modes [False], # use_sliding_windows [False], # use_logits_soft_caps ) + gen_prefill_attention_modules( [torch.float16], # q_dtypes [ torch.float16, torch.float8_e4m3fn, ], # kv_dtypes [128, 256], # head_dims [0, 1], # pos_encoding_modes [False], # use_sliding_windows [False], # use_logits_soft_caps [False], # use_fp16_qk_reductions ), verbose=False, ) yield @pytest.mark.parametrize("batch_size", [12, 17, 128]) @pytest.mark.parametrize("kv_len", [54, 97, 512, 2048, 16384]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) @pytest.mark.parametrize("logits_soft_cap", [0.0]) @pytest.mark.parametrize("return_lse", [True]) @pytest.mark.parametrize("q_dtype", [torch.float16]) @pytest.mark.parametrize("kv_dtype", [torch.float16, torch.float8_e4m3fn]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_batch_decode_with_paged_kv_cache( batch_size, kv_len, page_size, num_kv_heads, num_qo_heads, head_dim, kv_layout, pos_encoding_mode, logits_soft_cap, return_lse, q_dtype, kv_dtype, contiguous_kv, ): q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim] else: kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] if not contiguous_kv: tmp = [kv_shape[0]] for v in kv_shape[1:]: tmp.append(2) tmp.append(v) kv_shape = tmp kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") kv_data = kv_data_fp32.to(kv_dtype) kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] # actual data is stored in non-contiguous memory assert ( kv_data.stride(-4) != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] ) else: kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") kv_data = kv_data_fp32.to(kv_dtype) kv_indptr = ( torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * num_pages_per_seq ) kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) kv_last_page_len = torch.full( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" ) workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0") wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) wrapper.plan( kv_indptr, kv_indices, kv_last_page_len, num_qo_heads, num_kv_heads, head_dim, page_size, logits_soft_cap=logits_soft_cap, pos_encoding_mode=pos_encoding_mode, data_type=kv_dtype, q_data_type=q_dtype, ) if return_lse: o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = q[i] ki = torch.cat( [ kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]] if kv_layout == "HND" else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], dim=0, ).to(kv_dtype) vi = torch.cat( [ kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]] if kv_layout == "HND" else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], dim=0, ).to(kv_dtype) o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) # test user-allocated output o_buffer = torch.empty_like(o) wrapper.run(q, kv_data, out=o_buffer) torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17, 128]) @pytest.mark.parametrize("kv_len", [54, 97, 512, 2048, 16384]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) @pytest.mark.parametrize("logits_soft_cap", [0.0]) @pytest.mark.parametrize("return_lse", [True]) @pytest.mark.parametrize("q_dtype", [torch.float16]) @pytest.mark.parametrize("kv_dtype", [torch.float16, torch.float8_e4m3fn]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_batch_decode_with_tuple_paged_kv_cache( batch_size, kv_len, page_size, num_kv_heads, num_qo_heads, head_dim, kv_layout, pos_encoding_mode, logits_soft_cap, return_lse, q_dtype, kv_dtype, contiguous_kv, ): q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": kv_shape = [total_num_pages, num_kv_heads, page_size, head_dim] else: kv_shape = [total_num_pages, page_size, num_kv_heads, head_dim] if not contiguous_kv: tmp = [kv_shape[0]] for v in kv_shape[1:]: tmp.append(2) tmp.append(v) kv_shape = tmp kv_data_fp32 = [ torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") for _ in range(2) ] kv_data = [kv_data_fp32[i].to(kv_dtype) for i in range(2)] for i in range(2): kv_data_fp32[i] = kv_data_fp32[i][:, 1, :, 1, :, 1, :] kv_data[i] = kv_data[i][:, 1, :, 1, :, 1, :] # actual data is stored in non-contiguous memory assert ( kv_data[i].stride(-4) != kv_data[i].shape[-3] * kv_data[i].shape[-2] * kv_data[i].shape[-1] ) else: kv_data_fp32 = [ torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") for _ in range(2) ] kv_data = [kv_data_fp32[i].to(kv_dtype) for i in range(2)] kv_data = tuple(kv_data) kv_indptr = ( torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * num_pages_per_seq ) kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32) kv_last_page_len = torch.full( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0" ) workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0") wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_layout ) wrapper.plan( kv_indptr, kv_indices, kv_last_page_len, num_qo_heads, num_kv_heads, head_dim, page_size, logits_soft_cap=logits_soft_cap, pos_encoding_mode=pos_encoding_mode, data_type=kv_dtype, q_data_type=q_dtype, ) if return_lse: o, _ = wrapper.run(q, kv_data, return_lse=True) else: o = wrapper.run(q, kv_data) k_cache, v_cache = kv_data_fp32 for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = q[i] ki = torch.cat( [ k_cache[kv_indptr[i] : kv_indptr[i + 1] - 1] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( k_cache[kv_indptr[i + 1] - 1, :, : kv_last_page_len[i]] if kv_layout == "HND" else k_cache[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], dim=0, ).to(kv_dtype) vi = torch.cat( [ v_cache[kv_indptr[i] : kv_indptr[i + 1] - 1] .to(torch.float32) # torch.cat does not support some fp8 types .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( v_cache[kv_indptr[i + 1] - 1, :, : kv_last_page_len[i]] if kv_layout == "HND" else v_cache[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], dim=0, ).to(kv_dtype) o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode, logits_soft_cap=logits_soft_cap, ) torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17, 128]) @pytest.mark.parametrize("kv_len", [54, 2048, 16384]) @pytest.mark.parametrize("page_size", [1, 8, 16]) @pytest.mark.parametrize("num_kv_heads", [4]) @pytest.mark.parametrize("num_qo_heads", [4, 32]) @pytest.mark.parametrize("head_dim", [128, 256]) @pytest.mark.parametrize("kv_layout", ["NHD"]) @pytest.mark.parametrize("pos_encoding_mode", ["NONE", "ROPE_LLAMA"]) @pytest.mark.parametrize("q_dtype", [torch.float16]) @pytest.mark.parametrize("kv_dtype", [torch.float16, torch.float8_e4m3fn]) @pytest.mark.parametrize("contiguous_kv", [True]) def test_cuda_graph_batch_decode_with_paged_kv_cache( batch_size, kv_len, page_size, num_kv_heads, num_qo_heads, head_dim, kv_layout, pos_encoding_mode, q_dtype, kv_dtype, contiguous_kv, ): q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype) num_pages_per_seq = (kv_len + page_size - 1) // page_size total_num_pages = num_pages_per_seq * batch_size if kv_layout == "HND": kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim] else: kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim] if not contiguous_kv: tmp = [kv_shape[0]] for v in kv_shape[1:]: tmp.append(2) tmp.append(v) kv_shape = tmp kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") kv_data = kv_data_fp32.to(kv_dtype) kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :] kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :] # actual data is stored in non-contiguous memory assert ( kv_data.stride(-4) != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1] ) else: kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0") kv_data = kv_data_fp32.to(kv_dtype) kv_indptr_host_warmup = torch.arange( 0, batch_size + 1, device="cuda:0", dtype=torch.int32 ) kv_indices_host_warmup = torch.arange( 0, batch_size, device="cuda:0", dtype=torch.int32 ) kv_last_page_len_host_warmup = torch.full( (batch_size,), page_size, dtype=torch.int32 ) # NOTE(Zihao): allocate more space than needed for testing kv_indptr_device_buffer = torch.empty( batch_size + 1, device="cuda:0", dtype=torch.int32 ) kv_indices_device_buffer = torch.empty( total_num_pages, device="cuda:0", dtype=torch.int32 ) kv_last_page_device_buffer = torch.empty( batch_size, device="cuda:0", dtype=torch.int32 ) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda:0") wrapper = flashinfer.decode.CUDAGraphBatchDecodeWithPagedKVCacheWrapper( workspace_buffer, kv_indptr_device_buffer, kv_indices_device_buffer, kv_last_page_device_buffer, kv_layout, ) wrapper.plan( kv_indptr_host_warmup, kv_indices_host_warmup, kv_last_page_len_host_warmup, num_qo_heads, num_kv_heads, head_dim, page_size, data_type=kv_dtype, pos_encoding_mode=pos_encoding_mode, q_data_type=q_dtype, ) # warmup s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): o = wrapper.run(q, kv_data) torch.cuda.current_stream().wait_stream(s) # capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): o = wrapper.run(q, kv_data) # replay multiple times for i in range(1, min(4, num_pages_per_seq)): kv_indptr_host = torch.arange(0, batch_size + 1).int() * i kv_indices_host = torch.arange(0, i * batch_size).int() kv_last_page_len_host = torch.full((batch_size,), page_size, dtype=torch.int32) wrapper.plan( kv_indptr_host, kv_indices_host, kv_last_page_len_host, num_qo_heads, num_kv_heads, head_dim, page_size, data_type=kv_dtype, pos_encoding_mode=pos_encoding_mode, q_data_type=q_dtype, ) g.replay() # replay again kv_indptr_host = torch.arange(0, batch_size + 1).int() * num_pages_per_seq kv_indices_host = torch.arange(0, total_num_pages).int() kv_last_page_len_host = torch.full( (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32 ) wrapper.plan( kv_indptr_host, kv_indices_host, kv_last_page_len_host, num_qo_heads, num_kv_heads, head_dim, page_size, data_type=kv_dtype, pos_encoding_mode=pos_encoding_mode, q_data_type=q_dtype, ) g.replay() # compute ground truth and compare kv_indptr = kv_indptr_host.to(0) kv_last_page_len = kv_last_page_len_host.to(0) for i in range(batch_size): perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3] perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2] qi = q[i] ki = torch.cat( [ kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]] if kv_layout == "HND" else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], dim=0, ).to(kv_dtype) vi = torch.cat( [ kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1] .permute(*perm_dims) .reshape(-1, num_kv_heads, head_dim), ( kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]] if kv_layout == "HND" else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :] ) .permute(*perm_dims_last) .reshape(-1, num_kv_heads, head_dim), ], dim=0, ).to(kv_dtype) o_ref_i = flashinfer.decode.single_decode_with_kv_cache( qi, ki, vi, pos_encoding_mode=pos_encoding_mode ) torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) if __name__ == "__main__": test_batch_decode_with_paged_kv_cache( 256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16, True, ) test_batch_decode_with_tuple_paged_kv_cache( 256, 54, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16, True, ) test_batch_decode_with_paged_kv_cache( 12, 2048, 8, 8, 8, 128, "NHD", "NONE", 0.0, False, torch.float16, torch.float16, True, ) test_batch_decode_with_paged_kv_cache( 12, 54, 1, 8, 8, 128, "HND", "NONE", 0.0, True, torch.float16, torch.float8_e5m2, True, ) test_cuda_graph_batch_decode_with_paged_kv_cache( 12, 2048, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16, True ) test_cuda_graph_batch_decode_with_paged_kv_cache( 128, 54, 8, 8, 8, 128, "NHD", "NONE", torch.float16, torch.float16, True ) test_batch_decode_with_paged_kv_cache( 12, 54, 1, 8, 8, 128, "HND", "NONE", 0.0, True, torch.float16, torch.float8_e5m2, True, ) test_cuda_graph_batch_decode_with_paged_kv_cache( 12, 54, 8, 8, 8, 128, "HND", "NONE", torch.float16, torch.float8_e5m2, True )