# Adapted from https://github.com/Dao-AILab/flash-attention/blob/203b9b3dba39d5d08dffb49c09aa622984dff07d/flash_attn/cute/interface.py # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0. import math from typing import Optional, Tuple import cuda.bindings.driver as cuda import cutlass import cutlass.cute as cute import torch from cutlass.cute.runtime import from_dlpack from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x torch2cute_dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, torch.float32: cutlass.Float32, } def _flash_attn_fwd( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, softcap: Optional[float] = None, window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, learnable_sink: Optional[torch.Tensor] = None, # m_block_size: int = 128, # n_block_size: int = 64, # num_threads: int = 128, m_block_size: int = 128, n_block_size: int = 128, num_threads: int = 384, pack_gqa: Optional[bool] = None, _compute_capability: Optional[int] = None, return_softmax_lse: Optional[bool] = False, ) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] if cu_seqlens_q is None: batch_size, seqlen_q = q.shape[:2] total_q = batch_size * seqlen_q else: batch_size = cu_seqlens_q.shape[0] - 1 seqlen_q = None total_q = q.shape[0] if page_table is not None: assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" assert page_table.dtype == torch.int32, "page_table must be int32" assert ( page_table.stride(-1) == 1 ), "page_table must be contiguous in the last dimension" max_num_pages_per_seq = page_table.shape[1] assert page_table.shape == (batch_size, max_num_pages_per_seq) num_pages, page_size = k.shape[:2] seqlen_k = num_pages * page_size else: num_pages, page_size = None, None seqlen_k = k.shape[-3] num_head_kv = k.shape[-2] head_dim_v = v.shape[-1] if cu_seqlens_k is None: if page_table is None: assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) else: assert k.shape == (num_pages, page_size, num_head_kv, head_dim) assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) else: assert k.shape == (seqlen_k, num_head_kv, head_dim) assert v.shape == (seqlen_k, num_head_kv, head_dim_v) assert cu_seqlens_k.shape == ( batch_size + 1, ), "cu_seqlens_k must have shape (batch_size + 1,)" if cu_seqlens_q is not None: assert cu_seqlens_q.shape == ( batch_size + 1, ), "cu_seqlens_q must have shape (batch_size + 1,)" assert seqused_q is None or seqused_q.shape == ( batch_size, ), "seqused_q must have shape (batch_size,)" assert seqused_k is None or seqused_k.shape == ( batch_size, ), "seqused_k must have shape (batch_size,)" assert q.dtype in [ torch.float16, torch.bfloat16, ], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: assert ( t.dtype == torch.int32 ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" assert ( t.stride(0) == 1 ), "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" if learnable_sink is not None: assert learnable_sink.shape == (num_head,) assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" assert all( t is None or t.is_cuda for t in ( q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, learnable_sink, ) ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" assert head_dim <= 256, "head_dim must be less than or equal to 256" alignment = 16 // q.element_size() assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}" assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}" if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(head_dim) if softcap == 0.0: softcap = None qhead_per_kvhead = num_head // num_head_kv if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 out_torch_dtype = q.dtype device = q.device q_batch_seqlen_shape = ( (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) ) out = torch.empty( *q_batch_seqlen_shape, num_head, head_dim_v, dtype=out_torch_dtype, device=device, ) lse_shape = ( (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) ) lse = ( torch.empty(lse_shape, dtype=torch.float32, device=device) if return_softmax_lse else None ) dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic( leading_dim=t.ndim - 1 ) for t in (q, k, v, out) ] lse_tensor = ( from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic( leading_dim=lse.ndim - 1 ) if lse is not None else None ) ( cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, learnable_sink_tensor, ) = [ ( from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=0) if t is not None else None ) for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, learnable_sink) ] page_table_tensor = ( from_dlpack(page_table.detach(), assumed_align=4).mark_layout_dynamic( leading_dim=1 ) if page_table is not None else None ) if causal: window_size_right = 0 local = window_size_left is not None or window_size_right is not None if window_size_left is not None or window_size_right is not None: if window_size_left is None and window_size_right == 0: causal, local = True, False else: causal, local = False, True compute_capability = ( torch.cuda.get_device_capability()[0] if _compute_capability is None else _compute_capability ) assert compute_capability in [ 9, 10, ], "Unsupported compute capability. Supported: 9.x, 10.x" current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) if compute_capability == 9: # TODO: tune block size according to hdim if head_dim == head_dim_v == 128 and not causal and not local: n_block_size = 192 if compute_capability == 10: # TODO: fix the varlen case if ( pack_gqa and (128 % qhead_per_kvhead != 0) or (cu_seqlens_q is not None or seqused_q is not None) ): pack_gqa = False compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, softcap is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, seqused_k is None, page_table is not None, window_size_left is not None, window_size_right is not None, learnable_sink is not None, m_block_size, n_block_size, num_threads, pack_gqa, compute_capability, ) if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( dtype, head_dim, head_dim_v, qhead_per_kvhead, is_causal=causal, is_local=local, pack_gqa=pack_gqa, m_block_size=m_block_size, n_block_size=n_block_size, # num_stages=1, num_stages=2, num_threads=num_threads, Q_in_regs=False, ) elif compute_capability == 10: assert page_size in [ None, 128, ], "Only page_size=128 is supported for paged KV on SM 10.0" fa_fwd = FlashAttentionForwardSm100( head_dim, head_dim_v, qhead_per_kvhead=qhead_per_kvhead, is_causal=causal, is_local=local, pack_gqa=pack_gqa, is_persistent=not causal and not local and cu_seqlens_q is None and seqused_q is None, ) else: raise ValueError( f"Unsupported compute capability: {compute_capability}. Supported: 9.x, 10.x" ) # TODO: check @can_implement _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, softcap, window_size_left, window_size_right, learnable_sink_tensor, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, k_tensor, v_tensor, o_tensor, lse_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor, page_table_tensor, softcap, window_size_left, window_size_right, learnable_sink_tensor, ) return out, lse _flash_attn_fwd.compile_cache = {} def flash_attn_varlen_func( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_k: Optional[torch.Tensor] = None, seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, pack_gqa: Optional[bool] = None, return_softmax_lse: Optional[bool] = False, ) -> Tuple[torch.Tensor, torch.Tensor]: out, lse = _flash_attn_fwd( q, k, v, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table=page_table, softmax_scale=softmax_scale, causal=causal, window_size_left=window_size[0], window_size_right=window_size[1], learnable_sink=learnable_sink, softcap=softcap, pack_gqa=pack_gqa, return_softmax_lse=return_softmax_lse, ) return (out, lse) if return_softmax_lse else out