sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/attention/hopper/prefill_sm90.cuh

/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_
#define FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_
#include <cuda.h>
#include <cuda_device_runtime_api.h>
#include <cutlass/arch/reg_reconfig.h>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include <type_traits>
#include <vector>
#include "../../cutlass_utils.cuh"
#include "../../exception.h"
#include "../mask.cuh"
#include "cute/tensor.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "epilogue.cuh"
#include "kernel_traits.cuh"
#include "mainloop.cuh"
#include "mainloop_mma.cuh"
#include "sparse_mainloop.cuh"
#include "tile_scheduler.cuh"
#include "utils.cuh"
namespace flashinfer {
using namespace cute;
DEFINE_HAS_MEMBER(maybe_prefix_len_ptr)
DEFINE_HAS_MEMBER(maybe_token_pos_in_items_ptr)
DEFINE_HAS_MEMBER(token_pos_in_items_len)
DEFINE_HAS_MEMBER(maybe_max_item_len_ptr)
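// DEFINE_HAS_MEMBER(x) is expected to generate a detection trait has_x_v<T> that is true iff T
// declares a member named x (the macro itself lives in one of the included utility headers).
// A minimal sketch of the idiom, assuming the usual std::void_t-based detection pattern:
//
//   template <typename T, typename = void>
//   struct has_x : std::false_type {};
//   template <typename T>
//   struct has_x<T, std::void_t<decltype(std::declval<T>().x)>> : std::true_type {};
//   template <typename T>
//   inline constexpr bool has_x_v = has_x<T>::value;
//
// The kernel below uses these traits to read the optional multi-item-scoring fields from
// AdditionalParams only when the params struct actually provides them.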
template <typename CollectiveMainloop, typename CollectiveEpilogue, typename Ktraits,
bool LEFT_SLIDING_WINDOW, bool CAUSAL, typename TileScheduler,
bool MULTIITEMSCORING = false>
__global__ void __launch_bounds__(Ktraits::NUM_WARPS * cutlass::NumThreadsPerWarp, 1)
PrefillWithKVCacheKernel(CUTE_GRID_CONSTANT
typename CollectiveMainloop::Params const mainloop_params,
CUTE_GRID_CONSTANT
typename CollectiveEpilogue::Params const epilogue_params,
CUTE_GRID_CONSTANT
typename TileScheduler::Params const scheduler_params) {
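// Warp-specialized FlashAttention-3 style prefill kernel: warp group 0 acts as the producer,
// loading Q/K/V tiles into shared memory (via TMA or cp.async), while the remaining warp groups
// act as consumers that run the QK^T and PV MMAs, the online softmax, and the epilogue store.
// Work tiles (q_tile_idx, qo_head_idx, kv_head_idx, ...) are handed out by TileScheduler.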
using DTypeQ = typename Ktraits::DTypeQ;
using DTypeKV = typename Ktraits::DTypeKV;
using DTypeO = typename Ktraits::DTypeO;
using DTypeQKAccum = typename Ktraits::DTypeQKAccum;
using TileShape_QKD = typename Ktraits::TileShape_QKD;
using TileShape_PDV = typename Ktraits::TileShape_PDV;
using AttentionVariant = typename Ktraits::AttentionVariant;
static constexpr int NUM_MMA_THREADS = Ktraits::NUM_MMA_THREADS;
static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
static constexpr int CTA_Q = Ktraits::CTA_Q;
static constexpr int CTA_KV = Ktraits::CTA_KV;
static constexpr bool use_tma_load_kv = CollectiveMainloop::USE_TMA_LOAD_KV;
using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline;
using PipelineParams = typename MainloopPipeline::Params;
using PipelineState = typename MainloopPipeline::PipelineState;
extern __shared__ char shared_memory[];
auto& shared_storage = *reinterpret_cast<typename Ktraits::SharedStorage*>(shared_memory);
int const lane_predicate = cute::elect_one_sync();
int const warp_idx = cutlass::canonical_warp_idx_sync();
// Issue Tma Descriptor Prefetch from a single thread
if (warp_idx == 0 && lane_predicate) {
CollectiveMainloop::prefetch_tma_descriptors(mainloop_params);
CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params);
}
// Obtain thread index within the warp group
int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup;
PipelineParams pipeline_params;
int warp_group_idx = cutlass::canonical_warp_group_idx();
pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer
: MainloopPipeline::ThreadCategory::Consumer;
if constexpr (use_tma_load_kv) {
pipeline_params.is_leader = warp_group_thread_idx == 0;
pipeline_params.num_consumers = NUM_MMA_THREADS;
} else {
pipeline_params.producer_arv_count = NUM_COPY_THREADS;
pipeline_params.consumer_arv_count = NUM_MMA_THREADS;
}
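// The two parameterizations correspond to the two K/V load paths: with TMA, the pipeline's full
// barrier is a transaction barrier (a single leader thread commits the expected transaction
// bytes, set per pipeline below), whereas with cp.async every copy thread arrives individually,
// so producer/consumer arrival counts are used instead.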
if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(/*num_threads=*/1);
shared_storage.barrier_O.init(/*num_threads=*/1);
}
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
MainloopPipeline pipeline_k = [&] {
if constexpr (use_tma_load_kv) {
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK;
return MainloopPipeline(shared_storage.pipeline_k, pipeline_params,
/*cluster_shape=*/Shape<_1, _1, _1>{});
} else {
return MainloopPipeline(shared_storage.pipeline_k, pipeline_params);
}
}();
MainloopPipeline pipeline_v = [&] {
if constexpr (use_tma_load_kv) {
pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesV;
return MainloopPipeline(shared_storage.pipeline_v, pipeline_params,
/*cluster_shape=*/Shape<_1, _1, _1>{});
} else {
return MainloopPipeline(shared_storage.pipeline_v, pipeline_params);
}
}();
CollectiveMainloop collective_mainloop;
CollectiveEpilogue collective_epilogue;
// We need this to guarantee that the Pipeline init is visible to all producers and consumer
// blocks in the Cluster
__syncthreads();
uint32_t* maybe_prefix_len_ptr = nullptr;
if constexpr (has_maybe_prefix_len_ptr_v<decltype(mainloop_params.additional_params)>) {
maybe_prefix_len_ptr = mainloop_params.additional_params.maybe_prefix_len_ptr;
}
uint16_t* maybe_token_pos_in_items_ptr = nullptr;
if constexpr (has_maybe_token_pos_in_items_ptr_v<decltype(mainloop_params.additional_params)>) {
maybe_token_pos_in_items_ptr = mainloop_params.additional_params.maybe_token_pos_in_items_ptr;
}
uint32_t token_pos_in_items_len = 0;
if constexpr (has_token_pos_in_items_len_v<decltype(mainloop_params.additional_params)>) {
token_pos_in_items_len = mainloop_params.additional_params.token_pos_in_items_len;
}
uint16_t* maybe_max_item_len_ptr = nullptr;
if constexpr (has_maybe_max_item_len_ptr_v<decltype(mainloop_params.additional_params)>) {
maybe_max_item_len_ptr = mainloop_params.additional_params.maybe_max_item_len_ptr;
}
if (warp_group_idx == 0) { // Producer
if constexpr (use_tma_load_kv) {
cutlass::arch::warpgroup_reg_dealloc<Ktraits::NUM_WARPS == 12 ? 24 : 32>();
} else {
cutlass::arch::warpgroup_reg_dealloc<72>();
}
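// Producer/consumer register rebalancing: the producer warp group shrinks its register budget so
// the MMA warp groups can grow theirs. As an illustrative budget check (an H100 SM provides
// 65536 32-bit registers): with NUM_WARPS == 12 and TMA loads, 128 producer threads * 24 regs +
// 256 consumer threads * 240 regs = 64512 <= 65536.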
int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
if (!use_tma_load_kv || warp_idx_in_warpgroup == 0) { // Load Q, K, V
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
int work_idx = 0;
TileScheduler scheduler;
for (auto work_tile_info = scheduler.get_initial_work(scheduler_params);
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*is_producer=*/true>(
scheduler_params, work_tile_info)) {
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len,
batch_idx] = block_coord;
if (q_tile_idx * CTA_Q >= qo_len) {
continue;
}
int num_kv_tiles =
collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len);
if (num_kv_tiles <= 0) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
continue;
}
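// For multi-item scoring, KV tiles that fall entirely outside the current query tile's items
// window get special handling in the mainloop; the counts below are derived from the batch's
// prefix length and maximum item length. Illustrative arithmetic with made-up numbers:
// CTA_Q = 128, CTA_KV = 96, q_tile_idx = 2, qo_len = 512, kv_len = 2048, max_item_len = 100
// gives valid_items_window_len = max(0, 256 + 2048 - 512 - 100) = 1692, so
// num_kv_tiles_outside_items_window = 1692 / 96 = 17; with prefix_len = 300,
// num_kv_tiles_prefix = ceil(300 / 96) = 4.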
if constexpr (MULTIITEMSCORING) {
auto prefix_len = __ldg(maybe_prefix_len_ptr + batch_idx);
auto max_item_len = __ldg(maybe_max_item_len_ptr + batch_idx);
auto valid_items_window_len =
std::max(0, q_tile_idx * CTA_Q + kv_len - qo_len - max_item_len);
int num_kv_tiles_outside_items_window = valid_items_window_len / CTA_KV;
int num_kv_tiles_prefix = cute::ceil_div(prefix_len, CTA_KV);
collective_mainloop.load<LEFT_SLIDING_WINDOW>(
mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx,
num_kv_tiles_outside_items_window, num_kv_tiles_prefix);
} else {
collective_mainloop.load<LEFT_SLIDING_WINDOW>(
mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
}
++work_idx;
}
collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
}
} else { // Consumer
if constexpr (use_tma_load_kv) {
cutlass::arch::warpgroup_reg_alloc<Ktraits::NUM_WARPS == 12 ? 240 : 160>();
} else {
cutlass::arch::warpgroup_reg_alloc<Ktraits::NUM_WARPS == 12 ? 216 : 144>();
}
TileScheduler scheduler;
// Initialize matmul objects.
typename Ktraits::TiledMmaPV tiled_mma_pv;
PipelineState smem_pipe_read_k, smem_pipe_read_v;
// We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
// (like in Cutlass's gemm) because the read and release pipeline states are always the same.
CollectiveMainloop::WarpScheduler::mma_init();
scheduler.init_consumer();
int work_idx = 0;
CUTLASS_PRAGMA_NO_UNROLL
for (auto work_tile_info = scheduler.get_initial_work(scheduler_params);
work_tile_info.is_valid(scheduler_params);
work_tile_info = scheduler.template get_next_work</*is_producer=*/false>(scheduler_params,
work_tile_info)) {
// Attention output (GEMM-II) accumulator.
Tensor tOrO = partition_fragment_C(tiled_mma_pv, select<0, 1>(TileShape_PDV{}));
auto block_coord = work_tile_info.get_block_coord(scheduler_params);
auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] =
block_coord;
AttentionVariant variant(mainloop_params, block_coord);
auto attention_updater =
variant.template GetAttentionUpdater<2 * (2 * CTA_Q / NUM_MMA_THREADS)>();
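// The template argument 2 * (2 * CTA_Q / NUM_MMA_THREADS) appears to be the number of
// accumulator rows owned per MMA thread (e.g. CTA_Q = 128 with NUM_MMA_THREADS = 256 gives 2),
// which presumably sizes the per-thread online-softmax state (running max / running sum).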
if (q_tile_idx * CTA_Q >= qo_len) {
continue;
}
int num_kv_tiles =
collective_mainloop.get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len);
if (num_kv_tiles <= 0) { // We exit early and write 0 to gO and -inf to gLSE.
collective_epilogue.store_zero(epilogue_params, shared_storage,
threadIdx.x - NUM_COPY_THREADS, block_coord);
continue;
}
int swa_begin_kv_tile_idx = 0;
int swa_end_kv_tile_idx = -1;
if constexpr (LEFT_SLIDING_WINDOW) {
swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx<CTA_Q, CTA_KV>(
mainloop_params.window_left, q_tile_idx, qo_len, kv_len);
swa_end_kv_tile_idx = get_swa_end_kv_tile_idx<CTA_Q, CTA_KV>(mainloop_params.window_left,
q_tile_idx, qo_len, kv_len);
}
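// For a left sliding window of width window_left, a query at position q attends only to keys in
// [q + kv_len - qo_len - window_left, q + kv_len - qo_len]; the helpers above presumably
// translate that range into the first and last KV tile indices this query tile can touch, so
// fully masked tiles can be skipped.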
uint32_t prefix_len = 0;
uint16_t* token_pos_in_items = nullptr;
int num_kv_tiles_outside_items_window = 0;
int num_kv_tiles_prefix = 0;
if constexpr (MULTIITEMSCORING) {
prefix_len = __ldg(maybe_prefix_len_ptr + batch_idx);
token_pos_in_items = maybe_token_pos_in_items_ptr + batch_idx * token_pos_in_items_len;
auto max_item_len = __ldg(maybe_max_item_len_ptr + batch_idx);
auto valid_items_window_len =
std::max(0, q_tile_idx * CTA_Q + kv_len - qo_len - max_item_len);
num_kv_tiles_outside_items_window = valid_items_window_len / CTA_KV;
num_kv_tiles_prefix = cute::ceil_div(prefix_len, CTA_KV);
}
mma_f16<Ktraits, /*LEFT_SLIDING_WINDOW=*/LEFT_SLIDING_WINDOW, CAUSAL, MULTIITEMSCORING,
CollectiveMainloop::WarpScheduler>(
mainloop_params, variant, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
tOrO, attention_updater, num_kv_tiles, swa_begin_kv_tile_idx, swa_end_kv_tile_idx,
threadIdx.x - NUM_COPY_THREADS, work_idx, q_tile_idx, shared_storage, qo_len, kv_len,
qo_head_idx, kv_head_idx, prefix_len, token_pos_in_items,
num_kv_tiles_outside_items_window, num_kv_tiles_prefix);
collective_epilogue.store(epilogue_params, tOrO, attention_updater.get_lse(), shared_storage,
tiled_mma_pv, threadIdx.x - NUM_COPY_THREADS, block_coord);
++work_idx;
}
collective_epilogue.store_tail();
}
}
template <typename KernelTraits, bool LEFT_SLIDING_WINDOW, bool CAUSAL, typename Params>
cudaError_t SinglePrefillWithKVCacheKernelTraitsDispatched(Params& params, cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
using DTypeKV = typename KernelTraits::DTypeKV;
using DTypeO = typename KernelTraits::DTypeO;
using CollectiveMainloop =
CollectiveMainloop<typename Params::AdditionalParams, KernelTraits, CAUSAL>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
using Scheduler = SingleTileScheduler;
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments(
{params.q_ptr,
get_gmem_layout(params.qo_len, params.num_qo_heads, KernelTraits::HEAD_DIM_QK,
params.q_stride_n,
params.q_stride_h), // layout_Q
params.k_ptr,
get_gmem_layout(params.kv_len, params.num_kv_heads, KernelTraits::HEAD_DIM_QK,
params.k_stride_n,
params.k_stride_h), // layout_K
params.v_ptr,
get_gmem_layout(params.kv_len, params.num_kv_heads, KernelTraits::HEAD_DIM_VO,
params.v_stride_n,
params.v_stride_h), // layout_V
params.window_left, params.additional_params});
typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments({
static_cast<DTypeO*>(params.o_ptr),
get_gmem_layout(params.qo_len, params.num_qo_heads, KernelTraits::HEAD_DIM_VO,
params.o_stride_n,
params.o_stride_h), // layout_O
static_cast<float*>(params.lse_ptr),
get_lse_gmem_layout(params.qo_len, params.num_qo_heads), // layout_LSE
});
int num_tiles_q = cutlass::ceil_div(params.qo_len, KernelTraits::CTA_Q);
// TODO(Zihao): also support kv-head major
typename Scheduler::Arguments scheduler_args = {
num_tiles_q, params.num_qo_heads, params.qo_len, params.kv_len,
cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads)};
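// The FastDivmod packs the GQA group size (num_qo_heads / num_kv_heads), letting the scheduler
// map a flat query head index to its KV head with a single divmod: e.g. 32 query heads over
// 8 KV heads gives group size 4, so qo_head_idx 13 -> kv_head_idx 13 / 4 = 3.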
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
auto kernel =
(void*)PrefillWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits,
LEFT_SLIDING_WINDOW, CAUSAL, Scheduler>;
int smem_size = sizeof(typename KernelTraits::SharedStorage);
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
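// Hopper SharedStorage typically exceeds the default 48 KiB dynamic shared memory limit, so the
// kernel must opt in to a larger carve-out (up to 227 KiB per block on H100) before launch.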
int device;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int multiprocessor_count;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize);
void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params};
FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream));
return cudaSuccess;
}
template <typename KernelTraits, bool LEFT_SLIDING_WINDOW, bool CAUSAL,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename Params, bool MULTIITEMSCORING = false>
cudaError_t BatchPrefillWithPagedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
using DTypeKV = typename KernelTraits::DTypeKV;
using DTypeO = typename KernelTraits::DTypeO;
using IdType = typename KernelTraits::IdType;
using CollectiveMainloop = SparseCollectiveMainloop<typename Params::AdditionalParams,
KernelTraits, CAUSAL, MULTIITEMSCORING>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
using Scheduler =
std::conditional_t<SAME_SCHEDULE_FOR_ALL_HEADS, BatchPrefillTileScheduler<IdType>,
BatchPrefillPersistentTileScheduler<IdType>>;
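// As the names suggest, SAME_SCHEDULE_FOR_ALL_HEADS selects a scheduler that reuses one tile
// schedule across all heads, while the persistent variant lets CTAs pull per-head work items
// independently from work_indptr.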
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments(
{params.q_ptr,
get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM_QK,
params.q_stride_n,
params.q_stride_h), // layout_Q
params.k_ptr,
// NOTE(Zihao): nnz is unused for the paged K/V layouts, so we can just pass 0
get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM_QK, params.k_stride_n,
params.k_stride_h), // layout_K
params.v_ptr,
get_gmem_layout(/*nnz=*/0, params.num_kv_heads, KernelTraits::HEAD_DIM_VO, params.v_stride_n,
params.v_stride_h), // layout_V
params.kv_indices, params.window_left, params.additional_params});
typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments({
params.o_ptr,
get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM_VO,
params.o_stride_n,
params.o_stride_h), // layout_O
params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE
});
typename Scheduler::Arguments scheduler_args = {
params.work_indptr,
params.head_indices,
params.qo_tile_indices,
params.qo_indptr,
params.kv_indptr,
params.qo_lens,
params.kv_lens,
params.batch_indices,
cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads),
params.num_qo_heads};
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
// Get the ptr to kernel function.
auto kernel =
(void*)PrefillWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits,
LEFT_SLIDING_WINDOW, CAUSAL, Scheduler, MULTIITEMSCORING>;
int smem_size = sizeof(typename KernelTraits::SharedStorage);
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
int device;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int multiprocessor_count;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize);
void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params};
FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream));
return cudaSuccess;
}
template <typename KernelTraits, bool LEFT_SLIDING_WINDOW, bool CAUSAL,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename Params>
cudaError_t BatchPrefillWithRaggedKVCacheKernelTraitsDispatched(Params& params,
cudaStream_t stream) {
using DTypeQ = typename KernelTraits::DTypeQ;
using DTypeKV = typename KernelTraits::DTypeKV;
using DTypeO = typename KernelTraits::DTypeO;
using IdType = typename KernelTraits::IdType;
using CollectiveMainloop =
CollectiveMainloop<typename Params::AdditionalParams, KernelTraits, CAUSAL>;
using CollectiveEpilogue = CollectiveEpilogue<KernelTraits>;
using Scheduler =
std::conditional_t<SAME_SCHEDULE_FOR_ALL_HEADS, BatchPrefillTileScheduler<IdType>,
BatchPrefillPersistentTileScheduler<IdType>>;
typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments(
{params.q_ptr,
get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM_QK,
params.q_stride_n,
params.q_stride_h), // layout_Q
params.k_ptr,
get_gmem_layout(params.nnz_kv, params.num_kv_heads, KernelTraits::HEAD_DIM_QK,
params.k_stride_n,
params.k_stride_h), // layout_K
params.v_ptr,
get_gmem_layout(params.nnz_kv, params.num_kv_heads, KernelTraits::HEAD_DIM_VO,
params.v_stride_n,
params.v_stride_h), // layout_V
params.window_left, params.additional_params});
typename CollectiveEpilogue::Params epilogue_params =
CollectiveEpilogue::to_underlying_arguments({
params.o_ptr,
get_gmem_layout(params.nnz_qo, params.num_qo_heads, KernelTraits::HEAD_DIM_VO,
params.o_stride_n,
params.o_stride_h), // layout_O
params.lse_ptr, get_lse_gmem_layout(params.nnz_qo, params.num_qo_heads), // layout_LSE
});
// NOTE(Zihao): add support for kv head-major later
typename Scheduler::Arguments scheduler_args = {
params.work_indptr,
params.head_indices,
params.qo_tile_indices,
params.qo_indptr,
params.kv_indptr,
params.qo_lens,
params.kv_lens,
params.batch_indices,
cutlass::FastDivmod(params.num_qo_heads / params.num_kv_heads),
params.num_qo_heads};
typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);
// Get the ptr to kernel function.
auto kernel =
(void*)PrefillWithKVCacheKernel<CollectiveMainloop, CollectiveEpilogue, KernelTraits,
LEFT_SLIDING_WINDOW, CAUSAL, Scheduler>;
int smem_size = sizeof(typename KernelTraits::SharedStorage);
FLASHINFER_CUDA_CALL(
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
int device;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int multiprocessor_count;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device));
dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
static constexpr int ctaSize = KernelTraits::NUM_WARPS * 32;
dim3 block_dims(ctaSize);
void* args[] = {&mainloop_params, &epilogue_params, &scheduler_params};
FLASHINFER_CUDA_CALL(cudaLaunchKernel(kernel, grid_dims, block_dims, args, smem_size, stream));
return cudaSuccess;
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, bool CAUSAL>
constexpr auto getCTATileSize() {
if constexpr (HEAD_DIM_QK == HEAD_DIM_VO) {
if constexpr (HEAD_DIM_QK == 64) {
return std::make_tuple(192, 128);
} else if constexpr (HEAD_DIM_QK == 128) {
if constexpr (CAUSAL) {
return std::make_tuple(128, 128);
} else {
return std::make_tuple(128, 192);
}
} else {
return std::make_tuple(128, 64);
}
} else {
// NOTE(Zihao) hack for deepseek prefill
static_assert(HEAD_DIM_QK == 192 && HEAD_DIM_VO == 128);
return std::make_tuple(128, 128);
}
}
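// Tile sizes produced by the table above, for reference:
//   getCTATileSize<64, 64, CAUSAL>()      -> (CTA_Q, CTA_KV) = (192, 128)
//   getCTATileSize<128, 128, true>()      -> (128, 128)
//   getCTATileSize<128, 128, false>()     -> (128, 192)
//   getCTATileSize<256, 256, CAUSAL>()    -> (128, 64)
//   getCTATileSize<192, 128, CAUSAL>()    -> (128, 128)   // DeepSeek-style prefill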
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
typename AttentionVariant, typename Params>
cudaError_t SinglePrefillWithKVCacheDispatched(Params& params, cudaStream_t stream) {
static_assert(HEAD_DIM_VO == 64 || HEAD_DIM_VO == 128 || HEAD_DIM_VO == 256);
if constexpr (MASK_MODE == MaskMode::kCustom) {
return cudaErrorNotSupported; // Not supported yet.
}
constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal;
constexpr auto CTA_TILE_SIZE = getCTATileSize<HEAD_DIM_QK, HEAD_DIM_VO, CAUSAL>();
return SinglePrefillWithKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true, HEAD_DIM_QK, HEAD_DIM_VO,
/*CTA_Q_=*/get<0>(CTA_TILE_SIZE),
/*CTA_KV_=*/get<1>(CTA_TILE_SIZE),
/*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV,
typename Params::DTypeO, typename Params::IdType, AttentionVariant>,
LEFT_SLIDING_WINDOW, CAUSAL>(params, stream);
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename Params>
cudaError_t BatchPrefillWithRaggedKVCacheDispatched(Params& params, bool enable_pdl,
cudaStream_t stream) {
static_assert(HEAD_DIM_VO == 64 || HEAD_DIM_VO == 128 || HEAD_DIM_VO == 256);
// NOTE: enable_pdl is accepted for interface compatibility; this dispatcher does not use it.
if constexpr (MASK_MODE == MaskMode::kCustom) {
return cudaErrorNotSupported; // Not supported yet.
}
constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal;
constexpr auto CTA_TILE_SIZE = getCTATileSize<HEAD_DIM_QK, HEAD_DIM_VO, CAUSAL>();
return BatchPrefillWithRaggedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/true, HEAD_DIM_QK, HEAD_DIM_VO,
/*CTA_Q_=*/get<0>(CTA_TILE_SIZE),
/*CTA_KV_=*/get<1>(CTA_TILE_SIZE),
/*NUM_STAGES_=*/2, typename Params::DTypeQ, typename Params::DTypeKV,
typename Params::DTypeO, typename Params::IdType, AttentionVariant>,
LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS>(params, stream);
}
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, MaskMode MASK_MODE, bool LEFT_SLIDING_WINDOW,
bool SAME_SCHEDULE_FOR_ALL_HEADS, typename AttentionVariant, typename Params>
cudaError_t BatchPrefillWithPagedKVCacheDispatched(Params& params, bool enable_pdl,
cudaStream_t stream) {
static_assert(HEAD_DIM_VO == 64 || HEAD_DIM_VO == 128 || HEAD_DIM_VO == 256);
// NOTE: enable_pdl is accepted for interface compatibility; this dispatcher does not use it.
if constexpr (MASK_MODE == MaskMode::kCustom) {
return cudaErrorNotSupported; // Not supported yet.
}
constexpr bool CAUSAL = MASK_MODE == MaskMode::kCausal;
constexpr bool MULTIITEMSCORING = MASK_MODE == MaskMode::kMultiItemScoring;
if constexpr (HEAD_DIM_QK == HEAD_DIM_VO) {
if constexpr (HEAD_DIM_VO == 64) {
// NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 64, need to optimize later
return BatchPrefillWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false, HEAD_DIM_QK, HEAD_DIM_VO,
/*CTA_Q_=*/192,
/*CTA_KV_=*/96,
/*NUM_STAGES_=*/2, typename Params::DTypeQ,
typename Params::DTypeKV, typename Params::DTypeO,
typename Params::IdType, AttentionVariant>,
LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>(
params, stream);
} else if constexpr (HEAD_DIM_VO == 128) {
return BatchPrefillWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false, HEAD_DIM_QK, HEAD_DIM_VO,
/*CTA_Q_=*/128,
/*CTA_KV_=*/96,
/*NUM_STAGES_=*/2, typename Params::DTypeQ,
typename Params::DTypeKV, typename Params::DTypeO,
typename Params::IdType, AttentionVariant>,
LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>(
params, stream);
} else {
// HEAD_DIM == 256
// NOTE(Zihao): CTA_KV not tuned for HEAD_DIM == 256, need to optimize later
return BatchPrefillWithPagedKVCacheKernelTraitsDispatched<
AttentionKernelTraits</*USE_TMA_LOAD_KV=*/false, HEAD_DIM_QK, HEAD_DIM_VO,
/*CTA_Q_=*/128,
/*CTA_KV_=*/32,
/*NUM_STAGES_=*/2, typename Params::DTypeQ,
typename Params::DTypeKV, typename Params::DTypeO,
typename Params::IdType, AttentionVariant>,
LEFT_SLIDING_WINDOW, CAUSAL, SAME_SCHEDULE_FOR_ALL_HEADS, Params, MULTIITEMSCORING>(
params, stream);
}
} else {
return cudaErrorNotSupported;
}
}
} // namespace flashinfer
#endif // FLASHINFER_ATTENTION_HOPPER_PREFILL_SM90_CUH_