/*
 * Copyright (c) 2025 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_PERSISTENT_CUH_
#define FLASHINFER_PERSISTENT_CUH_

#include "../cp_async.cuh"
#include "../math.cuh"
#include "../utils.cuh"
#include "mask.cuh"
#include "persistent_template.cuh"
#include "prefill.cuh"
#include "state.cuh"

namespace flashinfer {

using cp_async::PrefetchMode;
using cp_async::SharedMemFillMode;

template <typename Params>
__device__ __forceinline__ auto get_block_coord(const Params& params, const uint32_t work_idx) {
  return std::tuple(params.q_indptr[work_idx], params.kv_indptr[work_idx],
                    params.partial_indptr[work_idx], params.q_len[work_idx],
                    params.kv_len[work_idx], params.q_start[work_idx], params.kv_start[work_idx],
                    params.kv_end[work_idx], params.kv_head_idx_arr[work_idx],
                    *params.len_kv_chunk);
}
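// get_block_coord (above) packs the scheduler-produced metadata for one work item into a
// tuple: (q_indptr, kv_indptr, partial_indptr, q_len, kv_len, q_start, kv_start, kv_end,
// kv_head_idx, len_kv_chunk).  A persistent CTA consumes one tuple per work item inside its
// work loop, as in BlockBatchPagedAttentionPersistent::Run below (sketch of the same pattern):
//
//   for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
//        ++work_idx) {
//     const auto [q_indptr, kv_indptr, o_indptr, q_len, kv_len, packed_qo_start, kv_start,
//                 kv_end, kv_head_idx, len_kv_chunk] = get_block_coord(params, work_idx);
//     ...  // process one (query tile, KV chunk, KV head) workload
//   }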
template <typename KTraits>
__device__ __forceinline__ void prefetch_offest(
    const uint32_t packed_block_iter_base, const uint32_t packed_kv_bound,
    const uint32_t kv_head_idx, const uint32_t kv_stride_page, const uint32_t kv_stride_h,
    const uint32_t kv_stride_n, const uint_fastdiv& block_size, typename KTraits::IdType* indices,
    size_t* kv_offset) {
  using DTypeKV = typename KTraits::DTypeKV;
  constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW;
  constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL;
  constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
  constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV;
  constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
  constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV;
  const uint32_t lane_idx = threadIdx.x % 32, warp_idx = threadIdx.x / 32;
#pragma unroll
  for (uint32_t i = 0;
       i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) {
    uint32_t page_iter, entry_idx;
    uint32_t packed_block_iter = packed_block_iter_base + warp_idx * KV_THR_LAYOUT_ROW +
                                 lane_idx / KV_THR_LAYOUT_COL +
                                 KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i;
    block_size.divmod(packed_block_iter, page_iter, entry_idx);
    kv_offset[i] =
        (packed_block_iter < packed_kv_bound ? indices[page_iter] : 0) * kv_stride_page +
        entry_idx * kv_stride_n + kv_head_idx * kv_stride_h +
        (lane_idx % KV_THR_LAYOUT_COL) * upcast_size<DTypeKV>();
  }
}

template <typename KTraits>
__device__ __forceinline__ void write_o_(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
                                         smem_t<KTraits::SWIZZLE_MODE_Q>* o_smem,
                                         typename KTraits::DTypeO* o_ptr_base,
                                         const uint32_t o_packed_idx_base_warp,
                                         const uint32_t o_packed_idx_base_cta,
                                         const uint32_t qo_upper_bound, const uint32_t o_stride_n,
                                         const uint_fastdiv group_size, const uint32_t warp_idx,
                                         const uint32_t lane_idx, const dim3 tid) {
  using DTypeO = typename KTraits::DTypeO;
  constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O;
  const uint32_t warp_idx_x = get_warp_idx_q<KTraits>(tid.y),
                 warp_idx_z = get_warp_idx_kv<KTraits>(tid.z);
  static_assert(sizeof(DTypeO) == 2);
  if (warp_idx_z == 0) {
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
        uint32_t o_frag_f16[8 / 2];
        vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);
#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
        uint32_t o_smem_offset_w = o_smem->get_permuted_offset(
            (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
            mma_d * 2 + lane_idx / 16);
        o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
#else
        uint32_t o_smem_offset_w = o_smem->get_permuted_offset(
            (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
        ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
        ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
            o_frag_f16[1];
        ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
        ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
            o_frag_f16[3];
#endif
      }
    }

    uint32_t o_smem_offset_w = o_smem->get_permuted_offset(
        warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
      for (uint32_t j = 0; j < 2 * 2; ++j) {
        uint32_t q, r;
        const uint32_t o_packed_idx = o_packed_idx_base_warp + lane_idx / 8 + mma_q * 16 + j * 4;
        group_size.divmod(o_packed_idx, q, r);
        const uint32_t o_idx = q;
        DTypeO* o_ptr = o_ptr_base + (o_packed_idx - o_packed_idx_base_cta) * o_stride_n +
                        (lane_idx % 8) * upcast_size<DTypeO>();
#pragma unroll
        for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) {
          if (o_idx < qo_upper_bound) {
            o_smem->store_128b(o_smem_offset_w, o_ptr);
          }
          o_ptr += 8 * upcast_size<DTypeO>();
          o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do);
        }
        o_smem_offset_w =
            o_smem->template advance_offset_by_row<4, UPCAST_STRIDE_O>(o_smem_offset_w) -
            2 * KTraits::NUM_MMA_D_VO;
      }
    }
  }
}
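// NOTE: write_o_ (above) handles the partial-output path.  It down-converts the fp32
// accumulator fragments to the 16-bit DTypeO (hence the static_assert), stages the tile in the
// shared-memory buffer passed as o_smem (the caller reuses q_smem for this), and then streams
// 128-bit vectors to global memory.  Rows are addressed relative to the CTA base
// (o_packed_idx - o_packed_idx_base_cta) with a single element stride o_stride_n, and a row is
// stored only when its de-grouped query index (o_packed_idx / group_size) is below
// qo_upper_bound.  Only warps with get_warp_idx_kv(tid.z) == 0 store, since the caller has
// already merged the per-warp KV-split state via threadblock_sync_mdo_states.  This differs
// from write_o_reg_gmem (prefill.cuh), used by the write-through path, which writes with the
// (q_stride_n, q_stride_h) layout of the final output tensor.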
template <typename KTraits>
__device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
                                            typename KTraits::DTypeQKAccum (*m)[2],
                                            float (*d)[2]) {
  using AttentionVariant = typename KTraits::AttentionVariant;
  if constexpr (AttentionVariant::use_softmax) {
    float d_rcp[KTraits::NUM_MMA_Q][2];
    // compute reciprocal of d
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
      for (uint32_t j = 0; j < 2; ++j) {
        d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-math::inf))
                              ? math::ptx_rcp(d[mma_q][j])
                              : 0.f;
      }
    }
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
#pragma unroll
        for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
          o_frag[mma_q][mma_d][reg_id] =
              o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id >> 1) & 1];
        }
      }
    }
  }
}

template <typename KTraits_, typename Params_>
struct BlockBatchPagedAttentionPersistent {
  using KTraits = KTraits_;
  using Params = Params_;

  static __device__ __forceinline__ void Run(const Params& params,
                                             typename KTraits::SharedStorage* smem_storage
                                                 PROFILER_CLOSURE_FUNC_PARAMS) {
    using DTypeQ = typename Params::DTypeQ;
    using DTypeKV = typename Params::DTypeKV;
    using DTypeO = typename Params::DTypeO;
    using IdType = typename Params::IdType;
    using DTypeQKAccum = typename KTraits::DTypeQKAccum;
    using AttentionVariant = typename KTraits::AttentionVariant;
    [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
    [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK;
    [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO;
    [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK;
    [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO;
    [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
    [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K;
    [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V;
    [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O;
    [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
    [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV;
    [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q;
    [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV;
    [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q;
    [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV;
    [[maybe_unused]] constexpr bool CAUSAL = KTraits::MASK_MODE == MaskMode::kCausal;
    [[maybe_unused]] constexpr uint32_t NUM_STAGES = KTraits::NUM_STAGES;

    DTypeQ* q = params.q;
    DTypeKV* k = params.k;
    DTypeKV* v = params.v;
    IdType* kv_indices = params.kv_indices;
    float* partial_lse = params.partial_lse;
    IdType* work_indptr = params.work_indptr;

    float s_frag[NUM_MMA_Q][NUM_MMA_KV][8];
    alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8];
    float m[NUM_MMA_Q][2];
    float d[NUM_MMA_Q][2];

    const uint_fastdiv& gqa_group_size = params.gqa_group_size;
    const uint32_t num_kv_heads = params.num_kv_heads;
    const uint_fastdiv& block_size = params.page_size;
    const uint32_t q_stride_n = params.q_stride_n;
    const uint32_t q_stride_h = params.q_stride_h;
    const uint32_t k_stride_page = params.k_stride_page;
    const uint32_t k_stride_h = params.k_stride_h;
    const uint32_t k_stride_n = params.k_stride_n;
    const uint32_t v_stride_page = params.v_stride_page;
    const uint32_t v_stride_h = params.v_stride_h;
    const uint32_t v_stride_n = params.v_stride_n;
    const uint32_t cluster_tile_q = gridDim.x * CTA_TILE_Q;

    smem_t<SWIZZLE_MODE_Q> q_smem(smem_storage->q_smem);
    AttentionVariant variant(params, /*batch_idx=*/0, nullptr);
    const uint32_t lane_idx = threadIdx.x % 32;
    const uint32_t warp_idx = threadIdx.x / 32;
    // threadIdx: [32, NUM_WARPS_Q, NUM_WARPS_KV]
    // remap so that the FA2 prefill helper functions can be reused
    const dim3 tid = dim3(lane_idx, warp_idx % NUM_WARPS_Q, warp_idx / NUM_WARPS_Q);
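    // The CTA is launched with a flat 1-D thread block of 32 * NUM_WARPS_Q * NUM_WARPS_KV
    // threads; the dim3 above recovers the (lane, warp_q, warp_kv) coordinates expected by the
    // FA2 prefill helpers.  For example (illustrative), with NUM_WARPS_Q = 4 and
    // NUM_WARPS_KV = 1, threadIdx.x = 70 maps to lane_idx = 6, tid.y = 2, tid.z = 0.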
    uint32_t q_smem_offset_r = get_permuted_offset(
        get_warp_idx_q<KTraits>(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
    uint32_t k_smem_offset_r = get_permuted_offset(
                 get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) +
                     lane_idx % 8,
                 (lane_idx % 16) / 8),
             v_smem_offset_r = get_permuted_offset(
                 get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16,
                 lane_idx / 16);
    uint32_t k_smem_offset_w = get_permuted_offset(
                 warp_idx * KTraits::KV_THR_LAYOUT_ROW + lane_idx / KTraits::KV_THR_LAYOUT_COL,
                 lane_idx % KTraits::KV_THR_LAYOUT_COL),
             v_smem_offset_w = get_permuted_offset(
                 warp_idx * KTraits::KV_THR_LAYOUT_ROW + lane_idx / KTraits::KV_THR_LAYOUT_COL,
                 lane_idx % KTraits::KV_THR_LAYOUT_COL);
    size_t thr_local_kv_offset[NUM_MMA_KV * KTraits::KV_THR_LAYOUT_COL / 2 / KTraits::NUM_WARPS_Q];

#pragma unroll 1
    for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
         ++work_idx) {
      // profile log
      if constexpr (CTA_TILE_Q > 64) {
        PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kRunner1);
      } else {
        PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kRunner2);
      }

      const auto [q_indptr, kv_indptr, o_indptr, q_len, kv_len, packed_qo_start, kv_start, kv_end,
                  kv_head_idx, len_kv_chunk] = get_block_coord(params, work_idx);
      const uint32_t kv_chunk_idx = kv_start / len_kv_chunk;
      const uint32_t num_kv_chunks = ceil_div(
          CAUSAL
              ? min((kv_len - q_len) + ceil_div(packed_qo_start + cluster_tile_q, gqa_group_size),
                    kv_len)
              : kv_len,
          len_kv_chunk);
      const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * CTA_TILE_Q +
                                          get_warp_idx_q<KTraits>(tid.y) * NUM_MMA_Q * 16;
      const uint32_t qo_upperbound =
          min(q_len, ceil_div(qo_packed_idx_base + CTA_TILE_Q, gqa_group_size));

      init_states<KTraits>(variant, o_frag, m, d);

      DTypeQ* q_ptr_base = q + q_indptr * q_stride_n + (kv_head_idx * gqa_group_size) * q_stride_h;
      // load_q
      load_q_global_smem<KTraits>(qo_packed_idx_base, qo_upperbound, q_ptr_base, q_stride_n,
                                  q_stride_h, gqa_group_size, &q_smem, tid);

      smem_t<SWIZZLE_MODE_KV> k_smem(smem_storage->k_smem), v_smem(smem_storage->v_smem);
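      // For causal attention, query rows in this tile attend at most up to KV position
      // kv_len - q_len + ceil(q_hi / gqa_group_size), where q_hi is the highest packed query
      // index covered by the cluster tile.  kv_tile_idx below counts CTA_TILE_KV-sized tiles
      // between kv_start and that bound, and mask_tile_idx marks the first tile index at which
      // per-element masking is needed (the split condition handed to LOOP_SPLIT_MASK).
      // Illustrative example (assuming a single KV chunk): q_len = 128, kv_len = 1024,
      // gqa_group_size = 1, packed_qo_start = 0, cluster_tile_q = 128, kv_start = 0,
      // kv_end = 1024, CTA_TILE_KV = 64 gives kv_tile_idx = 1024 / 64 - 1 = 15 and
      // mask_tile_idx = 896 / 64 = 14, so only the last two tiles run the masked path.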
      int kv_tile_idx =
          ceil_div(
              (CAUSAL ? min(kv_end, kv_len - q_len + ceil_div((packed_qo_start + cluster_tile_q),
                                                              gqa_group_size))
                      : kv_end),
              CTA_TILE_KV) -
          1 - (kv_start / CTA_TILE_KV);
      int mask_tile_idx =
          (CAUSAL ? min(kv_end, kv_len - q_len + ceil_div(packed_qo_start, gqa_group_size))
                  : kv_end) /
              CTA_TILE_KV -
          (kv_start / CTA_TILE_KV);

      uint32_t block_iter_base = kv_indptr * block_size + kv_start;
      // last kv tile
      __syncthreads();
      uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;
      prefetch_offest<KTraits>(block_iter_base + kv_tile_idx * CTA_TILE_KV, packed_kv_bound,
                               kv_head_idx, k_stride_page, k_stride_h, k_stride_n, block_size,
                               kv_indices, thr_local_kv_offset);
      page_produce_kv(smem_storage, &k_smem_offset_w, k, kv_start + kv_tile_idx * CTA_TILE_KV,
                      thr_local_kv_offset, kv_end, warp_idx, lane_idx);
      cp_async::commit_group();
      page_produce_kv(smem_storage, &v_smem_offset_w, v, kv_start + kv_tile_idx * CTA_TILE_KV,
                      thr_local_kv_offset, kv_end, warp_idx, lane_idx);
      cp_async::commit_group();

      // loop with mask
      LOOP_SPLIT_MASK(
          kv_tile_idx, kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0,
          kv_tile_idx + 1 > NUM_STAGES, {
            prefetch_offest<KTraits>(block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV,
                                     packed_kv_bound, kv_head_idx, k_stride_page, k_stride_h,
                                     k_stride_n, block_size, kv_indices, thr_local_kv_offset);
            cp_async::wait_group<1>();
            __syncthreads();

            compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
            if constexpr (AttentionVariant::use_logits_soft_cap) {
              logits_transform<KTraits>(
                  params, variant, /*batch_idx=*/0, qo_packed_idx_base,
                  kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
                                 NUM_MMA_KV * 16,
                  q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
            }
            if constexpr (WITH_MASK) {
              logits_mask<KTraits>(
                  params, variant, /*batch_idx=*/0, qo_packed_idx_base,
                  kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
                                 NUM_MMA_KV * 16,
                  q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx);
            }
            update_mdo_states<KTraits>(variant, s_frag, o_frag, m, d);
            __syncthreads();

            page_produce_kv(smem_storage, &k_smem_offset_w, k,
                            kv_start + (kv_tile_idx - 1) * CTA_TILE_KV, thr_local_kv_offset,
                            kv_end, warp_idx, lane_idx);
            cp_async::commit_group();
            cp_async::wait_group<1>();
            __syncthreads();

            compute_sfm_v<KTraits>(&v_smem, &v_smem_offset_r, s_frag, o_frag, d);
            __syncthreads();

            page_produce_kv(smem_storage, &v_smem_offset_w, v,
                            kv_start + (kv_tile_idx - 1) * CTA_TILE_KV, thr_local_kv_offset,
                            kv_end, warp_idx, lane_idx);
            cp_async::commit_group();
          });
      cp_async::wait_group<0>();
      __syncthreads();

#pragma unroll
      for (; kv_tile_idx >= 0; --kv_tile_idx) {
        compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
        if constexpr (AttentionVariant::use_logits_soft_cap) {
          logits_transform<KTraits>(
              params, variant, /*batch_idx=*/0, qo_packed_idx_base,
              kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
                             NUM_MMA_KV * 16,
              q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
        }
        logits_mask<KTraits>(
            params, variant, /*batch_idx=*/0, qo_packed_idx_base,
            kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
                           NUM_MMA_KV * 16,
            q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx);
        update_mdo_states<KTraits>(variant, s_frag, o_frag, m, d);
        compute_sfm_v<KTraits>(&v_smem, &v_smem_offset_r, s_frag, o_frag, d);
      }
      __syncthreads();

      finalize_m<KTraits>(variant, m);

      // threadblock synchronization
      threadblock_sync_mdo_states<KTraits>(o_frag, smem_storage, m, d, warp_idx, lane_idx, tid);

      // normalize d
      normalize_d<KTraits>(o_frag, m, d);
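      // Two write-back paths follow.  If the KV range of this query tile was split across more
      // than one chunk (num_kv_chunks > 1), the already-normalized tile output goes to
      // params.partial_o and its per-row LSE is recorded as log2(d) + m (the kernel works in the
      // base-2 log domain), to be merged later by BlockBatchReductionPersistent.  Otherwise the
      // tile is complete and is written directly to params.final_o / params.final_lse
      // ("write through"), so no reduction pass is needed for it.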
      // write back to global memory
      // o_indptr (partial_o): [packed_qo_len * num_kv_chunks, num_kv_heads, head_dim]
      // q_indptr (final_o): [qo_len, num_kv_heads, gqa_group_size, head_dim]
      if (num_kv_chunks > 1) {
        DTypeO* o_ptr_base =
            params.partial_o +
            ((o_indptr + kv_chunk_idx) * num_kv_heads + kv_head_idx) * HEAD_DIM_VO;
        write_o_<KTraits>(o_frag, &q_smem, o_ptr_base, qo_packed_idx_base, packed_qo_start,
                          qo_upperbound, num_kv_chunks * num_kv_heads * HEAD_DIM_VO,
                          gqa_group_size, warp_idx, lane_idx, tid);
      } else {
        // write through
        DTypeO* o_ptr_base =
            params.final_o + q_indptr * q_stride_n + (kv_head_idx * gqa_group_size) * q_stride_h;
        write_o_reg_gmem<KTraits>(o_frag, &q_smem, o_ptr_base, qo_packed_idx_base, q_len,
                                  q_stride_n, q_stride_h, gqa_group_size, tid);
      }

      if constexpr (variant.use_softmax) {
        if (get_warp_idx_kv<KTraits>(tid.z) == 0) {
#pragma unroll
          for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t j = 0; j < 2; ++j) {
              uint32_t q, r;
              const uint32_t packed_qo_idx =
                  qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16;
              gqa_group_size.divmod(packed_qo_idx, q, r);
              if (q < qo_upperbound) {
                if (num_kv_chunks > 1) {
                  partial_lse[(o_indptr + (packed_qo_idx - packed_qo_start) * num_kv_chunks +
                               kv_chunk_idx) *
                                  num_kv_heads +
                              kv_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]);
                } else if (params.final_lse != nullptr) {
                  // write through
                  const uint32_t qo_head_idx = kv_head_idx * gqa_group_size + r;
                  params.final_lse[(q_indptr + q) * num_kv_heads * gqa_group_size + qo_head_idx] =
                      math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]);
                }
              }
            }
          }
        }
      }

      // profile
      if constexpr (CTA_TILE_Q > 64) {
        PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner1);
      } else {
        PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner2);
      }
    }
  }
};

template <typename DTypeIn_, typename DTypeO_, typename IdType_, uint32_t HEAD_DIM_VO_,
          uint32_t NUM_SMEM_STAGES_, uint32_t NUM_THREADS_>
struct StateReductionKernelTraits {
  using DTypeIn = DTypeIn_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;
  static constexpr uint32_t HEAD_DIM_VO = HEAD_DIM_VO_;
  static constexpr uint32_t NUM_SMEM_STAGES = NUM_SMEM_STAGES_;
  static constexpr uint32_t NUM_THREADS = NUM_THREADS_;
  static constexpr uint32_t NUM_WARPS = NUM_THREADS / 32;

  static constexpr uint32_t vec_size =
      std::max(16U / static_cast<uint32_t>(sizeof(DTypeIn)), HEAD_DIM_VO / 32U);
  static constexpr uint32_t bdx = HEAD_DIM_VO / vec_size;
  // gridDim is accessed at runtime and should be set by the core attention workload
  // layout: [bdx, bdy, num_warps]
  static_assert(NUM_THREADS % bdx == 0);
  static constexpr uint32_t bdy = 32 / bdx;

  // pipeline load & reduction
  static constexpr size_t SMEM_SIZE =
      NUM_WARPS * NUM_SMEM_STAGES * bdy * HEAD_DIM_VO * sizeof(DTypeIn) +
      NUM_THREADS * sizeof(float);
};
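// Example (illustrative): for HEAD_DIM_VO_ = 128 and DTypeIn_ = half (2 bytes),
//   vec_size = max(16 / 2, 128 / 32) = 8,  bdx = 128 / 8 = 16,  bdy = 32 / 16 = 2,
// so each warp reduces bdy = 2 partial rows per iteration with vec_size elements per thread.
// With NUM_THREADS_ = 128 (4 warps) and NUM_SMEM_STAGES_ = 4, SMEM_SIZE would be
//   4 * 4 * 2 * 128 * 2 + 128 * 4 = 8704 bytes.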
template <typename KTraits_>
struct BlockBatchReductionPersistent {
  using KTraits = KTraits_;

  static __device__ __forceinline__ void Run(
      typename KTraits::DTypeIn* __restrict__ V, typename KTraits::DTypeO* __restrict__ v_merged,
      float* __restrict__ S, float* __restrict__ s_merged,
      const typename KTraits::IdType num_packed_qo_len, const uint_fastdiv gqa_group_size,
      const uint32_t num_kv_heads, const typename KTraits::IdType* indptr,
      const typename KTraits::IdType* o_indices, uint8_t* smem PROFILER_CLOSURE_FUNC_PARAMS) {
    __syncthreads();  // NOTE(Zihao): required to guarantee correctness on Blackwell
    using DTypeIn = typename KTraits::DTypeIn;
    using DTypeO = typename KTraits::DTypeO;
    using IdType = typename KTraits::IdType;
    [[maybe_unused]] constexpr uint32_t bdx = KTraits::bdx;
    [[maybe_unused]] constexpr uint32_t bdy = KTraits::bdy;
    [[maybe_unused]] constexpr uint32_t num_warps = KTraits::NUM_WARPS;
    [[maybe_unused]] constexpr uint32_t vec_size = KTraits::vec_size;
    [[maybe_unused]] constexpr uint32_t head_dim = KTraits::HEAD_DIM_VO;
    [[maybe_unused]] constexpr uint32_t num_smem_stages = KTraits::NUM_SMEM_STAGES;
    [[maybe_unused]] constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;

    // control flow metadata
    const uint32_t warp_idx = threadIdx.x / 32;
    const uint32_t tx = (threadIdx.x % 32) % bdx, ty = (threadIdx.x % 32) / bdx;
    const uint32_t worker_id = blockIdx.y * num_warps + warp_idx;
    const uint32_t num_workers = gridDim.x * gridDim.y * gridDim.z * num_warps;

    DTypeIn* v_smem = (DTypeIn*)smem + warp_idx * num_smem_stages * bdy * head_dim;
    // FIXME: fix the offset calculation
    float* s_smem =
        (float*)(smem + num_warps * num_smem_stages * bdy * head_dim * sizeof(DTypeIn) +
                 warp_idx * 32 * sizeof(float));

    // V: [num_packed_qo_len x num_kv_tiles, num_kv_heads, head_dim]
    // v_merged: [qo_len, num_kv_heads, gqa_group_size, head_dim]
#pragma unroll 1
    for (uint32_t i = worker_id; i < num_packed_qo_len * num_kv_heads; i += num_workers) {
      PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kReduction);

      // remap workload
      uint32_t packed_qo_idx = i / num_kv_heads;
      uint32_t kv_head_idx = i % num_kv_heads;
      const uint32_t num_index_sets = indptr[packed_qo_idx + 1] - indptr[packed_qo_idx];
      if (num_index_sets == 0 || num_index_sets == 1) {
        // already written through by the attention kernel, bypass
        PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kReduction);
        continue;
      }

      // index calculation
      auto partial_idx_to_offset = [&](uint32_t off) {
        return (indptr[packed_qo_idx] + off) * num_kv_heads + kv_head_idx;
      };
      auto merge_idx_to_offset = [&]() {
        // NOTE (Yilong): qo_head_idx has been calculated in schedule.plan
        return o_indices[packed_qo_idx] + kv_head_idx * gqa_group_size;
      };
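      // Conceptually, each of the num_index_sets partial rows carries an output o_i that is
      // already normalized over its own KV chunk, plus a base-2 LSE value
      // lse_i = log2(d_i) + m_i written by the attention kernel above.  The merge below is the
      // standard attention-state combination:
      //   lse = log2(sum_i 2^(lse_i)),   o = sum_i 2^(lse_i - lse) * o_i,
      // accumulated incrementally by state_t::merge and finished by state_t::normalize.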
      state_t<vec_size> st;
#pragma unroll
      for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
        cp_async::pred_load(
            v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
            V + partial_idx_to_offset(iter * bdy + ty) * head_dim + tx * vec_size,
            (iter * bdy + ty) < num_index_sets);
        cp_async::commit_group();
      }

#pragma unroll 4
      for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
        if (iter % bdx == 0) {
          s_smem[ty * bdx + tx] = iter * bdy + (ty * bdx + tx) < num_index_sets
                                      ? S[partial_idx_to_offset(iter * bdy + ty * bdx + tx)]
                                      : 0.f;
          __syncwarp();
        }
        cp_async::wait_group<num_smem_stages - 1>();
        __syncwarp();

        vec_t<float, vec_size> v;
        v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
        if (iter * bdy + ty < num_index_sets) {
          float s = s_smem[(iter % bdx) * bdy + ty];
          st.merge(v, s, 1);
        }

        __syncwarp();
        cp_async::pred_load(
            v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
            V + partial_idx_to_offset((iter + num_smem_stages) * bdy + ty) * head_dim +
                tx * vec_size,
            (iter + num_smem_stages) * bdy + ty < num_index_sets);
        cp_async::commit_group();
      }
      cp_async::wait_group<0>();
      __syncwarp();

      st.normalize();
      if constexpr (bdy > 1) {
        warp_sync_state(st, v_smem, s_smem, tx, ty);
        st.normalize();
      }
      st.o.cast_store(v_merged + merge_idx_to_offset() * head_dim + tx * vec_size);
      if (s_merged != nullptr) {
        s_merged[merge_idx_to_offset()] = st.get_lse();
      }
      PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kReduction);
    }
  }
};

template
cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2,
                                          const uint32_t num_blks_x, const uint32_t num_blks_y,
                                          const cudaStream_t stream) {
  using DTypeQ = typename Params::DTypeQ;
  using DTypeKV = typename Params::DTypeKV;
  using DTypeO = typename Params::DTypeO;
  using IdType = typename Params::IdType;

  constexpr uint32_t NUM_WARPS_Q_1 = get_num_warps_q(CTA_TILE_Q_1);
  constexpr uint32_t NUM_WARPS_KV_1 = get_num_warps_kv(CTA_TILE_Q_1);
  constexpr uint32_t NUM_MMA_Q_1 = get_num_mma_q(CTA_TILE_Q_1);
  constexpr uint32_t NUM_MMA_KV_1 = 4;
  constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
  constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
  using KTraits1 = KernelTraits;

  constexpr uint32_t NUM_WARPS_Q_2 = get_num_warps_q(CTA_TILE_Q_2);
  constexpr uint32_t NUM_WARPS_KV_2 = get_num_warps_kv(CTA_TILE_Q_2);
  constexpr uint32_t NUM_MMA_Q_2 = get_num_mma_q(CTA_TILE_Q_2);
  constexpr uint32_t NUM_MMA_KV_2 = 2;
  using KTraits2 = KernelTraits;

  // Attention state reduction kernel
  constexpr uint32_t NUM_THREADS = KTraits1::NUM_THREADS > KTraits2::NUM_THREADS
                                       ? KTraits1::NUM_THREADS
                                       : KTraits2::NUM_THREADS;
  using ReductionKTraits = StateReductionKernelTraits;

  size_t smem_size =
      max(sizeof(typename KTraits1::SharedStorage), sizeof(typename KTraits2::SharedStorage));
  smem_size = max(smem_size, ReductionKTraits::SMEM_SIZE);

  // Launch persistent kernel
  auto kernel = PersistentKernelTemplate<BlockBatchPagedAttentionPersistent<KTraits1, Params>,
                                         BlockBatchPagedAttentionPersistent<KTraits2, Params>,
                                         BlockBatchReductionPersistent<ReductionKTraits>>;
  FLASHINFER_CUDA_CALL(
      cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
  dim3 nblks(num_blks_x, num_blks_y);
  dim3 nthrs(NUM_THREADS);
  void* args[] = {(void*)&params_1, (void*)&params_2};
  FLASHINFER_CUDA_CALL(
      cudaLaunchCooperativeKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  return cudaSuccess;
}

}  // namespace flashinfer

#endif  // FLASHINFER_PERSISTENT_CUH_