import pytest
import torch
from flashinfer import (
    BatchDecodeWithPagedKVCacheWrapper,
    BatchPrefillWithPagedKVCacheWrapper,
)

from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd
from sglang.srt.layers.attention.triton_ops.extend_attention import (
    extend_attention_fwd,
    redundant_attention,
)
from sglang.srt.utils import should_use_tensor_core

flashinfer_prefill_wrapper = None
flashinfer_decode_wrapper = None


@pytest.mark.parametrize("batch_size", [12, 37, 67])
@pytest.mark.parametrize("kv_len", [54, 97])
@pytest.mark.parametrize("qo_len", [37, 17])
@pytest.mark.parametrize("num_kv_heads", [4])
@pytest.mark.parametrize("num_qo_heads", [32, 4])
@pytest.mark.parametrize("head_dim", [128])
def test_batch_prefill_with_paged_kv_cache(
    batch_size,
    kv_len,
    qo_len,
    num_kv_heads,
    num_qo_heads,
    head_dim,
):
    init_flashinfer(num_qo_heads, num_kv_heads)

    q = torch.randn(batch_size * qo_len, num_qo_heads, head_dim).to(0).half()
    qo_indptr = torch.arange(0, batch_size + 1).to(0).int() * qo_len
    total_tokens = kv_len * batch_size
    kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half()
    kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len
    kv_indices = torch.arange(0, total_tokens).to(0).int()
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0)

    # init args for triton kernel
    k_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 0]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    v_extend = (
        kv_data.view(batch_size, kv_len, 2, -1)[:, -qo_len:, 1]
        .contiguous()
        .view(-1, num_kv_heads, head_dim)
    )
    o_triton = torch.empty_like(q)
    k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous()
    v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous()
    req_to_token = torch.arange(0, total_tokens).to(0).int().view(batch_size, kv_len)
    b_req_idx = torch.arange(0, batch_size).to(0).int()
    b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0)
    b_start_loc_extend = torch.arange(0, batch_size).to(0).int() * qo_len
    b_seq_len_extend = torch.full((batch_size,), qo_len, dtype=torch.int32).to(0)
    max_len_in_batch = kv_len
    max_len_extend = qo_len

    extend_attention_fwd(
        q,
        k_extend,
        v_extend,
        o_triton,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        None,  # b_start_loc = None
        b_seq_len,
        None,  # b_seq_len_prefix = None
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    # cross-check the triton kernel against the redundant reference implementation
    o_redundant = torch.empty_like(q)
    b_start_loc = torch.zeros((batch_size,), dtype=torch.int32).to(0)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], dim=0)
    b_seq_len_prefix = b_seq_len - b_seq_len_extend
    redundant_attention(
        q,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_token,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )

    print("Mean: ", torch.mean(torch.abs(o_redundant - o_triton)))
    print("Max: ", torch.max(torch.abs(o_redundant - o_triton)))
    assert torch.allclose(o_redundant, o_triton, rtol=1e-2, atol=1e-3)

    # run flashinfer prefill and compare against the triton result
    flashinfer_prefill_wrapper.end_forward()
    flashinfer_prefill_wrapper.begin_forward(
        qo_indptr,
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,
    )
    o = flashinfer_prefill_wrapper.forward(
        q.contiguous().view(-1, num_qo_heads, head_dim), kv_data
    )

    print("Mean: ", torch.mean(torch.abs(o - o_triton)))
    print("Max: ", torch.max(torch.abs(o - o_triton)))
    assert torch.allclose(o, o_triton, rtol=1e-2, atol=1e-3)
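
# A minimal sketch (not exercised by the tests) of how the paged-KV metadata
# built above maps onto the flat kv_data buffer. With page size 1, every token
# occupies its own page, so kv_indices degenerates to 0..total_tokens-1 and
# every request's last page holds exactly one token. The helper name below is
# hypothetical, for illustration only.
def _paged_kv_metadata_sketch(batch_size, kv_len):
    # indptr[i]:indptr[i + 1] delimits request i's page ids inside kv_indices
    kv_indptr = torch.arange(0, batch_size + 1).int() * kv_len
    # pages are stored contiguously, one token per page
    kv_indices = torch.arange(0, batch_size * kv_len).int()
    # each request's final page is full, i.e. holds exactly one token
    kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32)
    return kv_indptr, kv_indices, kv_last_page_len
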
@pytest.mark.parametrize("num_kv_heads", [32]) @pytest.mark.parametrize("num_qo_heads", [32]) @pytest.mark.parametrize("head_dim", [128]) def test_batch_decode_with_paged_kv_cache( batch_size, kv_len, num_kv_heads, num_qo_heads, head_dim, ): # note(lsyin): when pytest, the number of heads cannot change, because triton kernel has a cache # to test different shape of decode, change the parameters in the __main__, and run decode only once init_flashinfer(num_qo_heads, num_kv_heads) q = torch.randn(batch_size, num_qo_heads, head_dim).to(0).half() total_tokens = kv_len * batch_size kv_data = torch.randn(total_tokens, 2, num_kv_heads, head_dim).to(0).half() kv_indptr = torch.arange(0, batch_size + 1).to(0).int() * kv_len kv_indices = torch.arange(0, total_tokens).to(0).int() kv_last_page_len = torch.full((batch_size,), 1, dtype=torch.int32).to(0) # init args for triton kernel k_buffer = kv_data[:, 0].view(-1, num_kv_heads, head_dim).contiguous() v_buffer = kv_data[:, 1].view(-1, num_kv_heads, head_dim).contiguous() o_triton = torch.empty_like(q) req_to_token = ( torch.arange(0, kv_len * batch_size).to(0).int().view(batch_size, kv_len) ) b_req_idx = torch.arange(0, batch_size).to(0).int() b_start_loc = torch.arange(0, batch_size).to(0).int() * kv_len b_seq_len = torch.full((batch_size,), kv_len, dtype=torch.int32).to(0) max_len_in_batch = kv_len other_kv_index = 0 decode_attention_fwd( q, k_buffer, v_buffer, o_triton, req_to_token, b_req_idx, b_start_loc, b_seq_len, max_len_in_batch, other_kv_index, total_tokens, ) flashinfer_decode_wrapper.end_forward() flashinfer_decode_wrapper.begin_forward( kv_indptr, kv_indices, kv_last_page_len, num_qo_heads, num_kv_heads, head_dim, 1, pos_encoding_mode="NONE", data_type="float16", ) o = flashinfer_decode_wrapper.forward( q.contiguous().view(-1, num_qo_heads, head_dim), kv_data ) print("Mean: ", torch.mean(torch.abs(o - o_triton))) print("Max: ", torch.max(torch.abs(o - o_triton))) assert torch.allclose(o, o_triton, rtol=1e-2, atol=2e-3) def init_flashinfer(num_attention_heads, num_kv_heads): use_tensor_cores = should_use_tensor_core( torch.half, num_attention_heads, num_kv_heads ) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device="cuda") global flashinfer_prefill_wrapper, flashinfer_decode_wrapper flashinfer_prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper( workspace_buffer, "NHD" ) flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( workspace_buffer, "NHD", use_tensor_cores=use_tensor_cores ) if __name__ == "__main__": test_batch_prefill_with_paged_kv_cache(12, 54, 37, 8, 8, 128) test_batch_prefill_with_paged_kv_cache(37, 1111, 456, 32, 32, 128) test_batch_decode_with_paged_kv_cache(12, 54, 4, 32, 128)