// sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/attention/hopper/tile_scheduler.cuh
// (file-viewer metadata: 323 lines, 9.8 KiB, Plaintext)

/*
* Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri
* Dao. Licensed under the BSD 3-Clause.
*
* Modified by the FlashInfer team.
*/
#ifndef FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_
#define FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_
#include "cutlass/arch/barrier.h"
#include "cutlass/fast_math.h"
#include "named_barrier.cuh"
namespace flashinfer {
/*!
 * \brief Tile scheduler in which every thread block processes exactly one work tile.
 *
 * The query tile index is taken from blockIdx.x and the q/o head index from blockIdx.y.
 * get_next_work() always reports an invalid tile, so a block has no second tile to process.
 */
struct SingleTileScheduler {
 public:
  // Host side kernel arguments
  struct Arguments {
    int const num_qo_tiles, num_qo_heads, qo_len, kv_len;
    // Divisor applied to the q/o head index to obtain the kv head index.
    cutlass::FastDivmod group_size_fastdiv;
  };

  // Device side kernel params
  struct Params {
    int const qo_len, kv_len;
    cutlass::FastDivmod group_size_fastdiv;
  };

  /*!
   * \brief Convert host-side arguments into device-side params.
   * \note num_qo_tiles and num_qo_heads are only consumed by get_grid_dim() and are dropped here.
   */
  static Params to_underlying_arguments(Arguments const& args) {
    return {args.qo_len, args.kv_len, args.group_size_fastdiv};
  }

  /*!
   * \brief Launch grid: x = number of query tiles, y = number of q/o heads.
   * \note num_sm is unused by this scheduler.
   */
  static dim3 get_grid_dim(Arguments const& args, int num_sm) {
    return {uint32_t(args.num_qo_tiles), uint32_t(args.num_qo_heads)};
  }

  struct WorkTileInfo {
    int q_tile_idx = 0;
    int qo_head_idx = 0;
    int kv_head_idx = 0;
    bool is_valid_tile = false;

    //! Whether this tile describes real work; params is unused.
    CUTLASS_DEVICE
    bool is_valid(Params const& params) const { return is_valid_tile; }

    /*!
     * \brief Coordinates of this tile as an 8-element tuple:
     *   (q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx).
     * The indptr and batch entries are fixed to 0; the lengths come from params.
     */
    CUTLASS_DEVICE
    auto get_block_coord(Params const& params) const {
      return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, /*qo_indptr=*/0,
                         /*kv_indptr=*/0, params.qo_len,  params.kv_len, /*batch_idx=*/0};
    }
  };

  CUTLASS_DEVICE
  SingleTileScheduler() {}

  /*!
   * \brief The single tile owned by the calling block.
   * The kv head index is the q/o head index divided by params.group_size_fastdiv.
   */
  CUTLASS_DEVICE
  WorkTileInfo get_initial_work(Params const& params) const {
    int qo_head_idx = blockIdx.y;
    int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx);
    return {/*q_tile_idx=*/int(blockIdx.x), qo_head_idx, kv_head_idx, /*is_valid_tile*/ true};
  }

  // The three hooks below are intentionally empty for this scheduler.
  CUTLASS_DEVICE
  void init_consumer() const {}

  CUTLASS_DEVICE
  void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

  CUTLASS_DEVICE
  void broadcast_next_work(WorkTileInfo& current_work) const {}

  /*!
   * \brief There is never a second tile: always returns an invalid WorkTileInfo.
   *
   * All four fields are initialized explicitly. With only three initializers the `false`
   * would be consumed by kv_head_idx, and is_valid_tile would be false only because of its
   * default member initializer.
   */
  template <bool is_producer = false>
  CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params,
                                            WorkTileInfo const& current_work) const {
    return {/*q_tile_idx=*/-1, /*qo_head_idx=*/-1, /*kv_head_idx=*/-1, /*is_valid_tile=*/false};
  }
};
/*!
 * \brief Tile scheduler in which each thread block walks a contiguous range of work items.
 *
 * Block b processes the work items with indices in [work_indptr[b], work_indptr[b + 1]).
 * For every work item the q/o head index is read from head_indices, and the kv head index is
 * that value divided by group_size_fastdiv.
 *
 * \tparam IdType Element type of the index arrays; values are converted to int when read.
 */
template <typename IdType>
struct BatchPrefillPersistentTileScheduler {
 public:
  // Host side kernel arguments
  struct Arguments {
    IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens,
        *kv_lens, *batch_indices;
    cutlass::FastDivmod group_size_fastdiv;
    int num_qo_heads;  // placeholder
  };

  // Device side kernel params
  struct Params {
    IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens,
        *kv_lens, *batch_indices;
    cutlass::FastDivmod group_size_fastdiv;
  };

  /*!
   * \brief Convert host-side arguments into device-side params.
   * \note num_qo_heads is not forwarded.
   */
  static Params to_underlying_arguments(Arguments const& args) {
    return {args.work_indptr, args.head_indices, args.qo_tile_indices,
            args.qo_indptr,   args.kv_indptr,    args.qo_lens,
            args.kv_lens,     args.batch_indices, args.group_size_fastdiv};
  }

  //! Launch grid: one-dimensional with num_sm blocks; args is unused.
  static dim3 get_grid_dim(Arguments const& args, int num_sm) { return {(unsigned)num_sm}; }

  struct WorkTileInfo {
    int q_tile_idx = 0;
    int qo_head_idx = 0;
    int kv_head_idx = 0;
    int qo_indptr = 0;
    int kv_indptr = 0;
    int qo_len = 0;
    int kv_len = 0;
    int batch_idx = 0;
    // Number of work items already consumed by this block (0 for the first one).
    int counter = 0;
    // Half-open range [ptr_begin, ptr_end) of work indices assigned to this block.
    int ptr_begin = 0;
    int ptr_end = 0;

    //! True while the current work index (ptr_begin + counter) is inside the block's range.
    CUTLASS_DEVICE
    bool is_valid(Params const& params) const { return counter + ptr_begin < ptr_end; }

    /*!
     * \brief Coordinates of this tile as an 8-element tuple:
     *   (q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx).
     */
    CUTLASS_DEVICE
    auto get_block_coord(Params const& params) const {
      return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr,
                         kv_indptr,  qo_len,      kv_len,      batch_idx};
    }
  };

  CUTLASS_DEVICE
  BatchPrefillPersistentTileScheduler() {}

  /*!
   * \brief First work item of the calling block, or an invalid tile if its range is empty.
   *
   * Reads work_indptr[blockIdx.x] and work_indptr[blockIdx.x + 1] to obtain the range.
   */
  CUTLASS_DEVICE
  WorkTileInfo get_initial_work(Params const& params) const {
    int ptr_begin = params.work_indptr[blockIdx.x];
    int ptr_end = params.work_indptr[blockIdx.x + 1];
    if (ptr_begin < ptr_end) {
      int work_idx = ptr_begin;
      int qo_head_idx = params.head_indices[work_idx];
      int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx);
      return {params.qo_tile_indices[work_idx],
              qo_head_idx,
              kv_head_idx,
              params.qo_indptr[work_idx],
              params.kv_indptr[work_idx],
              params.qo_lens[work_idx],
              params.kv_lens[work_idx],
              params.batch_indices[work_idx],
              /*counter=*/0,
              ptr_begin,
              ptr_end};
    } else {
      // Empty range: return a sentinel. Every one of the 11 fields is listed so that
      // counter / ptr_begin / ptr_end land in their own slots (omitting batch_idx would
      // shift each of them one field to the left).
      return {/*q_tile_idx=*/-1,
              /*qo_head_idx=*/-1,
              /*kv_head_idx=*/-1,
              /*qo_indptr=*/-1,
              /*kv_indptr=*/-1,
              /*qo_len=*/-1,
              /*kv_len=*/-1,
              /*batch_idx=*/-1,
              /*counter=*/0,
              ptr_begin,
              ptr_end};
    }
  }

  // The three hooks below are intentionally empty for this scheduler.
  CUTLASS_DEVICE
  void init_consumer() const {}

  CUTLASS_DEVICE
  void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

  CUTLASS_DEVICE
  void broadcast_next_work(WorkTileInfo& current_work) const {}

  /*!
   * \brief Work item following current_work in the block's range.
   *
   * The returned tile always carries counter = current_work.counter + 1 and the unchanged
   * range. When the next index falls outside the range, the coordinate fields are set to -1
   * and is_valid() on the result is false.
   */
  template <bool is_producer = false>
  CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params,
                                            WorkTileInfo const& current_work) const {
    int work_idx = current_work.ptr_begin + current_work.counter + 1;
    if (work_idx < current_work.ptr_end) {
      int qo_head_idx = params.head_indices[work_idx];
      int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx);
      return {params.qo_tile_indices[work_idx],
              qo_head_idx,
              kv_head_idx,
              params.qo_indptr[work_idx],
              params.kv_indptr[work_idx],
              params.qo_lens[work_idx],
              params.kv_lens[work_idx],
              params.batch_indices[work_idx],
              current_work.counter + 1,
              current_work.ptr_begin,
              current_work.ptr_end};
    } else {
      // Range exhausted: sentinel with all 11 fields spelled out (see get_initial_work).
      return {/*q_tile_idx=*/-1,
              /*qo_head_idx=*/-1,
              /*kv_head_idx=*/-1,
              /*qo_indptr=*/-1,
              /*kv_indptr=*/-1,
              /*qo_len=*/-1,
              /*kv_len=*/-1,
              /*batch_idx=*/-1,
              /*counter=*/current_work.counter + 1,
              current_work.ptr_begin,
              current_work.ptr_end};
    }
  }
};
/*!
 * \brief Tile scheduler that maps q/o head to blockIdx.y
 *
 * Along x, block b processes the work items with indices in
 * [work_indptr[b], work_indptr[b + 1]). The q/o head index is blockIdx.y for every work item
 * of a block, and the kv head index is that value divided by group_size_fastdiv.
 *
 * \tparam IdType Element type of the index arrays; values are converted to int when read.
 */
template <typename IdType>
struct BatchPrefillTileScheduler {
 public:
  // Host side kernel arguments
  struct Arguments {
    IdType *work_indptr, *head_indices, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens,
        *kv_lens, *batch_indices;  // head_indices is a placeholder
    cutlass::FastDivmod group_size_fastdiv;
    int num_qo_heads;
  };

  // Device side kernel params
  struct Params {
    IdType *work_indptr, *qo_tile_indices, *qo_indptr, *kv_indptr, *qo_lens, *kv_lens,
        *batch_indices;
    cutlass::FastDivmod group_size_fastdiv;
    int num_qo_heads;
  };

  /*!
   * \brief Convert host-side arguments into device-side params.
   * \note head_indices is not forwarded; the head index comes from blockIdx.y instead.
   */
  static Params to_underlying_arguments(Arguments const& args) {
    return {args.work_indptr, args.qo_tile_indices, args.qo_indptr,     args.kv_indptr,
            args.qo_lens,     args.kv_lens,         args.batch_indices, args.group_size_fastdiv,
            args.num_qo_heads};
  }

  //! Launch grid: x = num_sm, y = number of q/o heads.
  static dim3 get_grid_dim(Arguments const& args, int num_sm) {
    return {(unsigned)num_sm, (unsigned)args.num_qo_heads};
  }

  struct WorkTileInfo {
    int q_tile_idx = 0;
    int qo_head_idx = 0;
    int kv_head_idx = 0;
    int qo_indptr = 0;
    int kv_indptr = 0;
    int qo_len = 0;
    int kv_len = 0;
    int batch_idx = 0;
    // Number of work items already consumed by this block (0 for the first one).
    int counter = 0;
    // Half-open range [ptr_begin, ptr_end) of work indices assigned to this block.
    int ptr_begin = 0;
    int ptr_end = 0;

    //! True while the current work index (ptr_begin + counter) is inside the block's range.
    CUTLASS_DEVICE
    bool is_valid(Params const& params) const { return counter + ptr_begin < ptr_end; }

    /*!
     * \brief Coordinates of this tile as an 8-element tuple:
     *   (q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx).
     */
    CUTLASS_DEVICE
    auto get_block_coord(Params const& params) const {
      return cute::tuple{q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr,
                         kv_indptr,  qo_len,      kv_len,      batch_idx};
    }
  };

  CUTLASS_DEVICE
  BatchPrefillTileScheduler() {}

  /*!
   * \brief First work item of the calling block, or an invalid tile if its range is empty.
   *
   * Reads work_indptr[blockIdx.x] and work_indptr[blockIdx.x + 1] to obtain the range.
   */
  CUTLASS_DEVICE
  WorkTileInfo get_initial_work(Params const& params) const {
    int ptr_begin = params.work_indptr[blockIdx.x];
    int ptr_end = params.work_indptr[blockIdx.x + 1];
    if (ptr_begin < ptr_end) {
      int work_idx = ptr_begin;
      int qo_head_idx = blockIdx.y;
      int kv_head_idx = params.group_size_fastdiv.divide(qo_head_idx);
      return {params.qo_tile_indices[work_idx],
              /*qo_head_idx=*/qo_head_idx,
              /*kv_head_idx=*/kv_head_idx,
              params.qo_indptr[work_idx],
              params.kv_indptr[work_idx],
              params.qo_lens[work_idx],
              params.kv_lens[work_idx],
              params.batch_indices[work_idx],
              /*counter=*/0,
              ptr_begin,
              ptr_end};
    } else {
      // Empty range: return a sentinel. Every one of the 11 fields is listed so that
      // counter / ptr_begin / ptr_end land in their own slots (omitting batch_idx would
      // shift each of them one field to the left).
      return {/*q_tile_idx=*/-1,
              /*qo_head_idx=*/-1,
              /*kv_head_idx=*/-1,
              /*qo_indptr=*/-1,
              /*kv_indptr=*/-1,
              /*qo_len=*/-1,
              /*kv_len=*/-1,
              /*batch_idx=*/-1,
              /*counter=*/0,
              ptr_begin,
              ptr_end};
    }
  }

  // The three hooks below are intentionally empty for this scheduler.
  CUTLASS_DEVICE
  void init_consumer() const {}

  CUTLASS_DEVICE
  void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

  CUTLASS_DEVICE
  void broadcast_next_work(WorkTileInfo& current_work) const {}

  /*!
   * \brief Work item following current_work in the block's range.
   *
   * The head indices are carried over from current_work. The returned tile always carries
   * counter = current_work.counter + 1 and the unchanged range. When the next index falls
   * outside the range, the coordinate fields are set to -1 and is_valid() on the result is
   * false.
   */
  template <bool is_producer = false>
  CUTLASS_DEVICE WorkTileInfo get_next_work(Params const& params,
                                            WorkTileInfo const& current_work) const {
    int work_idx = current_work.ptr_begin + current_work.counter + 1;
    if (work_idx < current_work.ptr_end) {
      return {params.qo_tile_indices[work_idx],
              current_work.qo_head_idx,
              current_work.kv_head_idx,
              params.qo_indptr[work_idx],
              params.kv_indptr[work_idx],
              params.qo_lens[work_idx],
              params.kv_lens[work_idx],
              params.batch_indices[work_idx],
              current_work.counter + 1,
              current_work.ptr_begin,
              current_work.ptr_end};
    } else {
      // Range exhausted: sentinel with all 11 fields spelled out (see get_initial_work).
      return {/*q_tile_idx=*/-1,
              /*qo_head_idx=*/-1,
              /*kv_head_idx=*/-1,
              /*qo_indptr=*/-1,
              /*kv_indptr=*/-1,
              /*qo_len=*/-1,
              /*kv_len=*/-1,
              /*batch_idx=*/-1,
              /*counter=*/current_work.counter + 1,
              current_work.ptr_begin,
              current_work.ptr_end};
    }
  }
};
} // namespace flashinfer
#endif // FLASHINFER_ATTENTION_HOPPER_TILE_SCHEDULER_CUH_