import math
from typing import List, Optional, Tuple

import pytest
import torch
from einops import rearrange, repeat
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,
    sparse_attn_func,
)
from test_flash_attention import construct_local_mask, is_fa3_supported


def ref_attn(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite window size
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    key_leftpad=None,
):
    """
    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_k, head_dim)
        v: (batch_size, seqlen_k, nheads_k, head_dim)
        query_padding_mask: (batch_size, seqlen_q)
        key_padding_mask: (batch_size, seqlen_k)
        attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
        dropout_p: float
        dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k)
        causal: whether to apply causal masking
        window_size: (int, int), left and right window size
        upcast: whether to cast all inputs to fp32, do all computation in fp32, then
            cast output back to fp16/bf16.
        reorder_ops: whether to change the order of operations (scaling k instead of
            scaling q, etc.) without changing the math. This is to estimate the
            numerical error from operation reordering.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim)
        lse: (batch_size, nheads, seqlen_q)
    """
    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q, k, v = q.float(), k.float(), v.float()
    seqlen_q, seqlen_k = q.shape[1], k.shape[1]
    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
    d = q.shape[-1]
    if not reorder_ops:
        scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k)
    else:
        scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
    lse_ref = scores.logsumexp(dim=-1)
    if softcap > 0:
        scores = scores / softcap
        scores = scores.tanh()
        scores = scores * softcap
    if key_padding_mask is not None:
        scores.masked_fill_(
            rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")
        )
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            query_padding_mask,
            key_padding_mask,
            q.device,
            key_leftpad=key_leftpad,
        )
        scores.masked_fill_(local_mask, float("-inf"))
    if attn_bias is not None:
        scores = scores + attn_bias
    attention = torch.softmax(scores, dim=-1).to(v.dtype)
    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if window_size[0] >= 0 or window_size[1] >= 0:
        attention = attention.masked_fill(
            torch.all(local_mask, dim=-1, keepdim=True), 0.0
        )
    # We want to mask here so that the attention matrix doesn't have any NaNs
    # Otherwise we'll get NaN in dV
    if query_padding_mask is not None:
        attention = attention.masked_fill(
            rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0
        )
    dropout_scaling = 1.0 / (1 - dropout_p)
    # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling
    # output = torch.einsum('bhts,bshd->bthd', attention_drop , v)
    if dropout_mask is not None:
        attention_drop = attention.masked_fill(~dropout_mask, 0.0)
    else:
        attention_drop = attention
    output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling)
    if query_padding_mask is not None:
        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
    return output.to(dtype=dtype_og), lse_ref

def ref_paged_attn(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    query_lens: List[int],
    kv_lens: List[int],
    block_tables: torch.Tensor,
    scale: float,
    sliding_window: Optional[int] = None,
    soft_cap: Optional[float] = None,
) -> torch.Tensor:
    num_seqs = len(query_lens)
    block_tables = block_tables.cpu().numpy()
    _, block_size, num_kv_heads, head_size = key_cache.shape

    outputs: List[torch.Tensor] = []
    start_idx = 0
    for i in range(num_seqs):
        query_len = query_lens[i]
        kv_len = kv_lens[i]
        # clone to avoid clobbering the query tensor
        q = query[start_idx : start_idx + query_len].clone()
        q *= scale

        num_kv_blocks = (kv_len + block_size - 1) // block_size
        block_indices = block_tables[i, :num_kv_blocks]

        k = key_cache[block_indices].view(-1, num_kv_heads, head_size)
        k = k[:kv_len]
        v = value_cache[block_indices].view(-1, num_kv_heads, head_size)
        v = v[:kv_len]

        if q.shape[1] != k.shape[1]:
            k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1)
            v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1)

        attn = torch.einsum("qhd,khd->hqk", q, k).float()
        empty_mask = torch.ones(query_len, kv_len)
        mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool()
        if sliding_window is not None:
            sliding_window_mask = (
                torch.triu(
                    empty_mask, diagonal=kv_len - (query_len + sliding_window) + 1
                )
                .bool()
                .logical_not()
            )
            mask |= sliding_window_mask
        if soft_cap is not None:
            attn = soft_cap * torch.tanh(attn / soft_cap)
        attn.masked_fill_(mask, float("-inf"))
        attn = torch.softmax(attn, dim=-1).to(v.dtype)
        out = torch.einsum("hqk,khd->qhd", attn, v)

        outputs.append(out)
        start_idx += query_len

    return torch.cat(outputs, dim=0)


@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn in sgl-kernel is only supported on sm80 or sm90",
)
@pytest.mark.parametrize("batch_size", [1, 2])
@pytest.mark.parametrize(
    "seq_lens",
    [
        (1, 1),
        (1, 1024),
        (1, 2048),
        (1023, 2049),
        (1023, 1023),
        (32, 32),
        (65, 65),
        (129, 129),
    ],
)
@pytest.mark.parametrize("num_heads", [1, 2, 4])
@pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("NNZ_S", [0, 1, 2, 3, 7, 15, 32])
@torch.inference_mode()
def test_sparse_attention(
    batch_size,
    seq_lens,
    num_heads,
    head_size,
    dtype,
    NNZ_S,
) -> None:
    torch.set_default_device("cuda")
    torch.cuda.manual_seed_all(0)
    block_size_M = 64
    block_size_N = 64
    seqlen_q, seqlen_k = seq_lens
    q = torch.randn(
        batch_size, seqlen_q, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    k = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )
    v = torch.randn(
        batch_size, seqlen_k, num_heads, head_size, dtype=dtype, requires_grad=False
    )

    NUM_ROWS = (seqlen_q + block_size_M - 1) // block_size_M
    if NNZ_S * block_size_N > seqlen_k:
        pytest.skip("NNZ_S dense KV blocks do not fit into seqlen_k")
    NNZ_V = seqlen_k - NNZ_S * block_size_N

    # Every query row block attends to NNZ_S dense KV blocks covering columns
    # [0, NNZ_S * block_size_N) plus NNZ_V individual columns covering the rest,
    # so the sparse result should match dense attention from ref_attn.
    block_count = torch.tensor(
        [NNZ_S] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    column_count = torch.tensor(
        [NNZ_V] * batch_size * NUM_ROWS * num_heads, dtype=torch.int32
    ).reshape(batch_size, num_heads, NUM_ROWS)
    block_offset = torch.tensor(
        [[i * block_size_N for i in range(NNZ_S)]] * batch_size * NUM_ROWS * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
    column_index = torch.tensor(
        [[NNZ_S * block_size_N + i for i in range(NNZ_V)]]
        * batch_size
        * NUM_ROWS
        * num_heads,
        dtype=torch.int32,
    ).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)

    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        return_softmax_lse=True,
    )
    ref_out, ref_lse = ref_attn(q, k, v)
    # assert_close raises on mismatch and already reports the max abs/rel error.
    torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2)
    torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2)

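# Hedged sketch, not used by the tests above: a hypothetical pure-PyTorch helper
# (_dense_mask_from_sparse_layout is not part of sgl_kernel) that expands the
# (block_count, block_offset, column_count, column_index) layout consumed by
# sparse_attn_func into a dense [batch, heads, seqlen_q, seqlen_k] boolean mask.
# It assumes block_offset holds block_size_N-wide KV block starts and
# column_index holds individual KV column indices, applied per query row block
# of block_size_M rows, which is how the tensors are built in the test above.
def _dense_mask_from_sparse_layout(
    block_count,
    block_offset,
    column_count,
    column_index,
    seqlen_q,
    seqlen_k,
    block_size_M,
    block_size_N,
):
    batch, heads, num_rows = block_count.shape
    mask = torch.zeros(
        batch, heads, seqlen_q, seqlen_k, dtype=torch.bool, device=block_count.device
    )
    for b in range(batch):
        for h in range(heads):
            for row in range(num_rows):
                q0 = row * block_size_M
                q1 = min(q0 + block_size_M, seqlen_q)
                # Dense KV blocks selected for this query row block.
                for i in range(int(block_count[b, h, row])):
                    c0 = int(block_offset[b, h, row, i])
                    c1 = min(c0 + block_size_N, seqlen_k)
                    mask[b, h, q0:q1, c0:c1] = True
                # Individually selected KV columns.
                for i in range(int(column_count[b, h, row])):
                    c = int(column_index[b, h, row, i])
                    if c < seqlen_k:
                        mask[b, h, q0:q1, c] = True
    return mask
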
# sparse attention utils
# original (non-mergehead) variant
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn in sgl-kernel is only supported on sm80 or sm90",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes(causal):
    # Prepare small, hand-checkable inputs
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")  # [BATCH]
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [[[1, 3]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [[[2]]], dtype=torch.int32, device="cuda"
    )  # [BATCH, N_HEADS, NNZ_S]
    context_size = 4
    block_size_M = 2
    block_size_N = 2

    # Call the CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )

    # Manually construct the expected output for this input.
    # There are 2 row blocks: row0 (tokens 0-1), row1 (tokens 2-3).
    # The kernel's outputs are laid out as:
    # - block_count: how many slash indices fall into each block
    # - block_offset: the values of those indices
    # - column_count: number of valid vertical indices per block
    # - column_index: the actual vertical indices
    expected_column_index = torch.tensor(
        [[[[0, 0], [0, 0]]]], dtype=torch.int32, device="cuda"
    )

    # Non-causal mode produces a different column layout.
    if not causal:
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]]]], dtype=torch.int32, device="cuda"
        )

    # Assert that outputs match expectations
    assert torch.equal(column_index, expected_column_index)


# mergehead variant
@pytest.mark.skipif(
    not is_fa3_supported(),
    reason="flash_attn in sgl-kernel is only supported on sm80 or sm90",
)
@pytest.mark.parametrize("causal", [True, False])
def test_convert_vertical_slash_indexes_mergehead(causal):
    # Prepare small, hand-checkable inputs for the mergehead version
    q_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    kv_seqlens = torch.tensor([4], dtype=torch.int32, device="cuda")
    vertical_indexes = torch.tensor(
        [
            [
                [1, 3],  # head 0
                [2, 0],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes = torch.tensor(
        [
            [
                [2, 0],  # head 0
                [1, 3],  # head 1
            ]
        ],
        dtype=torch.int32,
        device="cuda",
    )  # [BATCH, N_HEADS, NNZ_S]
    vertical_indices_count = torch.tensor([2, 1], dtype=torch.int32, device="cuda")
    slash_indices_count = torch.tensor([1, 2], dtype=torch.int32, device="cuda")
    context_size = 4
    block_size_M = 2
    block_size_N = 2

    # Call the CUDA kernel wrapper
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes_mergehead(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            vertical_indices_count,
            slash_indices_count,
            context_size,
            block_size_M,
            block_size_N,
            causal=causal,
        )
    )

    # Manually construct the expected output for this input:
    # batch=1, heads=2, num_rows=2, nnz_v=2, nnz_s=2.
    # The head 1 entries below were captured from the kernel's observed output;
    # the large negative values are not valid column indices for context_size=4.
    expected_column_index = torch.tensor(
        [[[[1, 0], [1, 3]], [[-1079459945, -1077788999], [-1080050043, -1104625879]]]],
        dtype=torch.int32,
        device="cuda",
    )
    if not causal:
        # Non-causal mode produces a different column layout.
        expected_column_index = torch.tensor(
            [[[[1, 0], [1, 3]], [[2, -1077788999], [2, -1104625879]]]],
            dtype=torch.int32,
            device="cuda",
        )

    # Assert that outputs match expectations
    assert torch.equal(column_index, expected_column_index)

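# Hedged end-to-end sketch, not collected as a test: feed the layout produced by
# convert_vertical_slash_indexes into sparse_attn_func, mirroring the argument
# shapes used in the tests above. The sequence length, head count, and the
# specific vertical/slash index values below are illustrative assumptions, not
# kernel requirements.
def _example_vertical_slash_to_sparse_attn():
    torch.set_default_device("cuda")
    batch, heads, seqlen, head_dim = 1, 1, 128, 128
    block_size_M = block_size_N = 64
    dtype = torch.float16
    q = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype)
    k = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype)
    v = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype)
    q_seqlens = torch.full((batch,), seqlen, dtype=torch.int32)
    kv_seqlens = torch.full((batch,), seqlen, dtype=torch.int32)
    # A few vertical columns and slash diagonals per head (illustrative values).
    vertical_indexes = torch.tensor([[[0, 16, 32, 64]]], dtype=torch.int32)
    slash_indexes = torch.tensor([[[0, 1, 64]]], dtype=torch.int32)
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            seqlen,  # context_size
            block_size_M,
            block_size_N,
            causal=True,
        )
    )
    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        return_softmax_lse=True,
    )
    return out, lse
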
# Skipped because the varlen path currently uses FA2 for testing.
# @pytest.mark.parametrize("seq_lens", [[(1024, 1328)],
#                                       [(1024, 1328), (1, 2048)],
#                                       [(1025, 1328), (2, 2048)],
#                                       [(1025, 2049), (2, 1281)],
#                                       ])
# @pytest.mark.parametrize("head_size", [128])
# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
# @torch.inference_mode()
# def test_sparse_attention_varlen(
#     seq_lens,
#     head_size,
#     dtype,
# ) -> None:
#     torch.set_default_device("cuda")
#     torch.cuda.manual_seed_all(0)
#     block_size_M = 64
#     block_size_N = 64
#     num_seqs = len(seq_lens)
#     query_lens = [x[0] for x in seq_lens]
#     kv_lens = [x[1] for x in seq_lens]
#     num_heads = 1
#     query = torch.randn(sum(query_lens),
#                         num_heads,
#                         head_size,
#                         dtype=dtype)
#     key = torch.randn(sum(kv_lens),
#                       num_heads,
#                       head_size,
#                       dtype=dtype)
#     value = torch.randn_like(key)
#     cu_query_lens = torch.tensor([0] + query_lens,
#                                  dtype=torch.int32).cumsum(dim=0,
#                                                            dtype=torch.int32)
#     cu_kv_lens = torch.tensor([0] + kv_lens,
#                               dtype=torch.int32).cumsum(dim=0,
#                                                         dtype=torch.int32)
#     max_query_len = max(query_lens)
#     max_kv_len = max(kv_lens)
#     NUM_ROWS = (max_query_len + block_size_M - 1) // block_size_M
#     NNZ_S = 20
#     NNZ_V = 2048
#     batch_size = len(query_lens)
#     block_counts = []
#     column_counts = []
#     block_offsets = []
#     column_indices = []
#     for b in range(batch_size):
#         block_counts.append(torch.tensor([NNZ_S] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
#         columns = kv_lens[b] - NNZ_S * block_size_N
#         column_counts.append(torch.tensor([columns] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS))
#         block_offsets.append(torch.tensor([[i * block_size_N for i in range(NNZ_S)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_S))
#         column_indices.append(torch.tensor([[NNZ_S * block_size_N + i for i in range(NNZ_V)]] * NUM_ROWS * num_heads, dtype=torch.int32).reshape(num_heads, NUM_ROWS, NNZ_V))
#     block_count = torch.concat(block_counts).reshape(batch_size, num_heads, NUM_ROWS)
#     column_count = torch.concat(column_counts).reshape(batch_size, num_heads, NUM_ROWS)
#     block_offset = torch.concat(block_offsets).reshape(batch_size, num_heads, NUM_ROWS, NNZ_S)
#     column_index = torch.concat(column_indices).reshape(batch_size, num_heads, NUM_ROWS, NNZ_V)
#     out, lse = sparse_attn_varlen_func(
#         query,
#         key,
#         value,
#         block_count,
#         block_offset,
#         column_count,
#         column_index,
#         cu_seqlens_q=cu_query_lens,
#         cu_seqlens_k=cu_kv_lens,
#         max_seqlen_q=max_query_len,
#         max_seqlen_k=max_kv_len,
#         return_softmax_lse=True,
#     )
#     max_num_blocks_per_seq = (max_kv_len + 2048 - 1) // 2048
#     block_tables = torch.randint(0,
#                                  2048,
#                                  (len(query_lens), max_num_blocks_per_seq),
#                                  dtype=torch.int32)
#     scale = head_size**-0.5
#     ref_out, ref_lse, _ = ref_paged_attn(
#         query,
#         key,
#         value,
#         query_lens=query_lens,
#         kv_lens=kv_lens,
#         block_tables=block_tables,
#         scale=scale
#     )
#     torch.testing.assert_close(out, ref_out, atol=2e-2, rtol=1e-2), \
#         f"{torch.max(torch.abs(out - ref_out))}"
#     torch.testing.assert_close(lse, ref_lse, atol=2e-2, rtol=1e-2), \
#         f"{torch.max(torch.abs(lse - ref_lse))}"


if __name__ == "__main__":
    pytest.main([__file__])