/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef FLASHINFER_MLA_FA2_CUH_
#define FLASHINFER_MLA_FA2_CUH_
#include <cooperative_groups.h>

#include <cstdint>
#include <sstream>

#include "../profiler.cuh"
#include "mla_params.cuh"
#include "prefill.cuh"
#include "variant_helper.cuh"

namespace flashinfer {

namespace mla {

struct StandardAttention : AttentionVariantBase {
  float sm_scale_log2;

  PROFILER_CLOSURE_PARAMS_DECL

  template <typename Params>
  __device__ __host__ StandardAttention(const Params& params, uint32_t batch_idx,
                                        uint8_t* smem_ptr) {
    sm_scale_log2 = params.sm_scale * math::log2e;
  }
};

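// Shared-memory layout for one CTA. The output tile (o_smem) lives in a union with the Q/KV
// staging buffers: it is only written after the KV loop has finished, so it can safely reuse the
// same storage. kpe_p_smem doubles as the staging buffer for the P (probability) tile when
// QK_SHARD is enabled, hence its column count is the larger of HEAD_DIM_KPE and CTA_TILE_Q.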
template <uint32_t NUM_STAGES, uint32_t CTA_TILE_Q, uint32_t CTA_TILE_KV, uint32_t HEAD_DIM_CKV,
          uint32_t HEAD_DIM_KPE, typename DTypeQ, typename DTypeKV, typename DTypeO>
struct SharedStorageQKVO {
  union {
    struct {
      alignas(16) DTypeQ q_smem_nope[CTA_TILE_Q * HEAD_DIM_CKV];
      alignas(16) DTypeQ q_smem_pe[CTA_TILE_Q * HEAD_DIM_KPE];
      alignas(16) DTypeKV ckv_smem[NUM_STAGES][CTA_TILE_KV * HEAD_DIM_CKV];
      alignas(16) DTypeKV
          kpe_p_smem[NUM_STAGES]
                    [CTA_TILE_KV * (HEAD_DIM_KPE > CTA_TILE_Q ? HEAD_DIM_KPE : CTA_TILE_Q)];
      union {
        alignas(16) float m_wg[2][CTA_TILE_Q];  // cross warpgroup synchronization
        alignas(16) float d_wg[2][CTA_TILE_Q];  // cross warpgroup synchronization
      };
    };
    alignas(16) DTypeO o_smem[CTA_TILE_Q * HEAD_DIM_CKV];
  };
};

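// Compile-time kernel configuration: tile sizes, MMA tiling factors, swizzle modes for each
// shared-memory buffer, and the shared-storage/variant types derived from them.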
template <bool CAUSAL_, uint32_t NUM_STAGES_, bool QK_SHARD_, uint32_t HEAD_DIM_CKV_,
          uint32_t HEAD_DIM_KPE_, uint32_t CTA_TILE_Q_, uint32_t CTA_TILE_KV_, typename DTypeQ_,
          typename DTypeKV_, typename DTypeO_, typename IdType_>
struct KernelTraits {
  static constexpr bool CAUSAL = CAUSAL_;
  static constexpr uint32_t NUM_STAGES = NUM_STAGES_;
  // NOTE(Zihao): whether to shard Q*K computation across warpgroups
  // if true, each warpgroup will compute a subset of Q*K (sharded on the KV dimension)
  // if false, each warpgroup will compute the full Q*K, which is duplicated across warpgroups
  static constexpr bool QK_SHARD = QK_SHARD_;
  static constexpr uint32_t NUM_MMA_KV = CTA_TILE_KV_ / 16;
  static constexpr uint32_t HEAD_DIM_CKV = HEAD_DIM_CKV_;
  static constexpr uint32_t HEAD_DIM_KPE = HEAD_DIM_KPE_;
  static constexpr uint32_t HEAD_DIM_ALL = HEAD_DIM_CKV + HEAD_DIM_KPE;
  static constexpr uint32_t NUM_MMA_D_CKV = HEAD_DIM_CKV / 16;
  static constexpr uint32_t NUM_MMA_D_KPE = HEAD_DIM_KPE / 16;
  static constexpr uint32_t NUM_THREADS = 256;
  static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_;
  static constexpr uint32_t CTA_TILE_KV = CTA_TILE_KV_;

  static constexpr SwizzleMode SWIZZLE_MODE_Q_NOPE = SwizzleMode::k128B;
  static constexpr SwizzleMode SWIZZLE_MODE_Q_PE = SwizzleMode::k128B;
  static constexpr SwizzleMode SWIZZLE_MODE_CKV = SwizzleMode::k128B;
  static constexpr SwizzleMode SWIZZLE_MODE_KPE = SwizzleMode::k128B;
  static constexpr SwizzleMode SWIZZLE_MODE_P =
      CTA_TILE_KV >= 64 ? SwizzleMode::k128B : SwizzleMode::k64B;
  static constexpr SwizzleMode SWIZZLE_MODE_O = SwizzleMode::k128B;
  static constexpr uint32_t UPCAST_STRIDE_Q_NOPE = HEAD_DIM_CKV / upcast_size<DTypeQ_>();
  static constexpr uint32_t UPCAST_STRIDE_Q_PE = HEAD_DIM_KPE / upcast_size<DTypeQ_>();
  static constexpr uint32_t UPCAST_STRIDE_CKV = HEAD_DIM_CKV / upcast_size<DTypeKV_>();
  static constexpr uint32_t UPCAST_STRIDE_KPE = HEAD_DIM_KPE / upcast_size<DTypeKV_>();
  static constexpr uint32_t UPCAST_STRIDE_FINAL_O = HEAD_DIM_CKV / upcast_size<DTypeO_>();
  static constexpr uint32_t UPCAST_STRIDE_P = CTA_TILE_KV / upcast_size<DTypeKV_>();

  using DTypeQ = DTypeQ_;
  using DTypeKV = DTypeKV_;
  using DTypeO = DTypeO_;
  using IdType = IdType_;
  using DTypeQKAccum = float;

  using SharedStorage = SharedStorageQKVO<NUM_STAGES, CTA_TILE_Q, CTA_TILE_KV, HEAD_DIM_CKV,
                                          HEAD_DIM_KPE, DTypeQ, DTypeKV, DTypeO>;
  using AttentionVariant = StandardAttention;

  static constexpr DTypeQKAccum MaskFillValue = -math::inf;
};

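// Resets the per-thread output accumulator fragments to zero and initializes the online-softmax
// running maximum (m) to -inf and the running denominator (d) to 1.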
template <typename KTraits>
__device__ __forceinline__ void init_states_(float (*o_frag)[8], typename KTraits::DTypeQKAccum* m,
                                             float* d) {
#pragma unroll
  for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 2; ++mma_d) {
#pragma unroll
    for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
      o_frag[mma_d][reg_id] = 0.f;
    }
  }

#pragma unroll
  for (uint32_t j = 0; j < 2; ++j) {
    m[j] = typename KTraits::DTypeQKAccum(-math::inf);
    d[j] = 1.f;
  }
}

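// Copies the CTA's Q tile (both the "no positional encoding" part and the RoPE part) from global
// memory into shared memory with 128-bit cp.async transfers. The packed query index is decomposed
// into (query position, head) via num_heads.divmod, and out-of-range rows are zero-filled.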
template <typename KTraits>
__device__ __forceinline__ void load_q(
    typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeQ* q_nope,
    typename KTraits::DTypeQ* q_pe, const uint32_t q_nope_stride_n, const uint32_t q_nope_stride_h,
    const uint32_t q_pe_stride_n, const uint32_t q_pe_stride_h, const uint32_t q_len,
    const uint32_t packed_offset, const uint_fastdiv& num_heads) {
  using DTypeQ = typename KTraits::DTypeQ;
  constexpr uint32_t UPCAST_STRIDE_Q_NOPE = KTraits::UPCAST_STRIDE_Q_NOPE;
  constexpr uint32_t UPCAST_STRIDE_Q_PE = KTraits::UPCAST_STRIDE_Q_PE;
  constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
  constexpr uint32_t NUM_MMA_D_KPE = KTraits::NUM_MMA_D_KPE;
  const uint32_t lane_idx = threadIdx.x;
  const uint32_t warpgroup_idx = threadIdx.z;
  const uint32_t warp_idx_in_wg = threadIdx.y;

  smem_t<KTraits::SWIZZLE_MODE_Q_NOPE> q_smem_nope(smem_storage->q_smem_nope);
  smem_t<KTraits::SWIZZLE_MODE_Q_PE> q_smem_pe(smem_storage->q_smem_pe);

#pragma unroll
  for (uint32_t mma_q = 0; mma_q < 2; ++mma_q) {
    uint32_t q, r;
    num_heads.divmod(
        packed_offset + lane_idx / 8 + (warpgroup_idx + mma_q * 2) * 16 + warp_idx_in_wg * 4, q, r);
    DTypeQ* q_nope_ptr =
        q_nope + q * q_nope_stride_n + r * q_nope_stride_h + (lane_idx % 8) * upcast_size<DTypeQ>();
    DTypeQ* q_pe_ptr =
        q_pe + q * q_pe_stride_n + r * q_pe_stride_h + (lane_idx % 8) * upcast_size<DTypeQ>();
#pragma unroll
    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
      uint32_t q_smem_nope_offset_w =
          q_smem_nope.template get_permuted_offset<UPCAST_STRIDE_Q_NOPE>(
              32 * mma_q + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
              mma_d * 8 + lane_idx % 8);
      q_smem_nope.load_128b_async<SharedMemFillMode::kFillZero>(q_smem_nope_offset_w, q_nope_ptr,
                                                                q < q_len);
      q_nope_ptr += 8 * upcast_size<DTypeQ>();
    }
#pragma unroll
    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_KPE / 4; ++mma_d) {
      uint32_t q_smem_pe_offset_w = q_smem_pe.template get_permuted_offset<UPCAST_STRIDE_Q_PE>(
          32 * mma_q + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
          mma_d * 8 + lane_idx % 8);

      q_smem_pe.load_128b_async<SharedMemFillMode::kFillZero>(q_smem_pe_offset_w, q_pe_ptr,
                                                              q < q_len);
      q_pe_ptr += 8 * upcast_size<DTypeQ>();
    }
  }
}

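// Loads one CTA_TILE_KV tile of the compressed KV cache (ckv) and its rotary part (kpe) from the
// paged layout into the given pipeline stage's shared-memory buffers. The packed KV index is
// decomposed into (page, in-page offset) via block_size.divmod, indices[] maps the page to its
// physical page id, and positions beyond packed_kv_bound are zero-filled.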
template <typename KTraits>
__device__ __forceinline__ void load_kv(
    typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeKV* ckv,
    typename KTraits::DTypeKV* kpe, typename KTraits::IdType* indices, const uint32_t ckv_stride_n,
    const uint32_t ckv_stride_page, const uint32_t kpe_stride_n, const uint32_t kpe_stride_page,
    const uint32_t packed_kv_bound, const uint32_t packed_block_iter_base,
    const uint_fastdiv& block_size, const uint32_t stage_idx) {
  using DTypeKV = typename KTraits::DTypeKV;
  constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
  constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
  constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
  constexpr uint32_t NUM_MMA_D_KPE = KTraits::NUM_MMA_D_KPE;
  const uint32_t lane_idx = threadIdx.x;
  const uint32_t warpgroup_idx = threadIdx.z;
  const uint32_t warp_idx_in_wg = threadIdx.y;

  smem_t<KTraits::SWIZZLE_MODE_CKV> ckv_smem(smem_storage->ckv_smem[stage_idx]);
  smem_t<KTraits::SWIZZLE_MODE_KPE> kpe_smem(smem_storage->kpe_p_smem[stage_idx]);

  if constexpr (KTraits::NUM_MMA_KV == 1) {
    if (warpgroup_idx == 0) {
      uint32_t q, r;
      uint32_t packed_block_iter = packed_block_iter_base + lane_idx / 8 + warp_idx_in_wg * 4;
      block_size.divmod(packed_block_iter, q, r);

      DTypeKV* ckv_ptr = ckv +
                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
                         r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
      DTypeKV* kpe_ptr = kpe +
                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
                         r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();

#pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
        uint32_t ckv_smem_offset_w = ckv_smem.template get_permuted_offset<UPCAST_STRIDE_CKV>(
            warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
        ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
                                                               packed_block_iter < packed_kv_bound);
        ckv_ptr += 8 * upcast_size<DTypeKV>();
      }

#pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_KPE / 4; ++mma_d) {
        uint32_t kpe_smem_offset_w = kpe_smem.template get_permuted_offset<UPCAST_STRIDE_KPE>(
            warp_idx_in_wg * 4 + lane_idx / 8, 8 * mma_d + lane_idx % 8);
        kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
                                                               packed_block_iter < packed_kv_bound);
        kpe_ptr += 8 * upcast_size<DTypeKV>();
      }
    }
  } else {
#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
      uint32_t q, r;
      uint32_t packed_block_iter = packed_block_iter_base + lane_idx / 8 +
                                   (warpgroup_idx + mma_kv * 2) * 16 + warp_idx_in_wg * 4;
      block_size.divmod(packed_block_iter, q, r);

      DTypeKV* ckv_ptr = ckv +
                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * ckv_stride_page +
                         r * ckv_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();
      DTypeKV* kpe_ptr = kpe +
                         (packed_block_iter < packed_kv_bound ? indices[q] : 0) * kpe_stride_page +
                         r * kpe_stride_n + (lane_idx % 8) * upcast_size<DTypeKV>();

#pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 4; ++mma_d) {
        uint32_t ckv_smem_offset_w = ckv_smem.template get_permuted_offset<UPCAST_STRIDE_CKV>(
            32 * mma_kv + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
            8 * mma_d + lane_idx % 8);
        ckv_smem.load_128b_async<SharedMemFillMode::kFillZero>(ckv_smem_offset_w, ckv_ptr,
                                                               packed_block_iter < packed_kv_bound);
        ckv_ptr += 8 * upcast_size<DTypeKV>();
      }

#pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_KPE / 4; ++mma_d) {
        uint32_t kpe_smem_offset_w = kpe_smem.template get_permuted_offset<UPCAST_STRIDE_KPE>(
            32 * mma_kv + warpgroup_idx * 16 + warp_idx_in_wg * 4 + lane_idx / 8,
            8 * mma_d + lane_idx % 8);
        kpe_smem.load_128b_async<SharedMemFillMode::kFillZero>(kpe_smem_offset_w, kpe_ptr,
                                                               packed_block_iter < packed_kv_bound);
        kpe_ptr += 8 * upcast_size<DTypeKV>();
      }
    }
  }
}

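// Accumulates a Q*K^T tile into s_frag with m16n16k16 tensor-core MMAs over fragments loaded via
// ldmatrix. When `init` is true the first MMA initializes the accumulator instead of adding to
// it, which lets the caller chain the KPE and CKV contributions into the same s_frag. With
// QK_SHARD, each warpgroup covers half of the KV tiles of the CTA tile.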
template <bool init, typename KTraits, uint32_t NUM_MMA_D_QK, uint32_t UPCAST_STRIDE_Q,
          uint32_t UPCAST_STRIDE_K, SwizzleMode SWIZZLE_MODE_Q, SwizzleMode SWIZZLE_MODE_KV>
__device__ __forceinline__ void compute_qk_(smem_t<SWIZZLE_MODE_Q> q_smem,
                                            smem_t<SWIZZLE_MODE_KV> k_smem,
                                            typename KTraits::DTypeQKAccum (*s_frag)[8]) {
  const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
  alignas(16) uint32_t q_frag[4], k_frag[4];
  // compute q*k^T
#pragma unroll
  for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) {
    uint32_t q_smem_offset_r = q_smem.template get_permuted_offset<UPCAST_STRIDE_Q>(
        warp_idx_in_wg * 16 + lane_idx % 16, mma_d * 2 + lane_idx / 16);
    q_smem.ldmatrix_m8n8x4(q_smem_offset_r, q_frag);

    if constexpr (KTraits::QK_SHARD) {
#pragma unroll
      for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
        uint32_t k_smem_offset_r = k_smem.template get_permuted_offset<UPCAST_STRIDE_K>(
            (warpgroup_idx * (KTraits::NUM_MMA_KV / 2) + mma_kv) * 16 + 8 * (lane_idx / 16) +
                lane_idx % 8,
            2 * mma_d + (lane_idx % 16) / 8);

        k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);

        if (init && mma_d == 0) {
          mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ, MMAMode::kInit>(
              s_frag[mma_kv], q_frag, k_frag);
        } else {
          mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ>(s_frag[mma_kv],
                                                                              q_frag, k_frag);
        }
      }
    } else {
#pragma unroll
      for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
        uint32_t k_smem_offset_r = k_smem.template get_permuted_offset<UPCAST_STRIDE_K>(
            mma_kv * 16 + 8 * (lane_idx / 16) + lane_idx % 8, 2 * mma_d + (lane_idx % 16) / 8);

        k_smem.ldmatrix_m8n8x4(k_smem_offset_r, k_frag);

        if (init && mma_d == 0) {
          mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ, MMAMode::kInit>(
              s_frag[mma_kv], q_frag, k_frag);
        } else {
          mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeQ>(s_frag[mma_kv],
                                                                              q_frag, k_frag);
        }
      }
    }
  }
}

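// Applies the attention mask directly to the score fragments: entries that fall outside kv_end,
// or that the causal mask forbids (kv_idx + qo_len > kv_len + q_idx), are overwritten with
// MaskFillValue (-inf) so they vanish after the softmax.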
template <typename KTraits>
__device__ __forceinline__ void logits_mask_(const uint32_t qo_packed_idx_base,
                                             const uint32_t kv_idx_base, const uint32_t qo_len,
                                             const uint32_t kv_len, const uint32_t kv_end,
                                             const uint_fastdiv num_heads,
                                             typename KTraits::DTypeQKAccum (*s_frag)[8]) {
  const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
  constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
  using DTypeQKAccum = typename KTraits::DTypeQKAccum;
  uint32_t q[2];
#pragma unroll
  for (uint32_t j = 0; j < 2; ++j) {
    q[j] = (qo_packed_idx_base + warp_idx_in_wg * 16 + lane_idx / 4 + 8 * j) / num_heads;
  }

  if constexpr (KTraits::QK_SHARD) {
#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV / 2; ++mma_kv) {
#pragma unroll
      for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
        const uint32_t q_idx = q[(reg_id % 4) / 2],
                       kv_idx = kv_idx_base + warpgroup_idx * (NUM_MMA_KV / 2) * 16 + mma_kv * 16 +
                                2 * (lane_idx % 4) + 8 * (reg_id / 4) + reg_id % 2;
        const bool mask =
            (!(KTraits::CAUSAL ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= kv_end))
                               : kv_idx >= kv_end));
        s_frag[mma_kv][reg_id] = (mask) ? s_frag[mma_kv][reg_id] : (KTraits::MaskFillValue);
      }
    }
  } else {
#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) {
#pragma unroll
      for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
        const uint32_t q_idx = q[(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 +
                                                            2 * (lane_idx % 4) + 8 * (reg_id / 4) +
                                                            reg_id % 2;
        const bool mask =
            (!(KTraits::CAUSAL ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= kv_end))
                               : kv_idx >= kv_end));
        s_frag[mma_kv][reg_id] = (mask) ? s_frag[mma_kv][reg_id] : (KTraits::MaskFillValue);
      }
    }
  }
}

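// Online-softmax update for the current KV tile: recomputes the running row maximum m (reduced
// across lanes with shuffles, and across the two warpgroups through shared memory when QK_SHARD
// is on), rescales the existing output accumulator and denominator d by exp2(m_prev - m_new),
// and converts the masked scores in s_frag into unnormalized probabilities with base-2
// exponentials (sm_scale_log2 already folds in log2(e)).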
template <typename KTraits>
__device__ __forceinline__ void update_mdo_states_(typename KTraits::SharedStorage* smem_storage,
                                                   const uint32_t stage_idx,
                                                   typename KTraits::AttentionVariant variant,
                                                   typename KTraits::DTypeQKAccum (*s_frag)[8],
                                                   float (*o_frag)[8],
                                                   typename KTraits::DTypeQKAccum* m, float* d) {
  using DTypeQKAccum = typename KTraits::DTypeQKAccum;
  using AttentionVariant = typename KTraits::AttentionVariant;
  const float sm_scale = variant.sm_scale_log2;
  const uint32_t warpgroup_idx = threadIdx.z, lane_idx = threadIdx.x, warp_idx_in_wg = threadIdx.y;
  float m_prev[2];
  if constexpr (KTraits::QK_SHARD) {
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      m_prev[j] = m[j];
#pragma unroll
      for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
        float m_local = max(max(s_frag[mma_kv][j * 2 + 0], s_frag[mma_kv][j * 2 + 1]),
                            max(s_frag[mma_kv][j * 2 + 4], s_frag[mma_kv][j * 2 + 5]));
        m[j] = max(m[j], m_local);
      }
      m[j] = max(m[j], math::shfl_xor_sync(m[j], 0x2));
      m[j] = max(m[j], math::shfl_xor_sync(m[j], 0x1));
      if (lane_idx % 4 == 0) {
        smem_storage->m_wg[warpgroup_idx][warp_idx_in_wg * 16 + j * 8 + lane_idx / 4] = m[j];
      }
    }

    __syncthreads();

#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      m[j] = max(smem_storage->m_wg[0][warp_idx_in_wg * 16 + j * 8 + lane_idx / 4],
                 smem_storage->m_wg[1][warp_idx_in_wg * 16 + j * 8 + lane_idx / 4]);
      float o_scale = math::ptx_exp2(m_prev[j] * sm_scale - m[j] * sm_scale);
      d[j] *= o_scale;
#pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 2; ++mma_d) {
        o_frag[mma_d][j * 2 + 0] *= o_scale;
        o_frag[mma_d][j * 2 + 1] *= o_scale;
        o_frag[mma_d][j * 2 + 4] *= o_scale;
        o_frag[mma_d][j * 2 + 5] *= o_scale;
      }
#pragma unroll
      for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV / 2; ++mma_kv) {
        s_frag[mma_kv][j * 2 + 0] =
            math::ptx_exp2(s_frag[mma_kv][j * 2 + 0] * sm_scale - m[j] * sm_scale);
        s_frag[mma_kv][j * 2 + 1] =
            math::ptx_exp2(s_frag[mma_kv][j * 2 + 1] * sm_scale - m[j] * sm_scale);
        s_frag[mma_kv][j * 2 + 4] =
            math::ptx_exp2(s_frag[mma_kv][j * 2 + 4] * sm_scale - m[j] * sm_scale);
        s_frag[mma_kv][j * 2 + 5] =
            math::ptx_exp2(s_frag[mma_kv][j * 2 + 5] * sm_scale - m[j] * sm_scale);
      }
    }
  } else {
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      m_prev[j] = m[j];
#pragma unroll
      for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
        float m_local = max(max(s_frag[mma_kv][j * 2 + 0], s_frag[mma_kv][j * 2 + 1]),
                            max(s_frag[mma_kv][j * 2 + 4], s_frag[mma_kv][j * 2 + 5]));
        m[j] = max(m[j], m_local);
      }
      m[j] = max(m[j], math::shfl_xor_sync(m[j], 0x2));
      m[j] = max(m[j], math::shfl_xor_sync(m[j], 0x1));
    }

#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      float o_scale = math::ptx_exp2(m_prev[j] * sm_scale - m[j] * sm_scale);
      d[j] *= o_scale;
#pragma unroll
      for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 2; ++mma_d) {
        o_frag[mma_d][j * 2 + 0] *= o_scale;
        o_frag[mma_d][j * 2 + 1] *= o_scale;
        o_frag[mma_d][j * 2 + 4] *= o_scale;
        o_frag[mma_d][j * 2 + 5] *= o_scale;
      }
#pragma unroll
      for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
        s_frag[mma_kv][j * 2 + 0] =
            math::ptx_exp2(s_frag[mma_kv][j * 2 + 0] * sm_scale - m[j] * sm_scale);
        s_frag[mma_kv][j * 2 + 1] =
            math::ptx_exp2(s_frag[mma_kv][j * 2 + 1] * sm_scale - m[j] * sm_scale);
        s_frag[mma_kv][j * 2 + 4] =
            math::ptx_exp2(s_frag[mma_kv][j * 2 + 4] * sm_scale - m[j] * sm_scale);
        s_frag[mma_kv][j * 2 + 5] =
            math::ptx_exp2(s_frag[mma_kv][j * 2 + 5] * sm_scale - m[j] * sm_scale);
      }
    }
  }
}

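// MLA attention scores for one KV tile: the RoPE part (q_pe x kpe) initializes s_frag and the
// compressed "no-PE" part (q_nope x ckv) is accumulated on top, so s_frag ends up holding the
// full Q*K^T contribution of the tile.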
template <typename KTraits>
__device__ __forceinline__ void compute_mla_qk(typename KTraits::SharedStorage* smem_storage,
                                               const uint32_t stage_idx,
                                               typename KTraits::DTypeQKAccum (*s_frag)[8]) {
  constexpr uint32_t UPCAST_STRIDE_Q_NOPE = KTraits::UPCAST_STRIDE_Q_NOPE;
  constexpr uint32_t UPCAST_STRIDE_Q_PE = KTraits::UPCAST_STRIDE_Q_PE;
  constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
  constexpr uint32_t UPCAST_STRIDE_KPE = KTraits::UPCAST_STRIDE_KPE;
  constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
  smem_t<KTraits::SWIZZLE_MODE_Q_NOPE> q_smem_nope(smem_storage->q_smem_nope);
  smem_t<KTraits::SWIZZLE_MODE_Q_PE> q_smem_pe(smem_storage->q_smem_pe);
  smem_t<KTraits::SWIZZLE_MODE_CKV> ckv_smem(smem_storage->ckv_smem[stage_idx]);
  smem_t<KTraits::SWIZZLE_MODE_KPE> kpe_smem(smem_storage->kpe_p_smem[stage_idx]);
  const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
  compute_qk_</*init=*/true, KTraits, KTraits::NUM_MMA_D_KPE, KTraits::UPCAST_STRIDE_Q_PE,
              KTraits::UPCAST_STRIDE_KPE>(q_smem_pe, kpe_smem, s_frag);
  compute_qk_</*init=*/false, KTraits, KTraits::NUM_MMA_D_CKV, KTraits::UPCAST_STRIDE_Q_NOPE,
              KTraits::UPCAST_STRIDE_CKV>(q_smem_nope, ckv_smem, s_frag);
}

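// P*V for one KV tile: the probability fragments are down-cast to the KV dtype, their row sums
// are accumulated into d, and they are multiplied with the ckv tile (which doubles as V in MLA).
// With QK_SHARD each warpgroup only owns half of P, so the fragments are staged through shared
// memory (reusing kpe_p_smem) and re-loaded so both warpgroups can apply the full tile to their
// half of the output head dimension.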
template <typename KTraits>
__device__ __forceinline__ void compute_mla_pv(typename KTraits::SharedStorage* smem_storage,
                                               const uint32_t stage_idx,
                                               typename KTraits::DTypeQKAccum (*s_frag)[8],
                                               typename KTraits::DTypeQKAccum* d,
                                               float (*o_frag)[8]) {
  const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
  constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
  constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
  constexpr uint32_t UPCAST_STRIDE_CKV = KTraits::UPCAST_STRIDE_CKV;
  smem_t<KTraits::SWIZZLE_MODE_CKV> ckv_smem(smem_storage->ckv_smem[stage_idx]);
  uint32_t ckv_smem_offset_r = ckv_smem.template get_permuted_offset<UPCAST_STRIDE_CKV>(
      lane_idx % 16, warpgroup_idx * NUM_MMA_D_CKV + lane_idx / 16);
  if constexpr (KTraits::QK_SHARD) {
    // shard s_frag computation on KV dimension across warpgroups, need allgather
    alignas(16) typename KTraits::DTypeKV p_f16[NUM_MMA_KV / 2][8];
#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV / 2; ++mma_kv) {
      vec_cast<typename KTraits::DTypeKV, float>::cast<8>(p_f16[mma_kv], s_frag[mma_kv]);
      mma::m16k16_rowsum_f16f16f32(d, p_f16[mma_kv]);
    }

    __syncthreads();
    smem_t<KTraits::SWIZZLE_MODE_P> p_smem(smem_storage->kpe_p_smem[stage_idx]);
    constexpr uint32_t UPCAST_STRIDE_P = KTraits::UPCAST_STRIDE_P;
#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV / 2; ++mma_kv) {
#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
      uint32_t p_smem_offset_w = p_smem.template get_permuted_offset<UPCAST_STRIDE_P>(
          warp_idx_in_wg * 16 + lane_idx % 16,
          warpgroup_idx * NUM_MMA_KV + mma_kv * 2 + lane_idx / 16);
      p_smem.stmatrix_m8n8x4(p_smem_offset_w, (uint32_t*)p_f16[mma_kv]);
#else
      uint32_t p_smem_offset_w = p_smem.template get_permuted_offset<UPCAST_STRIDE_P>(
          warp_idx_in_wg * 16 + lane_idx / 4, warpgroup_idx * NUM_MMA_KV + mma_kv * 2);
      ((uint32_t*)(p_smem.base + p_smem_offset_w))[lane_idx % 4] = *(uint32_t*)&p_f16[mma_kv][0];
      ((uint32_t*)(p_smem.base + p_smem_offset_w + 8 * UPCAST_STRIDE_P))[lane_idx % 4] =
          *(uint32_t*)&p_f16[mma_kv][2];
      ((uint32_t*)(p_smem.base + (p_smem_offset_w ^ 0x1)))[lane_idx % 4] =
          *(uint32_t*)&p_f16[mma_kv][4];
      ((uint32_t*)(p_smem.base + (p_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_P))[lane_idx % 4] =
          *(uint32_t*)&p_f16[mma_kv][6];
#endif
    }
    uint32_t p_smem_offset_r = p_smem.template get_permuted_offset<UPCAST_STRIDE_P>(
        warp_idx_in_wg * 16 + lane_idx % 16, lane_idx / 16);

    // wait for p_smem to be filled
    __syncthreads();

#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) {
      uint32_t p_frag[4];
      p_smem.ldmatrix_m8n8x4(p_smem_offset_r, p_frag);
      p_smem_offset_r = p_smem.template advance_offset_by_column<2>(p_smem_offset_r, mma_kv);

#pragma unroll
      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
        uint32_t v_frag[4];
        ckv_smem.ldmatrix_m8n8x4_trans(ckv_smem_offset_r, v_frag);
        mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeKV>(o_frag[mma_d], p_frag,
                                                                             v_frag);
        ckv_smem_offset_r = ckv_smem.template advance_offset_by_column<2>(ckv_smem_offset_r, mma_d);
      }
      ckv_smem_offset_r =
          ckv_smem.template advance_offset_by_row<16, UPCAST_STRIDE_CKV>(ckv_smem_offset_r) -
          NUM_MMA_D_CKV;
    }
  } else {
    // no need to store p_smem because all warpgroups are working on the same p
    alignas(16) typename KTraits::DTypeKV p_f16[NUM_MMA_KV][8];
#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) {
      vec_cast<typename KTraits::DTypeKV, float>::cast<8>(p_f16[mma_kv], s_frag[mma_kv]);
      mma::m16k16_rowsum_f16f16f32(d, p_f16[mma_kv]);
    }
#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) {
#pragma unroll
      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
        uint32_t v_frag[4];
        ckv_smem.ldmatrix_m8n8x4_trans(ckv_smem_offset_r, v_frag);
        mma::mma_sync_m16n16k16_row_col_f16f16f32<typename KTraits::DTypeKV>(
            o_frag[mma_d], (uint32_t*)p_f16[mma_kv], v_frag);
        ckv_smem_offset_r = ckv_smem.template advance_offset_by_column<2>(ckv_smem_offset_r, mma_d);
      }
      ckv_smem_offset_r =
          ckv_smem.template advance_offset_by_row<16, UPCAST_STRIDE_CKV>(ckv_smem_offset_r) -
          NUM_MMA_D_CKV;
    }
  }
}

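// Final softmax normalization: with QK_SHARD the per-warpgroup partial denominators are summed
// through shared memory first, then every output fragment is multiplied by 1/d (or zeroed for
// rows that never saw a valid key, detected via m == -inf).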
template <typename KTraits>
__device__ __forceinline__ void normalize_d_(typename KTraits::SharedStorage* smem_storage,
                                             const uint32_t stage_idx, float (*o_frag)[8],
                                             typename KTraits::DTypeQKAccum* m, float* d) {
  const uint32_t warpgroup_idx = threadIdx.z, lane_idx = threadIdx.x, warp_idx_in_wg = threadIdx.y;
  if constexpr (KTraits::QK_SHARD) {
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      if (lane_idx % 4 == 0) {
        smem_storage->d_wg[warpgroup_idx][warp_idx_in_wg * 16 + j * 8 + lane_idx / 4] = d[j];
      }
    }
    __syncthreads();
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      d[j] = smem_storage->d_wg[0][warp_idx_in_wg * 16 + j * 8 + lane_idx / 4] +
             smem_storage->d_wg[1][warp_idx_in_wg * 16 + j * 8 + lane_idx / 4];
    }
  }

  float d_rcp[2];
  // compute reciprocal of d
#pragma unroll
  for (uint32_t j = 0; j < 2; ++j) {
    d_rcp[j] = (m[j] != typename KTraits::DTypeQKAccum(-math::inf)) ? math::ptx_rcp(d[j]) : 0.f;
  }

#pragma unroll
  for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_CKV / 2; ++mma_d) {
#pragma unroll
    for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
      o_frag[mma_d][reg_id] = o_frag[mma_d][reg_id] * d_rcp[(reg_id % 4) / 2];
    }
  }
}

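// Scales the running maxima by sm_scale_log2 so that m can later be combined with log2(d) to
// form the log-sum-exp (LSE) value written out alongside the output.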
template <typename KTraits>
__device__ __forceinline__ void finalize_m_(typename KTraits::AttentionVariant variant,
                                            typename KTraits::DTypeQKAccum* m) {
  if constexpr (variant.use_softmax) {
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      if (m[j] != typename KTraits::DTypeQKAccum(-math::inf)) {
        m[j] *= variant.sm_scale_log2;
      }
    }
  }
}

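// Second phase of split-KV attention: each CTA walks the (query, head) rows assigned to it,
// merges the partial outputs and LSEs produced by different KV chunks with the standard
// attention-state merge, normalizes, and writes the final output (and LSE, if requested).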
template <typename KTraits>
__device__ void DevicePersistentMergeStates(
    typename KTraits::IdType* merge_packed_offset_start,
    typename KTraits::IdType* merge_packed_offset_end,
    typename KTraits::IdType* merge_partial_packed_offset_start,
    typename KTraits::IdType* merge_partial_packed_offset_end,
    typename KTraits::IdType* merge_partial_stride, typename KTraits::DTypeO* partial_o,
    float* partial_lse, typename KTraits::DTypeO* final_o, float* final_lse,
    const uint32_t o_stride_n, const uint32_t o_stride_h, const uint_fastdiv& num_heads) {
  constexpr uint32_t VEC_SIZE = 8;  // partial o has data type float
  constexpr uint32_t NUM_THRS_PER_ROW = KTraits::HEAD_DIM_CKV / VEC_SIZE;
  constexpr uint32_t ROWS_PER_ITERATION = (KTraits::NUM_THREADS) / NUM_THRS_PER_ROW;
  const uint32_t cta_idx = (gridDim.x * blockIdx.y + blockIdx.x);
  const uint32_t thread_id = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
  const uint32_t offset_start = merge_packed_offset_start[cta_idx];
  const uint32_t len = merge_packed_offset_end[cta_idx] - offset_start;
  const uint32_t partial_offset_start = merge_partial_packed_offset_start[cta_idx];
  const uint32_t partial_offset_end = merge_partial_packed_offset_end[cta_idx];
  const uint32_t stride = merge_partial_stride[cta_idx];
#pragma unroll 1
  for (uint32_t local_packed_offset = thread_id / NUM_THRS_PER_ROW; local_packed_offset < len;
       local_packed_offset += ROWS_PER_ITERATION) {
    uint32_t final_packed_offset = offset_start + local_packed_offset;
    uint32_t q, r;
    num_heads.divmod(final_packed_offset, q, r);
    state_t<VEC_SIZE> st;
#pragma unroll 8
    for (uint32_t partial_packed_offset = partial_offset_start + local_packed_offset;
         partial_packed_offset < partial_offset_end; partial_packed_offset += stride) {
      vec_t<float, VEC_SIZE> o_partial;
      float lse_partial;
      o_partial.cast_load(partial_o + partial_packed_offset * KTraits::HEAD_DIM_CKV +
                          (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE);
      lse_partial = partial_lse[partial_packed_offset];
      st.merge(o_partial, lse_partial, 1);
    }
    st.normalize();
    st.o.cast_store(final_o +
                    (q * o_stride_n + r * o_stride_h + (thread_id % NUM_THRS_PER_ROW) * VEC_SIZE));
    if (final_lse) {
      final_lse[q * num_heads + r] = st.get_lse();
    }
  }
}

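// Writes the CTA's output tile: registers are first staged into o_smem (stmatrix when available),
// then copied to global memory in 128-bit chunks. If the work item is a partial chunk of a
// split-KV request, the tile and its LSE go to partial_o/partial_lse for the merge phase;
// otherwise they are written directly to final_o (and final_lse when provided).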
template <typename KTraits>
__device__ __forceinline__ void write_o(typename KTraits::SharedStorage* smem_storage,
                                        typename KTraits::DTypeO* final_o, float* final_lse,
                                        typename KTraits::DTypeO* partial_o, float* partial_lse,
                                        float (*o_frag)[8], typename KTraits::DTypeQKAccum* m,
                                        float* d, const uint32_t o_stride_n,
                                        const uint32_t o_stride_h, const uint32_t q_len,
                                        const uint32_t packed_offset,
                                        const uint_fastdiv& num_heads) {
  using DTypeO = typename KTraits::DTypeO;
  constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
  constexpr uint32_t HEAD_DIM_CKV = KTraits::HEAD_DIM_CKV;
  constexpr uint32_t UPCAST_STRIDE_FINAL_O = KTraits::UPCAST_STRIDE_FINAL_O;
  const uint32_t lane_idx = threadIdx.x, warpgroup_idx = threadIdx.z, warp_idx_in_wg = threadIdx.y;
  smem_t<KTraits::SWIZZLE_MODE_O> o_smem(smem_storage->o_smem);
#pragma unroll
  for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 2; ++mma_d) {
    uint32_t o_frag_f16[8 / 2];
    vec_cast<DTypeO, float>::cast<8>((DTypeO*)o_frag_f16, o_frag[mma_d]);
#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED
    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
        warp_idx_in_wg * 16 + lane_idx % 16,
        warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2 + lane_idx / 16);
    o_smem.template stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16);
#else
    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
        warp_idx_in_wg * 16 + lane_idx / 4, warpgroup_idx * NUM_MMA_D_CKV + mma_d * 2);
    ((uint32_t*)(o_smem.base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0];
    ((uint32_t*)(o_smem.base + o_smem_offset_w + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
        o_frag_f16[1];
    ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = o_frag_f16[2];
    ((uint32_t*)(o_smem.base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_FINAL_O))[lane_idx % 4] =
        o_frag_f16[3];
#endif
  }

  if (partial_o != nullptr) {
    // write to partial_o
#pragma unroll
    for (uint32_t j = 0; j < 2; ++j) {
      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4) / num_heads;
      if (lane_idx % 4 == 0 && q_idx < q_len) {
        partial_lse[(blockIdx.x * 4 + warp_idx_in_wg) * 16 + 8 * j + lane_idx / 4] =
            math::ptx_log2(d[j]) + float(m[j]);
      }
    }

    // step 1. smem to gmem
    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
        warp_idx_in_wg * 16 + lane_idx / 8, warpgroup_idx * NUM_MMA_D_CKV + lane_idx % 8);
#pragma unroll
    for (uint32_t j = 0; j < 4; ++j) {
      uint32_t q_idx = (packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8) / num_heads;
      DTypeO* o_partial_ptr =
          partial_o +
          ((blockIdx.x * 4 + warp_idx_in_wg) * 16 + 4 * j + lane_idx / 8) * HEAD_DIM_CKV +
          warpgroup_idx * (HEAD_DIM_CKV / 2) + (lane_idx % 8) * upcast_size<DTypeO>();
#pragma unroll
      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 8; ++mma_d) {
        if (q_idx < q_len) {
          o_smem.template store_128b(o_smem_offset_w, o_partial_ptr);
        }
        o_partial_ptr += 8 * upcast_size<DTypeO>();
        o_smem_offset_w = o_smem.template advance_offset_by_column<8>(o_smem_offset_w, mma_d);
      }
      o_smem_offset_w =
          o_smem.template advance_offset_by_row<4, UPCAST_STRIDE_FINAL_O>(o_smem_offset_w) -
          NUM_MMA_D_CKV;
    }
  } else {
    // write to final_o

    if (final_lse) {
#pragma unroll
      for (uint32_t j = 0; j < 2; ++j) {
        uint32_t q, r;
        num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 8 * j + lane_idx / 4, q, r);
        if (lane_idx % 4 == 0 && q < q_len) {
          final_lse[q * num_heads + r] = math::ptx_log2(d[j]) + float(m[j]);
        }
      }
    }

    // step 1. smem to gmem
    uint32_t o_smem_offset_w = o_smem.template get_permuted_offset<UPCAST_STRIDE_FINAL_O>(
        warp_idx_in_wg * 16 + lane_idx / 8, warpgroup_idx * NUM_MMA_D_CKV + lane_idx % 8);
#pragma unroll
    for (uint32_t j = 0; j < 4; ++j) {
      uint32_t q, r;
      num_heads.divmod(packed_offset + warp_idx_in_wg * 16 + 4 * j + lane_idx / 8, q, r);
      DTypeO* o_final_ptr = final_o + q * o_stride_n + r * o_stride_h +
                            warpgroup_idx * (HEAD_DIM_CKV / 2) +
                            (lane_idx % 8) * upcast_size<DTypeO>();
#pragma unroll
      for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_CKV / 8; ++mma_d) {
        if (q < q_len) {
          o_smem.template store_128b(o_smem_offset_w, o_final_ptr);
        }
        o_final_ptr += 8 * upcast_size<DTypeO>();
        o_smem_offset_w = o_smem.template advance_offset_by_column<8>(o_smem_offset_w, mma_d);
      }
      o_smem_offset_w =
          o_smem.template advance_offset_by_row<4, UPCAST_STRIDE_FINAL_O>(o_smem_offset_w) -
          NUM_MMA_D_CKV;
    }
  }
}

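// Persistent FA2-style kernel for MLA over a paged KV cache. Each CTA (two warpgroups, 256
// threads) iterates over the work items assigned to it via work_indptr, streaming KV tiles
// through a NUM_STAGES-deep cp.async pipeline. The KV loop runs backwards from the last tile so
// the causally masked tiles are handled first, followed by the unmasked bulk, and finally the
// remaining pipelined tiles are drained. After all work items are done, a grid-wide sync hands
// over to the merge phase for split-KV requests.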
template <typename KTraits, typename Params>
__global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchMLAPagedAttentionKernel(
    const __grid_constant__ Params params) {
  using DTypeQ = typename Params::DTypeQ;
  using DTypeKV = typename Params::DTypeKV;
  using DTypeO = typename Params::DTypeO;
  using IdType = typename Params::IdType;

  extern __shared__ __align__(alignof(typename KTraits::SharedStorage)) uint8_t smem[];
  auto& smem_storage = reinterpret_cast<typename KTraits::SharedStorage&>(smem);

  typename KTraits::AttentionVariant variant(params, blockIdx.y, smem);

  [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q_NOPE = KTraits::SWIZZLE_MODE_Q_NOPE;
  [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q_PE = KTraits::SWIZZLE_MODE_Q_PE;
  [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_CKV = KTraits::SWIZZLE_MODE_CKV;
  [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KPE = KTraits::SWIZZLE_MODE_KPE;
  [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
  [[maybe_unused]] constexpr uint32_t NUM_MMA_D_CKV = KTraits::NUM_MMA_D_CKV;
  [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q;
  [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV;
  [[maybe_unused]] constexpr int32_t NUM_STAGES = KTraits::NUM_STAGES;
  [[maybe_unused]] constexpr bool CAUSAL = KTraits::CAUSAL;

  DTypeQ* q_nope = params.q_nope;
  DTypeQ* q_pe = params.q_pe;
  DTypeKV* ckv = params.ckv;
  DTypeKV* kpe = params.kpe;
  IdType* kv_indices = params.kv_indices;
  DTypeO* partial_o = params.partial_o;
  float* partial_lse = params.partial_lse;
  DTypeO* final_o = params.final_o;
  float* final_lse = params.final_lse;
  IdType* work_indptr = params.work_indptr;

  float s_frag[KTraits::QK_SHARD ? NUM_MMA_KV / 2 : NUM_MMA_KV][8];
  alignas(16) float o_frag[NUM_MMA_D_CKV / 2][8];
  float m[2];
  float d[2];

  const uint_fastdiv& num_heads = params.num_heads;
  const uint_fastdiv& block_size = params.block_size;
  const uint32_t q_nope_stride_n = params.q_nope_stride_n;
  const uint32_t q_nope_stride_h = params.q_nope_stride_h;
  const uint32_t q_pe_stride_n = params.q_pe_stride_n;
  const uint32_t q_pe_stride_h = params.q_pe_stride_h;
  const uint32_t ckv_stride_page = params.ckv_stride_page;
  const uint32_t ckv_stride_n = params.ckv_stride_n;
  const uint32_t kpe_stride_page = params.kpe_stride_page;
  const uint32_t kpe_stride_n = params.kpe_stride_n;
  const uint32_t o_stride_n = params.o_stride_n;
  const uint32_t o_stride_h = params.o_stride_h;
  const uint32_t cluster_tile_q = gridDim.x * KTraits::CTA_TILE_Q;

#pragma unroll 1
  for (IdType work_idx = work_indptr[blockIdx.y]; work_idx < work_indptr[blockIdx.y + 1];
       ++work_idx) {
    const uint32_t q_indptr = params.q_indptr[work_idx];
    const uint32_t kv_indptr = params.kv_indptr[work_idx];
    const int32_t partial_indptr = params.partial_indptr[work_idx];
    const uint32_t q_len = params.q_len[work_idx];
    const uint32_t kv_len = params.kv_len[work_idx];
    const uint32_t packed_qo_start = params.q_start[work_idx];
    const uint32_t kv_start = params.kv_start[work_idx];
    const uint32_t kv_end = params.kv_end[work_idx];

    const uint32_t qo_packed_idx_base = packed_qo_start + blockIdx.x * KTraits::CTA_TILE_Q;
    const uint32_t qo_upperbound =
        min(q_len, ceil_div(qo_packed_idx_base + KTraits::CTA_TILE_Q, num_heads));

    init_states_<KTraits>(o_frag, m, d);

    __syncthreads();
    load_q<KTraits>(&smem_storage, q_nope + q_indptr * q_nope_stride_n,
                    q_pe + q_indptr * q_pe_stride_n, q_nope_stride_n, q_nope_stride_h,
                    q_pe_stride_n, q_pe_stride_h, qo_upperbound, qo_packed_idx_base,
                    params.num_heads);

    int kv_tile_idx =
        ceil_div(
            (CAUSAL ? min(kv_end, kv_len - q_len + (packed_qo_start + cluster_tile_q) / num_heads)
                    : kv_end),
            CTA_TILE_KV) -
        1 - (kv_start / CTA_TILE_KV);

    int mask_tile_idx =
        (CAUSAL ? min(kv_end, kv_len - q_len + packed_qo_start / num_heads) : kv_end) /
            CTA_TILE_KV -
        (kv_start / CTA_TILE_KV);

    uint32_t block_iter_base = kv_indptr * block_size + kv_start;
    // last kv tile
    __syncthreads();
    uint32_t packed_kv_bound = kv_indptr * block_size + kv_len;
    load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
                     kpe_stride_n, kpe_stride_page, packed_kv_bound,
                     block_iter_base + kv_tile_idx * CTA_TILE_KV, block_size,
                     kv_tile_idx % NUM_STAGES);
    cp_async::commit_group();
#pragma unroll
    for (int stage_idx = 1; stage_idx < NUM_STAGES; ++stage_idx) {
      if (kv_tile_idx - stage_idx >= 0) {
        load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
                         kpe_stride_n, kpe_stride_page, packed_kv_bound,
                         block_iter_base + (kv_tile_idx - stage_idx) * CTA_TILE_KV, block_size,
                         (kv_tile_idx - stage_idx) % NUM_STAGES);
        cp_async::commit_group();
      }
    }

    // loop with mask
#pragma unroll 1
    for (; kv_tile_idx >= mask_tile_idx && kv_tile_idx > 0; --kv_tile_idx) {
      cp_async::wait_group<NUM_STAGES - 1>();
      __syncthreads();

      // compute mla qk
      compute_mla_qk<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag);

      // logits mask
      logits_mask_<KTraits>(qo_packed_idx_base, kv_start + kv_tile_idx * CTA_TILE_KV, q_len, kv_len,
                            kv_end, num_heads, s_frag);

      // compute m,d states in online softmax
      update_mdo_states_<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, variant, s_frag, o_frag,
                                  m, d);

      // compute sfm * v
      compute_mla_pv<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag, d, o_frag);

      if (kv_tile_idx - NUM_STAGES >= 0) {
        __syncthreads();
        load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
                         kpe_stride_n, kpe_stride_page, packed_kv_bound,
                         block_iter_base + (kv_tile_idx - NUM_STAGES) * CTA_TILE_KV, block_size,
                         (kv_tile_idx - NUM_STAGES) % NUM_STAGES);
        cp_async::commit_group();
      }
    }

    // loop without mask
#pragma unroll 1
    for (; kv_tile_idx + 1 > NUM_STAGES; --kv_tile_idx) {
      cp_async::wait_group<NUM_STAGES - 1>();
      __syncthreads();

      // compute mla qk
      compute_mla_qk<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag);

      // compute m,d states in online softmax
      update_mdo_states_<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, variant, s_frag, o_frag,
                                  m, d);
      // compute sfm * v
      compute_mla_pv<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag, d, o_frag);

      __syncthreads();
      load_kv<KTraits>(&smem_storage, ckv, kpe, kv_indices, ckv_stride_n, ckv_stride_page,
                       kpe_stride_n, kpe_stride_page, packed_kv_bound,
                       block_iter_base + (kv_tile_idx - NUM_STAGES) * CTA_TILE_KV, block_size,
                       (kv_tile_idx - NUM_STAGES) % NUM_STAGES);
      cp_async::commit_group();
    }
    cp_async::wait_group<0>();
    __syncthreads();

    // last tiles
#pragma unroll
    for (; kv_tile_idx >= 0; --kv_tile_idx) {
      // compute mla qk
      compute_mla_qk<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag);

      logits_mask_<KTraits>(qo_packed_idx_base, kv_start + kv_tile_idx * CTA_TILE_KV, q_len, kv_len,
                            kv_end, num_heads, s_frag);

      // compute m,d states in online softmax
      update_mdo_states_<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, variant, s_frag, o_frag,
                                  m, d);

      // compute sfm * v
      compute_mla_pv<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, s_frag, d, o_frag);
    }

    __syncthreads();

    // normalize and write back
    normalize_d_<KTraits>(&smem_storage, kv_tile_idx % NUM_STAGES, o_frag, m, d);

    finalize_m_<KTraits>(variant, m);

    write_o<KTraits>(
        &smem_storage, final_o + q_indptr * o_stride_n,
        final_lse ? final_lse + q_indptr * num_heads : nullptr,
        (partial_indptr == -1) ? nullptr : partial_o + partial_indptr * KTraits::HEAD_DIM_CKV,
        (partial_indptr == -1) ? nullptr : partial_lse + partial_indptr, o_frag, m, d, o_stride_n,
        o_stride_h, qo_upperbound, qo_packed_idx_base, num_heads);
  }

  auto grid = cg::this_grid();
  grid.sync();

  // the second stage, merge partial outputs
  DevicePersistentMergeStates<KTraits>(
      params.merge_packed_offset_start, params.merge_packed_offset_end,
      params.merge_partial_packed_offset_start, params.merge_partial_packed_offset_end,
      params.merge_partial_stride, partial_o, partial_lse, final_o, final_lse, o_stride_n,
      o_stride_h, num_heads);
}

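// Selects NUM_STAGES, CTA_TILE_KV, and QK_SHARD from the amount of shared memory the device
// offers per SM, falling back to smaller KV tiles and fewer pipeline stages as the budget
// shrinks, and reporting an error on devices below the smallest supported configuration.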
#define DISPATCH_SMEM_CONFIG(smem_limit_per_sm, NUM_STAGES, CTA_TILE_KV, QK_SHARD, ...) \
  if (smem_limit_per_sm >= 221696) {                                                    \
    constexpr uint32_t NUM_STAGES = 2;                                                  \
    constexpr uint32_t CTA_TILE_KV = 64;                                                \
    constexpr bool QK_SHARD = true;                                                     \
    __VA_ARGS__;                                                                        \
  } else if (smem_limit_per_sm >= 147968) {                                             \
    constexpr uint32_t NUM_STAGES = 2;                                                  \
    constexpr uint32_t CTA_TILE_KV = 32;                                                \
    constexpr bool QK_SHARD = true;                                                     \
    __VA_ARGS__;                                                                        \
  } else if (smem_limit_per_sm >= 92672) {                                              \
    constexpr uint32_t NUM_STAGES = 1;                                                  \
    constexpr uint32_t CTA_TILE_KV = 16;                                                \
    constexpr bool QK_SHARD = false;                                                    \
    __VA_ARGS__;                                                                        \
  } else {                                                                              \
    std::ostringstream err;                                                             \
    err << "Unsupported shared memory size: " << smem_limit_per_sm;                     \
    FLASHINFER_ERROR(err.str());                                                        \
    return cudaErrorNotSupported;                                                       \
  }

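// Host-side launcher: resolves the tile configuration from the device's shared-memory budget,
// raises the kernel's dynamic shared-memory limit accordingly, and launches it as a cooperative
// kernel (required because the device code calls grid.sync() before the merge phase) with a grid
// of (num_blks_x, num_blks_y) CTAs of 32x4x2 threads.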
template <MaskMode MASK_MODE, uint32_t HEAD_DIM_CKV, uint32_t HEAD_DIM_KPE, typename Params>
cudaError_t BatchMLAPagedAttention(Params params, uint32_t num_blks_x, uint32_t num_blks_y,
                                   cudaStream_t stream) {
  using DTypeQ = typename Params::DTypeQ;
  using DTypeKV = typename Params::DTypeKV;
  using DTypeO = typename Params::DTypeO;
  using IdType = typename Params::IdType;
  if (MASK_MODE == MaskMode::kCustom) {
    return cudaErrorNotSupported;
  }
  constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal;

  dim3 nblks(num_blks_x, num_blks_y);
  dim3 nthrs(32, 4, 2);

  // get GPU shared memory size
  int device;
  int smem_limit_per_sm;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&smem_limit_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);

  DISPATCH_SMEM_CONFIG(smem_limit_per_sm, NUM_STAGES, CTA_TILE_KV, QK_SHARD, {
    using KTraits = KernelTraits<CAUSAL, NUM_STAGES, QK_SHARD, HEAD_DIM_CKV, HEAD_DIM_KPE,
                                 /*CTA_TILE_Q_=*/64, CTA_TILE_KV, DTypeQ, DTypeKV, DTypeO, IdType>;
    size_t smem_size = sizeof(typename KTraits::SharedStorage);
    auto kernel = BatchMLAPagedAttentionKernel<KTraits, Params>;
    void* args[] = {(void*)&params};

    FLASHINFER_CUDA_CALL(
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
    FLASHINFER_CUDA_CALL(
        cudaLaunchCooperativeKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
  });

  return cudaSuccess;
}

}  // namespace mla

}  // namespace flashinfer

#endif  // FLASHINFER_MLA_FA2_CUH_