""" Copyright (c) 2023 by FlashInfer team. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import math import pytest import torch from conftest import clear_cuda_cache import flashinfer from flashinfer.jit import build_jit_specs from flashinfer.jit.attention import ( gen_batch_mla_module, gen_batch_prefill_module, gen_single_prefill_module, ) from flashinfer.utils import ( is_sm90a_supported, is_sm100a_supported, is_sm110a_supported, ) @pytest.fixture(autouse=True, scope="module") def warmup_jit(): try: modules = [] for backend in ["fa2", "fa3"]: if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): continue modules.append( gen_single_prefill_module( backend, torch.float16, torch.float16, torch.float16, 192, 128, 0, False, False, False, ) ) for backend in ["fa2", "fa3"]: if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): continue modules.append( gen_batch_prefill_module( backend, torch.float16, torch.float16, torch.float16, torch.int32, 192, 128, 0, False, False, False, ) ) for backend in ["fa2", "fa3"]: if backend == "fa3" and not is_sm90a_supported(torch.device("cuda")): continue modules.append( gen_batch_mla_module( backend, torch.float16, torch.float16, torch.float16, torch.int32, 512, 64, False, ) ) build_jit_specs(modules, verbose=False) except Exception as e: # abort the test session if warmup fails pytest.exit(str(e)) finally: yield def attention_ref( batch_size, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool, sm_scale: float, ) -> torch.Tensor: qo_len = q.shape[0] // batch_size kv_len = k.shape[0] // batch_size num_qo_heads = q.shape[1] head_dim_qk = q.shape[2] head_dim_vo = v.shape[2] logits = ( torch.einsum( "bmhd,bnhd->bhmn", q.view(batch_size, qo_len, num_qo_heads, head_dim_qk).float(), k.view(batch_size, kv_len, num_qo_heads, head_dim_qk).float(), ) * sm_scale ) if causal: mask = torch.arange(kv_len - qo_len, kv_len, device=q.device).unsqueeze( 1 ) >= torch.arange(0, kv_len, device=q.device).unsqueeze(0) else: mask = torch.ones(qo_len, kv_len, device=q.device) logits = logits.masked_fill(mask.unsqueeze(0).unsqueeze(0) == 0, float("-inf")) lse_ref = torch.logsumexp(logits, -1).transpose(-1, -2) p = torch.softmax(logits, dim=-1) o_ref = ( torch.einsum( "bhmn,bnhd->bmhd", p, v.view(batch_size, kv_len, num_qo_heads, head_dim_vo).float(), ) .contiguous() .view(batch_size * qo_len, num_qo_heads, head_dim_vo) .to(q) ) return o_ref, lse_ref * math.log2(math.e) @pytest.mark.parametrize("kv_len", [5532, 7563]) @pytest.mark.parametrize("qo_len", [1832, 3928]) @pytest.mark.parametrize("num_heads", [4, 32, 128]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) @pytest.mark.parametrize("dtype", [torch.half]) def test_single_prefill_with_kv_cache( kv_len, qo_len, num_heads, causal, backend, dtype, ): device = torch.device("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") if is_sm110a_supported(device) and num_heads * kv_len > 700000: pytest.skip("skip large tests on Thor due to memory limit") torch.manual_seed(42) head_dim_qk = 192 head_dim_vo = 128 q = torch.randn(qo_len, num_heads, head_dim_qk, dtype=dtype, device=device) k = torch.randn(kv_len, num_heads, head_dim_qk, dtype=dtype, device=device) v = torch.randn(kv_len, num_heads, head_dim_vo, dtype=dtype, device=device) o, lse = flashinfer.single_prefill_with_kv_cache( q, k, v, causal=causal, backend=backend, return_lse=True ) sm_scale = 1.0 / (head_dim_qk**0.5) o_ref, lse_ref = attention_ref(1, q, k, v, causal, sm_scale) torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) torch.testing.assert_close(lse, lse_ref.squeeze(0), rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [12, 17]) @pytest.mark.parametrize("kv_len", [544, 977]) @pytest.mark.parametrize("qo_len", [377, 177]) @pytest.mark.parametrize("num_heads", [4, 32, 128]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) @pytest.mark.parametrize("dtype", [torch.half]) def test_batch_prefill_with_ragged_kv_cache( batch_size, kv_len, qo_len, num_heads, causal, backend, dtype, ): device = torch.device("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") torch.manual_seed(42) kv_layout = "NHD" head_dim_qk = 192 head_dim_vo = 128 q = torch.randn( batch_size * qo_len, num_heads, head_dim_qk, dtype=dtype, device=device ) q_indptr = ( torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len ) k = torch.zeros( batch_size * kv_len, num_heads, head_dim_qk, dtype=dtype, device=device ) v = torch.zeros( batch_size * kv_len, num_heads, head_dim_vo, dtype=dtype, device=device ) kv_indptr = ( torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * kv_len ) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( workspace_buffer, kv_layout, backend=backend ) wrapper.plan( q_indptr, kv_indptr, num_heads, num_heads, head_dim_qk, head_dim_vo=head_dim_vo, causal=causal, ) o, lse = wrapper.run_return_lse(q, k, v) sm_scale = 1.0 / (head_dim_qk**0.5) o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale) lse_ref = lse_ref.flatten(0, 1) torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) # test with pre-allocated output o_buffer = torch.empty_like(o) lse_buffer = torch.empty_like(lse) wrapper.run(q, k, v, out=o_buffer, lse=lse_buffer) torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3) def generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads): bs_page_num, page_size, ckv_dim = ckv.shape page_num = bs_page_num // batch_size _, _, kpe_dim = kpe.shape ckv = ckv.view(batch_size, page_num * page_size, ckv_dim) kpe = kpe.view(batch_size, page_num * page_size, kpe_dim) ckv = ckv[:, :kv_len, :] kpe = kpe[:, :kv_len, :] k = ( torch.cat([ckv, kpe], dim=-1) .view(-1, 1, ckv_dim + kpe_dim) .repeat_interleave(num_heads, dim=1) ) v = ckv.repeat_interleave(num_heads, dim=1) return k, v @pytest.mark.parametrize("batch_size", [1, 3, 5, 7]) @pytest.mark.parametrize("kv_len_0", [0, 1, 3, 11]) @pytest.mark.parametrize("kv_len_1", [17, 33, 79, 114]) @pytest.mark.parametrize("kv_len_2", [514, 2743, 8736]) @pytest.mark.parametrize("qo_len", [1, 3, 5, 7, 9, 11, 13, 15, 17]) @pytest.mark.parametrize("num_heads", [16, 64]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("page_size", [1]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) @pytest.mark.parametrize("dtype", [torch.half]) def test_batch_mla_varlen_page_attention( batch_size, kv_len_0, kv_len_1, kv_len_2, qo_len, num_heads, causal, page_size, backend, dtype, ): device = torch.device("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") if causal and qo_len > min(kv_len_0, kv_len_1, kv_len_2): pytest.skip("qo_len > kv_len not supported for causal attention") num_different_kv_len = 3 kv_lens = torch.tensor([kv_len_0, kv_len_1, kv_len_2], dtype=torch.int32) torch.manual_seed(42) head_dim_ckv = 512 head_dim_kpe = 64 q_nope = torch.randn( num_different_kv_len * batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device=device, ) q_pe = torch.randn( num_different_kv_len * batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device=device, ) pages_nums = torch.tensor( [math.ceil(kv_len / page_size) for kv_len in kv_lens], dtype=torch.int32, ) pages_nums_indptr = torch.zeros(num_different_kv_len + 1, dtype=torch.int32) pages_nums_indptr[1:] = pages_nums.cumsum(0) pages_nums_sum = pages_nums_indptr[-1] ckv = torch.randn( batch_size * pages_nums_sum, page_size, head_dim_ckv, dtype=dtype, device=device, ) kpe = torch.randn( batch_size * pages_nums_sum, page_size, head_dim_kpe, dtype=dtype, device=device, ) sm_scale = 1.0 / ((128 + 64) ** 0.5) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( workspace_buffer, backend=backend ) q_indptr = ( torch.arange( 0, num_different_kv_len * batch_size + 1, device=device, dtype=torch.int32 ) * qo_len ) kv_indptr = torch.cat( [ torch.arange(0, batch_size + 1).unsqueeze(-1).int() * pages_nums_sum + pages_nums_indptr[i] for i in range(num_different_kv_len) ], dim=-1, ).flatten() kv_indices = torch.arange( 0, batch_size * pages_nums_sum, device=device, dtype=torch.int32 ) kv_lens = torch.tensor(kv_lens, dtype=torch.int32, device=device).repeat(batch_size) wrapper.plan( q_indptr, kv_indptr, kv_indices, kv_lens, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_nope.dtype, ckv.dtype, ) o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) q_rows = ( torch.arange(0, num_different_kv_len * qo_len)[None, :] + torch.arange(0, batch_size)[:, None] * num_different_kv_len * qo_len ).int() kv_rows = ( torch.arange(0, pages_nums_sum)[None, :] + torch.arange(0, batch_size)[:, None] * pages_nums_sum ).int() q_rows_arr = [ q_rows[:, i * qo_len : (i + 1) * qo_len].flatten() for i in range(num_different_kv_len) ] kv_rows_arr = [ kv_rows[:, pages_nums_indptr[i] : pages_nums_indptr[i + 1]].flatten() for i in range(num_different_kv_len) ] for i in range(num_different_kv_len): k, v = generate_kv_from_cache( ckv[kv_rows_arr[i]], kpe[kv_rows_arr[i]], kv_lens[i], batch_size, num_heads ) q = torch.cat([q_nope, q_pe], dim=-1)[q_rows_arr[i]] o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale) lse_ref = lse_ref.flatten(0, 1) o_i = o[q_rows_arr[i]] torch.testing.assert_close(o_i, o_ref, rtol=1e-3, atol=1e-3) # if kv_lens[i] != 0: # torch.testing.assert_close(lse_i, lse_ref, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 2, 3, 4, 5, 6, 7, 157]) @pytest.mark.parametrize("kv_len", [17, 33, 75, 197]) @pytest.mark.parametrize("qo_len", [3, 7, 17]) @pytest.mark.parametrize("num_heads", [16]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("page_size", [16, 32]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) @pytest.mark.parametrize("dtype", [torch.half]) def test_batch_mla_oob_kv_nan( batch_size, kv_len, qo_len, num_heads, causal, page_size, backend, dtype ): device = torch.device("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") if causal and qo_len > kv_len: pytest.skip("qo_len > kv_len not supported for causal attention") torch.manual_seed(42) head_dim_ckv = 512 head_dim_kpe = 64 q_nope = torch.randn( batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device=device ) q_pe = torch.randn( batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device=device ) pages_num = math.ceil(kv_len / page_size) ckv = torch.randn( batch_size * pages_num, page_size, head_dim_ckv, dtype=dtype, device=device ) kpe = torch.randn( batch_size * pages_num, page_size, head_dim_kpe, dtype=dtype, device=device ) # Fill oob positions with nan for i in range(batch_size): last_page_len = kv_len - (pages_num - 1) * page_size ckv[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan") kpe[(i + 1) * pages_num - 1, last_page_len:, :] = float("nan") sm_scale = 1.0 / ((128 + 64) ** 0.5) workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( workspace_buffer, backend=backend ) q_indptr = ( torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len ) kv_indptr = ( torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * pages_num ) kv_indices = torch.arange( 0, batch_size * pages_num, device=device, dtype=torch.int32 ) kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device) wrapper.plan( q_indptr, kv_indptr, kv_indices, kv_lens, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_nope.dtype, ckv.dtype, ) o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads) q = torch.cat([q_nope, q_pe], dim=-1) o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale) lse_ref = lse_ref.flatten(0, 1) torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) if kv_len != 0: torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 3, 5, 7, 157]) @pytest.mark.parametrize("kv_len", [0, 17, 33, 96, 97, 114, 514, 1024]) @pytest.mark.parametrize("qo_len", [1, 3, 5, 7, 9, 11, 13, 15, 17]) @pytest.mark.parametrize("num_heads", [16]) @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("page_size", [1, 16]) @pytest.mark.parametrize("backend", ["fa2", "fa3"]) @pytest.mark.parametrize("use_cuda_graph", [False]) @pytest.mark.parametrize("dtype", [torch.half]) def test_batch_mla_page_attention( batch_size, kv_len, qo_len, num_heads, causal, page_size, backend, use_cuda_graph, dtype, ): device = torch.device("cuda:0") clear_cuda_cache(device) if backend == "fa3" and not is_sm90a_supported(device): pytest.skip("FA3 is not supported on this device") if causal and qo_len > kv_len: pytest.skip("qo_len > kv_len not supported for causal attention") torch.manual_seed(42) head_dim_ckv = 512 head_dim_kpe = 64 q_nope = torch.randn( batch_size * qo_len, num_heads, head_dim_ckv, dtype=dtype, device=device ) q_pe = torch.randn( batch_size * qo_len, num_heads, head_dim_kpe, dtype=dtype, device=device ) pages_num = math.ceil(kv_len / page_size) ckv = torch.randn( batch_size * pages_num, page_size, head_dim_ckv, dtype=dtype, device=device, ) kpe = torch.randn( batch_size * pages_num, page_size, head_dim_kpe, dtype=dtype, device=device, ) sm_scale = 1.0 / ((128 + 64) ** 0.5) # use head dimension before matrix absorption workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( workspace_buffer, backend=backend, use_cuda_graph=True, qo_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device=device), kv_indptr=torch.empty(batch_size + 1, dtype=torch.int32, device=device), kv_indices=torch.empty(1048576, dtype=torch.int32, device=device), kv_len_arr=torch.empty(batch_size, dtype=torch.int32, device=device), ) q_indptr = ( torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * qo_len ) kv_indptr = ( torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * pages_num ) kv_indices = torch.arange( 0, batch_size * pages_num, device=device, dtype=torch.int32 ) kv_lens = torch.full((batch_size,), kv_len, dtype=torch.int32, device=device) if use_cuda_graph: kv_indptr_warmup = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) kv_indices_warmup = torch.arange( 0, batch_size, device=device, dtype=torch.int32 ) kv_lens_warmup = torch.full((batch_size,), 0, dtype=torch.int32, device=device) wrapper.plan( q_indptr, kv_indptr_warmup, kv_indices_warmup, kv_lens_warmup, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_nope.dtype, ckv.dtype, ) # warmup s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): for _ in range(3): o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) torch.cuda.current_stream().wait_stream(s) # capture g = torch.cuda.CUDAGraph() with torch.cuda.graph(g): o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) wrapper.plan( q_indptr, kv_indptr, kv_indices, kv_lens, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_nope.dtype, ckv.dtype, ) if use_cuda_graph: o.fill_(0) lse.fill_(0) g.replay() else: o, lse = wrapper.run(q_nope, q_pe, ckv, kpe, return_lse=True) k, v = generate_kv_from_cache(ckv, kpe, kv_len, batch_size, num_heads) q = torch.cat([q_nope, q_pe], dim=-1) o_ref, lse_ref = attention_ref(batch_size, q, k, v, causal, sm_scale) lse_ref = lse_ref.flatten(0, 1) torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) if kv_len != 0: torch.testing.assert_close(lse, lse_ref, rtol=1e-3, atol=1e-3) # test with pre-allocated output o_buffer = torch.empty_like(o) lse_buffer = torch.empty_like(lse) wrapper.run(q_nope, q_pe, ckv, kpe, out=o_buffer, lse=lse_buffer) torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3) torch.testing.assert_close(lse, lse_buffer, rtol=1e-3, atol=1e-3) @pytest.mark.parametrize("batch_size", [1, 2, 4]) @pytest.mark.parametrize("max_seq_len", [128, 1024, 4096]) @pytest.mark.parametrize("page_size", [1, 16, 128]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half]) def test_cutlass_mla(batch_size, max_seq_len, page_size, dtype): device = torch.device("cuda:0") clear_cuda_cache(device) if not is_sm100a_supported(device) and not is_sm110a_supported(device): pytest.skip("Cutlass MLA is not supported on this device") torch.manual_seed(42) num_local_heads = 128 head_dim_ckv = 512 head_dim_kpe = 64 total_page_num = 8192 # NOTE(Zihao): use larger scale to detect bugs such as # https://github.com/flashinfer-ai/flashinfer/pull/1055 q_nope_pe = ( torch.randn( batch_size, num_local_heads, head_dim_ckv + head_dim_kpe, dtype=dtype, device=device, ) * 100 ) ckv_kpe = torch.randn( total_page_num, page_size, head_dim_ckv + head_dim_kpe, dtype=dtype, device=device, ) kv_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device=device) page_num_per_batch = (max_seq_len + page_size - 1) // page_size # Cutlass MLA requires small pages (< 128) are packed into a 128 page. assert page_num_per_batch % (128 // page_size) == 0 page_table = torch.randint( 0, total_page_num, (batch_size, page_num_per_batch), dtype=torch.int32, device=device, ) mla_ref = flashinfer.mla.BatchMLAPagedAttentionWrapper( torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device), backend="fa2" ) # for decode, each query length is 1 q_indptr = torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) kv_lens = torch.full((batch_size,), max_seq_len, dtype=torch.int32, device=device) kv_indptr = ( torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) * page_num_per_batch ) kv_indices = page_table.flatten() q_nope = q_nope_pe[..., :head_dim_ckv] q_pe = q_nope_pe[..., head_dim_ckv:] ckv = ckv_kpe[..., :head_dim_ckv] kpe = ckv_kpe[..., head_dim_ckv:] # use head dimension before matrix absorption sm_scale = 1.0 / ((128 + 64) ** 0.5) mla_ref.plan( q_indptr, kv_indptr, kv_indices, kv_lens, num_local_heads, head_dim_ckv, head_dim_kpe, page_size, False, # causal sm_scale, q_nope.dtype, ckv.dtype, ) o_ref = mla_ref.run(q_nope, q_pe, ckv, kpe, return_lse=False) mla_ans = flashinfer.mla.BatchMLAPagedAttentionWrapper( torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device), backend="cutlass", ) o_ans = mla_ans.run(q_nope, q_pe, ckv, kpe, kv_len=kv_lens, page_table=page_table) torch.testing.assert_close(o_ans, o_ref, rtol=1e-2, atol=1e-2) if __name__ == "__main__": test_batch_mla_varlen_page_attention( 1, 65, 65, 65, 1, 128, True, 64, "fa2", torch.half ) # test_batch_mla_varlen_page_attention( # 155, 1024, 8, 128, 128, 16, False, 1, "fa3", torch.half # ) # test_batch_mla_page_attention(1, 1024, 128, 128, False, 1, "fa2", True, torch.half)