from enum import Enum
from typing import Optional

import torch

from ..jit import get_cudnn_fmha_gen_module

try:
    import cudnn

    CUDNN_AVAILABLE = True
except Exception:
    cudnn = None
    CUDNN_AVAILABLE = False

# Global cudnn handle. This needs to be made per-device in the future.
_cudnn_handle = None


def _create_cudnn_handle(stream: torch.cuda.Stream):
    global _cudnn_handle
    if _cudnn_handle is None:
        _cudnn_handle = cudnn.create_handle()
    cudnn.set_stream(_cudnn_handle, stream.cuda_stream)
    return _cudnn_handle


# Tensor ids
class UIDs(Enum):
    RESERVED_INVALID_UID = 0

    Q_UID = 1  # Query tensor
    K_UID = 2  # Key cache tensor
    V_UID = 3  # Value cache tensor

    ACTUAL_SEQ_LENS_Q_UID = 100  # Actual sequence lengths for query tensor
    ACTUAL_SEQ_LENS_KV_UID = 101  # Actual sequence lengths for key/value tensor

    BLOCK_TABLES_UID = 200  # Block tables tensor
    BLOCK_TABLES_K_UID = 201  # Block tables tensor for key
    BLOCK_TABLES_V_UID = 202  # Block tables tensor for value

    RAGGED_Q_UID = 50  # Ragged query tensor
    RAGGED_O_UID = 51  # Ragged output tensor
    RAGGED_STATS_UID = 52  # Ragged stats tensor
    RAGGED_K_UID = 53  # Ragged key tensor
    RAGGED_V_UID = 54  # Ragged value tensor

    O_UID = 1000  # Output tensor
    STATS_UID = 1001  # Stats tensor


def _sdpa_prefill_key_fn(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    scale: float,
    *,
    max_token_seq_q: Optional[int] = None,
    max_sequence_kv: Optional[int] = None,
    actual_seq_lens_q: Optional[torch.Tensor] = None,
    actual_seq_lens_kv: torch.Tensor,
    block_tables: Optional[torch.Tensor] = None,
    page_size: Optional[int] = None,
    bottom_right_causal_mask: Optional[bool] = None,
    return_lse: Optional[bool] = False,
    batch_offsets_q: Optional[torch.Tensor] = None,
    batch_offsets_o: Optional[torch.Tensor] = None,
    batch_offsets_k: Optional[torch.Tensor] = None,
    batch_offsets_v: Optional[torch.Tensor] = None,
    batch_offsets_stats: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
):
    graph_b = actual_seq_lens_q.shape[0]

    if q.dim() == 3:
        h_qo, d_qk = q.shape[1], q.shape[2]
    elif q.dim() == 4:
        h_qo, d_qk = q.shape[1], q.shape[3]
    else:
        raise ValueError(f"Invalid query tensor shape: {q.shape}")

    # Head count and value head dim come from the value cache; the original
    # mixed k_cache and v_cache here, which mislabels d_vo when d_qk != d_vo.
    if v_cache.dim() == 3:
        h_kv, d_vo = v_cache.shape[1], v_cache.shape[2]
    elif v_cache.dim() == 4:
        h_kv, d_vo = v_cache.shape[1], v_cache.shape[3]
    else:
        raise ValueError(f"Invalid kv cache tensor shape: {v_cache.shape}")

    if block_tables is not None:
        page_size = k_cache.shape[2]

    key = (
        graph_b,
        q.dim(),
        k_cache.dim(),
        max_token_seq_q,
        max_sequence_kv,
        h_qo,
        d_qk,
        h_kv,
        d_vo,
        block_tables is not None,
        return_lse,
        bottom_right_causal_mask,
        page_size,
    )
    return key


if CUDNN_AVAILABLE:

    @cudnn.jit(heur_modes=[cudnn.heur_mode.A])
    @cudnn.graph_cache(key_fn=_sdpa_prefill_key_fn)
    def _build_prefill_graph(
        q: torch.Tensor,
        k_cache: torch.Tensor,
        v_cache: torch.Tensor,
        scale: float,
        *,
        max_token_seq_q: Optional[int] = None,
        max_sequence_kv: Optional[int] = None,
        actual_seq_lens_q: Optional[torch.Tensor] = None,
        actual_seq_lens_kv: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
        bottom_right_causal_mask: Optional[bool] = True,
        return_lse: Optional[bool] = False,
        batch_offsets_q: Optional[torch.Tensor] = None,
        batch_offsets_o: Optional[torch.Tensor] = None,
        batch_offsets_k: Optional[torch.Tensor] = None,
        batch_offsets_v: Optional[torch.Tensor] = None,
        batch_offsets_stats: Optional[torch.Tensor] = None,
        out: Optional[torch.Tensor] = None,
        lse: Optional[torch.Tensor] = None,
    ):
        handle = _create_cudnn_handle(torch.cuda.current_stream(q.device))

        graph_b = actual_seq_lens_q.shape[0]
        graph_s_qo = max_token_seq_q
        graph_s_kv = max_sequence_kv

        with cudnn.graph(handle) as (g, _):
            # Create tensors from the input tensors
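            # Layout note (inferred from the strides below, not stated in the
            # original): 3-D inputs are THD-packed, i.e. (total_tokens, heads,
            # head_dim), so the graph tensors are declared as (b, h, s, d)
            # with a nominal batch stride; the real per-sequence start offsets
            # are supplied through the ragged offset tensors when given.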
            if q.dim() == 3:
                h_qo, d_qk = q.shape[1], q.shape[2]
            elif q.dim() == 4:
                h_qo, d_qk = q.shape[2], q.shape[3]
            else:
                raise ValueError(f"Invalid query tensor shape: {q.shape}")

            cudnn_q = g.tensor(
                name="q",
                dim=(graph_b, h_qo, graph_s_qo, d_qk),
                stride=(h_qo * d_qk, d_qk, d_qk * h_qo, 1),
                data_type=cudnn.data_type.BFLOAT16,
            )
            if batch_offsets_q is not None:
                ragged_q = g.tensor_like(batch_offsets_q)
                ragged_q.set_uid(UIDs.RAGGED_Q_UID.value)
                cudnn_q.set_ragged_offset(ragged_q)

            if v_cache.dim() == 3:
                assert block_tables is None, (
                    "block_tables requires a 4-D (paged) kv cache"
                )
                h_kv, d_vo = v_cache.shape[1], v_cache.shape[2]
            elif v_cache.dim() == 4:
                h_kv, d_vo = (
                    v_cache.shape[1],
                    v_cache.shape[3],
                )
            else:
                raise ValueError(f"Invalid kv cache tensor shape: {v_cache.shape}")

            if k_cache.dim() == 3:
                cudnn_k_cache = g.tensor(
                    name="k_cache",
                    dim=(graph_b, h_kv, graph_s_kv, d_qk),
                    stride=(h_kv * d_qk * graph_s_kv, d_qk, d_qk * h_kv, 1),
                    data_type=cudnn.data_type.BFLOAT16,
                )
                if batch_offsets_k is not None:
                    ragged_k = g.tensor_like(batch_offsets_k)
                    ragged_k.set_uid(UIDs.RAGGED_K_UID.value)
                    cudnn_k_cache.set_ragged_offset(ragged_k)
                cudnn_v_cache = g.tensor(
                    name="v_cache",
                    dim=(graph_b, h_kv, graph_s_kv, d_vo),
                    stride=(h_kv * d_vo * graph_s_kv, d_vo, d_vo * h_kv, 1),
                    data_type=cudnn.data_type.BFLOAT16,
                )
                if batch_offsets_v is not None:
                    ragged_v = g.tensor_like(batch_offsets_v)
                    ragged_v.set_uid(UIDs.RAGGED_V_UID.value)
                    cudnn_v_cache.set_ragged_offset(ragged_v)
            elif k_cache.dim() == 4:
                cudnn_k_cache = g.tensor(
                    name="k_cache",
                    dim=k_cache.shape,
                    stride=k_cache.stride(),
                    data_type=cudnn.data_type.BFLOAT16,
                )
                cudnn_v_cache = g.tensor(
                    name="v_cache",
                    dim=v_cache.shape,
                    stride=v_cache.stride(),
                    data_type=cudnn.data_type.BFLOAT16,
                )

            cudnn_q.set_uid(UIDs.Q_UID.value)
            cudnn_k_cache.set_uid(UIDs.K_UID.value)
            cudnn_v_cache.set_uid(UIDs.V_UID.value)

            if block_tables is not None:
                nd_block_tables = block_tables.reshape(
                    block_tables.shape[0], 1, block_tables.shape[1], 1
                )
                cudnn_k_block_tables = g.tensor_like(nd_block_tables)
                cudnn_k_block_tables.set_uid(UIDs.BLOCK_TABLES_K_UID.value)
                cudnn_v_block_tables = g.tensor_like(nd_block_tables)
                cudnn_v_block_tables.set_uid(UIDs.BLOCK_TABLES_V_UID.value)

            if actual_seq_lens_q is not None:
                cudnn_actual_seq_lens_q = g.tensor_like(actual_seq_lens_q)
                cudnn_actual_seq_lens_q.set_name("actual_seq_lens_q")
                cudnn_actual_seq_lens_q.set_uid(UIDs.ACTUAL_SEQ_LENS_Q_UID.value)

            if actual_seq_lens_kv is not None:
                cudnn_actual_seq_lens_kv = g.tensor_like(actual_seq_lens_kv)
                cudnn_actual_seq_lens_kv.set_name("actual_seq_lens_kv")
                cudnn_actual_seq_lens_kv.set_uid(UIDs.ACTUAL_SEQ_LENS_KV_UID.value)

            padding_mask = (
                actual_seq_lens_q is not None and actual_seq_lens_kv is not None
            )

            O, Stats = g.sdpa(
                name="sdpa",
                q=cudnn_q,
                k=cudnn_k_cache,
                v=cudnn_v_cache,
                seq_len_q=(
                    cudnn_actual_seq_lens_q if actual_seq_lens_q is not None else None
                ),
                seq_len_kv=(
                    cudnn_actual_seq_lens_kv if actual_seq_lens_kv is not None else None
                ),
                use_padding_mask=padding_mask,
                attn_scale=scale,
                generate_stats=return_lse,
                use_causal_mask_bottom_right=bottom_right_causal_mask,
                paged_attention_k_table=(
                    cudnn_k_block_tables if block_tables is not None else None
                ),
                paged_attention_v_table=(
                    cudnn_v_block_tables if block_tables is not None else None
                ),
                paged_attention_max_seq_len_kv=(
                    graph_s_kv if block_tables is not None else None
                ),
                compute_data_type=cudnn.data_type.FLOAT,
            )
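            # The single fused sdpa node above covers padding (via the
            # per-sequence length tensors), bottom-right-aligned causal
            # masking (so the last query token attends to the full KV prefix
            # during chunked prefill), and paged attention (via the per-head
            # block tables) in one graph op.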
            if batch_offsets_o is not None:
                ragged_o = g.tensor_like(batch_offsets_o)
                ragged_o.set_uid(UIDs.RAGGED_O_UID.value)
                O.set_ragged_offset(ragged_o)

            if batch_offsets_stats is not None:
                ragged_stats = g.tensor_like(batch_offsets_stats)
                ragged_stats.set_uid(UIDs.RAGGED_STATS_UID.value)
                Stats.set_ragged_offset(ragged_stats)

            O.set_uid(UIDs.O_UID.value).set_output(True).set_dim(
                [graph_b, h_qo, graph_s_qo, d_vo]
            ).set_stride(
                [graph_s_qo * d_vo * h_qo, d_vo, d_vo * h_qo, 1]
            ).set_data_type(cudnn.data_type.BFLOAT16)

            if return_lse:
                Stats.set_uid(UIDs.STATS_UID.value).set_output(
                    return_lse
                ).set_data_type(cudnn.data_type.FLOAT).set_dim(
                    [graph_b, h_qo, graph_s_qo, 1]
                ).set_stride([graph_s_qo * h_qo, 1, h_qo, 1])

            tensors_to_return = [cudnn_q, cudnn_k_cache, cudnn_v_cache, O]
            if return_lse:
                tensors_to_return.append(Stats)
            if actual_seq_lens_q is not None:
                tensors_to_return.append(cudnn_actual_seq_lens_q)
            if actual_seq_lens_kv is not None:
                tensors_to_return.append(cudnn_actual_seq_lens_kv)

        return g, tensors_to_return


def _batch_prefill_with_kv_cache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    *,
    max_token_per_sequence: int,
    max_sequence_kv: int,
    actual_seq_lens_q: torch.Tensor,
    actual_seq_lens_kv: torch.Tensor,
    block_tables: Optional[torch.Tensor] = None,
    causal: bool,
    return_lse: bool,
    batch_offsets_q: Optional[torch.Tensor] = None,
    batch_offsets_o: Optional[torch.Tensor] = None,
    batch_offsets_k: Optional[torch.Tensor] = None,
    batch_offsets_v: Optional[torch.Tensor] = None,
    batch_offsets_stats: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    graph, tensors = _build_prefill_graph(
        q=q,
        k_cache=k_cache,
        v_cache=v_cache,
        scale=scale,
        max_token_seq_q=max_token_per_sequence,
        max_sequence_kv=max_sequence_kv,
        actual_seq_lens_q=actual_seq_lens_q,
        actual_seq_lens_kv=actual_seq_lens_kv,
        block_tables=block_tables,
        bottom_right_causal_mask=causal,
        return_lse=return_lse,
        batch_offsets_q=batch_offsets_q,
        batch_offsets_o=batch_offsets_o,
        batch_offsets_k=batch_offsets_k,
        batch_offsets_v=batch_offsets_v,
        batch_offsets_stats=batch_offsets_stats,
        out=out,
        lse=lse,
    )

    # Bind the runtime tensors to the cached graph by UID; this lets the same
    # compiled graph be replayed with fresh device pointers on every call.
    var_map = {
        UIDs.Q_UID.value: q,
        UIDs.K_UID.value: k_cache,
        UIDs.V_UID.value: v_cache,
        UIDs.O_UID.value: out,
    }
    if actual_seq_lens_q is not None:
        var_map[UIDs.ACTUAL_SEQ_LENS_Q_UID.value] = actual_seq_lens_q
    if actual_seq_lens_kv is not None:
        var_map[UIDs.ACTUAL_SEQ_LENS_KV_UID.value] = actual_seq_lens_kv
    if batch_offsets_q is not None:
        var_map[UIDs.RAGGED_Q_UID.value] = batch_offsets_q
    if batch_offsets_o is not None:
        var_map[UIDs.RAGGED_O_UID.value] = batch_offsets_o
    if batch_offsets_k is not None:
        var_map[UIDs.RAGGED_K_UID.value] = batch_offsets_k
    if batch_offsets_v is not None:
        var_map[UIDs.RAGGED_V_UID.value] = batch_offsets_v
    if block_tables is not None:
        var_map[UIDs.BLOCK_TABLES_K_UID.value] = block_tables
        var_map[UIDs.BLOCK_TABLES_V_UID.value] = block_tables

    if return_lse:
        var_map[UIDs.STATS_UID.value] = lse
    if batch_offsets_stats is not None:
        var_map[UIDs.RAGGED_STATS_UID.value] = batch_offsets_stats

    handle = _create_cudnn_handle(torch.cuda.current_stream(q.device))
    graph.execute(var_map, workspace=workspace_buffer, handle=handle)

    if return_lse:
        return out, lse
    else:
        return out, None
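

# Illustrative helper (an assumption, not part of the original API): one
# plausible way to construct the ragged batch offsets consumed above, as an
# exclusive cumulative sum of per-sequence token counts scaled by the
# per-token element stride. The shape (batch_size + 1,) with a leading zero
# and the element-offset unit are assumptions about what cuDNN's ragged
# offset tensors expect.
def _example_batch_offsets(
    actual_seq_lens: torch.Tensor, row_stride: int
) -> torch.Tensor:
    lens = actual_seq_lens.reshape(-1).to(torch.int64)
    offsets = torch.zeros(lens.numel() + 1, dtype=torch.int32, device=lens.device)
    offsets[1:] = torch.cumsum(lens * row_stride, dim=0).to(torch.int32)
    return offsets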


def cudnn_batch_prefill_with_kv_cache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    scale: float,
    workspace_buffer: torch.Tensor,
    *,
    max_token_per_sequence: int,
    max_sequence_kv: int,
    actual_seq_lens_q: torch.Tensor,
    actual_seq_lens_kv: torch.Tensor,
    block_tables: Optional[torch.Tensor] = None,
    causal: bool,
    return_lse: bool,
    batch_offsets_q: Optional[torch.Tensor] = None,
    batch_offsets_o: Optional[torch.Tensor] = None,
    batch_offsets_k: Optional[torch.Tensor] = None,
    batch_offsets_v: Optional[torch.Tensor] = None,
    batch_offsets_stats: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    lse: Optional[torch.Tensor] = None,
    is_cuda_graph_compatible: bool = False,
    backend: Optional[str] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Performs batched prefill attention with paged KV cache using cuDNN.

    Args:
        q: Query tensor of shape (total_num_tokens, num_heads_qo, head_dim).
        k_cache: Key cache tensor of shape (total_num_pages, num_heads_kv, page_size, head_dim) if paged KV cache is enabled, else (total_kv_sequence_length, num_heads_kv, d_qk).
        v_cache: Value cache tensor of shape (total_num_pages, num_heads_kv, page_size, head_dim) if paged KV cache is enabled, else (total_kv_sequence_length, num_heads_kv, d_vo).
        scale: Scaling factor for attention scores, typically 1/sqrt(head_dim).
        workspace_buffer: Workspace buffer for cuDNN operations. Scales with batch size; 128 MB should be sufficient for most cases.
        max_token_per_sequence: Maximum number of tokens per query sequence (s_qo_max).
        max_sequence_kv: Maximum number of tokens per key/value sequence (s_kv_max).
        actual_seq_lens_q: Actual number of tokens per query sequence, shape (batch_size,); may be on CPU if is_cuda_graph_compatible is False.
        actual_seq_lens_kv: Actual number of tokens per key/value sequence, shape (batch_size,); may be on CPU if is_cuda_graph_compatible is False.
        block_tables: Page table mapping for the KV cache, shape (batch_size, num_pages_per_seq), on GPU.
        causal: Whether to apply (bottom-right) causal masking.
        return_lse: Whether to return log-sum-exp values (must be True when the cubin backend is used).
        batch_offsets_q: Optional batch offsets for the query tensor, shape (batch_size,), on GPU.
        batch_offsets_o: Optional batch offsets for the output tensor, shape (batch_size,), on GPU.
        batch_offsets_k: Optional batch offsets for the key tensor, shape (batch_size,), on GPU.
        batch_offsets_v: Optional batch offsets for the value tensor, shape (batch_size,), on GPU.
        out: Optional pre-allocated output tensor.
        lse: Optional pre-allocated tensor for log-sum-exp values; used only if return_lse is True.
        is_cuda_graph_compatible: Whether the prefill operation is compatible with CUDA graph capture.

    Returns:
        Output tensor of shape (batch_size * seq_len_q, num_heads_qo, head_dim).
        If return_lse is True, also returns the log-sum-exp tensor of shape
        (batch_size, seq_len_q, num_heads_qo); otherwise the second element is None.

    Note:
        Query and KV heads can have different counts (num_heads_qo >= num_heads_kv).
        When using CUDA graphs, actual_seq_lens_q and actual_seq_lens_kv must be on the same device as q.
        The head dimension of query and key must be 128 or 192; the head dimension of value and output must be 128.
    """
    num_tokens = q.shape[0]
    num_sequences = actual_seq_lens_q.shape[0]

    if q.dim() == 3:
        h_qo, d_qk = q.shape[1], q.shape[2]
    elif q.dim() == 4:
        h_qo, d_qk = q.shape[1], q.shape[3]
    else:
        raise ValueError(f"Invalid query tensor shape: {q.shape}")

    if v_cache.dim() == 3:
        d_vo = v_cache.shape[2]
    elif v_cache.dim() == 4:
        d_vo = v_cache.shape[3]
    else:
        raise ValueError(f"Invalid kv cache tensor shape: {v_cache.shape}")

    if return_lse:
        if lse is None:
            lse = torch.empty(
                num_sequences,
                max_token_per_sequence,
                h_qo,
                device=q.device,
                dtype=torch.float32,
            )
        if lse.shape != (num_sequences, max_token_per_sequence, h_qo):
            raise ValueError(
                "lse must have shape (num_sequences, max_token_per_sequence, h_qo)"
            )

    if out is None:
        out = torch.empty(num_tokens, h_qo, d_vo, device=q.device, dtype=q.dtype)
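
    # Dispatch: prefer the cuDNN Python frontend graph when the `cudnn`
    # package is importable and no explicit "cubin" backend is requested;
    # otherwise fall back to the precompiled cubin kernel, which has stricter
    # head-dim constraints (see the asserts below).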
    if CUDNN_AVAILABLE and backend != "cubin":
        return _batch_prefill_with_kv_cache(
            q=q,
            k_cache=k_cache,
            v_cache=v_cache,
            scale=scale,
            workspace_buffer=workspace_buffer,
            max_token_per_sequence=max_token_per_sequence,
            max_sequence_kv=max_sequence_kv,
            actual_seq_lens_q=actual_seq_lens_q,
            actual_seq_lens_kv=actual_seq_lens_kv,
            block_tables=block_tables,
            causal=causal,
            return_lse=return_lse,
            batch_offsets_q=batch_offsets_q,
            batch_offsets_o=batch_offsets_o,
            batch_offsets_k=batch_offsets_k,
            batch_offsets_v=batch_offsets_v,
            batch_offsets_stats=batch_offsets_stats,
            out=out,
            lse=lse,
        )
    else:
        assert return_lse, "The cubin backend currently requires return_lse=True"
        assert (d_qk == 192 and block_tables is None) or (
            d_qk == 128 and block_tables is not None
        ), (
            "The cubin backend currently supports only d_qk == 192 without "
            "block_tables, or d_qk == 128 with block_tables"
        )

        if max_sequence_kv is None:
            max_sequence_kv = max_token_per_sequence

        actual_seq_lens_q_gpu = actual_seq_lens_q.to(q.device, non_blocking=True)
        actual_seq_lens_kv_gpu = actual_seq_lens_kv.to(q.device, non_blocking=True)

        run_func = get_cudnn_fmha_gen_module().prefill
        run_func(
            num_sequences,
            max_token_per_sequence,  # max_s_qo
            max_sequence_kv,  # max_s_kv
            q,
            k_cache,
            v_cache,
            scale,
            workspace_buffer,
            actual_seq_lens_q,
            actual_seq_lens_kv,
            actual_seq_lens_q_gpu,
            actual_seq_lens_kv_gpu,
            block_tables,
            causal,
            return_lse,
            out,
            lse,
            None,
            None,
            None,
            None,
            is_cuda_graph_compatible,
        )
        return out, lse
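

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): a minimal smoke test
# under assumed shapes -- bf16 tensors, a paged KV cache with page_size 16,
# head_dim 128, and equal-length sequences. All concrete values and the
# ragged-offset construction are assumptions, not requirements beyond those
# documented above. Because this module uses relative imports, run it as a
# module (python -m <package>.<this_module>) rather than as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import math

    batch, s_q, s_kv, h_qo, h_kv, d = 2, 32, 64, 8, 8, 128
    page_size = 16
    pages_per_seq = s_kv // page_size
    total_pages = batch * pages_per_seq

    q = torch.randn(batch * s_q, h_qo, d, device="cuda", dtype=torch.bfloat16)
    k_cache = torch.randn(
        total_pages, h_kv, page_size, d, device="cuda", dtype=torch.bfloat16
    )
    v_cache = torch.randn_like(k_cache)
    # Identity page mapping: sequence i owns pages [i * pages_per_seq, ...).
    block_tables = torch.arange(
        total_pages, device="cuda", dtype=torch.int32
    ).reshape(batch, pages_per_seq)

    actual_seq_lens_q = torch.full((batch,), s_q, device="cuda", dtype=torch.int32)
    actual_seq_lens_kv = torch.full((batch,), s_kv, device="cuda", dtype=torch.int32)
    workspace_buffer = torch.empty(
        128 * 1024 * 1024, device="cuda", dtype=torch.uint8
    )

    # Ragged offsets for the packed q/out layouts, built with the illustrative
    # helper above; the element-offset unit is an assumption.
    offsets_qo = _example_batch_offsets(actual_seq_lens_q, h_qo * d)

    out, lse = cudnn_batch_prefill_with_kv_cache(
        q,
        k_cache,
        v_cache,
        1.0 / math.sqrt(d),
        workspace_buffer,
        max_token_per_sequence=s_q,
        max_sequence_kv=s_kv,
        actual_seq_lens_q=actual_seq_lens_q,
        actual_seq_lens_kv=actual_seq_lens_kv,
        block_tables=block_tables,
        causal=True,
        return_lse=True,
        batch_offsets_q=offsets_qo,
        batch_offsets_o=offsets_qo,
    )
    print(out.shape, lse.shape)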