from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn


def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


# Sparse attention utils
def convert_vertical_slash_indexes(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.zeros(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    block_offset = torch.zeros(
        batch_size,
        num_heads,
        num_rows,
        nnz_slash,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )
    column_count = torch.zeros(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    column_index = torch.zeros(
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )

    torch.ops.sgl_kernel.convert_vertical_slash_indexes.default(
        block_count,
        block_offset,
        column_count,
        column_index,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return block_count, block_offset, column_count, column_index


def convert_vertical_slash_indexes_mergehead(
    q_seqlens: torch.Tensor,  # [BATCH, ]
    kv_seqlens: torch.Tensor,  # [BATCH, ]
    vertical_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_V]
    slash_indexes: torch.Tensor,  # [BATCH, N_HEADS, NNZ_S]
    # [N_HEADS]: each head uses a different number of indices
    vertical_indices_count: torch.Tensor,
    slash_indices_count: torch.Tensor,
    context_size: int,
    block_size_M: int,
    block_size_N: int,
    causal: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    batch_size = slash_indexes.size(0)
    num_heads = slash_indexes.size(1)
    nnz_slash = slash_indexes.size(2)
    nnz_vertical = vertical_indexes.size(2)
    num_rows = (context_size + block_size_M - 1) // block_size_M

    block_count = torch.empty(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    block_offset = torch.empty(
        batch_size,
        num_heads,
        num_rows,
        nnz_slash,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )
    column_count = torch.empty(
        batch_size, num_heads, num_rows, dtype=q_seqlens.dtype, device=q_seqlens.device
    )
    column_index = torch.empty(
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        dtype=q_seqlens.dtype,
        device=q_seqlens.device,
    )

    torch.ops.sgl_kernel.convert_vertical_slash_indexes_mergehead.default(
        block_count,
        block_offset,
        column_count,
        column_index,
        q_seqlens,
        kv_seqlens,
        vertical_indexes,
        slash_indexes,
        vertical_indices_count,
        slash_indices_count,
        context_size,
        block_size_M,
        block_size_N,
        causal,
    )
    return block_count, block_offset, column_count, column_index
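
# Example sketch (not part of the original module): one way the index conversion above
# might be driven. The helper name, shapes, and index values below are assumptions for
# illustration only; running it requires a CUDA device and the compiled sgl_kernel extension.
def _example_convert_vertical_slash_indexes():
    device = "cuda"
    batch_size, num_heads = 1, 2
    context_size = 256
    block_size_M, block_size_N = 64, 64
    nnz_v, nnz_s = 8, 4

    q_seqlens = torch.full(
        (batch_size,), context_size, dtype=torch.int32, device=device
    )
    kv_seqlens = torch.full(
        (batch_size,), context_size, dtype=torch.int32, device=device
    )
    # Per-head column indices (vertical pattern) and diagonal indices (slash pattern);
    # the values here are placeholders, not a meaningful sparsity selection.
    vertical_indexes = torch.arange(nnz_v, dtype=torch.int32, device=device).repeat(
        batch_size, num_heads, 1
    )
    slash_indexes = torch.arange(nnz_s, dtype=torch.int32, device=device).repeat(
        batch_size, num_heads, 1
    )

    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            q_seqlens,
            kv_seqlens,
            vertical_indexes,
            slash_indexes,
            context_size,
            block_size_M,
            block_size_N,
        )
    )
    # block_count / column_count: [BATCH, N_HEADS, cdiv(context_size, block_size_M)]
    # block_offset: [..., NNZ_S], column_index: [..., NNZ_V]
    return block_count, block_offset, column_count, column_index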

def sparse_attn_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_func interface, except for four extra
    arguments: block_count and block_offset for the slash sparsity pattern, and column_count
    and column_index for the vertical sparsity pattern. For more details, please refer to
    Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (batch_size, seqlen, nheads, headdim)
        k: (batch_size, seqlen, nheads_k, headdim)
        v: (batch_size, seqlen, nheads_k, headdim)
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is
            for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (batch_size, seqlen, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen).
            The logsumexp of each row of the matrix QK^T * scaling
            (e.g., the log of the softmax normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        alibi_slopes,
        dropout_p,
        softmax_scale,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
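
# Example sketch (not part of the original module): a possible end-to-end call of
# sparse_attn_func on dummy tensors, reusing convert_vertical_slash_indexes to build the
# sparsity metadata. The helper name, shapes, dtypes, and index values are illustrative
# assumptions; running it requires a CUDA device and the compiled sgl_kernel extension.
def _example_sparse_attn_func():
    device, dtype = "cuda", torch.float16
    batch_size, seqlen, nheads, headdim = 1, 256, 2, 64
    block_size_M, block_size_N = 64, 64

    q = torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    seqlens = torch.full((batch_size,), seqlen, dtype=torch.int32, device=device)
    # Placeholder per-head vertical (column) and slash (diagonal) indices.
    vertical_indexes = torch.arange(8, dtype=torch.int32, device=device).repeat(
        batch_size, nheads, 1
    )
    slash_indexes = torch.arange(4, dtype=torch.int32, device=device).repeat(
        batch_size, nheads, 1
    )
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            seqlens,
            seqlens,
            vertical_indexes,
            slash_indexes,
            seqlen,
            block_size_M,
            block_size_N,
        )
    )

    out, lse = sparse_attn_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        causal=True,
        return_softmax_lse=True,
    )
    # out: (batch_size, seqlen, nheads, headdim); lse: (batch_size, nheads, seqlen)
    return out, lse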

def sparse_attn_varlen_func(
    q,
    k,
    v,
    block_count,
    block_offset,
    column_count,
    column_index,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q,
    max_seqlen_k,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    softcap=0.0,  # 0.0 means deactivated
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    *,
    return_softmax_lse=False,
    out=None,
):
    """Compute attention with vertical and slash sparsity patterns.

    Most arguments are the same as in the flash_attn_varlen_func interface, except for four extra
    arguments: block_count and block_offset for the slash sparsity pattern, and column_count
    and column_index for the vertical sparsity pattern. For more details, please refer to
    Appendix C.4.2 of the paper https://arxiv.org/abs/2407.02490.

    Arguments:
        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
        block_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        block_offset: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_S)
        column_count: (batch_size, nheads, cdiv(seqlen, BLOCK_M))
        column_index: (batch_size, nheads, cdiv(seqlen, BLOCK_M), NNZ_V)
        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into q.
        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
            of the sequences in the batch, used to index into kv.
        max_seqlen_q: int. Maximum query sequence length in the batch.
        max_seqlen_k: int. Maximum key sequence length in the batch.
        dropout_p: float. Dropout probability.
        softmax_scale: float. The scaling of QK^T before applying softmax.
            Defaults to 1 / sqrt(headdim).
        causal: bool. Whether to apply a causal attention mask (e.g., for auto-regressive modeling).
        softcap: float. Anything > 0 activates softcapping attention.
        alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of
            (-alibi_slope * |i + seqlen_k - seqlen_q - j|)
            is added to the attention score of query i and key j.
        deterministic: bool. Whether to use the deterministic implementation of the backward pass,
            which is slightly slower and uses more memory. The forward pass is always deterministic.
        return_attn_probs: bool. Whether to return the attention probabilities. This option is
            for testing only. The returned probabilities are not guaranteed to be correct
            (they might not have the right scaling).

    Return:
        out: (total_q, nheads, headdim).
        softmax_lse [optional, if return_softmax_lse=True]: (nheads, total_q_seqlen).
            The logsumexp of each row of the matrix QK^T * scaling
            (e.g., the log of the softmax normalization factor).
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, softmax_lse = torch.ops.sgl_kernel.varlen_fwd_sparse.default(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        out,
        cu_seqlens_q,
        cu_seqlens_k,
        None,
        alibi_slopes,
        max_seqlen_q,
        max_seqlen_k,
        dropout_p,
        softmax_scale,
        False,
        causal,
        softcap,
        return_attn_probs and dropout_p > 0,
        None,
    )
    return (out, softmax_lse) if return_softmax_lse else out
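
# Example sketch (not part of the original module): a possible call of sparse_attn_varlen_func
# with two packed sequences. The helper name, shapes, index values, and the choice of using the
# maximum sequence length as context_size for the index conversion are assumptions for
# illustration; running it requires a CUDA device and the compiled sgl_kernel extension.
def _example_sparse_attn_varlen_func():
    device, dtype = "cuda", torch.float16
    nheads, headdim = 2, 64
    block_size_M, block_size_N = 64, 64
    seqlens = [128, 256]
    batch_size, total = len(seqlens), sum(seqlens)
    max_seqlen = max(seqlens)

    q = torch.randn(total, nheads, headdim, device=device, dtype=dtype)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    # Cumulative sequence lengths delimit each packed sequence: [0, 128, 384].
    cu_seqlens = torch.tensor([0, 128, 384], dtype=torch.int32, device=device)

    seqlens_t = torch.tensor(seqlens, dtype=torch.int32, device=device)
    # Placeholder per-head vertical (column) and slash (diagonal) indices.
    vertical_indexes = torch.arange(8, dtype=torch.int32, device=device).repeat(
        batch_size, nheads, 1
    )
    slash_indexes = torch.arange(4, dtype=torch.int32, device=device).repeat(
        batch_size, nheads, 1
    )
    block_count, block_offset, column_count, column_index = (
        convert_vertical_slash_indexes(
            seqlens_t,
            seqlens_t,
            vertical_indexes,
            slash_indexes,
            max_seqlen,
            block_size_M,
            block_size_N,
        )
    )

    out = sparse_attn_varlen_func(
        q,
        k,
        v,
        block_count,
        block_offset,
        column_count,
        column_index,
        cu_seqlens,
        cu_seqlens,
        max_seqlen,
        max_seqlen,
        causal=True,
    )
    # out: (total, nheads, headdim)
    return out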