// sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/attention/pod.cuh

#ifndef FLASHINFER_POD_CUH_
#define FLASHINFER_POD_CUH_
#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include "../cp_async.cuh"
#include "../fastdiv.cuh"
#include "../frag_layout_swizzle.cuh"
#include "../layout.cuh"
#include "../math.cuh"
#include "../mma.cuh"
#include "../page.cuh"
#include "../permuted_smem.cuh"
#include "../pos_enc.cuh"
#include "../utils.cuh"
#include "cascade.cuh"
#include "decode.cuh"
#include "mask.cuh"
#include "prefill.cuh"
#include "variants.cuh"
namespace flashinfer {
namespace cg = cooperative_groups;
using cp_async::SharedMemFillMode;
using mma::MMAMode;
enum Operation {
  PREFILL = 0,
  DECODE = 1,
};
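// Fused POD kernel: a single grid serves both the prefill and decode operations. Each CTA first
// runs the SM-aware scheduler below to decide which operation it executes, then jumps into the
// corresponding device function (SinglePrefillWithKVCacheDevice or BatchPrefillWithPagedKVCacheDevice).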
template <typename KTraits_P, typename KTraits_D, bool PartitionKV_P, typename PrefillParams,
          typename DecodeParams>
__global__ __launch_bounds__(std::max(KTraits_P::NUM_THREADS, KTraits_D::NUM_THREADS)) void
    PODWithKVCacheTensorKernel(const uint32_t xsize,
                               const __grid_constant__ PrefillParams prefill_params,
                               const __grid_constant__ DecodeParams decode_params, int* tbAssign) {
  extern __shared__ uint8_t smem[];
  // PREFILL VARS
  const uint32_t num_kv_heads_p = prefill_params.num_kv_heads;
  const uint32_t num_chunks = prefill_params.partition_kv;
  const uint32_t qo_len = prefill_params.qo_len;

  // DECODE VARS
  const uint32_t padded_bsize = decode_params.padded_batch_size;
  const uint32_t num_kv_heads_d = decode_params.paged_kv.num_heads;

  // THREADBLOCKS
  const uint32_t prefill_blocks = num_kv_heads_p * xsize * (PartitionKV_P ? num_chunks : 1);
  const uint32_t decode_blocks = padded_bsize * num_kv_heads_d;

  int op;
  int linear_bid;
  // SM-aware CTA scheduler
  if (threadIdx.x == 0) {
    // TODO_AK: If num_threads don't match, use virtual sub-CTAs.
    // Requires changing block-level sync in main prefill/decode kernels.
    constexpr int blk_factor_p = 1;
    constexpr int blk_factor_d = 1;

    // SM-aware threadblock scheduler code
    // Find out which SM this threadblock is scheduled on
    int num_SMs;
    // WARNING: nsmid has only been tested on A100/H100, and matches SM count
    // No guarantee this will work on other GPUs
    asm volatile("mov.u32 %0, %%nsmid;" : "=r"(num_SMs));
    asm volatile("mov.u32 %0, %%smid;" : "=r"(linear_bid));
    const int prefill_slots = (prefill_blocks + blk_factor_p - 1) / blk_factor_p;
    const int decode_slots = (decode_blocks + blk_factor_d - 1) / blk_factor_d;

    if (prefill_slots <= decode_slots) {
      // Total tags = (decode + prefill) / min(decode, prefill)
      //            = 1 + decode / prefill; when prefill < decode
      const int total_tags = decode_slots / prefill_slots + 1;
      // For this SM, what's the next operation we want to run?
      op = (atomicAdd(&tbAssign[linear_bid], 1) % total_tags);
      if (op > 0) {
        op = 1;
      }
    } else {
      // Total tags = (decode + prefill) / min(decode, prefill)
      //            = 1 + prefill / decode; when decode < prefill
      const int pref_tags = prefill_slots / decode_slots;
      // For this SM, what's the next operation we want to run?
      op = (atomicAdd(&tbAssign[linear_bid], 1) % (pref_tags + 1));
      if (op < pref_tags) {
        op = 0;
      } else {
        op = 1;
      }
    }

    // Get the next blockId for that operation
    linear_bid = atomicAdd(&tbAssign[num_SMs + op], 1);
    // If the blockId obtained exceeds the max blockIds for that op, switch to the other op
    if (op == 0 && linear_bid >= prefill_slots) {
      linear_bid = atomicAdd(&tbAssign[num_SMs + 1], 1);
      op = !op;
    } else if (op == 1 && linear_bid >= decode_slots) {
      op = !op;
      linear_bid = atomicAdd(&tbAssign[num_SMs + 0], 1);
    }

    // Write the blockId and operation to shared memory
    ((int*)smem)[0] = linear_bid;
    ((int*)smem)[1] = op;
  }
  // Sync to wait for dynamic scheduler to finish
  __syncthreads();
  // Fetch from shared memory the assigned blockId and operation.
  linear_bid = ((int*)smem)[0];
  op = ((int*)smem)[1];
  // Sync to force all threads to wait
  __syncthreads();
  if (op == PREFILL) {
    const uint32_t linear_tid = threadIdx.x;
    // Return if threadId exceeds number of threads for this op
    if (linear_tid >= 32 * KTraits_P::NUM_WARPS_Q * KTraits_P::NUM_WARPS_KV) return;
    const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_P::NUM_WARPS_Q,
                          (linear_tid / 32) / KTraits_P::NUM_WARPS_Q);
    // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads);
    // dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads);

    // BlockID exceeds limit
    if (linear_bid >= prefill_blocks) return;
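    // Decompose the linear block id: the query-tile index (xsize tiles) varies fastest,
    // then the kv chunk (when partitioning kv), then the kv head.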
    const uint32_t bx = linear_bid % xsize;
    auto& smem_storage = reinterpret_cast<typename KTraits_P::SharedStorage&>(smem);
    // Not partition_kv
    if constexpr (!PartitionKV_P) {
      const uint32_t chunk_idx = 0;
      const uint32_t kv_head_idx = linear_bid / xsize;
      SinglePrefillWithKVCacheDevice<KTraits_P>(prefill_params, smem_storage, tid, bx, chunk_idx,
                                                kv_head_idx, 1, num_kv_heads_p);
    } else {
      const uint32_t chunk_idx = (linear_bid / xsize) % num_chunks;
      const uint32_t kv_head_idx = linear_bid / (xsize * num_chunks);
      SinglePrefillWithKVCacheDevice<KTraits_P>(prefill_params, smem_storage, tid, bx, chunk_idx,
                                                kv_head_idx, num_chunks, num_kv_heads_p);
    }
  } else /* OP == DECODE */ {
    auto& smem_storage = reinterpret_cast<typename KTraits_D::SharedStorage&>(smem);
    // dim3 nblks_d(padded_batch_size_d, 1, num_kv_heads);
    if (linear_bid >= decode_blocks) return;
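    // Decompose the linear block id into (batch tile, kv head) for the decode grid.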
    const uint32_t bx = linear_bid % padded_bsize;
    const uint32_t kv_head_idx = linear_bid / padded_bsize;

    // dim3 nthrs_d(32, NUM_WARPS_Q_D, NUM_WARPS_KV_D);
    const uint32_t linear_tid = threadIdx.x;
    // Return if threadId exceeds number of threads for this op
    if (linear_tid >= 32 * KTraits_D::NUM_WARPS_Q * KTraits_D::NUM_WARPS_KV) return;
    const dim3 tid = dim3(linear_tid % 32, (linear_tid / 32) % KTraits_D::NUM_WARPS_Q,
                          (linear_tid / 32) / KTraits_D::NUM_WARPS_Q);

    BatchPrefillWithPagedKVCacheDevice<KTraits_D>(decode_params, smem_storage, tid, bx, kv_head_idx,
                                                  num_kv_heads_d);
  }
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, PosEncodingMode POS_ENCODING_MODE,
          bool USE_FP16_QK_REDUCTION, MaskMode MASK_MODE_P, uint32_t CTA_TILE_Q_D,
          MaskMode MASK_MODE_D, typename PrefillAttentionVariant, typename DecodeAttentionVariant,
          typename PrefillParams, typename DecodeParams>
cudaError_t PODWithKVCacheTensorDispatched(PrefillParams prefill_params,
                                           typename PrefillParams::DTypeO* tmp_p,
                                           DecodeParams decode_params,
                                           typename DecodeParams::DTypeO* tmp_v, float* tmp_s,
                                           bool enable_pdl, cudaStream_t stream) {
  static_assert(std::is_same<typename PrefillParams::DTypeQ, typename DecodeParams::DTypeQ>::value);
  static_assert(
      std::is_same<typename PrefillParams::DTypeKV, typename DecodeParams::DTypeKV>::value);
  static_assert(std::is_same<typename PrefillParams::DTypeO, typename DecodeParams::DTypeO>::value);
  // Ensure heads match
  assert(prefill_params.num_kv_heads == decode_params.paged_kv.num_heads);
  assert(prefill_params.num_qo_heads == decode_params.num_qo_heads);

  // Prefill variable setup
  using DTypeQ_P = typename PrefillParams::DTypeQ;
  using DTypeKV_P = typename PrefillParams::DTypeKV;
  using DTypeO_P = typename PrefillParams::DTypeO;
  const uint32_t num_qo_heads = prefill_params.num_qo_heads;
  const uint32_t num_kv_heads = prefill_params.num_kv_heads;
  const uint32_t qo_len = prefill_params.qo_len;
  const uint32_t kv_len = prefill_params.kv_len;
  if (kv_len < qo_len && MASK_MODE_P == MaskMode::kCausal) {
    std::ostringstream err_msg;
    err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal "
               "to qo_len, got kv_len "
            << kv_len << " and qo_len " << qo_len;
    FLASHINFER_ERROR(err_msg.str());
  }

  const uint32_t group_size = num_qo_heads / num_kv_heads;
  const uint_fastdiv group_size_fastdiv(group_size);
  constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
  constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
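  // Choose the prefill CTA query-tile size from the packed query length (qo_len * group_size);
  // the thresholds below also depend on the head dimension and the compute capability.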
  uint32_t cta_tile_q_p = 0;
  int64_t unpacked_qo_len = qo_len * group_size;
  if (unpacked_qo_len > 64 && HEAD_DIM_VO < 256) {
    cta_tile_q_p = 128;
  } else {
    auto compute_capacity = GetCudaComputeCapability();
    if (compute_capacity.first >= 8) {
      // Ampere or newer
      if (unpacked_qo_len > 16) {
        // avg_packed_qo_len <= 64
        cta_tile_q_p = 64;
      } else {
        // avg_packed_qo_len <= 16
        cta_tile_q_p = 16;
      }
    } else {
      // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp layout
      cta_tile_q_p = 64;
    }
  }
  // Decode vars setup
  using DTypeQ_D = typename DecodeParams::DTypeQ;
  using DTypeKV_D = typename DecodeParams::DTypeKV;
  using DTypeO_D = typename DecodeParams::DTypeO;
  const uint32_t padded_batch_size_d = decode_params.padded_batch_size;
  constexpr uint32_t NUM_MMA_Q_D = get_num_mma_q(CTA_TILE_Q_D);
  constexpr uint32_t NUM_WARPS_Q_D = get_num_warps_q(CTA_TILE_Q_D);
  constexpr uint32_t NUM_WARPS_KV_D = get_num_warps_kv(CTA_TILE_Q_D);

  if (padded_batch_size_d == 0) {
    // No request, skip
    // this won't happen in CUDAGraph mode because we fixed the padded_batch_size
    return cudaSuccess;
  }

  // constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
  // constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
  using DTypeQKAccum_D =
      typename std::conditional<USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ_D, half>, half,
                                float>::type;

  int dev_id = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  int max_smem_per_sm = 0;
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&max_smem_per_sm,
                                              cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id));
  // we expect each sm execute two threadblocks
  // TODO(Zihao): fix the following computation
  const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_D) * 16) ? 2 : 1;
  const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;
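  // Bound NUM_MMA_KV for the decode path by both the register budget (max_num_mma_kv_reg_d) and
  // the shared-memory budget (max_num_mma_kv_smem_d); the dispatch below takes the minimum.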
  constexpr uint32_t max_num_mma_kv_reg_d =
      (HEAD_DIM_VO >= 128 && NUM_MMA_Q_D == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
       !USE_FP16_QK_REDUCTION)
          ? 2
          : (8 / NUM_MMA_Q_D);
  // TODO(Zihao): fix the following computation
  const uint32_t max_num_mma_kv_smem_d =
      (max_smem_per_threadblock / (16 * HEAD_DIM_QK * sizeof(DTypeQ_D)) -
       NUM_MMA_Q_D * NUM_WARPS_Q_D) /
      (2 * NUM_WARPS_KV_D);
  DISPATCH_CTA_TILE_Q(cta_tile_q_p, CTA_TILE_Q_P, {
    constexpr uint32_t NUM_WARPS_Q_P = get_num_warps_q(CTA_TILE_Q_P);
    constexpr uint32_t NUM_WARPS_KV_P = get_num_warps_kv(CTA_TILE_Q_P);
    constexpr uint32_t NUM_MMA_Q_P = get_num_mma_q(CTA_TILE_Q_P);
    using DTypeQKAccum_P =
        typename std::conditional<USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ_P, half>, half,
                                  float>::type;

    // we expect each sm execute two threadblocks
    // TODO(Zihao): fix the following computation
    const int num_ctas_per_sm_p =
        max_smem_per_sm > (16 * HEAD_DIM_QK * sizeof(DTypeQ_P) * 16) ? 2 : 1;
    const int max_smem_per_threadblock_p = max_smem_per_sm / num_ctas_per_sm_p;
    constexpr uint32_t max_num_mma_kv_reg_p =
        (HEAD_DIM_VO >= 128 && NUM_MMA_Q_P == 2 &&
         POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && !USE_FP16_QK_REDUCTION)
            ? 2
            : (8 / NUM_MMA_Q_P);
    // TODO(Zihao): fix the following computation
    const uint32_t max_num_mma_kv_smem_p =
        (max_smem_per_threadblock_p / (16 * HEAD_DIM_QK * sizeof(DTypeQ_P)) -
         NUM_MMA_Q_P * NUM_WARPS_Q_P) /
        (2 * NUM_WARPS_KV_P);

    // control NUM_MMA_KV for maximum warp occupancy
    DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_p, max_num_mma_kv_reg_p), NUM_MMA_KV_P, {
      using KTraits_P =
          KernelTraits<MASK_MODE_P, CTA_TILE_Q_P, NUM_MMA_Q_P, NUM_MMA_KV_P, NUM_MMA_D_QK,
                       NUM_MMA_D_VO, NUM_WARPS_Q_P, NUM_WARPS_KV_P, POS_ENCODING_MODE, DTypeQ_P,
                       DTypeKV_P, DTypeO_P, DTypeQKAccum_P, typename PrefillParams::IdType,
                       PrefillAttentionVariant>;
      if constexpr (KTraits_P::IsInvalid()) {
        // Invalid configuration, skip
        std::ostringstream err_msg;
        err_msg << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_P
                << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
                << " NUM_MMA_KV=" << NUM_MMA_KV_P << " NUM_WARPS_Q=" << NUM_WARPS_Q_P
                << " NUM_WARPS_KV=" << NUM_WARPS_KV_P
                << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
                   " and report the issue to the developers.";
        FLASHINFER_ERROR(err_msg.str());
      } else {
        // Decode stuff
        // TODO: Is there a way to avoid this nested dispatch?
        DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem_d, max_num_mma_kv_reg_d), NUM_MMA_KV_D, {
          using KTraits_D =
              KernelTraits<MASK_MODE_D, CTA_TILE_Q_D, NUM_MMA_Q_D, NUM_MMA_KV_D, NUM_MMA_D_QK,
                           NUM_MMA_D_VO, NUM_WARPS_Q_D, NUM_WARPS_KV_D, POS_ENCODING_MODE, DTypeQ_D,
                           DTypeKV_D, DTypeO_D, DTypeQKAccum_D, typename DecodeParams::IdType,
                           DecodeAttentionVariant>;
          if constexpr (KTraits_D::IsInvalid()) {
            // Invalid configuration, skip
            std::ostringstream err_msg;
            err_msg
                << "FlashInfer Internal Error: Invalid configuration : NUM_MMA_Q=" << NUM_MMA_Q_D
                << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
                << " NUM_MMA_KV=" << NUM_MMA_KV_D << " NUM_WARPS_Q=" << NUM_WARPS_Q_D
                << " NUM_WARPS_KV=" << NUM_WARPS_KV_D
                << " please create an issue (https://github.com/flashinfer-ai/flashinfer/issues)"
                   " and report the issue to the developers.";
            FLASHINFER_ERROR(err_msg.str());
          } else {
            // End decode stuff
            constexpr uint32_t num_threads_p = (NUM_WARPS_Q_P * NUM_WARPS_KV_P) * WARP_SIZE;
            size_t smem_size_p = sizeof(typename KTraits_P::SharedStorage);
            size_t smem_size_d = sizeof(typename KTraits_D::SharedStorage);
            auto kernel =
                PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams, DecodeParams>;

            // Prefill: decide num_splits for split-kv
            int num_blocks_per_sm = 0;
            int num_sm = 0;
            FLASHINFER_CUDA_CALL(
                cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
            // FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            //     &num_blocks_per_sm, kernel, num_threads_p, smem_size_p));
            // Above function returns 0 for some reason, so we use a workaround
            num_blocks_per_sm = std::max(
                1, std::min((int)(max_smem_per_sm / smem_size_p), (int)(256 / num_threads_p)));
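            // Split kv into chunks sized to fill the GPU (num_blocks_per_sm * num_sm CTAs beyond
            // the base prefill grid); each chunk covers at least 256 kv positions.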
            uint32_t max_num_kv_chunks =
                (num_blocks_per_sm * num_sm) /
                (num_kv_heads * ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q));
            uint32_t num_chunks;
            if (max_num_kv_chunks > 0) {
              uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256);
              num_chunks = ceil_div(kv_len, chunk_size);
            } else {
              num_chunks = 0;
            }

            // Setup new prefill params if (not) split
            auto o_p = prefill_params.o;
            auto lse_p = prefill_params.lse;
            float* tmp_lse = (float*)(tmp_p + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO);
            if (num_chunks <= 1 || tmp_p == nullptr) {
              // Enough parallelism, do not split-kv
              prefill_params.partition_kv = 0;
              kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, false, PrefillParams,
                                                  DecodeParams>;
            } else {
              // Use cooperative groups to increase occupancy
              prefill_params.partition_kv = num_chunks;
              prefill_params.o = tmp_p;
              prefill_params.lse = tmp_lse;
              kernel = PODWithKVCacheTensorKernel<KTraits_P, KTraits_D, true, PrefillParams,
                                                  DecodeParams>;
            }

            // Setup new decode params if (not) split
            auto o_d = decode_params.o;
            auto lse_d = decode_params.lse;
            if (tmp_v == nullptr) {
              // do not partition kv
              decode_params.partition_kv = false;
            } else {
              decode_params.partition_kv = true;
              decode_params.o = tmp_v;
              decode_params.lse = tmp_s;
            }

            uint32_t xsize = ceil_div(qo_len * group_size, KTraits_P::CTA_TILE_Q);
            int nblks_p(xsize * (prefill_params.partition_kv ? prefill_params.partition_kv : 1) *
                        num_kv_heads);
            int nthrs_p(32 * NUM_WARPS_Q_P * NUM_WARPS_KV_P);
            int nblks_d(padded_batch_size_d * 1 * num_kv_heads);
            int nthrs_d(32 * NUM_WARPS_Q_D * NUM_WARPS_KV_D);

            // ******* Select final combined sizes here ******* /
            size_t smem_size = max(smem_size_p, smem_size_d);
            int nblks = nblks_p + nblks_d;
            int nthrs = max(nthrs_p, nthrs_d);
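            // The fused grid is the sum of both grids; the block size and dynamic shared memory
            // take the max of the two ops so that either op fits in any CTA.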
            // printf("Smem: prefill %zu, decode %zu, total %zu\n", smem_size_p, smem_size_d, smem_size);
            // printf("Blocks: prefill %d, decode %d, total %d\n", nblks_p, nblks_d, nblks);
            // printf("Threads: prefill %d, decode %d, total %d\n", nthrs_p, nthrs_d, nthrs);
            // ************************************************ /
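            // Scheduler scratch buffer: one counter per SM plus two global block counters
            // (prefill, decode). Allocated once per process and zeroed before every launch.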
            static int* tbAssign = nullptr;
            if (tbAssign == nullptr) cudaMalloc(&tbAssign, sizeof(int) * (num_sm + 2));
            cudaMemset(tbAssign, 0, sizeof(int) * (num_sm + 2));

            // Setup kernel arguments
            void* args[] = {(void*)&xsize, (void*)&prefill_params, (void*)&decode_params,
                            (void*)&tbAssign};
            FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
                kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));

            // Launch kernel
            if (enable_pdl) {
              cudaLaunchAttribute attribute[1];
              cudaLaunchConfig_t config;
              attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
              attribute[0].val.programmaticStreamSerializationAllowed = 1;
              config.attrs = attribute;
              config.numAttrs = 1;
              config.gridDim = nblks;
              config.blockDim = nthrs;
              config.dynamicSmemBytes = smem_size;
              config.stream = stream;
              FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(&config, kernel, xsize, prefill_params,
                                                      decode_params, tbAssign));
            } else {
              FLASHINFER_CUDA_CALL(
                  cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
            }

            // Post-kernel stuff for split-kv prefill
            if (!(num_chunks <= 1 || tmp_p == nullptr)) {
              if constexpr (PrefillAttentionVariant::use_softmax) {
                FLASHINFER_CUDA_CALL(MergeStates(tmp_p, tmp_lse, o_p, lse_p, num_chunks, qo_len,
                                                 num_qo_heads, HEAD_DIM_VO, stream));
              } else {
                FLASHINFER_CUDA_CALL(AttentionSum(tmp_p, o_p, num_chunks, qo_len, num_qo_heads,
                                                  HEAD_DIM_VO, stream));
              }
            }

            // Post-kernel stuff for split-kv decode
            if (tmp_v != nullptr) {
              if constexpr (DecodeAttentionVariant::use_softmax) {
                FLASHINFER_CUDA_CALL(VariableLengthMergeStates(
                    tmp_v, tmp_s, decode_params.merge_indptr, o_d, lse_d,
                    decode_params.max_total_num_rows, decode_params.total_num_rows, num_qo_heads,
                    HEAD_DIM_VO, enable_pdl, stream));
              } else {
                FLASHINFER_CUDA_CALL(VariableLengthAttentionSum(
                    tmp_v, decode_params.merge_indptr, o_d, decode_params.max_total_num_rows,
                    decode_params.total_num_rows, num_qo_heads, HEAD_DIM_VO, enable_pdl, stream));
              }
            }
          }
        });
      }
    });
  });
  return cudaSuccess;
}
} // namespace flashinfer
#endif // FLASHINFER_POD_CUH_