/*
 * Copyright (c) 2024 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_
#define FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_

#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

#include "../../math.cuh"
#include "block_sparse_gather.cuh"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"

namespace flashinfer {

using namespace cute;

template <typename AdditionalParams, typename Ktraits, bool CAUSAL, bool MULTIITEMSCORING = false>
struct SparseCollectiveMainloop {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using IdType = typename Ktraits::IdType;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  using TileShape_PDV = typename Ktraits::TileShape_PDV;
  static constexpr int CTA_Q = get<0>(TileShape_QKD{});
  static constexpr int CTA_KV = get<1>(TileShape_QKD{});
  static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
  static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
  static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
  static_assert(HEAD_DIM_QK == HEAD_DIM_VO);

  static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;

  using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
  static constexpr auto AlignmentKV = 128 / cutlass::sizeof_bits<DTypeKV>::value;
  using AlignmentTypeKV = cute::uint_byte_t<static_cast<int>(sizeof(DTypeKV)) * AlignmentKV>;
  // NOTE(Zihao): use SM80_CP_ASYNC for sparse loading of KV-cache
  using GmemCopyAtomKV = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<AlignmentTypeKV>, DTypeKV>;
  using GmemTiledCopyK = decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy<
                                  GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV,
                                  cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>,
                                  decltype(cute::get<1>(TileShape_QKD{})),
                                  decltype(cute::get<2>(TileShape_QKD{}))>());
  using GmemTiledCopyV = decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy<
                                  GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV,
                                  cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>,
                                  decltype(cute::get<2>(TileShape_PDV{})),
                                  decltype(cute::get<1>(TileShape_PDV{}))>());

  using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
  using SmemLayoutK = typename Ktraits::SmemLayoutK;
  using SmemLayoutV = typename Ktraits::SmemLayoutV;
  using SmemLayoutVt = typename Ktraits::SmemLayoutVt;

  using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
  using StrideT = cute::Shape<int64_t, _1, int64_t>;  // (N, D, H)
  using LayoutT = cute::Layout<ShapeT, StrideT>;

  using ShapeLseT = cute::Shape<int32_t, int32_t>;
  using StrideLseT = cute::Shape<_1, int64_t>;
  using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;

  using TMA_Q = decltype(make_tma_copy(
      GmemTiledCopyQ{},
      make_tensor(make_gmem_ptr(static_cast<DTypeQ const*>(nullptr)),
                  repeat_like(StrideT{}, int32_t(0)), StrideT{}),
      SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{}));  // no mcast for Q

  static constexpr bool USE_TMA_LOAD_KV = false;
  static constexpr int NUM_MMA_THREADS = size(typename Ktraits::TiledMmaQK{});
  using MainloopPipeline = typename Ktraits::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;

  static constexpr uint32_t TmaTransactionBytesQ =
      static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);

  static constexpr bool UseSchedulerBarrier =
      cutlass::sizeof_bits_v<DTypeQ> == 8 ? HEAD_DIM_VO >= 128 : HEAD_DIM_VO <= 128;
  using WarpScheduler = WarpScheduler<Ktraits, UseSchedulerBarrier>;
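
  // Illustrative usage sketch (comment only, not part of the kernel): the host fills
  // `Arguments` with the Q/K/V pointers, their (N, D, H) layouts, the gather indices of the
  // sparse KV-cache, and the sliding-window size, then converts them once per launch:
  //
  //   Mainloop::Arguments args{q_ptr, layout_Q, k_ptr, layout_K, v_ptr, layout_V,
  //                            kv_indices, window_left, additional_params};
  //   Mainloop::Params params = Mainloop::to_underlying_arguments(args);
  //
  // to_underlying_arguments builds the TMA descriptor for Q and const-casts the K/V/index
  // pointers for device use; `Mainloop` here stands for a concrete SparseCollectiveMainloop
  // instantiation.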

  // Host side kernel arguments
  struct Arguments {
    DTypeQ const* Q_ptr;
    LayoutT layout_Q;
    DTypeKV const* K_ptr;
    LayoutT layout_K;
    DTypeKV const* V_ptr;
    LayoutT layout_V;
    IdType const* kv_indices;
    int window_left;
    AdditionalParams additional_params;
  };

  // Device side kernel params
  struct Params {
    LayoutT layout_Q;
    LayoutT layout_K;
    LayoutT layout_V;
    TMA_Q tma_load_Q;
    DTypeKV* K_ptr;
    DTypeKV* V_ptr;
    IdType* kv_indices;
    int window_left;
    AdditionalParams additional_params;
  };

  static Params to_underlying_arguments(Arguments const& args) {
    Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q);
    TMA_Q tma_load_Q =
        make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{});
    return {args.layout_Q,
            args.layout_K,
            args.layout_V,
            tma_load_Q,
            const_cast<DTypeKV*>(args.K_ptr),
            const_cast<DTypeKV*>(args.V_ptr),
            const_cast<IdType*>(args.kv_indices),
            args.window_left,
            args.additional_params};
  }

  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& mainloop_params) {
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
  }

  CUTLASS_DEVICE
  int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len,
                       const int kv_len) {
    static constexpr int CTA_Q = get<0>(TileShape_QKD{});
    static constexpr int CTA_KV = get<1>(TileShape_QKD{});
    int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV);
    if constexpr (CAUSAL) {
      num_kv_tiles = std::min(
          num_kv_tiles, cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV));
    }
    if constexpr (MULTIITEMSCORING) {
      num_kv_tiles = std::min(
          num_kv_tiles, cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV));
    }
    return num_kv_tiles;
  }
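
  // Worked example (illustrative numbers only): with CTA_Q = 128, CTA_KV = 128,
  // qo_len = kv_len = 1024 and q_tile_idx = 2, the causal bound is
  //   ceil_div((2 + 1) * 128 + 1024 - 1024, 128) = 3,
  // so only 3 of the ceil_div(1024, 128) = 8 KV tiles are loaded for this query tile;
  // the remaining tiles lie entirely above the causal diagonal and are skipped.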

  template <bool LEFT_SLIDING_WINDOW, typename Scheduler, typename SharedStorage,
            typename BlockCoord>
  CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline_k,
                           MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k,
                           PipelineState& smem_pipe_write_v, SharedStorage& shared_storage,
                           Scheduler& scheduler, typename Scheduler::Params const& scheduler_params,
                           typename Scheduler::WorkTileInfo& work_tile_info,
                           BlockCoord const& block_coord, int work_idx,
                           const int num_kv_tiles_outside_items_window = 0,
                           const int num_kv_tiles_prefix = 0) {
    int thread_idx = threadIdx.x;
    int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);

    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});

    Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());

    auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] =
        block_coord;

    // Prepare the TMA loads
    Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr,
                                      qo_len)(_, _, q_tile_idx);  // (Q, D)
    Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
    Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
    auto [tQgQ, tQsQ] =
        tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x),
                      group_modes<0, 2>(gQ_x));  // (TMA), (TMA)

    int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len);
    int kv_tile_idx = num_kv_tiles - 1;
    int swa_begin_kv_tile_idx = 0;
    if constexpr (LEFT_SLIDING_WINDOW) {
      swa_begin_kv_tile_idx =
          get_swa_begin_kv_tile_idx(mainloop_params.window_left, q_tile_idx, qo_len, kv_len);
    }

    constexpr int HEAD_DIM_QK = get<2>(TileShape_QKD{});
    constexpr int HEAD_DIM_VO = get<1>(TileShape_PDV{});
    constexpr int CTA_KV = get<1>(TileShape_QKD{});

    auto indexed_gather = BlockSparseIndexedGather(mainloop_params.kv_indices + kv_indptr);
    Tensor mK = make_block_sparse_tensor(  // (kv_len, D_K)
        make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)),
        make_shape(kv_len, HEAD_DIM_QK), stride<0>(mainloop_params.layout_K), indexed_gather);
    Tensor mV = make_block_sparse_tensor(  // (kv_len, D_V)
        make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)),
        make_shape(kv_len, HEAD_DIM_VO), stride<0>(mainloop_params.layout_V), indexed_gather);

    Tensor gK = local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{}));  // (KV, D_K, kv)
    Tensor gV = local_tile(mV, select<2, 1>(TileShape_PDV{}), make_coord(_, _0{}));  // (KV, D_V, kv)
    Tensor cK = cute::make_identity_tensor(gK.shape());
    Tensor cV = cute::make_identity_tensor(gV.shape());

    GmemTiledCopyK gmem_tiled_copy_k;
    GmemTiledCopyV gmem_tiled_copy_v;
    auto gmem_thr_copy_k = gmem_tiled_copy_k.get_slice(thread_idx);
    auto gmem_thr_copy_v = gmem_tiled_copy_v.get_slice(thread_idx);

    Tensor tKgK = gmem_thr_copy_k.partition_S(gK);  // (CPY, CPY_KV, CPY_D, kv)
    Tensor tKsK = gmem_thr_copy_k.partition_D(sK);  // (CPY, CPY_KV, CPY_D, PIPE)
    Tensor tVgV = gmem_thr_copy_v.partition_S(gV);  // (CPY, CPY_KV, CPY_D, kv)
    Tensor tVsV = gmem_thr_copy_v.partition_D(sV);  // (CPY, CPY_KV, CPY_D, PIPE)
    Tensor tKcK = gmem_thr_copy_k.partition_D(cK);  // (CPY, CPY_KV, CPY_D)
    Tensor tKcKGroup = flatten_1(tKcK);             // (CPY, (CPY_KV, CPY_D))
    Tensor tVcV = gmem_thr_copy_v.partition_D(cV);  // (CPY, CPY_KV, CPY_D)
    Tensor tVcVGroup = flatten_1(tVcV);             // (CPY, (CPY_KV, CPY_D))

    int valid_last_kv_tile_size = std::min(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
    auto k_predicate_fn = [&](auto coords) {
      auto s_coords = tKcKGroup(_0{}, coords);
      return elem_less(get<0>(s_coords), valid_last_kv_tile_size);
    };
    auto v_predicate_fn = [&](auto coords) {
      auto s_coords = tVcVGroup(_0{}, coords);
      return elem_less(get<0>(s_coords), valid_last_kv_tile_size);
    };

    auto kv_tile_idx_decrement = [&](int kv_tile_idx) {
      int result = kv_tile_idx - 1;
      if constexpr (MULTIITEMSCORING) {
        if ((kv_tile_idx == num_kv_tiles_outside_items_window) &
            (kv_tile_idx >= num_kv_tiles_prefix)) {
          result = num_kv_tiles_prefix - 1;
        }
      }
      return result;
    };

    // load last k-tile
    {
      pipeline_k.producer_acquire(smem_pipe_write_k);
      Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
      Tensor tKsKiGroup =
          flatten_1(tKsK(_, _, _, smem_pipe_write_k.index()));  // (CPY, (CPY_KV, CPY_D))
      copy_if(gmem_tiled_copy_k, k_predicate_fn, tKgKiGroup, tKsKiGroup);
      pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive);
      ++smem_pipe_write_k;
    }

    // load Q tile
    if (warp_idx_in_warpgroup == 0) {
      cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + cutlass::NumThreadsPerWarp,
                                        static_cast<int>(NamedBarriers::kQueryEmpty));

      int lane_predicate = cute::elect_one_sync();
      if (lane_predicate) {
        shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
        copy(mainloop_params.tma_load_Q.with(
                 reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(
                     shared_storage.barrier_Q),
                 /*mcast_mask=*/0),
             tQgQ, tQsQ);
      }
    }

    shared_storage.barrier_O.wait((work_idx + 1) % 2);
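
    // Pipelining note: KV tiles are walked from the last (possibly partial, hence the predicated
    // copy above) tile down to swa_begin_kv_tile_idx, and K of the next tile is issued one
    // iteration ahead of the matching V tile, so the consumer's Q*K^T GEMM on tile i can overlap
    // with the gather of V for tile i and of K for tile i - 1.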

    if (kv_tile_idx == swa_begin_kv_tile_idx) {
      pipeline_v.producer_acquire(smem_pipe_write_v);
      Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
      Tensor tVsViGroup =
          flatten_1(tVsV(_, _, _, smem_pipe_write_v.index()));  // (CPY, (CPY_KV, CPY_D))
      copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup);
      pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive);
      ++smem_pipe_write_v;
    } else {
      // load second last k-tile and last v-tile
      pipeline_k.producer_acquire(smem_pipe_write_k);
      Tensor tKgKi = tKgK(_, _, _, kv_tile_idx_decrement(kv_tile_idx));  // (CPY, CPY_KV, CPY_D)
      Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index());           // (CPY, CPY_KV, CPY_D)
      copy(gmem_tiled_copy_k, tKgKi, tKsKi);
      pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive);
      ++smem_pipe_write_k;

      pipeline_v.producer_acquire(smem_pipe_write_v);
      Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
      Tensor tVsViGroup =
          flatten_1(tVsV(_, _, _, smem_pipe_write_v.index()));  // (CPY, (CPY_KV, CPY_D))
      copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup);
      pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive);
      kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx);
      ++smem_pipe_write_v;

      // load remaining k/v tiles
#pragma unroll 2
      for (; kv_tile_idx > swa_begin_kv_tile_idx;
           kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) {
        pipeline_k.producer_acquire(smem_pipe_write_k);
        Tensor tKgKi = tKgK(_, _, _, kv_tile_idx_decrement(kv_tile_idx));  // (CPY, CPY_KV, CPY_D)
        Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index());           // (CPY, CPY_KV, CPY_D)
        copy(gmem_tiled_copy_k, tKgKi, tKsKi);
        pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive);
        ++smem_pipe_write_k;

        pipeline_v.producer_acquire(smem_pipe_write_v);
        Tensor tVgVi = tVgV(_, _, _, kv_tile_idx);                // (CPY, CPY_KV, CPY_D)
        Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index());  // (CPY, CPY_KV, CPY_D)
        copy(gmem_tiled_copy_v, tVgVi, tVsVi);
        pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive);
        ++smem_pipe_write_v;
      }
      scheduler.prefetch_next_work(scheduler_params, work_tile_info);

      // load first v tile
      {
        pipeline_v.producer_acquire(smem_pipe_write_v);
        Tensor tVgVi = tVgV(_, _, _, 0);                          // (CPY, CPY_KV, CPY_D)
        Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index());  // (CPY, CPY_KV, CPY_D)
        copy(gmem_tiled_copy_v, tVgVi, tVsVi);
        pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive);
        ++smem_pipe_write_v;
      }
    }
    scheduler.broadcast_next_work(work_tile_info);
  }

  CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
                                PipelineState& smem_pipe_write_k,
                                PipelineState& smem_pipe_write_v) {
    pipeline_k.producer_tail(smem_pipe_write_k);
    pipeline_v.producer_tail(smem_pipe_write_v);
  }
};

}  // namespace flashinfer

#endif  // FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_