/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_DECODE_CUH_
#define FLASHINFER_DECODE_CUH_

#include <cooperative_groups.h>
#include <cuda_runtime.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <sstream>

#include "../cp_async.cuh"
#include "../math.cuh"
#include "../pos_enc.cuh"
#include "../utils.cuh"
#include "../vec_dtypes.cuh"
#include "cascade.cuh"
#include "state.cuh"

namespace flashinfer {

DEFINE_HAS_MEMBER(decode_maybe_q_rope_offset)

namespace cg = cooperative_groups;

using cp_async::PrefetchMode;
using cp_async::SharedMemFillMode;

namespace {

/*!
 * \brief Load k tile from smem and compute qk
 * \tparam pos_encoding_mode The positional encoding mode used in the kernel
 * \tparam head_dim A template integer indicates the head dimension
 * \tparam vec_size A template integer indicates the vector size
 * \tparam bdx A template integer indicates the block size in x dimension
 * \tparam tile_size A template integer indicates the tile size per (bdx * bdy) threads.
 * \tparam T A template type indicates the input data type
 * \param smem A pointer to the start of shared memory
 * \param q_vec A vector of float indicates the thread-local query vector
 * \param freq A vector of float indicates the thread-local rope frequency
 * \param kv_shared_offset An array of uint32_t indicates the k/v tiles offset
 *   in shared memory of different pipeline stages
 * \param kv_idx A integer indicates the thread-local kv position in kv-cache
 * \param compute_stage_idx A integer indicates the compute stage index in the pipeline
 * \param s A float indicates the thread-local result of qk
 * \param st The self-attention state to be updated
 */
template <PosEncodingMode pos_encoding_mode, uint32_t vec_size, uint32_t bdx, uint32_t tile_size,
          typename AttentionVariant, typename Params, typename T>
__device__ __forceinline__ void compute_qk(const Params& params, AttentionVariant variant,
                                           const uint32_t batch_idx, const T* smem,
                                           const vec_t<float, vec_size>& q_vec,
                                           const vec_t<float, vec_size>& freq, uint32_t kv_idx_base,
                                           uint32_t iter_base, uint32_t iter_bound,
                                           uint32_t qo_head_idx, uint32_t kv_head_idx, float* s,
                                           state_t<vec_size>& st, const uint32_t tx,
                                           const uint32_t ty, const uint32_t tz) {
  float m_prev = st.m;
#pragma unroll
  for (uint32_t j = 0; j < tile_size; ++j) {
    vec_t<float, vec_size> k_vec;
    if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) {
      // apply rotary embedding for all rows in k matrix of kv-cache
      k_vec = vec_apply_llama_rope<vec_size, bdx>(smem + j * bdx * vec_size, freq,
                                                  kv_idx_base + tz * tile_size + j);
    } else {
      // do not apply rotary embedding
      k_vec.cast_load(smem + (j * bdx + tx) * vec_size);
    }
    s[j] = 0.f;
#pragma unroll
    for (uint32_t i = 0; i < vec_size; ++i) {
      s[j] += q_vec[i] * k_vec[i];
    }
#pragma unroll
    for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) {
      s[j] += math::shfl_xor_sync(s[j], offset);
    }
    const uint32_t pos = kv_idx_base + tz * tile_size + j;
    s[j] = variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos,
                                   qo_head_idx, kv_head_idx);
    if constexpr (variant.use_softmax) {
      s[j] *= variant.sm_scale_log2;
    }
    bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx,
                                   kv_head_idx);
    s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ?
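    // NOTE: the state update below is the streaming ("online") softmax used by
    // FlashAttention-style decoding: each tile only refreshes the running maximum m and
    // denominator d and rescales the partial output o, so the full attention matrix is never
    // materialized. A sketch of the recurrence (base-2 exponentials, matching math::ptx_exp2
    // and sm_scale_log2):
    //   m_new = max(m_old, max_j s_j)
    //   d_new = d_old * exp2(m_old - m_new) + sum_j exp2(s_j - m_new)
    //   o_new = o_old * exp2(m_old - m_new) + sum_j exp2(s_j - m_new) * v_j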
s[j] : -math::inf; st.m = max(st.m, s[j]); } if constexpr (variant.use_softmax) { float o_scale = math::ptx_exp2(m_prev - st.m); st.d *= o_scale; #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { s[j] = math::ptx_exp2(s[j] - st.m); st.d += s[j]; } #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { st.o[i] = st.o[i] * o_scale; } } } /*! * \brief Load v tile from shared memory and update local state * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam tile_size A template integer indicates the tile size per (bdx * bdy) threads. * \tparam T A template type indicates the input data type * \param smem A pointer to the start of shared memory * \param s A float indicates the pre-softmax attention score * \param kv_shared_offset An array of uint32_t indicates the k/v tiles offset * in shared memory of different pipeline stages * \param compute_stage_idx A integer indicates the compute stage index in the pipeline * \param st The flashattention state to be updated */ template __device__ __forceinline__ void update_local_state(const T* smem, const float* s, uint32_t compute_stage_idx, state_t& st, uint32_t tx) { #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { vec_t v_vec; v_vec.cast_load(smem + (j * bdx + tx) * vec_size); #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { st.o[i] = st.o[i] + s[j] * v_vec[i]; } } } /*! * \brief Synchronize the state of all warps inside a threadblock. * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension * \param st The warp local state * \param smem The pointer to shared memory buffer for o * \param smem_md The pointer to shared memory buffer for m/d */ template __device__ __forceinline__ void sync_state(AttentionVariant variant, state_t& st, float* smem, float* smem_md, const uint32_t tx, const uint32_t ty, const uint32_t tz) { if constexpr (bdz > 1) { constexpr uint32_t head_dim = bdx * vec_size; auto block = cg::this_thread_block(); st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size); if constexpr (variant.use_softmax) { smem_md[(tz * bdy + ty) * 2] = st.m; smem_md[(tz * bdy + ty) * 2 + 1] = st.d; block.sync(); st.init(); #pragma unroll for (uint32_t j = 0; j < bdz; ++j) { float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1]; vec_t oz; oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size); st.merge(oz, mz, dz); } } else { block.sync(); st.init(); #pragma unroll for (uint32_t j = 0; j < bdz; ++j) { vec_t oz; oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size); #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { st.o[i] += oz[i]; } } } } } } // namespace /*! 
* \brief FlashAttention decoding cuda kernel with kv-cache for a single request * \tparam pos_encoding_mode The positional encoding mode * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension * \tparam DTypeQ A template type indicates the query data type * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeO A template type indicates the output data type * \param q [num_qo_heads, head_dim] The query matrix * \param k [seq_len, num_kv_heads, head_dim] The key matrix in kv-cache * \param v [seq_len, num_kv_heads, head_dim] The value matrix in kv-cache * \param o [num_qo_heads, head_dim] The output matrix * \param head_dim A integer indicates the head dimension * \param rope_rcp_scale A floating number indicate the reciprocal * of scaling ratio used in PI(Position Interpolation) for RoPE (Rotary * Positional Embeddings) * \param rope_rcp_theta A floating number indicate the reciprocal * of "theta" used in RoPE (Rotary Positional Embeddings) * \param kv_chunk_size A integer indicates the kv-chunk size */ template __global__ void SingleDecodeWithKVCacheKernel(const __grid_constant__ Params params) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; const DTypeQ* q = params.q; const DTypeKV* k = params.k; const DTypeKV* v = params.v; const uint32_t q_stride_n = params.q_stride_n; const uint32_t q_stride_h = params.q_stride_h; const uint32_t kv_stride_n = params.kv_stride_n; const uint32_t kv_stride_h = params.kv_stride_h; DTypeO* o = params.o; float* lse = params.lse; uint32_t kv_chunk_size = params.kv_chunk_size; auto block = cg::this_thread_block(); auto grid = cg::this_grid(); constexpr uint32_t head_dim = bdx * vec_size; uint32_t kv_head_idx = blockIdx.y; uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; uint32_t kv_chunk_idx = blockIdx.x; uint32_t num_qo_heads = params.num_qo_heads; extern __shared__ uint8_t smem[]; AttentionVariant variant(params, /*batch_idx=*/0, smem); const uint32_t seq_len = variant.kv_len; DTypeKV* k_smem = (DTypeKV*)smem; DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * sizeof(DTypeKV)); float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * sizeof(DTypeKV)); uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; vec_t q_vec; vec_t freq; if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { freq[i] = rope_rcp_scale * __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); } // apply rotary embedding to q matrix q_vec = vec_apply_llama_rope(q + qo_head_idx * q_stride_h, freq, seq_len - 1); } else { // do not apply rotary embedding to q matrix q_vec.cast_load(q + qo_head_idx * q_stride_h + tx * vec_size); } block.sync(); uint32_t chunk_start = kv_chunk_idx * kv_chunk_size; kv_chunk_size = min(kv_chunk_size, seq_len - chunk_start); uint32_t chunk_end = chunk_start + kv_chunk_size; // preload k tiles and v tiles uint32_t producer_kv_idx_base = chunk_start; constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { for (uint32_t j = 0; 
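  // NOTE: this prologue fills all num_stages_smem pipeline stages with K and V tiles via
  // cp.async before any compute starts. Each stage commits one async-copy group for its K tile
  // and one for its V tile, which is why the main loop below waits with
  // wait_group<2 * num_stages_smem - 1>() before consuming a tile: at most
  // 2 * num_stages_smem - 1 groups may still be in flight while the oldest one is used.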
j < tile_size_per_bdx; ++j) { cp_async::pred_load( k_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, k + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + kv_head_idx * kv_stride_h + tx * vec_size, producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } cp_async::commit_group(); for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( v_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, v + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + kv_head_idx * kv_stride_h + tx * vec_size, producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } cp_async::commit_group(); producer_kv_idx_base += bdy * bdz * tile_size_per_bdx; } // pipelining k/v tiles loading and state updating uint32_t consumer_kv_idx_base = chunk_start, stage_idx = 0; state_t st_local; float s[bdy * tile_size_per_bdx]; #pragma unroll 2 for (uint32_t iter = 0; iter < ceil_div(kv_chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) { // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); compute_qk( params, variant, /*batch_idx=*/0, k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, q_vec, freq, consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, qo_head_idx, kv_head_idx, s, st_local, tx, ty, tz); block.sync(); // load k for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, k + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + kv_head_idx * kv_stride_h + tx * vec_size, producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } cp_async::commit_group(); // update m/d/o state cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); update_local_state( v_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, s, stage_idx, st_local, tx); block.sync(); // load v for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, v + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + kv_head_idx * kv_stride_h + tx * vec_size, producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } cp_async::commit_group(); stage_idx = (stage_idx + 1) % num_stages_smem; producer_kv_idx_base += tile_size_per_bdx * bdy * bdz; consumer_kv_idx_base += tile_size_per_bdx * bdy * bdz; } cp_async::wait_group<0>(); block.sync(); // sync local state of all warps inside a threadblock sync_state(variant, st_local, reinterpret_cast(smem), smem_md, tx, ty, tz); if constexpr (variant.use_softmax) { st_local.normalize(); } st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); if (lse != nullptr) { lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse(); } } /*! 
* \brief FlashAttention decoding cuda kernel with paged kv-cache for multiple requests * \tparam pos_encoding_mode The positional encoding mode * \tparam vec_size A template integer indicates the vector size * \tparam bdx A template integer indicates the block size in x dimension * \tparam bdy A template integer indicates the block size in y dimension * \tparam bdz A template integer indicates the block size in z dimension * \tparam DTypeQ A template type indicates the query data type * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeO A template type indicates the output data type * \tparam IdType A template type indicates the index data type * \param q [batch_size, num_qo_heads, head_dim] The query matrix * \param paged_kv The paged kv-cache data structure * \param o [num_qo_heads, head_dim] The output matrix * \param tmp Used-allocated temporary buffer * \param lse The logsumexp values * \param sm_scale A float indicates the scale applied to pre-softmax logits * \param rope_rcp_scale A floating number indicate the reciprocal * of scaling ratio used in PI(Position Interpolation) for RoPE (Rotary * Positional Embeddings) * \param rope_rcp_theta A floating number indicate the reciprocal * of "theta" used in RoPE (Rotary Positional Embeddings) */ template __device__ __inline__ void BatchDecodeWithPagedKVCacheDevice(const Params& params, uint8_t smem[], const uint32_t bx = blockIdx.x, const uint32_t by = blockIdx.y, const uint32_t tx = threadIdx.x, const uint32_t ty = threadIdx.y, const uint32_t tz = threadIdx.z) { auto block = cg::this_thread_block(); using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; using IdType = typename Params::IdType; const DTypeQ* q = params.q; DTypeO* o = params.o; float* lse = params.lse; const auto paged_kv = params.paged_kv; const bool* block_valid_mask = params.block_valid_mask; const uint32_t padded_batch_size = params.padded_batch_size; const uint32_t num_qo_heads = params.num_qo_heads; const bool partition_kv = params.partition_kv; constexpr uint32_t head_dim = bdx * vec_size; const uint32_t batch_idx = params.request_indices[bx]; const uint32_t kv_tile_idx = params.kv_tile_indices[bx]; const uint32_t kv_head_idx = by; const uint32_t qo_head_idx = kv_head_idx * bdy + ty; // NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than // the actual batch size, so we need to check if the current batch is valid if (block_valid_mask && !block_valid_mask[bx]) return; const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); const uint32_t kv_len = paged_kv.get_length(batch_idx); const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; const uint32_t chunk_end = partition_kv ? 
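  // NOTE: when partition_kv is set, the KV cache of a single request is split into chunks of
  // kv_chunk_size tokens and each chunk is processed by its own thread block (selected via
  // request_indices / kv_tile_indices). Every block writes a partial output and LSE for its
  // [chunk_start, chunk_end) range; the dispatcher later combines the partial results with
  // VariableLengthMergeStates (or VariableLengthAttentionSum for non-softmax variants).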
min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; const uint32_t chunk_size = chunk_end - chunk_start; AttentionVariant variant(params, batch_idx, smem); DTypeKV* k_smem = (DTypeKV*)smem; DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * sizeof(DTypeKV)); size_t* kv_offset_smem = (size_t*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * sizeof(DTypeKV)); float* smem_md = (float*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * sizeof(DTypeKV)); vec_t q_vec; vec_t freq; const uint32_t q_stride_n = params.q_stride_n; const uint32_t q_stride_h = params.q_stride_h; if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { const IdType* q_rope_offset = nullptr; if constexpr (has_decode_maybe_q_rope_offset_v) { q_rope_offset = params.decode_maybe_q_rope_offset; } int32_t q_rope_offset_val = q_rope_offset == nullptr ? (kv_len - 1) : q_rope_offset[batch_idx]; const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; #pragma unroll for (uint32_t i = 0; i < vec_size; ++i) { freq[i] = rope_rcp_scale * __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif // apply rotary embedding to q matrix q_vec = vec_apply_llama_rope( q + batch_idx * q_stride_n + qo_head_idx * q_stride_h, freq, q_rope_offset_val); } else { // do not apply rotary embedding to q matrix #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + tx * vec_size); } // preload k/v tiles uint32_t stage_idx = 0; constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; static_assert(num_stages_smem <= bdx); uint32_t packed_page_iter_base = paged_kv.indptr[batch_idx] * paged_kv.page_size + chunk_start; #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { uint32_t q, r; paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r); kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr); } block.sync(); size_t kv_offset[tile_size_per_bdx]; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { kv_offset[j] = kv_offset_smem[((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size; } #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, paged_kv.k_data + kv_offset[j], ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, paged_kv.v_data + kv_offset[j], ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); stage_idx = (stage_idx + 1) % num_stages_smem; } state_t st; float s[bdy * tile_size_per_bdx]; #pragma unroll 2 for (uint32_t iter = 0; iter < ceil_div(chunk_size, tile_size_per_bdx * bdy * bdz); 
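      // NOTE: page-table lookups are amortized across iterations: each of the bdx * bdy * bdz
      // threads holds one precomputed KV offset in kv_offset_smem, which covers the next bdx
      // iterations of this loop, so the divmod over paged_kv.page_size is only redone when
      // (iter + num_stages_smem) % bdx == 0.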
++iter) { if ((iter + num_stages_smem) % bdx == 0) { #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { uint32_t q, r; paged_kv.page_size.divmod( packed_page_iter_base + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz + ((j * bdz + tz) * bdy + ty) * bdx + tx), q, r); kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr); } } // compute qk cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); compute_qk( params, variant, batch_idx, k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, q_vec, freq, (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[batch_idx]) + chunk_start + iter * tile_size_per_bdx * bdy * bdz, iter * tile_size_per_bdx * bdy * bdz, chunk_size, qo_head_idx, kv_head_idx, s, st, tx, ty, tz); block.sync(); #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { kv_offset[j] = kv_offset_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size; } // load k tiles #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, paged_kv.k_data + kv_offset[j], (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); // update m/d/o states cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); update_local_state( v_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, s, stage_idx, st, tx); block.sync(); // load v tiles #pragma unroll for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { cp_async::pred_load( v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + tx * vec_size, paged_kv.v_data + kv_offset[j], (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } cp_async::commit_group(); stage_idx = (stage_idx + 1) % num_stages_smem; } cp_async::wait_group<0>(); block.sync(); // sync local state of all warps inside a threadblock sync_state(variant, st, reinterpret_cast(smem), smem_md, tx, ty, tz); if constexpr (variant.use_softmax) { st.normalize(); } if (tz == 0) { st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); // write lse if (lse != nullptr) { lse[bx * num_qo_heads + qo_head_idx] = st.get_lse(); } } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } template __global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params params) { extern __shared__ uint8_t smem[]; BatchDecodeWithPagedKVCacheDevice(params, smem); } /*! * \brief Get the heuristic number of threads per threadblock * \param group_size The number of qo heads that maps to the same kv head in GQA. * \param sizeof_dtype The size (in terms of bytes) of the input data type */ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeof_dtype) { if (group_size == 8U) { if (sizeof_dtype == 1U) { return 256U; // not enough registers for 512 threads } else { return 512U; } } else { return 128U; } } /*! 
* \brief FlashAttention decoding with kv-cache for a single request * \tparam DTypeQ A template type indicates the query data type * \tparam DTypeKV A template type indicates the key-value data type * \tparam DTypeO A template type indicates the output data type * \param q The query matrix, shape: [num_qo_heads, head_dim] * \param k The key matrix in kv-cache, shape: [seq_len, num_kv_heads, head_dim] * for NHD layout, [num_kv_heads, seq_len, head_dim] for HND layout * \param v The value matrix in kv-cache, shape: [seq_len, num_kv_heads, * head_dim] for NHD layout, [num_kv_heads, seq_len, head_dim] for HND layout * \param o The output matrix, shape: [num_qo_heads, head_dim] * \param tmp Used-allocated temporary buffer * \param num_qo_heads A integer indicates the number of heads of query and output * \param num_kv_heads A integer indicates the number of heads of key and value * \param seq_len A integer indicates the sequence length * \param head_dim A integer indicates the head dimension * \param pos_encoding_mode The positional encoding mode * \param rope_scale The scaling factor used in RoPE Interpolation * \param rope_theta The theta used in RoPE * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ template cudaError_t SingleDecodeWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; const uint32_t num_qo_heads = params.num_qo_heads; const uint32_t num_kv_heads = params.num_kv_heads; const uint32_t seq_len = params.kv_len; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); constexpr uint32_t bdx = HEAD_DIM / vec_size; auto compute_capacity = GetCudaComputeCapability(); static_assert(bdx <= 32U); DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { constexpr uint32_t bdy = GROUP_SIZE; constexpr uint32_t num_threads = std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
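  // NOTE: for short sequences (seq_len <= 256) or when no temporary buffer is provided, a
  // single thread block per KV head scans the whole sequence. Otherwise the code below sizes
  // the grid from cudaOccupancyMaxActiveBlocksPerMultiprocessor, splits the KV cache into
  // chunks of at least 256 tokens, runs the partition-kv kernel into the tmp buffer, and merges
  // the per-chunk partial results with MergeStates (softmax variants) or AttentionSum.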
2U : 8U) : 1U; DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { const uint32_t smem_size = 2U * NUM_STAGES_SMEM * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + 2U * bdy * bdz * sizeof(float); auto kernel = SingleDecodeWithKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); if (seq_len <= 256 || tmp == nullptr) { // no need to use partition-kv kernel dim3 nblks = dim3(1, num_kv_heads); dim3 nthrs = dim3(bdx, bdy, bdz); params.kv_chunk_size = seq_len; void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } else { // use partition-kv kernel int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); FLASHINFER_CUDA_CALL( cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor( &num_blocks_per_sm, kernel, num_threads, smem_size)); uint32_t max_grid_size = uint32_t(num_blocks_per_sm) * uint32_t(num_sm); uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); dim3 nblks = dim3(num_chunks, num_kv_heads); if (nblks.x == 0 || nblks.y == 0) { std::ostringstream err_msg; err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; FLASHINFER_ERROR(err_msg.str()); } dim3 nthrs = dim3(bdx, bdy, bdz); float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM); auto o = params.o; auto lse = params.lse; params.o = tmp; params.lse = tmp_lse; params.kv_chunk_size = kv_chunk_size; void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL( MergeStates(tmp, tmp_lse, o, lse, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); } else { FLASHINFER_CUDA_CALL(AttentionSum(tmp, o, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); } } }); }); return cudaSuccess; } template cudaError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; using IdType = typename Params::IdType; const uint32_t num_qo_heads = params.num_qo_heads; const uint32_t num_kv_heads = params.paged_kv.num_heads; const uint32_t padded_batch_size = params.padded_batch_size; constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); auto compute_capacity = GetCudaComputeCapability(); constexpr uint32_t bdx = HEAD_DIM / vec_size; static_assert(bdx <= 32); DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { constexpr uint32_t bdy = GROUP_SIZE; constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
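  // NOTE: the dynamic shared memory requested below holds NUM_STAGES_SMEM staged K tiles plus
  // NUM_STAGES_SMEM staged V tiles, and one trailing region that is reused either for the
  // per-thread paged-KV offsets or for the per-warp m/d merge buffer, hence the std::max() of
  // the two sizes. When enable_pdl is true the kernel is launched with cudaLaunchKernelEx and
  // programmatic stream serialization, so it can start early and synchronize with the preceding
  // kernel through the griddepcontrol instructions inside the kernel body.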
2U : 4U) : 1U; DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { const uint32_t smem_size = 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); auto kernel = BatchDecodeWithPagedKVCacheKernel; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); dim3 nblks(padded_batch_size, num_kv_heads); dim3 nthrs(bdx, bdy, bdz); // PDL launch config cudaLaunchAttribute attribute[1]; cudaLaunchConfig_t config; if (enable_pdl) { attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = 1; config.attrs = attribute; config.numAttrs = 1; config.gridDim = nblks; config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; } if (tmp_v == nullptr) { // do not use partition-kv kernel params.partition_kv = false; if (enable_pdl) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, params)); } else { void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } } else { // use partition-kv kernel params.partition_kv = true; auto o = params.o; auto lse = params.lse; params.o = tmp_v; params.lse = tmp_s; if (enable_pdl) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, params)); } else { void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } if constexpr (AttentionVariant::use_softmax) { FLASHINFER_CUDA_CALL(VariableLengthMergeStates( tmp_v, tmp_s, params.o_indptr, o, lse, params.paged_kv.batch_size, nullptr, num_qo_heads, HEAD_DIM, enable_pdl, stream)); } else { FLASHINFER_CUDA_CALL( VariableLengthAttentionSum(tmp_v, params.o_indptr, o, params.paged_kv.batch_size, nullptr, num_qo_heads, HEAD_DIM, enable_pdl, stream)); } } }); }); return cudaSuccess; } template __device__ __forceinline__ void compute_qk_and_update_local_stat_mla( const Params& params, AttentionVariant variant, const uint32_t batch_idx, const T* ckv_smem, const vec_t& q_nope_vec, const T* kpe_smem, const vec_t& q_pe_vec, const vec_t& freq, uint32_t kv_idx_base, uint32_t iter_base, uint32_t iter_bound, state_t& st) { uint32_t tx = threadIdx.x, tz = threadIdx.z; constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv; constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe; float s[tile_size]; float m_prev = st.m; #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { vec_t ckv_vec; ckv_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv); vec_t kpe_vec; kpe_vec.cast_load(kpe_smem + j * head_dim_kpe + tx * vec_size_kpe); s[j] = 0.f; #pragma unroll for (uint32_t i = 0; i < vec_size_ckv; ++i) { s[j] += q_nope_vec[i] * ckv_vec[i]; } #pragma unroll for (uint32_t i = 0; i < vec_size_kpe; ++i) { s[j] += q_pe_vec[i] * kpe_vec[i]; } s[j] *= params.sm_scale; #pragma unroll for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { s[j] += math::shfl_xor_sync(s[j], offset); } s[j] = (iter_base + tz * tile_size + j < iter_bound) ? 
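    // NOTE: in the MLA path the compressed KV tile in ckv_smem doubles as the value tile: the
    // same rows used above for the q_nope/ckv and q_pe/kpe dot products are re-read below to
    // accumulate the output, so no separate V tile is ever loaded from global memory.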
s[j] : -math::inf; st.m = max(st.m, s[j]); } float o_scale = math::ptx_exp2(m_prev - st.m); st.d *= o_scale; #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { s[j] = math::ptx_exp2(s[j] - st.m); st.d += s[j]; } #pragma unroll for (uint32_t i = 0; i < vec_size_ckv; ++i) { st.o[i] = st.o[i] * o_scale; } #pragma unroll for (uint32_t j = 0; j < tile_size; ++j) { vec_t v_vec; v_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv); #pragma unroll for (uint32_t i = 0; i < vec_size_ckv; ++i) { st.o[i] = st.o[i] + s[j] * v_vec[i]; } } } template __global__ void BatchDecodeWithPagedKVCacheKernelMLA(Params params) { auto block = cg::this_thread_block(); using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; using IdType = typename Params::IdType; const DTypeQ* q_nope = params.q_nope; const DTypeQ* q_pe = params.q_pe; DTypeO* o = params.o; float* lse = params.lse; const auto& paged_kv = params.paged_kv; const IdType* q_rope_offset = params.q_rope_offset; const bool* block_valid_mask = params.block_valid_mask; const uint32_t num_qo_heads = params.num_qo_heads; const float rope_rcp_scale = params.rope_rcp_scale; const float rope_rcp_theta = params.rope_rcp_theta; const bool partition_kv = params.partition_kv; params.sm_scale *= math::log2e; constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv; constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe; const uint32_t batch_idx = blockIdx.x; const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; const uint32_t t_offset = dim3_offset(bdy, bdx, tz, ty, tx); // NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than // the actual batch size, so we need to check if the current batch is valid if (block_valid_mask && !block_valid_mask[batch_idx]) return; const uint32_t mapped_batch_idx = params.request_indices[batch_idx]; const uint32_t orig_seq_len = paged_kv.get_length(mapped_batch_idx); int32_t q_rope_offset_val = q_rope_offset == nullptr ? (orig_seq_len - 1) : q_rope_offset[mapped_batch_idx]; const uint32_t kv_chunk_idx_in_orig_mapped_batch = params.kv_tile_indices[batch_idx]; const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); const uint32_t cur_chunk_start = partition_kv ? kv_chunk_idx_in_orig_mapped_batch * kv_chunk_size : 0; const uint32_t cur_chunk_end = partition_kv ? 
min((kv_chunk_idx_in_orig_mapped_batch + 1) * kv_chunk_size, orig_seq_len) : orig_seq_len; const uint32_t cur_chunk_len = cur_chunk_end - cur_chunk_start; uint32_t packed_page_iter_base = paged_kv.indptr[mapped_batch_idx] * paged_kv.page_size + cur_chunk_start; const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; constexpr uint32_t kv_iter_len = bdy * bdz; constexpr uint32_t compute_qk_tile = bdy; extern __attribute__((shared)) uint8_t smem[]; DTypeKV* ckv_smem = (DTypeKV*)smem; DTypeKV* kpe_smem = (DTypeKV*)((uint8_t*)ckv_smem + num_stages_smem * kv_iter_len * head_dim_ckv * sizeof(DTypeKV)); size_t* ckv_offset_smem = (size_t*)((uint8_t*)kpe_smem + num_stages_smem * kv_iter_len * head_dim_kpe * sizeof(DTypeKV)); size_t* kpe_offset_smem = (size_t*)((uint8_t*)ckv_offset_smem + bdx * bdy * bdz * sizeof(size_t)); float* smem_md = (float*)ckv_offset_smem; AttentionVariant variant(params, batch_idx, smem); vec_t q_nope_vec[tile_size_qo_heads]; vec_t q_pe_vec[tile_size_qo_heads]; state_t st[tile_size_qo_heads]; uint32_t qo_head_idx[tile_size_qo_heads]; vec_t freq; #pragma unroll for (uint32_t i = 0; i < vec_size_kpe; ++i) { freq[i] = rope_rcp_scale * __powf(rope_rcp_theta, float(2 * ((tx * vec_size_kpe + i) / 2)) / float(head_dim_kpe)); } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif // load q_nope and q_pe tile #pragma unroll for (int i = 0; i < tile_size_qo_heads; ++i) { qo_head_idx[i] = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i); if (qo_head_idx[i] < num_qo_heads) { q_nope_vec[i].cast_load(q_nope + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv); q_pe_vec[i].cast_load(q_pe + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_kpe + tx * vec_size_kpe); } } // init paged-cache read offset to be used uint32_t q, r; paged_kv.page_size.divmod(packed_page_iter_base + t_offset, q, r); ckv_offset_smem[t_offset] = paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/ 0, last_indptr); kpe_offset_smem[t_offset] = paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/ 0, last_indptr); block.sync(); uint32_t stage_idx = 0; constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size_ckv * 8; constexpr uint32_t tx_fold = vec_size_ckv / vec_size_kpe; static_assert(num_stages_smem <= bdx); size_t offset_bytes; bool is_valid_range; #pragma unroll for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { is_valid_range = (iter * kv_iter_len + dim2_offset(bdy, tz, ty)) < cur_chunk_len; offset_bytes = ckv_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + tx * vec_size_ckv; cp_async::pred_load( ckv_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_ckv + tx * vec_size_ckv, paged_kv.ckv_data + offset_bytes, is_valid_range); offset_bytes = kpe_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + tx / tx_fold * vec_size_ckv; cp_async::pred_load( kpe_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_kpe + tx / tx_fold * vec_size_ckv, paged_kv.kpe_data + offset_bytes, is_valid_range); cp_async::commit_group(); stage_idx = (stage_idx + 1) % num_stages_smem; } #pragma unroll for (uint32_t iter = 0; iter < ceil_div(cur_chunk_len, kv_iter_len); ++iter) { cp_async::wait_group<1 * num_stages_smem - 1>(); block.sync(); const int32_t kv_idx_base = (paged_kv.rope_pos_offset == nullptr ? 
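    // NOTE: unlike the K/V kernels above, each MLA pipeline stage issues a single commit_group()
    // covering both its ckv and kpe loads, so the wait above only needs
    // wait_group<1 * num_stages_smem - 1>(). kv_idx_base below is the absolute RoPE position of
    // the first row of the tile consumed in this iteration.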
0 : paged_kv.rope_pos_offset[mapped_batch_idx]) + cur_chunk_start + iter * kv_iter_len; #pragma unroll for (int i = 0; i < tile_size_qo_heads; ++i) { compute_qk_and_update_local_stat_mla( params, variant, mapped_batch_idx, ckv_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile) * head_dim_ckv, q_nope_vec[i], kpe_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile) * head_dim_kpe, q_pe_vec[i], freq, kv_idx_base, /*iter_base*/ iter * kv_iter_len, /*iter_bound*/ cur_chunk_len, st[i]); } if ((iter + num_stages_smem) % bdx == 0) { uint32_t q, r; paged_kv.page_size.divmod( packed_page_iter_base + (iter + num_stages_smem) * kv_iter_len + t_offset, q, r); ckv_offset_smem[t_offset] = paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/ 0, last_indptr); kpe_offset_smem[t_offset] = paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/ 0, last_indptr); } block.sync(); is_valid_range = ((iter + num_stages_smem) * kv_iter_len + dim2_offset(bdy, tz, ty)) < cur_chunk_len; offset_bytes = ckv_offset_smem[dim3_offset(bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] + tx * vec_size_ckv; cp_async::pred_load( ckv_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_ckv + tx * vec_size_ckv, paged_kv.ckv_data + offset_bytes, is_valid_range); offset_bytes = kpe_offset_smem[dim3_offset(bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] + tx / tx_fold * vec_size_ckv; cp_async::pred_load( kpe_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_kpe + tx / tx_fold * vec_size_ckv, paged_kv.kpe_data + offset_bytes, is_valid_range); cp_async::commit_group(); stage_idx = (stage_idx + 1) % num_stages_smem; } cp_async::wait_group<0>(); block.sync(); if (bdz != 1) { #pragma unroll for (int i = 0; i < tile_size_qo_heads; ++i) { if (qo_head_idx[i] < num_qo_heads) sync_state(variant, st[i], (float*)smem, smem_md, tx, ty, tz); } } if (tz == 0) { #pragma unroll for (int i = 0; i < tile_size_qo_heads; ++i) { if (qo_head_idx[i] < num_qo_heads) { st[i].normalize(); st[i].o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + tx * vec_size_ckv); if (lse != nullptr) { lse[batch_idx * num_qo_heads + qo_head_idx[i]] = st[i].get_lse(); } } } } #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif } template cudaError_t BatchDecodeWithPagedKVCacheDispatchedMLA(Params params, typename Params::DTypeO* tmp_v, float* tmp_s, bool enable_pdl, cudaStream_t stream) { using DTypeQ = typename Params::DTypeQ; using DTypeKV = typename Params::DTypeKV; using DTypeO = typename Params::DTypeO; using IdType = typename Params::IdType; const uint32_t num_qo_heads = params.num_qo_heads; const uint32_t padded_batch_size = params.padded_batch_size; constexpr uint32_t vec_size_ckv = std::max(16UL / sizeof(DTypeKV), HEAD_DIM_CKV / 32UL); constexpr uint32_t bdx = HEAD_DIM_CKV / vec_size_ckv; constexpr uint32_t vec_size_kpe = HEAD_DIM_KPE / bdx; constexpr uint32_t bdy = 8; constexpr uint32_t tile_size_qo_heads = 2; constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads; constexpr uint32_t num_threads = std::max(128U, bdx * bdy); constexpr uint32_t bdz = num_threads / (bdx * bdy); const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block); auto compute_capacity = GetCudaComputeCapability(); DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { const uint32_t smem_size = NUM_STAGES_SMEM * bdy * bdz * (HEAD_DIM_CKV + HEAD_DIM_KPE) * sizeof(DTypeKV) + 
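      // NOTE: the first term above covers NUM_STAGES_SMEM staged tiles of the compressed KV
      // (HEAD_DIM_CKV) and the positional key (HEAD_DIM_KPE) for bdy * bdz rows per stage; the
      // std::max() term below reuses a single region either for the two per-thread offset arrays
      // (ckv_offset_smem / kpe_offset_smem, one size_t each) or for the m/d merge buffer,
      // whichever is larger.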
std::max(num_threads * sizeof(size_t) * 2, 2 * bdy * bdz * sizeof(float)); auto kernel = BatchDecodeWithPagedKVCacheKernelMLA; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); dim3 nblks(padded_batch_size, gdy); dim3 nthrs(bdx, bdy, bdz); // PDL launch config cudaLaunchAttribute attribute[1]; cudaLaunchConfig_t config; if (enable_pdl) { attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; attribute[0].val.programmaticStreamSerializationAllowed = 1; config.attrs = attribute; config.numAttrs = 1; config.gridDim = nblks; config.blockDim = nthrs; config.dynamicSmemBytes = smem_size; config.stream = stream; } if (tmp_v == nullptr) { // do not use partition-kv kernel params.partition_kv = false; if (enable_pdl) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, params)); } else { void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } } else { // use partition-kv kernel params.partition_kv = true; auto o = params.o; auto lse = params.lse; params.o = tmp_v; params.lse = tmp_s; if (enable_pdl) { FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, params)); } else { void* args[] = {(void*)¶ms}; FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); } FLASHINFER_CUDA_CALL(VariableLengthMergeStates( tmp_v, tmp_s, params.o_indptr, o, lse, params.paged_kv.batch_size, nullptr, num_qo_heads, HEAD_DIM_CKV, enable_pdl, stream)); } }); return cudaSuccess; } } // namespace flashinfer #endif // FLASHINFER_DECODE_CUH_
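// Worked example of the launch geometry chosen by the dispatch routines above (an illustration,
// not an additional API): with HEAD_DIM = 128 and 2-byte KV (fp16/bf16),
//   vec_size = max(16 / sizeof(DTypeKV), HEAD_DIM / 32) = max(8, 4) = 8
//   bdx      = HEAD_DIM / vec_size                      = 16
//   bdy      = GROUP_SIZE (qo heads per kv head), bdz = num_threads / (bdx * bdy)
// so single-request decode with GROUP_SIZE = 8 picks 512 threads and a (16, 8, 4) thread block,
// launched on a (num_chunks, num_kv_heads) grid; batched paged decode uses 128 threads, a
// (16, 8, 1) block, and a (padded_batch_size, num_kv_heads) grid.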