sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/attention/persistent.cuh

/*
* Copyright (c) 2025 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_PERSISTENT_CUH_
#define FLASHINFER_PERSISTENT_CUH_
#include "../cp_async.cuh"
#include "../math.cuh"
#include "../utils.cuh"
#include "mask.cuh"
#include "persistent_template.cuh"
#include "prefill.cuh"
#include "state.cuh"
namespace flashinfer {
using cp_async::PrefetchMode;
using cp_async::SharedMemFillMode;
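// Reads the coordinates of one work tile from the per-work-item arrays in `params`
// (filled by the host-side scheduler): the q/kv/partial indptrs, query and KV lengths,
// the packed query start, the KV range [kv_start, kv_end), the KV head index, and the
// KV chunk length.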
template <typename Params>
__device__ __forceinline__ auto get_block_coord(const Params& params, const uint32_t work_idx) {
return std::tuple(params.q_indptr[work_idx], params.kv_indptr[work_idx],
params.partial_indptr[work_idx], params.q_len[work_idx],
params.kv_len[work_idx], params.q_start[work_idx], params.kv_start[work_idx],
params.kv_end[work_idx], params.kv_head_idx_arr[work_idx],
*params.len_kv_chunk);
}
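// Computes, for every KV fragment this thread will copy, its source offset in the paged
// KV cache: the packed KV position is split into (page, in-page entry) with
// `block_size.divmod`, the page id is looked up in `indices` (positions past
// `packed_kv_bound` fall back to page 0), and the page/row/head strides are folded into
// a flat offset stored in `kv_offset`.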
template <typename KTraits>
__device__ __forceinline__ void prefetch_offset(
const uint32_t packed_block_iter_base, const uint32_t packed_kv_bound,
const uint32_t kv_head_idx, const uint32_t kv_stride_page, const uint32_t kv_stride_h,
const uint32_t kv_stride_n, const uint_fastdiv& block_size, typename KTraits::IdType* indices,
size_t* kv_offset) {
using DTypeKV = typename KTraits::DTypeKV;
constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW;
constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL;
constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV;
constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV;
const uint32_t lane_idx = threadIdx.x % 32, warp_idx = threadIdx.x / 32;
#pragma unroll
for (uint32_t i = 0;
i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) {
uint32_t page_iter, entry_idx;
uint32_t packed_block_iter = packed_block_iter_base + warp_idx * KV_THR_LAYOUT_ROW +
lane_idx / KV_THR_LAYOUT_COL +
KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i;
block_size.divmod(packed_block_iter, page_iter, entry_idx);
kv_offset[i] = (packed_block_iter < packed_kv_bound ? indices[page_iter] : 0) * kv_stride_page +
entry_idx * kv_stride_n + kv_head_idx * kv_stride_h +
(lane_idx % KV_THR_LAYOUT_COL) * upcast_size<DTypeKV>();
}
}
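// Writes one CTA's accumulated output fragments for a single KV chunk to the
// partial-output buffer: the fp32 fragments are cast to the 16-bit DTypeO, staged
// through shared memory (via stmatrix when available), and copied to global memory as
// 128-bit vectors for query rows below `qo_upper_bound`.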
template <typename KTraits>
__device__ __forceinline__ void write_o_(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
smem_t<KTraits::SWIZZLE_MODE_Q>* o_smem,
typename KTraits::DTypeO* o_ptr_base,
const uint32_t o_packed_idx_base_warp,
const uint32_t o_packed_idx_base_cta,
const uint32_t qo_upper_bound, const uint32_t o_stride_n,
const uint_fastdiv group_size, const uint32_t warp_idx,
const uint32_t lane_idx, const dim3 tid) {
using DTypeO = typename KTraits::DTypeO;
constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O;
const uint32_t warp_idx_x = get_warp_idx_q<KTraits>(tid.y),
warp_idx_z = get_warp_idx_kv<KTraits>(tid.z);
static_assert(sizeof(DTypeO) == 2);
if (warp_idx_z == 0) {
#pragma unroll
for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
uint32_t o_frag_f16[8 / 2];
vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_q][mma_d]);
#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
(warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16,
mma_d * 2 + lane_idx / 16);
o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
#else
uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
(warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2);
((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
o_frag_f16[1];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_O))[lane_idx % 4] =
o_frag_f16[3];
#endif
}
}
uint32_t o_smem_offset_w = o_smem->get_permuted_offset<UPCAST_STRIDE_O>(
warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8);
#pragma unroll
for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
for (uint32_t j = 0; j < 2 * 2; ++j) {
uint32_t q, r;
const uint32_t o_packed_idx = o_packed_idx_base_warp + lane_idx / 8 + mma_q * 16 + j * 4;
group_size.divmod(o_packed_idx, q, r);
const uint32_t o_idx = q;
DTypeO* o_ptr = o_ptr_base + (o_packed_idx - o_packed_idx_base_cta) * o_stride_n +
(lane_idx % 8) * upcast_size<DTypeO>();
#pragma unroll
for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) {
if (o_idx < qo_upper_bound) {
o_smem->store_128b(o_smem_offset_w, o_ptr);
}
o_ptr += 8 * upcast_size<DTypeO>();
o_smem_offset_w = o_smem->template advance_offset_by_column<8>(o_smem_offset_w, mma_do);
}
o_smem_offset_w =
o_smem->template advance_offset_by_row<4, UPCAST_STRIDE_O>(o_smem_offset_w) -
2 * KTraits::NUM_MMA_D_VO;
}
}
}
}
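// Scales the output fragments by 1/d (the online-softmax denominator); rows whose
// running maximum `m` is still -inf (i.e. fully masked) are zeroed instead.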
template <typename KTraits>
__device__ __forceinline__ void normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) {
using AttentionVariant = typename KTraits::AttentionVariant;
if constexpr (AttentionVariant::use_softmax) {
float d_rcp[KTraits::NUM_MMA_Q][2];
// compute reciprocal of d
#pragma unroll
for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-math::inf))
? math::ptx_rcp(d[mma_q][j])
: 0.f;
}
}
#pragma unroll
for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
o_frag[mma_q][mma_d][reg_id] =
o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id >> 1) & 1];
}
}
}
}
}
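// Persistent attention worker. Each CTA drains the work queue delimited by
// `work_indptr[blockIdx.y]`: for every work tile it loads a Q tile, software-pipelines
// paged K/V tiles through shared memory with cp.async, runs the FA2-style
// online-softmax update (compute_qk, logits transform/mask, update_mdo_states,
// compute_sfm_v), and then either writes the output and LSE through to the final
// buffers, or emits one partial output/LSE per KV chunk to be merged by
// BlockBatchReductionPersistent below.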
template <typename KTraits_, typename Params_>
struct BlockBatchPagedAttentionPersistent {
using KTraits = KTraits_;
using Params = Params_;
static __device__ __forceinline__ void Run(const Params& params,
typename KTraits::SharedStorage* smem_storage
PROFILER_CLOSURE_FUNC_PARAMS) {
using DTypeQ = typename Params::DTypeQ;
using DTypeKV = typename Params::DTypeKV;
using DTypeO = typename Params::DTypeO;
using IdType = typename Params::IdType;
using DTypeQKAccum = typename KTraits::DTypeQKAccum;
using AttentionVariant = typename KTraits::AttentionVariant;
[[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
[[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
[[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK;
[[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO;
[[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK;
[[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO;
[[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
[[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K;
[[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V;
[[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O;
[[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
[[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV;
[[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q;
[[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV;
[[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q;
[[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV;
[[maybe_unused]] constexpr bool CAUSAL = KTraits::MASK_MODE == MaskMode::kCausal;
[[maybe_unused]] constexpr uint32_t NUM_STAGES = KTraits::NUM_STAGES;
DTypeQ* q = params.q;
DTypeKV* k = params.k;
DTypeKV* v = params.v;
IdType* kv_indices = params.kv_indices;
float* partial_lse = params.partial_lse;
IdType* work_indptr = params.work_indptr;
float s_frag[NUM_MMA_Q][NUM_MMA_KV][8];
alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8];
float m[NUM_MMA_Q][2];
float d[NUM_MMA_Q][2];
const uint_fastdiv& gqa_group_size = params.gqa_group_size;
const uint32_t num_kv_heads = params.num_kv_heads;
const uint_fastdiv& block_size = params.page_size;
const uint32_t q_stride_n = params.q_stride_n;
const uint32_t q_stride_h = params.q_stride_h;
const uint32_t k_stride_page = params.k_stride_page;
const uint32_t k_stride_h = params.k_stride_h;
const uint32_t k_stride_n = params.k_stride_n;
const uint32_t v_stride_page = params.v_stride_page;
const uint32_t v_stride_h = params.v_stride_h;
const uint32_t v_stride_n = params.v_stride_n;
const uint32_t cluster_tile_q = gridDim.x * CTA_TILE_Q;
smem_t<SWIZZLE_MODE_Q> q_smem(smem_storage->q_smem);
AttentionVariant variant(params, /*batch_idx=*/0, nullptr);
const uint32_t lane_idx = threadIdx.x % 32;
const uint32_t warp_idx = threadIdx.x / 32;
// threadIdx is remapped to [32, NUM_WARPS_Q, NUM_WARPS_KV] so the helper
// functions of the FA2 prefill kernel can be reused
const dim3 tid = dim3(lane_idx, warp_idx % NUM_WARPS_Q, warp_idx / NUM_WARPS_Q);
uint32_t q_smem_offset_r = get_permuted_offset<SWIZZLE_MODE_Q, UPCAST_STRIDE_Q>(
get_warp_idx_q<KTraits>(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16);
uint32_t k_smem_offset_r = get_permuted_offset<SWIZZLE_MODE_KV, UPCAST_STRIDE_K>(
get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) +
lane_idx % 8,
(lane_idx % 16) / 8),
v_smem_offset_r = get_permuted_offset<SWIZZLE_MODE_KV, UPCAST_STRIDE_V>(
get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16);
uint32_t k_smem_offset_w = get_permuted_offset<SWIZZLE_MODE_KV, UPCAST_STRIDE_K>(
warp_idx * KTraits::KV_THR_LAYOUT_ROW + lane_idx / KTraits::KV_THR_LAYOUT_COL,
lane_idx % KTraits::KV_THR_LAYOUT_COL),
v_smem_offset_w = get_permuted_offset<SWIZZLE_MODE_KV, UPCAST_STRIDE_V>(
warp_idx * KTraits::KV_THR_LAYOUT_ROW + lane_idx / KTraits::KV_THR_LAYOUT_COL,
lane_idx % KTraits::KV_THR_LAYOUT_COL);
size_t thr_local_kv_offset[NUM_MMA_KV * KTraits::KV_THR_LAYOUT_COL / 2 / KTraits::NUM_WARPS_Q];
#pragma unroll 1
for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
++work_idx) {
// profile log
if constexpr (CTA_TILE_Q > 64) {
PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kRunner1);
} else {
PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kRunner2);
}
const auto [q_indptr, kv_indptr, o_indptr, q_len, kv_len, packed_qo_start, kv_start, kv_end,
kv_head_idx, len_kv_chunk] = get_block_coord(params, work_idx);
const uint32_t kv_chunk_idx = kv_start / len_kv_chunk;
const uint32_t num_kv_chunks = ceil_div(
CAUSAL
? min((kv_len - q_len) + ceil_div(packed_qo_start + cluster_tile_q, gqa_group_size),
kv_len)
: kv_len,
len_kv_chunk);
const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * CTA_TILE_Q +
get_warp_idx_q<KTraits>(tid.y) * NUM_MMA_Q * 16;
const uint32_t qo_upperbound =
min(q_len, ceil_div(qo_packed_idx_base + CTA_TILE_Q, gqa_group_size));
init_states<KTraits>(variant, o_frag, m, d);
DTypeQ* q_ptr_base = q + q_indptr * q_stride_n + (kv_head_idx * gqa_group_size) * q_stride_h;
// load_q
load_q_global_smem<KTraits>(qo_packed_idx_base, qo_upperbound, q_ptr_base, q_stride_n,
q_stride_h, gqa_group_size, &q_smem, tid);
smem_t<SWIZZLE_MODE_KV> k_smem(smem_storage->k_smem), v_smem(smem_storage->v_smem);
int kv_tile_idx =
ceil_div((CAUSAL ? min(kv_end,
kv_len - q_len +
ceil_div((packed_qo_start + cluster_tile_q), gqa_group_size))
: kv_end),
CTA_TILE_KV) -
1 - (kv_start / CTA_TILE_KV);
int mask_tile_idx =
(CAUSAL ? min(kv_end, kv_len - q_len + ceil_div(packed_qo_start, gqa_group_size))
: kv_end) /
CTA_TILE_KV -
(kv_start / CTA_TILE_KV);
uint32_t block_iter_base = kv_indptr * block_size + kv_start;
// prologue: prefetch and load the last KV tile (the loop walks backwards from it)
__syncthreads();
uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;
prefetch_offset<KTraits>(block_iter_base + kv_tile_idx * CTA_TILE_KV, packed_kv_bound,
kv_head_idx, k_stride_page, k_stride_h, k_stride_n, block_size,
kv_indices, thr_local_kv_offset);
page_produce_kv<false, KTraits>(smem_storage, &k_smem_offset_w, k,
kv_start + kv_tile_idx * CTA_TILE_KV, thr_local_kv_offset,
kv_end, warp_idx, lane_idx);
cp_async::commit_group();
page_produce_kv<true, KTraits>(smem_storage, &v_smem_offset_w, v,
kv_start + kv_tile_idx * CTA_TILE_KV, thr_local_kv_offset,
kv_end, warp_idx, lane_idx);
cp_async::commit_group();
// main KV loop, split by LOOP_SPLIT_MASK into a masked and an unmasked phase
LOOP_SPLIT_MASK(
kv_tile_idx, kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0,
kv_tile_idx + 1 > NUM_STAGES, {
prefetch_offset<KTraits>(block_iter_base + (kv_tile_idx - 1) * CTA_TILE_KV,
packed_kv_bound, kv_head_idx, k_stride_page, k_stride_h,
k_stride_n, block_size, kv_indices, thr_local_kv_offset);
cp_async::wait_group<1>();
__syncthreads();
compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
if constexpr (AttentionVariant::use_logits_soft_cap) {
logits_transform<KTraits>(
params, variant, /*batch_idx=*/0, qo_packed_idx_base,
kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
NUM_MMA_KV * 16,
q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
}
if constexpr (WITH_MASK) {
logits_mask<KTraits>(
params, variant, /*batch_idx=*/0, qo_packed_idx_base,
kv_start + (kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) *
NUM_MMA_KV * 16,
q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx);
}
update_mdo_states<KTraits>(variant, s_frag, o_frag, m, d);
__syncthreads();
page_produce_kv<false, KTraits>(smem_storage, &k_smem_offset_w, k,
kv_start + (kv_tile_idx - 1) * CTA_TILE_KV,
thr_local_kv_offset, kv_end, warp_idx, lane_idx);
cp_async::commit_group();
cp_async::wait_group<1>();
__syncthreads();
compute_sfm_v<KTraits>(&v_smem, &v_smem_offset_r, s_frag, o_frag, d);
__syncthreads();
page_produce_kv<true, KTraits>(smem_storage, &v_smem_offset_w, v,
kv_start + (kv_tile_idx - 1) * CTA_TILE_KV,
thr_local_kv_offset, kv_end, warp_idx, lane_idx);
cp_async::commit_group();
});
cp_async::wait_group<0>();
__syncthreads();
#pragma unroll
for (; kv_tile_idx >= 0; --kv_tile_idx) {
compute_qk<KTraits>(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag);
if constexpr (AttentionVariant::use_logits_soft_cap) {
logits_transform<KTraits>(
params, variant, /*batch_idx=*/0, qo_packed_idx_base,
kv_start +
(kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
q_len, kv_len, gqa_group_size, s_frag, tid, kv_head_idx);
}
logits_mask<KTraits>(
params, variant, /*batch_idx=*/0, qo_packed_idx_base,
kv_start +
(kv_tile_idx * NUM_WARPS_KV + get_warp_idx_kv<KTraits>(tid.z)) * NUM_MMA_KV * 16,
q_len, kv_len, kv_end, gqa_group_size, s_frag, tid, kv_head_idx);
update_mdo_states<KTraits>(variant, s_frag, o_frag, m, d);
compute_sfm_v<KTraits>(&v_smem, &v_smem_offset_r, s_frag, o_frag, d);
}
__syncthreads();
finalize_m<KTraits>(variant, m);
// threadblock synchronization
threadblock_sync_mdo_states<KTraits>(o_frag, smem_storage, m, d, warp_idx, lane_idx, tid);
// normalize d
normalize_d<KTraits>(o_frag, m, d);
// write back to global memory
// o_indptr (partial_o): [packed_qo_len * num_kv_chunks, num_kv_heads, head_dim]
// q_indptr (final_o): [qo_len, num_kv_heads, gqa_group_size, head_dim]
if (num_kv_chunks > 1) {
DTypeO* o_ptr_base = params.partial_o +
((o_indptr + kv_chunk_idx) * num_kv_heads + kv_head_idx) * HEAD_DIM_VO;
write_o_<KTraits>(o_frag, &q_smem, o_ptr_base, qo_packed_idx_base, packed_qo_start,
qo_upperbound, num_kv_chunks * num_kv_heads * HEAD_DIM_VO, gqa_group_size,
warp_idx, lane_idx, tid);
} else {
// write through
DTypeO* o_ptr_base =
params.final_o + q_indptr * q_stride_n + (kv_head_idx * gqa_group_size) * q_stride_h;
write_o_reg_gmem<KTraits>(o_frag, &q_smem, o_ptr_base, qo_packed_idx_base, q_len,
q_stride_n, q_stride_h, gqa_group_size, tid);
}
if constexpr (variant.use_softmax) {
if (get_warp_idx_kv<KTraits>(tid.z) == 0) {
#pragma unroll
for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
uint32_t q, r;
const uint32_t packed_qo_idx = qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16;
gqa_group_size.divmod(packed_qo_idx, q, r);
if (q < qo_upperbound) {
if (num_kv_chunks > 1) {
partial_lse[(o_indptr + (packed_qo_idx - packed_qo_start) * num_kv_chunks +
kv_chunk_idx) *
num_kv_heads +
kv_head_idx] = math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]);
} else if (params.final_lse != nullptr) {
// write through
const uint32_t qo_head_idx = kv_head_idx * gqa_group_size + r;
params.final_lse[(q_indptr + q) * num_kv_heads * gqa_group_size + qo_head_idx] =
math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]);
}
}
}
}
}
}
// profile
if constexpr (CTA_TILE_Q > 64) {
PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner1);
} else {
PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kRunner2);
}
}
}
};
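// Compile-time configuration of the state-reduction (merge) kernel: derives the
// per-thread vector width, the [bdx, bdy] layout of a warp over the head dimension, and
// the shared-memory footprint of the pipelined partial-output loads plus the per-warp
// LSE scratch.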
template <uint32_t HEAD_DIM_VO_, uint32_t NUM_SMEM_STAGES_, uint32_t NUM_THREADS_,
typename DTypeIn_, typename DTypeO_, typename IdType_>
struct StateReductionKernelTraits {
using DTypeIn = DTypeIn_;
using DTypeO = DTypeO_;
using IdType = IdType_;
static constexpr uint32_t HEAD_DIM_VO = HEAD_DIM_VO_;
static constexpr uint32_t NUM_SMEM_STAGES = NUM_SMEM_STAGES_;
static constexpr uint32_t NUM_THREADS = NUM_THREADS_;
static constexpr uint32_t NUM_WARPS = NUM_THREADS / 32;
static constexpr uint32_t vec_size =
std::max<uint32_t>(16U / static_cast<uint32_t>(sizeof(DTypeIn)), HEAD_DIM_VO / 32U);
static constexpr uint32_t bdx = HEAD_DIM_VO / vec_size;
// gridDim is read at runtime and is determined by the core attention kernel's launch
// per-block workload layout: [bdx, bdy, num_warps]
static_assert(NUM_THREADS % bdx == 0);
static constexpr uint32_t bdy = 32 / bdx;
// pipeline load & reduction
static constexpr size_t SMEM_SIZE =
NUM_WARPS * NUM_SMEM_STAGES * bdy * HEAD_DIM_VO * sizeof(DTypeIn) +
NUM_THREADS * sizeof(float);
};
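// Merges the per-KV-chunk partial outputs `V` and partial LSEs `S` into the final
// output/LSE buffers. Each warp processes one (packed query, kv head) pair at a time,
// pipelines the partial rows through shared memory with cp.async, and combines them
// with a numerically stable online-softmax state (state_t). Entries with at most one
// partial result were already written through by the attention kernel and are skipped.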
template <typename KTraits_>
struct BlockBatchReductionPersistent {
using KTraits = KTraits_;
static __device__ __forceinline__ void Run(
typename KTraits::DTypeIn* __restrict__ V, typename KTraits::DTypeO* __restrict__ v_merged,
float* __restrict__ S, float* __restrict__ s_merged,
const typename KTraits::IdType num_packed_qo_len, const uint_fastdiv gqa_group_size,
const uint32_t num_kv_heads, const typename KTraits::IdType* indptr,
const typename KTraits::IdType* o_indices, uint8_t* smem PROFILER_CLOSURE_FUNC_PARAMS) {
__syncthreads();  // NOTE(Zihao): required to guarantee correctness on Blackwell
using DTypeIn = typename KTraits::DTypeIn;
using DTypeO = typename KTraits::DTypeO;
using IdType = typename KTraits::IdType;
[[maybe_unused]] constexpr uint32_t bdx = KTraits::bdx;
[[maybe_unused]] constexpr uint32_t bdy = KTraits::bdy;
[[maybe_unused]] constexpr uint32_t num_warps = KTraits::NUM_WARPS;
[[maybe_unused]] constexpr uint32_t vec_size = KTraits::vec_size;
[[maybe_unused]] constexpr uint32_t head_dim = KTraits::HEAD_DIM_VO;
[[maybe_unused]] constexpr uint32_t num_smem_stages = KTraits::NUM_SMEM_STAGES;
[[maybe_unused]] constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
// control flow metadata
const uint32_t warp_idx = threadIdx.x / 32;
const uint32_t tx = (threadIdx.x % 32) % bdx, ty = (threadIdx.x % 32) / bdx;
const uint32_t worker_id = blockIdx.y * num_warps + warp_idx;
const uint32_t num_workers = gridDim.x * gridDim.y * gridDim.z * num_warps;
DTypeIn* v_smem = (DTypeIn*)smem + warp_idx * num_smem_stages * bdy * head_dim;
// FIXME: fix the offset calculation
float* s_smem = (float*)(smem + num_warps * num_smem_stages * bdy * head_dim * sizeof(DTypeIn) +
warp_idx * 32 * sizeof(float));
// V: [num_packed_qo_len x num_kv_tiles, num_kv_heads, head_dim]
// v_merged: [qo_len, num_kv_heads, gqa_group_size, head_dim]
#pragma unroll 1
for (uint32_t i = worker_id; i < num_packed_qo_len * num_kv_heads; i += num_workers) {
PROFILER_EVENT_START(profiler_closure, PersistentProfileEventType::kReduction);
// remap workload
uint32_t packed_qo_idx = i / num_kv_heads;
uint32_t kv_head_idx = i % num_kv_heads;
const uint32_t num_index_sets = indptr[packed_qo_idx + 1] - indptr[packed_qo_idx];
if (num_index_sets == 0 || num_index_sets == 1) {
// already written through by the attention kernel; bypass
PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kReduction);
continue;
}
// index calculation
auto partial_idx_to_offset = [&](uint32_t off) {
return (indptr[packed_qo_idx] + off) * num_kv_heads + kv_head_idx;
};
auto merge_idx_to_offset = [&]() {
// NOTE (Yilong): qo_head_idx has been calculated in schedule.plan
return o_indices[packed_qo_idx] + kv_head_idx * gqa_group_size;
};
state_t<vec_size> st;
#pragma unroll
for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
V + partial_idx_to_offset(iter * bdy + ty) * head_dim + tx * vec_size,
(iter * bdy + ty) < num_index_sets);
cp_async::commit_group();
}
#pragma unroll 4
for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
if (iter % bdx == 0) {
s_smem[ty * bdx + tx] = iter * bdy + (ty * bdx + tx) < num_index_sets
? S[partial_idx_to_offset(iter * bdy + ty * bdx + tx)]
: 0.f;
__syncwarp();
}
cp_async::wait_group<num_smem_stages - 1>();
__syncwarp();
vec_t<float, vec_size> v;
v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size);
if (iter * bdy + ty < num_index_sets) {
float s = s_smem[(iter % bdx) * bdy + ty];
st.merge(v, s, 1);
}
__syncwarp();
cp_async::pred_load<vec_bits, PrefetchMode::kPrefetch, SharedMemFillMode::kNoFill>(
v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size,
V + partial_idx_to_offset((iter + num_smem_stages) * bdy + ty) * head_dim +
tx * vec_size,
(iter + num_smem_stages) * bdy + ty < num_index_sets);
cp_async::commit_group();
}
cp_async::wait_group<0>();
__syncwarp();
st.normalize();
if constexpr (bdy > 1) {
warp_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem, tx, ty);
st.normalize();
}
st.o.cast_store(v_merged + merge_idx_to_offset() * head_dim + tx * vec_size);
if (s_merged != nullptr) {
s_merged[merge_idx_to_offset()] = st.get_lse();
}
PROFILER_EVENT_END(profiler_closure, PersistentProfileEventType::kReduction);
}
}
};
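// Host-side launcher: instantiates two attention configurations (one per CTA tile size)
// together with the matching state-reduction traits, sizes dynamic shared memory to the
// largest of the three requirements, and launches a single cooperative persistent
// kernel that runs both attention variants and the reduction pass on a
// num_blks_x x num_blks_y grid.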
template <uint32_t CTA_TILE_Q_1, uint32_t CTA_TILE_Q_2, uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
MaskMode MASK_MODE, typename AttentionVariant, typename Params>
cudaError_t BatchPagedAttentionPersistent(const Params params_1, const Params params_2,
const uint32_t num_blks_x, const uint32_t num_blks_y,
const cudaStream_t stream) {
using DTypeQ = typename Params::DTypeQ;
using DTypeKV = typename Params::DTypeKV;
using DTypeO = typename Params::DTypeO;
using IdType = typename Params::IdType;
constexpr uint32_t NUM_WARPS_Q_1 = get_num_warps_q(CTA_TILE_Q_1);
constexpr uint32_t NUM_WARPS_KV_1 = get_num_warps_kv(CTA_TILE_Q_1);
constexpr uint32_t NUM_MMA_Q_1 = get_num_mma_q(CTA_TILE_Q_1);
constexpr uint32_t NUM_MMA_KV_1 = 4;
constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
using KTraits1 = KernelTraits<MASK_MODE, CTA_TILE_Q_1, NUM_MMA_Q_1, NUM_MMA_KV_1, NUM_MMA_D_QK,
NUM_MMA_D_VO, NUM_WARPS_Q_1, NUM_WARPS_KV_1, PosEncodingMode::kNone,
DTypeQ, DTypeKV, DTypeO, float, IdType, AttentionVariant>;
constexpr uint32_t NUM_WARPS_Q_2 = get_num_warps_q(CTA_TILE_Q_2);
constexpr uint32_t NUM_WARPS_KV_2 = get_num_warps_kv(CTA_TILE_Q_2);
constexpr uint32_t NUM_MMA_Q_2 = get_num_mma_q(CTA_TILE_Q_2);
constexpr uint32_t NUM_MMA_KV_2 = 2;
using KTraits2 = KernelTraits<MASK_MODE, CTA_TILE_Q_2, NUM_MMA_Q_2, NUM_MMA_KV_2, NUM_MMA_D_QK,
NUM_MMA_D_VO, NUM_WARPS_Q_2, NUM_WARPS_KV_2, PosEncodingMode::kNone,
DTypeQ, DTypeKV, DTypeO, float, IdType, AttentionVariant>;
// Attention state reduction kernel
constexpr uint32_t NUM_THREADS =
KTraits1::NUM_THREADS > KTraits2::NUM_THREADS ? KTraits1::NUM_THREADS : KTraits2::NUM_THREADS;
using ReductionKTraits =
StateReductionKernelTraits<HEAD_DIM_VO, 4, NUM_THREADS, DTypeO, DTypeO, IdType>;
size_t smem_size =
max(sizeof(typename KTraits1::SharedStorage), sizeof(typename KTraits2::SharedStorage));
smem_size = max(smem_size, ReductionKTraits::SMEM_SIZE);
// Launch persistent kernel
auto kernel = PersistentKernelTemplate<BlockBatchPagedAttentionPersistent<KTraits1, Params>,
BlockBatchPagedAttentionPersistent<KTraits2, Params>,
BlockBatchReductionPersistent<ReductionKTraits>>;
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
dim3 nblks(num_blks_x, num_blks_y);
dim3 nthrs(NUM_THREADS);
void* args[] = {(void*)&params_1, (void*)&params_2};
FLASHINFER_CUDA_CALL(
cudaLaunchCooperativeKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
return cudaSuccess;
}
}  // namespace flashinfer
#endif // FLASHINFER_PERSISTENT_CUH_