/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_PREFILL_CUH_
#define FLASHINFER_PREFILL_CUH_
#include
#include
#include
#include
#include

#include "../cp_async.cuh"
#include "../fastdiv.cuh"
#ifdef FP16_QK_REDUCTION_SUPPORTED
#include "../fp16.h"
#endif
#include "../frag_layout_swizzle.cuh"
#include "../math.cuh"
#include "../mma.cuh"
#include "../page.cuh"
#include "../permuted_smem.cuh"
#include "../pos_enc.cuh"
#include "../utils.cuh"
#include "cascade.cuh"
#include "mask.cuh"
#include "variants.cuh"

namespace flashinfer {

DEFINE_HAS_MEMBER(maybe_q_rope_offset)
DEFINE_HAS_MEMBER(maybe_k_rope_offset)
DEFINE_HAS_MEMBER(maybe_prefix_len_ptr)
DEFINE_HAS_MEMBER(maybe_token_pos_in_items_ptr)
DEFINE_HAS_MEMBER(token_pos_in_items_len)
DEFINE_HAS_MEMBER(maybe_max_item_len_ptr)

namespace cg = cooperative_groups;
using cp_async::SharedMemFillMode;
using mma::MMAMode;

constexpr uint32_t WARP_SIZE = 32;

constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) {
  if (cta_tile_q > 16) {
    return 4;
  } else {
    return 1;
  }
}

constexpr uint32_t get_num_warps_kv(const uint32_t cta_tile_kv) {
  return 4 / get_num_warps_q(cta_tile_kv);
}

constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q) {
  if (cta_tile_q > 64) {
    return 2;
  } else {
    return 1;
  }
}

template struct SharedStorageQKVO {
  union {
    struct {
      alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK];
      alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK];
      alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO];
    };
    struct {
      // NOTE(Zihao): synchronize attention states across warps
      alignas(16) std::conditional_t cta_sync_o_smem;
      alignas(16) std::conditional_t cta_sync_md_smem;
    };
    alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO];
  };
};

template struct KernelTraits {
  static constexpr uint32_t NUM_STAGES = 1;  // used for BatchAttention Template
  static constexpr MaskMode MASK_MODE = MASK_MODE_;
  static constexpr uint32_t NUM_MMA_Q = NUM_MMA_Q_;
  static constexpr uint32_t NUM_MMA_KV = NUM_MMA_KV_;
  static constexpr uint32_t NUM_MMA_D_QK = NUM_MMA_D_QK_;
  static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_;
  static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_;
  static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_;
  static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * WARP_SIZE;
  static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV;
  static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16;
  static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16;
  static constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM_QK / upcast_size();
  static constexpr uint32_t UPCAST_STRIDE_K = HEAD_DIM_QK / upcast_size();
  static constexpr uint32_t UPCAST_STRIDE_V = HEAD_DIM_VO / upcast_size();
  static constexpr uint32_t UPCAST_STRIDE_O = HEAD_DIM_VO / upcast_size();
  static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_;
  static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16;
  static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B;
  static constexpr SwizzleMode SWIZZLE_MODE_KV =
      (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B;
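  // The copy-thread layout below mirrors the swizzle mode: with 128B swizzling a warp is
  // treated as 4 rows x 8 lanes (each lane issuing one 16B async copy, i.e. a full 128B row
  // per step); with 64B swizzling it is 8 rows x 4 lanes. The kernels compute K/V global
  // pointers and shared-memory write offsets as
  //   warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL,
  // which is presumably why these constants are tied to SWIZZLE_MODE_KV above.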
  static constexpr uint32_t KV_THR_LAYOUT_ROW = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 8;
  static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4;
  static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_;
  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using DTypeQKAccum = DTypeQKAccum_;
  using IdType = IdType_;
  using AttentionVariant = AttentionVariant_;

  static constexpr bool IsInvalid() {
    return ((NUM_MMA_D_VO < 4) || (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) ||
            (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && NUM_MMA_D_VO > 4 &&
             NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) ||
            (NUM_MMA_Q * (8 * NUM_MMA_D_VO + 2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >= 256) ||
            (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) ||
            (sizeof(DTypeKV) == 1 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama));
  }

  using SharedStorage = SharedStorageQKVO;

#ifdef FP16_QK_REDUCTION_SUPPORTED
  template static constexpr DT getNegInf() {
    if constexpr (std::is_same::value) {
      return std::bit_cast(fp16_ieee_from_fp32_value(-math::inf));
    } else {
      return static_cast(-math::inf);
    }
  }
  static constexpr DTypeQKAccum MaskFillValue =
      AttentionVariant::use_softmax ? getNegInf() : DTypeQKAccum(0.f);
#else
  static_assert(!std::is_same::value,
                "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math "
                "then recompile to support fp16 reduction");
  static constexpr DTypeQKAccum MaskFillValue =
      AttentionVariant::use_softmax ? DTypeQKAccum(-math::inf) : DTypeQKAccum(0.f);
#endif
};

namespace {

template __device__ __forceinline__ uint32_t get_warp_idx_q(const uint32_t tid_y = threadIdx.y) {
  if constexpr (KTraits::NUM_WARPS_Q == 1) {
    return 0;
  } else {
    return tid_y;
  }
}

template __device__ __forceinline__ uint32_t get_warp_idx_kv(const uint32_t tid_z = threadIdx.z) {
  if constexpr (KTraits::NUM_WARPS_KV == 1) {
    return 0;
  } else {
    return tid_z;
  }
}

template __device__ __forceinline__ uint32_t get_warp_idx(const uint32_t tid_y = threadIdx.y,
                                                          const uint32_t tid_z = threadIdx.z) {
  return get_warp_idx_kv(tid_z) * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid_y);
}

/*!
 * \brief Apply Llama-style rotary embedding to two 16x16 fragments.
 * \tparam T The data type of the input fragments.
 * \param x_first_half First fragment x[offset:offset+16, j*16:(j+1)*16]
 * \param x_second_half Second fragment x[offset:offset+16, j*16+d/2:(j+1)*16+d/2]
 * \param rope_freq Rope frequency
 * \param kv_offset The offset of the first row in both fragments.
 * \note The sin/cos computation is slow, especially on A100 GPUs, which have relatively low
 *       non-tensor-core FLOPS; this may be optimized in the future.
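 * \note A rough sketch of the per-register update performed below, where the angle
 *       theta = float(kv_offset + row) * rope_freq[freq_idx] is chosen per register:
 *         x_first'  = x_first  * cos(theta) - x_second * sin(theta)
 *         x_second' = x_second * cos(theta) + x_first  * sin(theta)
 *       i.e. the usual 2D rotation applied to the (x[i], x[i + d/2]) element pairs of
 *       Llama-style RoPE.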
*/ template __device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t kv_offset) { static_assert(sizeof(T) == 2); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { float cos, sin, tmp; // 0 1 | 2 3 // --------- // 4 5 | 6 7 uint32_t i = reg_id / 4, j = (reg_id % 4) / 2; __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); } } template __device__ __forceinline__ void q_frag_apply_llama_rope(T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t qo_packed_offset, const uint_fastdiv group_size) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { float cos, sin, tmp; // 0 1 | 4 5 // --------- // 2 3 | 6 7 uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); __sincosf(float((qo_packed_offset + 8 * i) / group_size) * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); } } template __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half, T* x_second_half, const float* rope_freq, const uint32_t qo_packed_offset, const uint_fastdiv group_size, const IdType* q_rope_offset) { float pos[2] = {static_cast(q_rope_offset[qo_packed_offset / group_size]), static_cast(q_rope_offset[(qo_packed_offset + 8) / group_size])}; #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { float cos, sin, tmp; // 0 1 | 4 5 // --------- // 2 3 | 6 7 uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); } } /*! * \brief Produce k/v fragments from global memory to shared memory. * \tparam fill_mode The fill mode of the shared memory. * \tparam NUM_MMA_D_VO The number of fragments in y dimension. * \tparam NUM_MMA_KV The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. * \tparam T The data type of the input tensor. * \param smem The shared memory to store kv fragments. * \param gptr The global memory pointer. * \param kv_idx_base The base kv index. * \param kv_len The length of kv tensor. */ template __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* smem_offset, typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, const uint32_t kv_len, const dim3 tid = threadIdx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment using DTypeKV = typename KTraits::DTypeKV; constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t UPCAST_STRIDE = produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * upcast_size(); } kv_idx += NUM_WARPS * 4; *smem_offset = smem.template advance_offset_by_row(*smem_offset) - sizeof(DTypeKV) * NUM_MMA_D; *gptr += NUM_WARPS * 4 * stride_n - sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); } *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_row(*smem_offset); kv_idx += NUM_WARPS * 8; *gptr += NUM_WARPS * 8 * stride_n; } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } } template __device__ __forceinline__ void page_produce_kv(typename KTraits::SharedStorage* smem_storage, uint32_t* smem_offset, typename KTraits::DTypeKV* kv_ptr, const uint32_t kv_idx_base, const size_t* thr_local_kv_offset, const uint32_t kv_len, const uint32_t warp_idx, const uint32_t lane_idx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment smem_t smem(produce_v ? smem_storage->v_smem : smem_storage->k_smem); using DType = typename KTraits::DTypeKV; using IdType = typename KTraits::IdType; constexpr SharedMemFillMode fill_mode = produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; constexpr uint32_t UPCAST_STRIDE = produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { DType* gptr = kv_ptr + thr_local_kv_offset[i]; #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); gptr += 8 * upcast_size(); } kv_idx += NUM_WARPS * 4; *smem_offset = smem.template advance_offset_by_row(*smem_offset) - sizeof(DType) * NUM_MMA_D; } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { DType* gptr = kv_ptr + thr_local_kv_offset[i]; smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); kv_idx += NUM_WARPS * 8; *smem_offset = smem.template advance_offset_by_row(*smem_offset); } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } } template __device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], const float rope_rcp_scale, const float rope_rcp_theta, const uint32_t tid_x = threadIdx.x) { constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; const uint32_t lane_idx = tid_x; #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { rope_freq[mma_d][j] = rope_rcp_scale * __powf(rope_rcp_theta, float(2 * ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % (HEAD_DIM / 2))) / float(HEAD_DIM)); } } } template __device__ __forceinline__ void init_states(typename KTraits::AttentionVariant variant, float (*o_frag)[KTraits::NUM_MMA_D_VO][8], typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { o_frag[mma_q][mma_d][reg_id] = 0.f; } } } if constexpr (variant.use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { m[mma_q][j] = typename KTraits::DTypeQKAccum(-math::inf); d[mma_q][j] = 1.f; } } } } template __device__ __forceinline__ void load_q_global_smem( uint32_t packed_offset, const uint32_t qo_upper_bound, typename KTraits::DTypeQ* q_ptr_base, const uint32_t q_stride_n, const uint32_t q_stride_h, const uint_fastdiv group_size, smem_t* q_smem, const dim3 tid = threadIdx) { using DTypeQ = typename KTraits::DTypeQ; constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q(tid.y); if (get_warp_idx_kv(tid.z) == 0) { uint32_t q_smem_offset_w = q_smem->get_permuted_offset( warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2 * 2; ++j) { uint32_t q, r; group_size.divmod(packed_offset + lane_idx / 8 + mma_q * 16 + j * 4, q, r); const uint32_t q_idx = q; DTypeQ* 
q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + (lane_idx % 8) * upcast_size(); #pragma unroll for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) { // load q fragment from gmem to smem q_smem->load_128b_async(q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, mma_do); q_ptr += 8 * upcast_size(); } q_smem_offset_w = q_smem->template advance_offset_by_row<4, UPCAST_STRIDE_Q>(q_smem_offset_w) - 2 * KTraits::NUM_MMA_D_QK; } } } } template __device__ __forceinline__ void q_smem_inplace_apply_rotary( const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, smem_t* q_smem, uint32_t* q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { if (get_warp_idx_kv(tid.z) == 0) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; const uint32_t lane_idx = tid.x; uint32_t q_frag_local[2][4]; static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = q_smem->template advance_offset_by_column( q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_frag_apply_llama_rope( (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], rope_freq[mma_di], q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx / 4, group_size); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); } *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; } *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; } } template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( const uint32_t q_packed_idx_base, const typename KTraits::IdType* q_rope_offset, smem_t* q_smem, const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { if (get_warp_idx_kv(tid.z) == 0) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; const uint32_t lane_idx = tid.x; uint32_t q_frag_local[2][4]; static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = q_smem->template advance_offset_by_column( q_smem_offset_r_first_half, 0); q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_frag_apply_llama_rope_with_pos( (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], rope_freq[mma_di], q_packed_idx_base + mma_q * 16 + lane_idx / 4, group_size, q_rope_offset); q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = q_smem->template 
advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); } *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; } *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; } } template __device__ __forceinline__ void k_smem_inplace_apply_rotary( const uint32_t kv_idx_base, smem_t* k_smem, uint32_t* k_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { using DTypeKV = typename KTraits::DTypeKV; static_assert(sizeof(DTypeKV) == 2); constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; uint32_t k_frag_local[2][4]; const uint32_t lane_idx = tid.x; if constexpr (KTraits::NUM_MMA_D_QK == 4 && KTraits::NUM_WARPS_Q == 4) { static_assert(KTraits::NUM_WARPS_KV == 1); const uint32_t warp_idx = get_warp_idx_q(tid.y); // horizontal-axis: y // vertical-axis: z // | 1-16 | 16-32 | 32-48 | 48-64 | // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 | // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | static_assert(KTraits::NUM_MMA_KV % 2 == 0, "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / 4; *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * UPCAST_STRIDE_K; #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_MMA_KV / 2; ++i) { uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; uint32_t mma_di = (warp_idx % 2); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = k_smem->template advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); *k_smem_offset_r += 32 * UPCAST_STRIDE_K; kv_idx += 32; } *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) - ((warp_idx / 2) + KTraits::NUM_MMA_KV) * 16 * UPCAST_STRIDE_K; } else { const uint32_t warp_idx_x = get_warp_idx_q(tid.y), warp_idx_z = get_warp_idx_kv(tid.z); static_assert(KTraits::NUM_MMA_D_QK % (2 * KTraits::NUM_WARPS_Q) == 0); // horizontal axis: y // vertical axis: z // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 | ... // | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | (0, 3) | ... // | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, 2) | (1, 3) | ... // ... 
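    // Each (warp_idx_z, warp_idx_x) pair therefore touches a disjoint set of K rows and
    // head-dimension columns (see the table above), so all warps can apply the rotation to
    // shared memory in place without additional synchronization.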
uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + lane_idx / 4; *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) { uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; #pragma unroll for (uint32_t j = 0; j < KTraits::NUM_MMA_D_QK / (2 * KTraits::NUM_WARPS_Q); ++j) { uint32_t mma_di = warp_idx_x + j * KTraits::NUM_WARPS_Q; k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = k_smem->template advance_offset_by_column( k_smem_offset_r_first_half, 0); k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_frag_apply_llama_rope((DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, k_frag_local[0]); k_smem_offset_r_first_half = k_smem->template advance_offset_by_column<2 * KTraits::NUM_WARPS_Q>( k_smem_offset_r_first_half, mma_di); } *k_smem_offset_r += 16 * UPCAST_STRIDE_K; kv_idx += 16; } *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } } template __device__ __forceinline__ void compute_qk( smem_t* q_smem, uint32_t* q_smem_offset_r, smem_t* k_smem, uint32_t* k_smem_offset_r, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8]) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; uint32_t a_frag[KTraits::NUM_MMA_Q][4], b_frag[4]; // compute q*k^T #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[mma_q]); *q_smem_offset_r = q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>(*q_smem_offset_r); } *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, mma_d) - KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { uint32_t b_frag_f8[2]; if (mma_d % 2 == 0) { k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, b_frag_f8); } else { k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r, b_frag_f8); } b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); vec_cast::cast<8>( (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); } else { k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); } *k_smem_offset_r = k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>(*k_smem_offset_r); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { if constexpr (std::is_same_v) { if (mma_d == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f32( s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } else { mma::mma_sync_m16n16k16_row_col_f16f16f32( s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } } else if (std::is_same_v) { if (mma_d == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f16( (uint32_t*)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } else { mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } } } } if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { if (mma_d % 2 == 1) { *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, mma_d / 2); } *k_smem_offset_r 
-= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } else { *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, mma_d) - KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } } *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * 2; *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); } template __device__ __forceinline__ void logits_transform( const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { const uint32_t lane_idx = tid.x; uint32_t q[KTraits::NUM_MMA_Q][2], r[KTraits::NUM_MMA_Q][2]; float logits = 0., logitsTransformed = 0.; #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], r[mma_q][j]); } } #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + 2 * (lane_idx % 4) + 8 * (reg_id / 4) + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; #ifdef FP16_QK_REDUCTION_SUPPORTED if constexpr (std::is_same::value) { logits = std::bit_cast(fp16_ieee_to_fp32_value(s_frag[mma_q][mma_kv][reg_id])); } else if constexpr (!std::is_same::value) { logits = s_frag[mma_q][mma_kv][reg_id]; } #else static_assert(!std::is_same::value, "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " "then recompile to support fp16 reduction"); logits = s_frag[mma_q][mma_kv][reg_id]; #endif logitsTransformed = variant.LogitsTransform(params, logits, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); #ifdef FP16_QK_REDUCTION_SUPPORTED if constexpr (std::is_same::value) { s_frag[mma_q][mma_kv][reg_id] = std::bit_cast(fp16_ieee_from_fp32_value(logitsTransformed)); } else if constexpr (!std::is_same::value) { s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; } #else s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; #endif } } } } template __device__ __forceinline__ void logits_mask( const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t chunk_end, const uint_fastdiv group_size, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { const uint32_t lane_idx = tid.x; constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], r[mma_q][j]); } } #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll 
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + 2 * (lane_idx % 4) + 8 * (reg_id / 4) + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; const bool mask = (!(MASK_MODE == MaskMode::kCausal || MASK_MODE == MaskMode::kMultiItemScoring ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end)) : kv_idx >= chunk_end)) && variant.LogitsMask(params, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); s_frag[mma_q][mma_kv][reg_id] = (mask) ? s_frag[mma_q][mma_kv][reg_id] : (KTraits::MaskFillValue); } } } } template __device__ __forceinline__ void logits_mask_multi_item_scoring( const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, const uint32_t kv_len, const uint32_t window_left, const uint32_t chunk_end, const uint_fastdiv group_size, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], // new arguments for compact description of mask const uint32_t prefix_len, uint16_t* token_pos_in_items, const uint32_t lane_idx = threadIdx.x, const uint32_t kv_head_idx = blockIdx.z) { constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; using DTypeQKAccum = typename KTraits::DTypeQKAccum; uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], r[mma_q][j]); } } // prefetching global memory to registers uint16_t token_pos_in_items_regs[NUM_MMA_Q][(4 / 2)]; #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t eff_reg_id = 0; eff_reg_id < (4 / 2); ++eff_reg_id) { const uint32_t q_idx = q[mma_q][eff_reg_id]; // use __ldca to hint compiler to cache in L1 for further reuse by other tiles const int idx_in_original_seq = q_idx + kv_len - qo_len; if (idx_in_original_seq >= prefix_len & idx_in_original_seq < kv_len) { token_pos_in_items_regs[mma_q][eff_reg_id] = __ldca(token_pos_in_items + idx_in_original_seq - prefix_len); } } } #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + 2 * (lane_idx % 4) + 8 * (reg_id / 4) + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; const uint32_t idx_in_original_seq = q_idx + kv_len - qo_len; const bool out_of_boundary = kv_idx > idx_in_original_seq || (kv_idx >= chunk_end) || kv_idx + window_left < idx_in_original_seq; const bool is_prefix = idx_in_original_seq < prefix_len; if (out_of_boundary || is_prefix) { s_frag[mma_q][mma_kv][reg_id] = out_of_boundary ? (KTraits::MaskFillValue) : s_frag[mma_q][mma_kv][reg_id]; } else { s_frag[mma_q][mma_kv][reg_id] = (kv_idx < prefix_len | (idx_in_original_seq < kv_idx + token_pos_in_items_regs[mma_q][((reg_id % 4) / 2)])) ? 
s_frag[mma_q][mma_kv][reg_id] : (KTraits::MaskFillValue); } } } } } template __device__ __forceinline__ void update_mdo_states( typename KTraits::AttentionVariant variant, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], float (*o_frag)[KTraits::NUM_MMA_D_VO][8], typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) { using DTypeQKAccum = typename KTraits::DTypeQKAccum; using AttentionVariant = typename KTraits::AttentionVariant; constexpr bool use_softmax = AttentionVariant::use_softmax; if constexpr (use_softmax) { const float sm_scale = variant.sm_scale_log2; if constexpr (std::is_same_v) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { float m_prev = m[mma_q][j]; #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { float m_local = max(max(s_frag[mma_q][mma_kv][j * 2 + 0], s_frag[mma_q][mma_kv][j * 2 + 1]), max(s_frag[mma_q][mma_kv][j * 2 + 4], s_frag[mma_q][mma_kv][j * 2 + 5])); m[mma_q][j] = max(m[mma_q][j], m_local); } m[mma_q][j] = max(m[mma_q][j], math::shfl_xor_sync(m[mma_q][j], 0x2)); m[mma_q][j] = max(m[mma_q][j], math::shfl_xor_sync(m[mma_q][j], 0x1)); float o_scale = math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); d[mma_q][j] *= o_scale; #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; } #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { s_frag[mma_q][mma_kv][j * 2 + 0] = math::ptx_exp2( s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - m[mma_q][j] * sm_scale); s_frag[mma_q][mma_kv][j * 2 + 1] = math::ptx_exp2( s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale - m[mma_q][j] * sm_scale); s_frag[mma_q][mma_kv][j * 2 + 4] = math::ptx_exp2( s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m[mma_q][j] * sm_scale); s_frag[mma_q][mma_kv][j * 2 + 5] = math::ptx_exp2( s_frag[mma_q][mma_kv][j * 2 + 5] * sm_scale - m[mma_q][j] * sm_scale); } } } } else if constexpr (std::is_same_v) { const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { half m_prev[2]; #pragma unroll for (uint32_t j = 0; j < 2; ++j) { m_prev[j] = m[mma_q][j]; #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { half2 m_local = __hmax2(*(half2*)&s_frag[mma_q][mma_kv][j * 2], *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4]); m[mma_q][j] = __hmax(m[mma_q][j], __hmax(m_local.x, m_local.y)); } } *(half2*)&m[mma_q] = __hmax2(*(half2*)&m[mma_q], math::shfl_xor_sync(*(half2*)&m[mma_q], 0x2)); *(half2*)&m[mma_q] = __hmax2(*(half2*)&m[mma_q], math::shfl_xor_sync(*(half2*)&m[mma_q], 0x1)); #pragma unroll for (uint32_t j = 0; j < 2; ++j) { float o_scale = math::ptx_exp2(float(m_prev[j] * sm_scale.x - m[mma_q][j] * sm_scale.x)); d[mma_q][j] *= o_scale; #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; } half2 m2 = make_half2(m[mma_q][j], m[mma_q][j]); #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { *(half2*)&s_frag[mma_q][mma_kv][j * 2] = math::ptx_exp2(*(half2*)&s_frag[mma_q][mma_kv][j * 2] * sm_scale - m2 * sm_scale); 
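            // The remaining registers of the fragment (j * 2 + 4, j * 2 + 5) receive the same
            // rescaling below; scores are kept in the log2 domain (sm_scale here is
            // sm_scale_log2 broadcast to half2), so ptx_exp2 evaluates 2^(s*scale - m*scale).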
*(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] = math::ptx_exp2( *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m2 * sm_scale); } } } } } } template __device__ __forceinline__ void compute_sfm_v( smem_t* v_smem, uint32_t* v_smem_offset_r, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], float (*o_frag)[KTraits::NUM_MMA_D_VO][8], float (*d)[2]) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV][8]; if constexpr (std::is_same_v) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { vec_cast::cast<8>(s_frag_f16[mma_q][mma_kv], s_frag[mma_q][mma_kv]); } } } if constexpr (KTraits::AttentionVariant::use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { if constexpr (std::is_same_v) { mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag_f16[mma_q][mma_kv]); } else { mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag[mma_q][mma_kv]); } } } } #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { uint32_t b_frag[4]; if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { uint32_t b_frag_f8[2]; if (mma_d % 2 == 0) { v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); } else { v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_f8); } b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); vec_cast::cast<8>( (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); swap(b_frag[1], b_frag[2]); } else { v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); } #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { if constexpr (std::is_same_v) { mma::mma_sync_m16n16k16_row_col_f16f16f32( o_frag[mma_q][mma_d], (uint32_t*)s_frag_f16[mma_q][mma_kv], b_frag); } else { mma::mma_sync_m16n16k16_row_col_f16f16f32( o_frag[mma_q][mma_d], (uint32_t*)s_frag[mma_q][mma_kv], b_frag); } } if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { if (mma_d % 2 == 1) { *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, mma_d / 2); } } else { *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, mma_d); } } *v_smem_offset_r = v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>(*v_smem_offset_r) - sizeof(typename KTraits::DTypeKV) * KTraits::NUM_MMA_D_VO; } *v_smem_offset_r -= 16 * KTraits::NUM_MMA_KV * UPCAST_STRIDE_V; } template __device__ __forceinline__ void finalize_m(typename KTraits::AttentionVariant variant, typename KTraits::DTypeQKAccum (*m)[2]) { if constexpr (variant.use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { if (m[mma_q][j] != typename KTraits::DTypeQKAccum(-math::inf)) { m[mma_q][j] *= variant.sm_scale_log2; } } } } } template __device__ __forceinline__ void transform_output( const Params& params, typename KTraits::AttentionVariant variant, float (*o_frag)[KTraits::NUM_MMA_D_VO][8], typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2], const uint32_t batch_idx, const uint32_t kv_tile_idx, const uint32_t qo_packed_idx_base, const uint32_t warp_idx, const 
uint32_t lane_idx, uint32_t kv_head_idx, const uint_fastdiv group_size) { uint32_t q[KTraits::NUM_MMA_Q][2], r[KTraits::NUM_MMA_Q][2]; float scale[KTraits::NUM_MMA_Q][2]; #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + 8 * j, q[mma_q][j], r[mma_q][j]); uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][j]; // Update the m and d when attention sinks are used. variant.update_m_d(params, kv_tile_idx, qo_head_idx, m[mma_q][j], d[mma_q][j], scale[mma_q][j]); } } #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { const uint32_t qo_idx = q[mma_q][(reg_id % 4) / 2]; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; o_frag[mma_q][mma_d][reg_id] = variant.OutputTransform( params, o_frag[mma_q][mma_d][reg_id], batch_idx, qo_idx, qo_head_idx, m[mma_q][(reg_id % 4) / 2], d[mma_q][(reg_id % 4) / 2], scale[mma_q][(reg_id % 4) / 2]); } } } } /*! * \brief Synchronize the states of the MDO kernel across the threadblock along threadIdx.z. */ template __device__ __forceinline__ void threadblock_sync_mdo_states( float (*o_frag)[KTraits::NUM_MMA_D_VO][8], typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2], const uint32_t warp_idx, const uint32_t lane_idx, const dim3 tid = threadIdx) { // only necessary when blockDim.z > 1 if constexpr (KTraits::NUM_WARPS_KV > 1) { float* smem_o = smem_storage->cta_sync_o_smem; float2* smem_md = smem_storage->cta_sync_md_smem; // o: [num_warps, NUM_MMA_Q, NUM_MMA_D_VO, WARP_SIZE(32), 8] // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { vec_t::memcpy( smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * KTraits::NUM_MMA_D_VO + mma_d) * WARP_SIZE + lane_idx) * 8, o_frag[mma_q][mma_d]); } } if constexpr (KTraits::AttentionVariant::use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * 2 + j) * 8 + lane_idx / 4] = make_float2(float(m[mma_q][j]), d[mma_q][j]); } } // synchronize m,d first __syncthreads(); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { float o_scale[2][KTraits::NUM_WARPS_KV]; #pragma unroll for (uint32_t j = 0; j < 2; ++j) { float m_new = -math::inf, d_new = 1.f; #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * 2 + j) * 8 + lane_idx / 4]; float m_prev = m_new, d_prev = d_new; m_new = max(m_new, md.x); d_new = d_prev * math::ptx_exp2(m_prev - m_new) + md.y * math::ptx_exp2(md.x - m_new); } #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * 2 + j) * 8 + lane_idx / 4]; float mi = md.x; o_scale[j][i] = math::ptx_exp2(float(mi - m_new)); } m[mma_q][j] = typename KTraits::DTypeQKAccum(m_new); d[mma_q][j] = d_new; } #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { vec_t 
o_new; o_new.fill(0.f); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { vec_t oi; oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * KTraits::NUM_MMA_D_VO + mma_d) * WARP_SIZE + lane_idx) * 8); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { o_new[reg_id] += oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; } } o_new.store(o_frag[mma_q][mma_d]); } } } else { // synchronize m,d first __syncthreads(); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { vec_t o_new; o_new.fill(0.f); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { vec_t oi; oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * KTraits::NUM_MMA_D_VO + mma_d) * WARP_SIZE + lane_idx) * 8); #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { o_new[reg_id] += oi[reg_id]; } } o_new.store(o_frag[mma_q][mma_d]); } } } } } template __device__ __forceinline__ void write_o_reg_gmem( float (*o_frag)[KTraits::NUM_MMA_D_VO][8], smem_t* o_smem, typename KTraits::DTypeO* o_ptr_base, const uint32_t o_packed_idx_base, const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv group_size, const dim3 tid = threadIdx) { using DTypeO = typename KTraits::DTypeO; constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; const uint32_t warp_idx_x = get_warp_idx_q(tid.y); const uint32_t lane_idx = tid.x; if constexpr (sizeof(DTypeO) == 4) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; group_size.divmod(o_packed_idx_base + lane_idx / 4 + mma_q * 16 + j * 8, q, r); const uint32_t o_idx = q; #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { if (o_idx < qo_upper_bound) { *reinterpret_cast(o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16 + (lane_idx % 4) * 2) = *reinterpret_cast(&o_frag[mma_q][mma_d][j * 2]); *reinterpret_cast(o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16 + 8 + (lane_idx % 4) * 2) = *reinterpret_cast(&o_frag[mma_q][mma_d][4 + j * 2]); } } } } } else { if (get_warp_idx_kv(tid.z) == 0) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { uint32_t o_frag_f16[8 / 2]; vec_cast::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED uint32_t o_smem_offset_w = o_smem->get_permuted_offset( (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16, mma_d * 2 + lane_idx / 16); o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); #else uint32_t o_smem_offset_w = o_smem->get_permuted_offset( (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = o_frag_f16[1]; ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2]; ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = o_frag_f16[3]; #endif } } uint32_t o_smem_offset_w = o_smem->get_permuted_offset( warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; 
++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2 * 2; ++j) { uint32_t q, r; group_size.divmod(o_packed_idx_base + lane_idx / 8 + mma_q * 16 + j * 4, q, r); const uint32_t o_idx = q; DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h + (lane_idx % 8) * upcast_size(); #pragma unroll for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) { if (o_idx < qo_upper_bound) { o_smem->store_128b(o_smem_offset_w, o_ptr); } o_ptr += 8 * upcast_size(); o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do); } o_smem_offset_w = o_smem->template advance_offset_by_row<4, UPCAST_STRIDE_O>(o_smem_offset_w) - 2 * KTraits::NUM_MMA_D_VO; } } } } } } // namespace /*! * \brief FlashAttention prefill CUDA kernel for a single request. * \tparam partition_kv Whether to split kv_len into chunks. * \tparam mask_mode The mask mode used in the attention operation. * \tparam POS_ENCODING_MODE The positional encoding mode. * \tparam NUM_MMA_Q The number of fragments in x dimension. * \tparam NUM_MMA_D_VO The number of fragments in y dimension. * \tparam NUM_MMA_KV The number of fragments in z dimension. * \tparam num_warps The number of warps in the threadblock. * \tparam DTypeQ The data type of the query tensor. * \tparam DTypeKV The data type of the key/value tensor. * \tparam DTypeO The data type of the output tensor. * \param q The query tensor. * \param k The key tensor. * \param v The value tensor. * \param o The output tensor. * \param tmp The temporary buffer (used when partition_kv is true). * \param lse The logsumexp value. * \param rope_rcp_scale 1/(rope_scale), where rope_scale is the scaling * factor used in RoPE interpolation. * \param rope_rcp_theta 1/(rope_theta), where rope_theta is the theta * used in RoPE. 
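 * \note When params.partition_kv is true, each KV chunk (indexed by blockIdx.y) writes a
 *       partial output and a partial logsumexp; the host-side dispatcher is expected to merge
 *       them afterwards (e.g. via MergeStates / AttentionSum in the split-KV path below),
 *       conceptually o = sum_i exp(lse_i - lse) * o_i with lse = log(sum_i exp(lse_i)).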
*/ template __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( const Params params, typename KTraits::SharedStorage& smem_storage, const dim3 tid = threadIdx, const uint32_t bx = blockIdx.x, const uint32_t chunk_idx = blockIdx.y, const uint32_t kv_head_idx = blockIdx.z, const uint32_t num_chunks = gridDim.y, const uint32_t num_kv_heads = gridDim.z) { using DTypeQ = typename Params::DTypeQ; #if (__CUDA_ARCH__ < 800) if constexpr (std::is_same_v) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; using DTypeQKAccum = typename KTraits::DTypeQKAccum; using AttentionVariant = typename KTraits::AttentionVariant; [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; DTypeQ* q = params.q; DTypeKV* k = params.k; DTypeKV* v = params.v; DTypeO* o = params.o; float* lse = params.lse; const uint32_t qo_len = params.qo_len; const uint32_t kv_len = params.kv_len; const bool partition_kv = params.partition_kv; const uint32_t q_stride_n = params.q_stride_n; const uint32_t q_stride_h = params.q_stride_h; const uint32_t k_stride_n = params.k_stride_n; const uint32_t k_stride_h = params.k_stride_h; const uint32_t v_stride_n = params.v_stride_n; const uint32_t v_stride_h = params.v_stride_h; const int32_t maybe_window_left = params.window_left; const uint_fastdiv& group_size = params.group_size; static_assert(sizeof(DTypeQ) == 2); const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); const uint32_t num_qo_heads = num_kv_heads * group_size; const uint32_t max_chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; const uint32_t chunk_start = partition_kv ? chunk_idx * max_chunk_size : 0; const uint32_t chunk_end = partition_kv ? 
min((chunk_idx + 1) * max_chunk_size, kv_len) : kv_len; const uint32_t chunk_size = chunk_end - chunk_start; auto block = cg::this_thread_block(); auto smem = reinterpret_cast(&smem_storage); AttentionVariant variant(params, /*batch_idx=*/0, smem); const uint32_t window_left = variant.window_left; DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; DTypeQKAccum m[NUM_MMA_Q][2]; float d[NUM_MMA_Q][2]; float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); } init_states(variant, o_frag, m, d); // cooperative fetch q fragment from gmem to reg const uint32_t qo_packed_idx_base = (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; smem_t qo_smem(smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ* q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; DTypeO* o_ptr_base = partition_kv ? o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h : o + (kv_head_idx * group_size) * o_stride_h; uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); cp_async::commit_group(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { cp_async::wait_group<0>(); block.sync(); q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, tid); block.sync(); } smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); const uint32_t num_iterations = ceil_div( MASK_MODE == MaskMode::kCausal ? min(chunk_size, sub_if_greater_or_zero( kv_len - qo_len + ceil_div(((bx + 1) * CTA_TILE_Q), group_size), chunk_start)) : chunk_size, CTA_TILE_KV); const uint32_t window_iteration = ceil_div(sub_if_greater_or_zero(kv_len + ceil_div((bx + 1) * CTA_TILE_Q, group_size), qo_len + window_left + chunk_start), CTA_TILE_KV); const uint32_t mask_iteration = (MASK_MODE == MaskMode::kCausal ? 
min(chunk_size, sub_if_greater_or_zero(kv_len + ceil_div((bx * CTA_TILE_Q), group_size) - qo_len, chunk_start)) : chunk_size) / CTA_TILE_KV; DTypeKV* k_ptr = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + kv_head_idx * k_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); DTypeKV* v_ptr = v + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + kv_head_idx * v_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), v_smem_offset_w = v_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL); produce_kv(k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); cp_async::commit_group(); produce_kv(v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); cp_async::commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { cp_async::wait_group<1>(); block.sync(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary(chunk_start + iter * CTA_TILE_KV, &k_smem, &k_smem_offset_r, rope_freq, tid); block.sync(); } // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); uint32_t kv_idx_base = chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16; logits_transform(params, variant, /*batch_idx=*/0, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { logits_mask(params, variant, /*batch_idx=*/0, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); cp_async::commit_group(); } cp_async::wait_group<0>(); block.sync(); finalize_m(variant, m); // threadblock synchronization threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); // transform output transform_output(params, variant, o_frag, m, d, /*batch_idx=*/0, chunk_idx, qo_packed_idx_base, warp_idx, lane_idx, kv_head_idx, group_size); // write back write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ partition_kv ? 
num_chunks * o_stride_n : o_stride_n, /*o_stride_h=*/o_stride_h, group_size, tid); // write lse if constexpr (variant.use_softmax) { if (lse != nullptr || partition_kv) { if (get_warp_idx_kv(tid.z) == 0) { #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_len) { if (partition_kv) { lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[qo_idx * num_qo_heads + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } } } } } #if (__CUDA_ARCH__ < 800) } #endif } template __global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCacheKernel( const __grid_constant__ Params params) { extern __shared__ uint8_t smem[]; auto& smem_storage = reinterpret_cast(smem); SinglePrefillWithKVCacheDevice(params, smem_storage); } template cudaError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; const uint32_t num_qo_heads = params.num_qo_heads; const uint32_t num_kv_heads = params.num_kv_heads; const uint32_t qo_len = params.qo_len; const uint32_t kv_len = params.kv_len; if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { std::ostringstream err_msg; err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " "to qo_len, got kv_len" << kv_len << " and qo_len " << qo_len; FLASHINFER_ERROR(err_msg.str()); } const uint32_t group_size = num_qo_heads / num_kv_heads; constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; int64_t packed_qo_len = qo_len * group_size; uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO); DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, { constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); using DTypeQKAccum = typename std::conditional, half, float>::type; int dev_id = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); int max_smem_per_sm = 0; FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute( &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_mma_kv_reg = (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) ? 
2 : (8 / NUM_MMA_Q); const uint32_t max_num_mma_kv_smem = (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); // control NUM_MMA_KV for maximum warp occupancy DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { using KTraits = KernelTraits; if constexpr (KTraits::IsInvalid()) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); } else { constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; auto kernel = SinglePrefillWithKVCacheKernel; size_t smem_size = sizeof(typename KTraits::SharedStorage); FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); int num_blocks_per_sm = 0; int num_sm = 0; FLASHINFER_CUDA_CALL( cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &num_blocks_per_sm, kernel, num_threads, smem_size)); uint32_t max_num_kv_chunks = (num_blocks_per_sm * num_sm) / (num_kv_heads * ceil_div(qo_len * group_size, CTA_TILE_Q)); uint32_t num_chunks; if (max_num_kv_chunks > 0) { uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); num_chunks = ceil_div(kv_len, chunk_size); } else { num_chunks = 0; } if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv params.partition_kv = false; void* args[] = {(void*)¶ms}; dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { // Use cooperative groups to increase occupancy params.partition_kv = true; float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO); auto o = params.o; auto lse = params.lse; params.o = tmp; params.lse = tmp_lse; void* args[] = {(void*)¶ms}; dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream)); } else { FLASHINFER_CUDA_CALL( AttentionSum(tmp, o, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream)); } } } }) }); return cudaSuccess; } template __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel( const __grid_constant__ Params params) { using DTypeQ = typename Params::DTypeQ; #if (__CUDA_ARCH__ < 800) if constexpr (std::is_same_v) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; using IdType = typename Params::IdType; using DTypeQKAccum = typename KTraits::DTypeQKAccum; using AttentionVariant = typename KTraits::AttentionVariant; [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; [[maybe_unused]] 
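// NOTE: the compile-time aliases in this block are marked [[maybe_unused]] because different
// instantiations (mask mode, position encoding, swizzle mode) only touch a subset of them,
// and unused-variable warnings would otherwise fire for the rest.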
constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; DTypeQ* q = params.q; IdType* request_indices = params.request_indices; IdType* qo_tile_indices = params.qo_tile_indices; IdType* kv_tile_indices = params.kv_tile_indices; IdType* q_indptr = params.q_indptr; IdType* kv_indptr = params.kv_indptr; DTypeKV* k = params.k; DTypeKV* v = params.v; IdType* o_indptr = params.o_indptr; DTypeO* o = params.o; float* lse = params.lse; bool* block_valid_mask = params.block_valid_mask; const bool partition_kv = params.partition_kv; const uint32_t q_stride_n = params.q_stride_n; const uint32_t q_stride_h = params.q_stride_h; const uint32_t k_stride_n = params.k_stride_n; const uint32_t k_stride_h = params.k_stride_h; const uint32_t v_stride_n = params.v_stride_n; const uint32_t v_stride_h = params.v_stride_h; const int32_t maybe_window_left = params.window_left; const uint_fastdiv& group_size = params.group_size; static_assert(sizeof(DTypeQ) == 2); const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); const dim3& tid = threadIdx; auto block = cg::this_thread_block(); const uint32_t bx = blockIdx.x, lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z), kv_head_idx = blockIdx.z; if (block_valid_mask && !block_valid_mask[bx]) { return; } const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads; const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], kv_tile_idx = kv_tile_indices[bx]; extern __shared__ uint8_t smem[]; auto& smem_storage = reinterpret_cast(smem); AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, window_left = variant.window_left; const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; const uint32_t chunk_end = partition_kv ? 
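// NOTE: this CTA covers KV positions [chunk_start, chunk_end) of the current request. With
// partition_kv they are derived from kv_tile_idx and the runtime kv_chunk_size (read through
// params.kv_chunk_size_ptr); otherwise the CTA scans the full [0, kv_len) range. chunk_size
// then drives the iteration counts computed below.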
min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; const uint32_t chunk_size = chunk_end - chunk_start; const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; DTypeQKAccum m[NUM_MMA_Q][2]; float d[NUM_MMA_Q][2]; float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); } init_states(variant, o_frag, m, d); const uint32_t qo_packed_idx_base = (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; smem_t qo_smem(smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n + kv_head_idx * group_size * q_stride_h; DTypeO* o_ptr_base = partition_kv ? o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + (kv_head_idx * group_size) * o_stride_h : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h; uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); cp_async::commit_group(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { cp_async::wait_group<0>(); block.sync(); IdType* q_rope_offset = nullptr; if constexpr (has_maybe_q_rope_offset_v) { q_rope_offset = params.maybe_q_rope_offset; } if (!q_rope_offset) { q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, tid); } else { q_smem_inplace_apply_rotary_with_pos( qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq, tid); } block.sync(); } const uint32_t num_iterations = ceil_div( (MASK_MODE == MaskMode::kCausal ? min(chunk_size, sub_if_greater_or_zero( kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size), chunk_start)) : chunk_size), CTA_TILE_KV); const uint32_t window_iteration = ceil_div( sub_if_greater_or_zero(kv_len + ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size), qo_len + window_left + chunk_start), CTA_TILE_KV); const uint32_t mask_iteration = (MASK_MODE == MaskMode::kCausal ? 
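// NOTE: num_iterations (above) stops the main loop at the last KV tile any query row of this
// CTA can attend to under causal masking, clamped to this CTA's KV chunk. window_iteration and
// mask_iteration delimit the iterations that actually need logits_mask: tiles before
// window_iteration may contain positions left of the sliding window, tiles at or after
// mask_iteration may cross the causal boundary, and tiles in between are fully visible, so the
// mask is skipped for them (MaskMode::kCustom always masks). For example, with
// qo_len = kv_len = 1024, group_size = 1, CTA_TILE_Q = 128, CTA_TILE_KV = 64, qo_tile_idx = 0
// and no KV partitioning, the causal bound is 1024 - 1024 + 128 = 128 positions, i.e.
// num_iterations = 2 and mask_iteration = 0 (both tiles straddle the diagonal).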
min(chunk_size, sub_if_greater_or_zero( kv_len + ceil_div((qo_tile_idx * CTA_TILE_Q), group_size) - qo_len, chunk_start)) : chunk_size) / CTA_TILE_KV; smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), v_smem_offset_w = v_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL); DTypeKV* k_ptr = k + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + kv_head_idx * k_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); DTypeKV* v_ptr = v + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + kv_head_idx * v_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); produce_kv(k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); cp_async::commit_group(); produce_kv(v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); cp_async::commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { cp_async::wait_group<1>(); block.sync(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { IdType* k_rope_offset = nullptr; if constexpr (has_maybe_k_rope_offset_v) { k_rope_offset = params.maybe_k_rope_offset; } k_smem_inplace_apply_rotary( (k_rope_offset == nullptr ? 
0 : k_rope_offset[request_idx]) + chunk_start + iter * CTA_TILE_KV, &k_smem, &k_smem_offset_r, rope_freq, tid); block.sync(); } // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); uint32_t kv_idx_base = chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16; logits_transform(params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { logits_mask(params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); cp_async::commit_group(); } cp_async::wait_group<0>(); block.sync(); finalize_m(variant, m); // threadblock synchronization threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; // transform output transform_output(params, variant, o_frag, m, d, /*batch_idx=*/request_idx, kv_tile_idx, qo_packed_idx_base, warp_idx, lane_idx, kv_head_idx, group_size); // write back write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ partition_kv ? 
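// NOTE: the LSE stored below uses the base-2 logarithm (math::ptx_log2(d) + m), and the merge
// kernels consume the same representation. With partition_kv, row qo_idx of this request keeps
// one partial result per KV chunk, laid out contiguously from o_indptr[request_idx], which is
// what VariableLengthMergeStates / VariableLengthAttentionSum reduce over in the dispatcher.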
num_kv_chunks * o_stride_n : o_stride_n, /*o_stride_h=*/o_stride_h, group_size, tid); // write lse if constexpr (AttentionVariant::use_softmax) { if (lse != nullptr) { if (get_warp_idx_kv(tid.z) == 0) { #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_len) { if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } } } } } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif #if (__CUDA_ARCH__ < 800) } #endif } template __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( const Params params, typename KTraits::SharedStorage& smem_storage, const dim3 tid = threadIdx, const uint32_t bx = blockIdx.x, const uint32_t kv_head_idx = blockIdx.z, const uint32_t num_kv_heads = gridDim.z) { using DTypeQ = typename Params::DTypeQ; #if (__CUDA_ARCH__ < 800) if constexpr (std::is_same_v) { FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); } else { #endif using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; using IdType = typename Params::IdType; using DTypeQKAccum = typename KTraits::DTypeQKAccum; using AttentionVariant = typename KTraits::AttentionVariant; [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; IdType* request_indices = params.request_indices; IdType* qo_tile_indices = params.qo_tile_indices; IdType* kv_tile_indices = params.kv_tile_indices; DTypeQ* q = params.q; IdType* q_indptr = params.q_indptr; IdType* o_indptr = params.o_indptr; DTypeO* o = params.o; float* lse = params.lse; bool* block_valid_mask = params.block_valid_mask; const paged_kv_t& paged_kv = 
params.paged_kv; const bool partition_kv = params.partition_kv; const int32_t maybe_window_left = params.window_left; const uint_fastdiv& group_size = params.group_size; uint32_t* maybe_prefix_len_ptr = nullptr; if constexpr (has_maybe_prefix_len_ptr_v) { maybe_prefix_len_ptr = params.maybe_prefix_len_ptr; } uint16_t* maybe_token_pos_in_items_ptr = nullptr; if constexpr (has_maybe_token_pos_in_items_ptr_v) { maybe_token_pos_in_items_ptr = params.maybe_token_pos_in_items_ptr; } uint32_t token_pos_in_items_len = 0; if constexpr (has_token_pos_in_items_len_v) { token_pos_in_items_len = params.token_pos_in_items_len; } uint16_t* maybe_max_item_len_ptr = nullptr; if constexpr (has_maybe_max_item_len_ptr_v) { maybe_max_item_len_ptr = params.maybe_max_item_len_ptr; } static_assert(sizeof(DTypeQ) == 2); auto block = cg::this_thread_block(); const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); if (block_valid_mask && !block_valid_mask[bx]) { return; } const uint32_t num_qo_heads = num_kv_heads * group_size; const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], kv_tile_idx = kv_tile_indices[bx]; auto smem = reinterpret_cast(&smem_storage); AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, window_left = variant.window_left; const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; const uint32_t chunk_end = partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; const uint32_t chunk_size = chunk_end - chunk_start; const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; DTypeQKAccum m[NUM_MMA_Q][2]; float d[NUM_MMA_Q][2]; float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); } init_states(variant, o_frag, m, d); const uint32_t qo_packed_idx_base = (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h; smem_t qo_smem(smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ* q_ptr_base = q + q_indptr[request_idx] * q_stride_n + (kv_head_idx * group_size) * q_stride_h; DTypeO* o_ptr_base = partition_kv ? 
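// NOTE: the griddepcontrol.wait below (compiled only for CUDA 12+ and sm90+) is the consumer
// side of programmatic dependent launch: when the dispatcher sets
// cudaLaunchAttributeProgrammaticStreamSerialization (enable_pdl), this kernel may start
// before the preceding grid has fully retired, and the wait guarantees that grid's global
// writes are visible before Q/K/V are read; griddepcontrol.launch_dependents at the end of the
// kernel symmetrically lets the next kernel begin early.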
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + (kv_head_idx * group_size) * o_stride_h : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h; uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem, tid); cp_async::commit_group(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { cp_async::wait_group<0>(); block.sync(); IdType* q_rope_offset = nullptr; if constexpr (has_maybe_q_rope_offset_v) { q_rope_offset = params.maybe_q_rope_offset; } if (q_rope_offset == nullptr) { q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, tid); } else { q_smem_inplace_apply_rotary_with_pos( qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], &qo_smem, group_size, &q_smem_offset_r, rope_freq, tid); } block.sync(); } smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); size_t thr_local_kv_offset[NUM_MMA_KV * KV_THR_LAYOUT_COL / 2 / NUM_WARPS_Q]; uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), v_smem_offset_w = v_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL); const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; uint32_t packed_page_iter_base = paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start; #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { uint32_t page_iter, entry_idx; paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, page_iter, entry_idx); thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( page_iter, kv_head_idx, entry_idx, (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); } page_produce_kv(&smem_storage, &k_smem_offset_w, paged_kv.k_data, 0, thr_local_kv_offset, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); page_produce_kv(&smem_storage, &v_smem_offset_w, paged_kv.v_data, 0, thr_local_kv_offset, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); uint32_t num_iterations_prefix; uint32_t num_iterations_mask; uint32_t num_iterations = 0; if constexpr (MASK_MODE != MaskMode::kMultiItemScoring) { num_iterations = ceil_div( (MASK_MODE == MaskMode::kCausal ? 
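// NOTE: for MaskMode::kMultiItemScoring the loop below first runs num_iterations_prefix
// iterations over the shared prefix (bounded by maybe_prefix_len_ptr[request_idx]), then jumps
// straight to num_iterations_mask, skipping KV tiles that the per-request maximum item length
// (maybe_max_item_len_ptr) implies are masked out for every row of this query tile;
// prefetch_skip_step bumps packed_page_iter_base by the same amount so the paged-KV gather
// stays aligned with the skipped iterations.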
min(chunk_size, sub_if_greater_or_zero( kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size), chunk_start)) : chunk_size), CTA_TILE_KV); } else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) { num_iterations_prefix = ceil_div( min(min(chunk_size, sub_if_greater_or_zero( kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size), chunk_start)), sub_if_greater_or_zero(__ldg(maybe_prefix_len_ptr + request_idx), chunk_start)), CTA_TILE_KV); num_iterations_mask = max(min(chunk_size, sub_if_greater_or_zero( sub_if_greater_or_zero( kv_len - qo_len + ceil_div((qo_tile_idx * CTA_TILE_Q), group_size), __ldg(maybe_max_item_len_ptr + request_idx)), chunk_start)) / (CTA_TILE_KV), num_iterations_prefix); num_iterations = max( num_iterations_mask, ceil_div(min(chunk_size, sub_if_greater_or_zero( kv_len - qo_len + ceil_div(((qo_tile_idx + 1) * CTA_TILE_Q), group_size), chunk_start)), CTA_TILE_KV)); } const uint32_t window_iteration = ceil_div( sub_if_greater_or_zero(kv_len + ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size), qo_len + window_left + chunk_start), CTA_TILE_KV); const uint32_t mask_iteration = (MASK_MODE == MaskMode::kCausal || MASK_MODE == MaskMode::kMultiItemScoring ? min(chunk_size, sub_if_greater_or_zero( kv_len + ceil_div((qo_tile_idx * CTA_TILE_Q), group_size) - qo_len, chunk_start)) : chunk_size) / CTA_TILE_KV; #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; iter = (MASK_MODE == MaskMode::kMultiItemScoring) ? ((iter + 1 == num_iterations_prefix) ? num_iterations_mask : (iter + 1)) : (iter + 1)) { const uint32_t prefetch_skip_step = (MASK_MODE == MaskMode::kMultiItemScoring) ? ((iter + 1 == num_iterations_prefix) ? (num_iterations_mask - num_iterations_prefix) : 0) : 0; packed_page_iter_base += (1 + prefetch_skip_step) * CTA_TILE_KV; #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { uint32_t page_iter, entry_idx; paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, page_iter, entry_idx); thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( page_iter, kv_head_idx, entry_idx, (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); } cp_async::wait_group<1>(); block.sync(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { k_smem_inplace_apply_rotary( (paged_kv.rope_pos_offset == nullptr ? 
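// NOTE: the gather loop above works in packed token coordinates: packed_page_iter_base counts
// KV entries from the start of this request's page list, paged_kv.page_size.divmod() splits it
// into (page_iter, entry_idx), and protective_get_kv_offset() converts that into an offset
// into paged_kv.k_data / v_data, using last_indptr to clamp out-of-range pages when the tail
// CTA tile runs past the end of the chunk.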
0 : paged_kv.rope_pos_offset[request_idx]) + chunk_start + iter * CTA_TILE_KV, &k_smem, &k_smem_offset_r, rope_freq, tid); block.sync(); } // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); uint32_t kv_idx_base = chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16; logits_transform(params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); // apply mask if (MASK_MODE == MaskMode::kCustom) { logits_mask(params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, chunk_end, group_size, s_frag); } else { if constexpr (MASK_MODE != MaskMode::kMultiItemScoring) { if (iter >= mask_iteration || iter < window_iteration) { logits_mask(params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, chunk_end, group_size, s_frag); } } else if constexpr (MASK_MODE == MaskMode::kMultiItemScoring) { if (iter + 1 >= num_iterations_prefix) { logits_mask_multi_item_scoring( params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, window_left, chunk_end, group_size, s_frag, __ldg(maybe_prefix_len_ptr + request_idx), maybe_token_pos_in_items_ptr + request_idx * token_pos_in_items_len, tid.x, kv_head_idx); } else { if (iter >= mask_iteration || iter < window_iteration) { logits_mask(params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, kv_idx_base, qo_len, kv_len, chunk_end, group_size, s_frag); } } } } // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d); block.sync(); page_produce_kv(&smem_storage, &k_smem_offset_w, paged_kv.k_data, (iter + 1) * CTA_TILE_KV, thr_local_kv_offset, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); // compute sfm*v compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); page_produce_kv(&smem_storage, &v_smem_offset_w, paged_kv.v_data, (iter + 1) * CTA_TILE_KV, thr_local_kv_offset, chunk_size, warp_idx, lane_idx); cp_async::commit_group(); } cp_async::wait_group<0>(); block.sync(); finalize_m(variant, m); // threadblock synchronization threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; // transform output transform_output(params, variant, o_frag, m, d, /*batch_idx=*/request_idx, kv_tile_idx, qo_packed_idx_base, warp_idx, lane_idx, kv_head_idx, group_size); // write_back write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, /*o_stride_n=*/ partition_kv ? 
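// NOTE: num_kv_chunks above is derived from kv_len_safe (kv_len clamped to at least 1), so a
// request with an empty KV cache still counts as a single chunk and the o / lse indexing below
// stays within the region reserved by o_indptr.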
num_kv_chunks * o_stride_n : o_stride_n, /*o_stride_h=*/o_stride_h, group_size, tid); // write lse if constexpr (variant.use_softmax) { if (lse != nullptr) { if (get_warp_idx_kv(tid.z) == 0) { #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; if (qo_idx < qo_upper_bound) { if (partition_kv) { lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } else { lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } } } } } } } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif #if (__CUDA_ARCH__ < 800) } #endif } template __global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVCacheKernel( const __grid_constant__ Params params) { extern __shared__ uint8_t smem[]; auto& smem_storage = reinterpret_cast(smem); BatchPrefillWithPagedKVCacheDevice(params, smem_storage); } template cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; const uint32_t padded_batch_size = params.padded_batch_size; const uint32_t num_qo_heads = params.num_qo_heads; const uint32_t num_kv_heads = params.num_kv_heads; constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); if (padded_batch_size == 0) { // No request, skip // this won't happen in CUDAGraph mode because we fixed the padded_batch_size return cudaSuccess; } dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; using DTypeQKAccum = typename std::conditional, half, float>::type; int dev_id = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); int max_smem_per_sm = 0; FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_mma_kv_reg = (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) ? 
2 : (8 / NUM_MMA_Q); const uint32_t max_num_mma_kv_smem = (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { using KTraits = KernelTraits; if constexpr (KTraits::IsInvalid()) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); } else { size_t smem_size = sizeof(typename KTraits::SharedStorage); auto kernel = BatchPrefillWithRaggedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // PDL launch config cudaLaunchAttribute attribute[1]; cudaLaunchConfig_t config; if (enable_pdl) { attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = 1; config.attrs = attribute; config.numAttrs = 1; config.gridDim = nblks; config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; } if (tmp_v == nullptr) { // do not partition kv params.partition_kv = false; void* args[] = {(void*)¶ms}; if (enable_pdl) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, params)); } else { FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } } else { // partition kv params.partition_kv = true; auto o = params.o; auto lse = params.lse; params.o = tmp_v; params.lse = tmp_s; void* args[] = {(void*)¶ms}; if (enable_pdl) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, params)); } else { FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates( tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows, params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream)); } else { FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( tmp_v, params.merge_indptr, o, params.max_total_num_rows, params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream)); } } } }); return cudaSuccess; } template cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; const uint32_t padded_batch_size = params.padded_batch_size; const uint32_t num_qo_heads = params.num_qo_heads; const uint32_t num_kv_heads = params.paged_kv.num_heads; constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); if (padded_batch_size == 0) { // No request, skip // this won't happen in CUDAGraph mode because we fixed the padded_batch_size return cudaSuccess; } dim3 nblks(padded_batch_size, 1, num_kv_heads); dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; using DTypeQKAccum = 
typename std::conditional, half, float>::type; int dev_id = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); int max_smem_per_sm = 0; FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; const uint32_t max_num_mma_kv_reg = (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION) ? 2 : (8 / NUM_MMA_Q); const uint32_t max_num_mma_kv_smem = (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { using KTraits = KernelTraits; if constexpr (KTraits::IsInvalid()) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q << " NUM_WARPS_KV=" << NUM_WARPS_KV << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); } else { size_t smem_size = sizeof(typename KTraits::SharedStorage); auto kernel = BatchPrefillWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); // PDL launch config cudaLaunchAttribute attribute[1]; cudaLaunchConfig_t config; if (enable_pdl) { attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = 1; config.attrs = attribute; config.numAttrs = 1; config.gridDim = nblks; config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; } if (tmp_v == nullptr) { // do not partition kv params.partition_kv = false; if (enable_pdl) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, params)); } else { void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } } else { params.partition_kv = true; auto o = params.o; auto lse = params.lse; params.o = tmp_v; params.lse = tmp_s; if (enable_pdl) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, params)); } else { void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates( tmp_v, tmp_s, params.merge_indptr, o, lse, params.max_total_num_rows, params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream)); } else { FLASHINFER_CUDA_CALL(VariableLengthAttentionSum( tmp_v, params.merge_indptr, o, params.max_total_num_rows, params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream)); } } } }); return cudaSuccess; } } // namespace flashinfer #endif // FLASHINFER_PREFILL_CUH_
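// NOTE: dispatch summary: the *Dispatched() host functions above (1) determine CTA_TILE_Q (the
// single-prefill path derives it from the packed query length via FA2DetermineCtaTileQ),
// (2) bound NUM_MMA_KV by both the register heuristic and the shared memory remaining after
// the Q tile (aiming for two CTAs per SM when it fits), (3) instantiate KernelTraits and the
// matching kernel, and (4) either launch it directly or, when a tmp buffer is provided, launch
// the KV-partitioned variant and merge the partial outputs with MergeStates /
// VariableLengthMergeStates (or the AttentionSum variants when softmax is disabled).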