/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani,
 * Tri Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */
#ifndef FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_
#define FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_

#include <cute/tensor.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
#include <cutlass/pipeline/pipeline.hpp>

namespace flashinfer {

template <typename Ktraits, bool LEFT_SLIDING_WINDOW, bool CAUSAL, bool MULTIITEMSCORING,
          typename WarpScheduler, typename Params, typename AttentionVariant,
          typename MainloopPipeline, typename PipelineState, typename SharedStorage,
          typename FrgTensorO, typename AttentionUpdater>
CUTLASS_DEVICE void mma_f16(
    const Params& mainloop_params, AttentionVariant& variant, MainloopPipeline pipeline_k,
    MainloopPipeline pipeline_v, PipelineState& smem_pipe_read_k, PipelineState& smem_pipe_read_v,
    FrgTensorO& tOrO, AttentionUpdater& attention_updater, int kv_tile_idx_count,
    int swa_begin_kv_tile_idx, int swa_end_kv_tile_idx, int thread_idx, int work_idx,
    int q_tile_idx, SharedStorage& shared_storage, const int32_t qo_len, const int32_t kv_len,
    const int32_t qo_head_idx, const int32_t kv_head_idx, const uint32_t prefix_len,
    uint16_t* token_pos_in_items, const int num_kv_tiles_outside_items_window = 0,
    const int num_kv_tiles_prefix = 0) {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using IdType = typename Ktraits::IdType;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;

  using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
  using SmemLayoutK = typename Ktraits::SmemLayoutK;
  using SmemLayoutV = typename Ktraits::SmemLayoutV;
  using SmemLayoutVt = typename Ktraits::SmemLayoutVt;

  static_assert(is_rmem<FrgTensorO>::value, "O tensor must be rmem resident.");

  static constexpr int CTA_Q = get<0>(TileShape_QKD{});
  static constexpr int CTA_KV = get<1>(TileShape_QKD{});

  Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
  Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
  Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{});

  typename Ktraits::TiledMmaQK tiled_mma_qk;
  typename Ktraits::TiledMmaPV tiled_mma_pv;
  auto threadMmaQK = tiled_mma_qk.get_thread_slice(thread_idx);
  auto threadMmaPV = tiled_mma_pv.get_thread_slice(thread_idx);

  Tensor tSrQ = threadMmaQK.partition_fragment_A(sQ);
  Tensor tSrK = threadMmaQK.partition_fragment_B(sK);
  Tensor tOrV = threadMmaPV.partition_fragment_B(sVt);

  auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
    auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
    pipeline.consumer_wait(smem_pipe_read, barrier_token);
  };

  tiled_mma_pv.accumulate_ = GMMA::ScaleOut::Zero;

  int kv_tile_idx = kv_tile_idx_count - 1;

  cutlass::ConsumerToken barrier_token =
      static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
  if (barrier_token == cutlass::BarrierStatus::WaitAgain) {
    shared_storage.barrier_Q.wait(work_idx % 2);
  }

  Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
  consumer_wait(pipeline_k, smem_pipe_read_k);

  WarpScheduler::barrier_sync();
  gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ,
                                      tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
  WarpScheduler::barrier_arrive();

  if (work_idx != 0) {
    int lane_predicate = cute::elect_one_sync();
    if (cutlass::canonical_warp_idx_sync() == Ktraits::NUM_WARPS - 1 && lane_predicate) {
#pragma unroll
      for (uint32_t cta_id = 0; cta_id < 1; ++cta_id) {
        shared_storage.barrier_O.arrive(cta_id, lane_predicate);
      }
    }
  }

  warpgroup_wait<0>();
  pipeline_k.consumer_release(smem_pipe_read_k);
  ++smem_pipe_read_k;

  auto col_limit_right = [&](int qo_idx) { return qo_idx + 1 + kv_len - qo_len; };
  auto col_limit_left = [&](int qo_idx) {
    return qo_idx + kv_len - qo_len - mainloop_params.window_left;
  };
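  // Multi-item scoring masking helpers. Query tokens inside the shared prefix
  // (idx_in_original_seq < prefix_len) follow plain causal masking; query tokens past the prefix
  // may attend to the prefix and to keys within their own item, as determined by their
  // token_pos_in_items entry. Masked logits are set to AttentionUpdater::fill_value.
  // mask_multi_item_scoring_assume_in_bound skips the causal/sequence-length boundary check and
  // is used only for KV tiles known to lie strictly below the causal diagonal.
  // kv_tile_idx_decrement lets the main loop jump from the last tile outside the items window
  // directly to the prefix tiles, skipping fully-masked tiles when multi-item scoring is enabled.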
  auto mask_multi_item_scoring = [&](decltype(tSrS)& tSrS, int i, int qo_idx, int kv_idx) {
    const uint32_t idx_in_original_seq = qo_idx + kv_len - qo_len;
    const bool out_of_boundary =
        kv_idx > idx_in_original_seq || (kv_idx >= std::min(kv_len, col_limit_right(qo_idx)));
    const bool is_prefix = idx_in_original_seq < prefix_len;
    uint16_t token_pos_in_items_regs = 0;
    // Only access idx_in_original_seq >= prefix_len && idx_in_original_seq < kv_len to avoid
    // out-of-bounds memory access
    if (idx_in_original_seq >= prefix_len & idx_in_original_seq < kv_len) {
      token_pos_in_items_regs = __ldca(token_pos_in_items + idx_in_original_seq - prefix_len);
    }
    if (out_of_boundary || is_prefix) {
      tSrS(i) = out_of_boundary ? (AttentionUpdater::fill_value) : tSrS(i);
    } else {
      tSrS(i) = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs))
                    ? tSrS(i)
                    : (AttentionUpdater::fill_value);
    }
  };

  auto mask_multi_item_scoring_assume_in_bound = [&](decltype(tSrS)& tSrS, int i, int qo_idx,
                                                     int kv_idx) {
    const uint32_t idx_in_original_seq = qo_idx + kv_len - qo_len;
    const bool is_prefix = idx_in_original_seq < prefix_len;
    if (is_prefix) {
      tSrS(i) = AttentionUpdater::fill_value;
    } else {
      uint16_t token_pos_in_items_regs = 0;
      // Only access idx_in_original_seq >= prefix_len && idx_in_original_seq < kv_len to avoid
      // out-of-bounds memory access
      if (idx_in_original_seq >= prefix_len & idx_in_original_seq < kv_len) {
        token_pos_in_items_regs = __ldca(token_pos_in_items + idx_in_original_seq - prefix_len);
      }
      tSrS(i) = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs))
                    ? tSrS(i)
                    : (AttentionUpdater::fill_value);
    }
  };

  auto kv_tile_idx_decrement = [&](int kv_tile_idx) {
    int result = kv_tile_idx - 1;
    if constexpr (MULTIITEMSCORING) {
      if ((kv_tile_idx == num_kv_tiles_outside_items_window) &
          (kv_tile_idx >= num_kv_tiles_prefix)) {
        result = num_kv_tiles_prefix - 1;
      }
    }
    return result;
  };

  {
    Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
    Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
    for (int i = 0; i < size(tSrS); ++i) {
      int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q;
      int kv_idx = get<1>(tScS(i)) + kv_tile_idx * CTA_KV;
      tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx,
                                        qo_head_idx, kv_head_idx);
      if constexpr (MULTIITEMSCORING) {
        mask_multi_item_scoring(tSrS, i, qo_idx, kv_idx);
      } else if constexpr (!CAUSAL) {
        // Just masking based on col
        if (kv_idx >= kv_len) {
          tSrS(i) = AttentionUpdater::fill_value;
        }
      } else {
        if (kv_idx >= std::min(kv_len, col_limit_right(qo_idx))) {
          tSrS(i) = AttentionUpdater::fill_value;
        }
      }
      if constexpr (LEFT_SLIDING_WINDOW) {
        if (kv_idx < col_limit_left(qo_idx)) {
          tSrS(i) = AttentionUpdater::fill_value;
        }
      }
    }
  }

  attention_updater.update(tSrS);
  Tensor tOrP = make_tensor(convert_type<DTypeKV>(tSrS).data(),
                            convert_layout_acc_Aregs<typename Ktraits::TiledMmaPV>(tSrS.layout()));

  constexpr int n_masking_steps = MULTIITEMSCORING ? (cute::ceil_div(CTA_Q, CTA_KV) + 1)
                                                   : (CAUSAL ? cute::ceil_div(CTA_Q, CTA_KV) : 0);
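  // The loops below are software-pipelined: each iteration issues the QK^T GEMM for the current
  // KV tile, overlaps the PV GEMM of the previous tile's probabilities (tOrP) with it, applies
  // the logits transform and masking, runs the online-softmax update, and finally converts the
  // new scores into the operand-A registers (tOrP) for the next PV GEMM.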
  // masking loops
  // ziangl@nvidia.com: for multi item scoring, we use this loop only to mask along the diagonal
#pragma unroll
  for (int masking_step = 0; masking_step < n_masking_steps && kv_tile_idx > swa_begin_kv_tile_idx;
       ++masking_step, kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) {
    Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
    consumer_wait(pipeline_k, smem_pipe_read_k);
    WarpScheduler::barrier_sync();
    gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ,
                                        tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
    if (masking_step > 0) {
      attention_updater.rescale_o(tOrO);
    }
    consumer_wait(pipeline_v, smem_pipe_read_v);
    gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv, tOrP,
                                         tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
    WarpScheduler::barrier_arrive();
    warpgroup_wait<1>();
    pipeline_k.consumer_release(smem_pipe_read_k);  // release K
    Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
    Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
    for (int i = 0; i < size(tSrS); ++i) {
      int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q;
      int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV;
      tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx,
                                        qo_head_idx, kv_head_idx);
      if constexpr (MULTIITEMSCORING) {
        mask_multi_item_scoring(tSrS, i, qo_idx, kv_idx);
      } else {
        if (kv_idx >= col_limit_right(qo_idx)) {
          tSrS(i) = AttentionUpdater::fill_value;
        }
      }
      if constexpr (LEFT_SLIDING_WINDOW) {
        if (kv_idx < col_limit_left(qo_idx)) {
          tSrS(i) = AttentionUpdater::fill_value;
        }
      }
    }
    attention_updater.update(tSrS);
    warpgroup_wait<0>();
    pipeline_v.consumer_release(smem_pipe_read_v);  // release V
    ++smem_pipe_read_k;
    ++smem_pipe_read_v;
    cute::copy(
        make_tensor(convert_type<DTypeKV>(tSrS).data(),
                    convert_layout_acc_Aregs<typename Ktraits::TiledMmaPV>(tSrS.layout())),
        tOrP);
  }

#pragma unroll 1
  for (; kv_tile_idx > swa_end_kv_tile_idx + 1; kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) {
    Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
    consumer_wait(pipeline_k, smem_pipe_read_k);
    WarpScheduler::barrier_sync();
    gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ,
                                        tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
    attention_updater.rescale_o(tOrO);
    consumer_wait(pipeline_v, smem_pipe_read_v);
    gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv, tOrP,
                                         tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
    WarpScheduler::barrier_arrive();
    warpgroup_wait<1>();
    pipeline_k.consumer_release(smem_pipe_read_k);  // release K
    Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
    Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
    for (int i = 0; i < size(tSrS); ++i) {
      int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q;
      int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV;
      tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx, kv_idx,
                                        qo_head_idx, kv_head_idx);
    }
    if constexpr (MULTIITEMSCORING) {
      if (kv_tile_idx >= num_kv_tiles_prefix - 1) {
#pragma unroll
        for (int i = 0; i < size(tSrS); ++i) {
          int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q;
          int kv_idx = get<1>(tScS(i)) + kv_tile_idx_decrement(kv_tile_idx) * CTA_KV;
          mask_multi_item_scoring_assume_in_bound(tSrS, i, qo_idx, kv_idx);
        }
      }
    }
    attention_updater.update(tSrS);
    warpgroup_wait<0>();
    pipeline_v.consumer_release(smem_pipe_read_v);  // release V
    ++smem_pipe_read_k;
    ++smem_pipe_read_v;
    cute::copy(
        make_tensor(convert_type<DTypeKV>(tSrS).data(),
                    convert_layout_acc_Aregs<typename Ktraits::TiledMmaPV>(tSrS.layout())),
        tOrP);
  }
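  // Tail loop for left sliding window attention: the remaining KV tiles intersect the left
  // window boundary, so only the col_limit_left mask is applied. After all loops, the epilogue
  // signals that smem_q may be reused, issues the PV GEMM for the last tile's probabilities, and
  // finalizes the online softmax before the output accumulator is rescaled.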
  if constexpr (LEFT_SLIDING_WINDOW) {
#pragma unroll 1
    for (; kv_tile_idx > swa_begin_kv_tile_idx; --kv_tile_idx) {
      Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_QKD{}));
      consumer_wait(pipeline_k, smem_pipe_read_k);
      WarpScheduler::barrier_sync();
      gemm</*init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ,
                                          tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
      attention_updater.rescale_o(tOrO);
      consumer_wait(pipeline_v, smem_pipe_read_v);
      gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv, tOrP,
                                           tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
      WarpScheduler::barrier_arrive();
      warpgroup_wait<1>();
      pipeline_k.consumer_release(smem_pipe_read_k);  // release K
      Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_QKD{}));
      Tensor tScS = threadMmaQK.partition_C(cS);
#pragma unroll
      for (int i = 0; i < size(tSrS); ++i) {
        int qo_idx = get<0>(tScS(i)) + q_tile_idx * CTA_Q;
        int kv_idx = get<1>(tScS(i)) + (kv_tile_idx - 1) * CTA_KV;
        tSrS(i) = variant.LogitsTransform(mainloop_params, tSrS(i), /*batch_idx=*/0, qo_idx,
                                          kv_idx, qo_head_idx, kv_head_idx);
        if (kv_idx < col_limit_left(qo_idx)) {
          tSrS(i) = AttentionUpdater::fill_value;
        }
      }
      attention_updater.update(tSrS);
      warpgroup_wait<0>();
      pipeline_v.consumer_release(smem_pipe_read_v);  // release V
      ++smem_pipe_read_k;
      ++smem_pipe_read_v;
      cute::copy(
          make_tensor(convert_type<DTypeKV>(tSrS).data(),
                      convert_layout_acc_Aregs<typename Ktraits::TiledMmaPV>(tSrS.layout())),
          tOrP);
    }
  }

  // Tell warp 0 that smem_q is ready
  cutlass::arch::NamedBarrier::arrive(NUM_MMA_THREADS + Ktraits::NUM_PRODUCER_THREADS,
                                      /*id=*/static_cast<int>(NamedBarriers::kQueryEmpty));
  attention_updater.rescale_o(tOrO);
  consumer_wait(pipeline_v, smem_pipe_read_v);
  gemm</*init=*/false, /*wg_wait=*/-1>(tiled_mma_pv, tOrP,
                                       tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
  attention_updater.finalize(tSrS);
  warpgroup_wait<0>();
  pipeline_v.consumer_release(smem_pipe_read_v);  // release V, otherwise producers will hang
  ++smem_pipe_read_v;

  attention_updater.rescale_o(tOrO);
  return;
}

}  // namespace flashinfer

#endif  // FLASHINFER_ATTENTION_HOPPER_MAINLOOP_MMA_CUH_