/*
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
 * Dao. Licensed under the BSD 3-Clause.
 *
 * Modified by the FlashInfer team.
 */
#ifndef FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
#define FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_

#include <type_traits>

#include "../../cutlass_utils.cuh"
#include "cute/algorithm/copy.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/layout/layout.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"

namespace flashinfer {

using namespace cute;

template <typename MainloopPipeline, typename DTypeQ, typename DTypeKV, typename DTypeO,
          typename SmemLayoutQ, typename SmemLayoutK, typename SmemLayoutV, typename SmemLayoutO>
struct SharedStorageQKVO {
  cute::array_aligned<DTypeQ, cute::cosize_v<SmemLayoutQ>> smem_q;
  cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutK>> smem_k;
  // V is fully consumed by the PV GEMM before the epilogue stages O through
  // shared memory, so the two buffers can alias the same bytes.
  union {
    cute::array_aligned<DTypeKV, cute::cosize_v<SmemLayoutV>> smem_v;
    cute::array_aligned<DTypeO, cute::cosize_v<SmemLayoutO>> smem_o;
  };
  // Synchronization objects shared by the producer and consumer warp groups.
  struct {
    cutlass::arch::ClusterTransactionBarrier barrier_Q;
    cutlass::arch::ClusterBarrier barrier_O;
    typename MainloopPipeline::SharedStorage pipeline_k;
    typename MainloopPipeline::SharedStorage pipeline_v;
  };
};

template <bool USE_TMA_LOAD_KV, int HEAD_DIM_QK_, int HEAD_DIM_VO_, int CTA_Q_, int CTA_KV_,
          int NUM_STAGES_, typename DTypeQ_, typename DTypeKV_, typename DTypeO_, typename IdType_,
          typename AttentionVariant_>
struct AttentionKernelTraits {
  using AttentionVariant = AttentionVariant_;

  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;
  using DTypeQKAccum = float;

  static constexpr int CTA_Q = CTA_Q_;
  static_assert(CTA_Q % 64 == 0);
  static constexpr int CTA_KV = CTA_KV_;
  static constexpr int HEAD_DIM_QK = HEAD_DIM_QK_;
  static constexpr int HEAD_DIM_VO = HEAD_DIM_VO_;
  static_assert(HEAD_DIM_QK % 32 == 0);
  static_assert(HEAD_DIM_VO % 32 == 0);

  // One producer warp group plus CTA_Q / 64 consumer warp groups, 4 warps each.
  static constexpr int NUM_WARPS = ((CTA_Q / 64) + 1) * 4;
  static constexpr int NUM_THREADS = NUM_WARPS * cutlass::NumThreadsPerWarp;
  // NOTE(Zihao): the following constant should only be used when TMA is enabled,
  // where only one warp inside a warp group is used for TMA.
  static constexpr int NUM_PRODUCER_THREADS = cutlass::NumThreadsPerWarp;

  using TileShape_QKD = Shape<Int<CTA_Q>, Int<CTA_KV>, Int<HEAD_DIM_QK>>;
  using TileShape_PDV = Shape<Int<CTA_Q>, Int<HEAD_DIM_VO>, Int<CTA_KV>>;

  static constexpr int NUM_STAGES = NUM_STAGES_;

  using AtomLayoutQKD = Layout<Shape<Int<CTA_Q / 64>, _1, _1>>;
  using TiledMmaQK = decltype(cute::make_tiled_mma(
      cute::GMMA::ss_op_selector<DTypeQ, DTypeKV, DTypeQKAccum, TileShape_QKD>(),
      AtomLayoutQKD{}));
  using TiledMmaPV = decltype(cute::make_tiled_mma(
      cute::GMMA::rs_op_selector<DTypeKV, DTypeKV, float, TileShape_PDV, GMMA::Major::K,
                                 GMMA::Major::MN>(),
      AtomLayoutQKD{}));

  static constexpr int NUM_MMA_THREADS = size(TiledMmaQK{});

  using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
                                   GMMA::Major::K, DTypeQ, decltype(cute::get<0>(TileShape_QKD{})),
                                   decltype(cute::get<2>(TileShape_QKD{}))>());
  using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_QKD{})));

  using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
                                   GMMA::Major::K, DTypeKV, decltype(cute::get<1>(TileShape_QKD{})),
                                   decltype(cute::get<2>(TileShape_QKD{}))>());
  using SmemLayoutK = decltype(tile_to_shape(
      SmemLayoutAtomK{},
      make_shape(shape<1>(TileShape_QKD{}), shape<2>(TileShape_QKD{}), Int<NUM_STAGES>{})));

  using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
                                   GMMA::Major::K, DTypeKV, decltype(cute::get<2>(TileShape_PDV{})),
                                   decltype(cute::get<1>(TileShape_PDV{}))>());
  using SmemLayoutV = decltype(tile_to_shape(
      SmemLayoutAtomV{},
      make_shape(get<2>(TileShape_PDV{}), get<1>(TileShape_PDV{}), Int<NUM_STAGES>{})));

  // Note this is the transpose in terms of the view, not in terms of memory.
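  // The PV wgmma consumes V as its B operand in (HEAD_DIM_VO, CTA_KV) shape
  // with MN-major layout, so SmemLayoutV is composed with a mode-swapping
  // ordered layout to obtain the transposed view; the bytes in shared memory
  // are untouched.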
  using SmemLayoutVt = decltype(composition(
      SmemLayoutV{},
      make_ordered_layout(
          make_shape(get<1>(TileShape_PDV{}), get<2>(TileShape_PDV{}), Int<NUM_STAGES>{}),
          Step<_2, _1, _3>{})));

  using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<
                                   GMMA::Major::K, DTypeO, decltype(cute::get<0>(TileShape_PDV{})),
                                   decltype(cute::get<1>(TileShape_PDV{}))>());
  using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 1>(TileShape_PDV{})));

  using MainloopPipeline =
      std::conditional_t<USE_TMA_LOAD_KV, typename cutlass::PipelineTmaAsync<NUM_STAGES>,
                         typename cutlass::PipelineAsync<NUM_STAGES>>;
  using PipelineState = typename cutlass::PipelineState<NUM_STAGES>;

  using SharedStorage = SharedStorageQKVO<MainloopPipeline, DTypeQ, DTypeKV, DTypeO, SmemLayoutQ,
                                          SmemLayoutK, SmemLayoutV, SmemLayoutO>;
};

}  // namespace flashinfer

#endif  // FLASHINFER_ATTENTION_HOPPER_KERNEL_TRAITS_CUH_
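// Illustrative sketch (not part of the upstream header): how a kernel might
// instantiate these traits for 128x128 tiles, 128-dim heads, fp16 I/O, and a
// 2-stage KV pipeline. `StandardAttention` is a hypothetical placeholder for
// any type satisfying the AttentionVariant interface; the parameter order
// follows the template definition above.
//
//   using Traits = flashinfer::AttentionKernelTraits<
//       /*USE_TMA_LOAD_KV=*/true,
//       /*HEAD_DIM_QK_=*/128, /*HEAD_DIM_VO_=*/128,
//       /*CTA_Q_=*/128, /*CTA_KV_=*/128, /*NUM_STAGES_=*/2,
//       /*DTypeQ_=*/cutlass::half_t, /*DTypeKV_=*/cutlass::half_t,
//       /*DTypeO_=*/cutlass::half_t, /*IdType_=*/int32_t,
//       /*AttentionVariant_=*/StandardAttention>;
//   static_assert(Traits::NUM_WARPS == 12);     // (128 / 64 + 1) * 4
//   static_assert(Traits::NUM_THREADS == 384);  // 12 warps * 32 threads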