sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/attention/scheduler.cuh

/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_ATTENTION_SCHEDULER_CUH_
#define FLASHINFER_ATTENTION_SCHEDULER_CUH_
#include <cuda_runtime_api.h>
#include <driver_types.h>
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <vector>
#include "../allocator.h"
#include "../exception.h"
#include "../pos_enc.cuh"
#include "../utils.cuh"
#include "heap.h"
namespace flashinfer {
template <PosEncodingMode POS_ENCODING_MODE, uint32_t num_stages_smem, uint32_t tile_size_per_bdx,
uint32_t vec_size, uint32_t bdx, uint32_t bdy, uint32_t bdz, typename AttentionVariant,
typename Params>
__global__ void BatchDecodeWithPagedKVCacheKernel(const __grid_constant__ Params params);
template <uint32_t num_stages_smem, uint32_t vec_size_ckv, uint32_t vec_size_kpe, uint32_t bdx,
uint32_t bdy, uint32_t bdz, uint32_t tile_size_qo_heads, typename AttentionVariant,
typename Params>
__global__ void BatchDecodeWithPagedKVCacheKernelMLA(Params params);
template <uint32_t HEAD_DIM_CKV, uint32_t HEAD_DIM_KPE, uint32_t QO_TILE_LEN, typename DTypeKV>
std::tuple<uint32_t, uint32_t, uint32_t> LaunchSpecForDecodeKernelMlaCuteSM80(
const uint32_t num_qo_heads);
template <uint32_t HEAD_DIM_CKV, uint32_t HEAD_DIM_KPE, uint32_t QO_TILE_LEN, typename Params>
__global__ void BatchDecodeWithPagedKVCacheKernelMlaCuteSM80(Params params);
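/*!
* \brief Copy a host-side vector into the page-locked (pinned) integer buffer at the
* given byte offset (as produced by AlignedAllocator), so that it can later be
* transferred to the device in a single cudaMemcpyAsync.
*/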
template <typename DType>
inline void CopyToPageLockedBuffer(void* page_locked_int_buffer, int64_t offset,
const std::vector<DType>& vec) {
DType* ptr = GetPtrFromBaseOffset<DType>(page_locked_int_buffer, offset);
std::copy(vec.begin(), vec.end(), ptr);
}
/*!
* \brief Compute the maximum number of pages per batch and the new batch size
* after we partition the Paged KV-Cache into multiple chunks along the KV
* sequence length dimension.
* \tparam IdType A template type that indicates the index data type
* \param max_grid_size The maximum grid size of the kernel
* \param gdy gridDim.y
* \param num_pages The number of pages per request in the batch
* \param min_num_pages_per_batch The pre-set lower bound on the number of pages
* per batch, defaults to 1
* \return (max_num_pages_per_batch, new_batch_size) The maximum number of pages
* per batch and the new batch size after the partition.
*/
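// Illustrative example (numbers chosen for exposition): with num_pages = {100, 10},
// gdy = 8 and max_grid_size = 256, the search finds the smallest chunk size c such
// that (ceil(100 / c) + ceil(10 / c)) * 8 <= 256, i.e. c = 4, which yields
// new_batch_size = ceil(100 / 4) + ceil(10 / 4) = 28 chunks.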
template <typename IdType>
inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(
const uint32_t max_grid_size, const uint32_t gdy, const std::vector<IdType>& num_pages,
const uint32_t min_num_pages_per_batch = 1) {
uint32_t low = min_num_pages_per_batch, high = 0;
for (const IdType& elem : num_pages) {
high = max(high, elem);
}
uint32_t new_batch_size;
while (low < high) {
uint32_t mid = (low + high) / 2;
new_batch_size = 0;
for (const IdType& elem : num_pages) {
new_batch_size += ceil_div(elem, mid);
}
if (new_batch_size * gdy > max_grid_size) {
low = mid + 1;
} else {
high = mid;
}
}
new_batch_size = 0;
for (const IdType& elem : num_pages) {
new_batch_size += ceil_div(std::max(elem, 1), low);
}
return std::make_tuple(low, new_batch_size);
}
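/*!
* \brief Binary search for the smallest KV chunk size such that the total number of
* (qo tile, kv chunk) work units does not exceed max_batch_size_if_split.
* \return (split_kv, kv_chunk_size); split_kv is true whenever the chosen chunk is
* shorter than the longest KV length, and is always true when CUDA graphs are
* enabled so that the captured graph covers the partition-kv code path.
*/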
inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph,
const uint32_t max_batch_size_if_split,
const std::vector<int64_t>& packed_qo_len_arr,
const std::vector<int64_t>& kv_len_arr,
const uint32_t qo_chunk_size,
const uint32_t min_kv_chunk_size = 1) {
const int64_t batch_size = packed_qo_len_arr.size();
int64_t max_kv_len = 1;
for (const int64_t& kv_len : kv_len_arr) {
max_kv_len = std::max(max_kv_len, kv_len);
}
int64_t low = min_kv_chunk_size;
int64_t high = max_kv_len;
constexpr int64_t min_kv_len = 1;
while (low < high) {
const int64_t mid = (low + high) / 2;
int64_t new_batch_size = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) *
ceil_div(std::max(kv_len_arr[i], min_kv_len), mid);
}
if (new_batch_size > max_batch_size_if_split) {
low = mid + 1;
} else {
high = mid;
}
}
return std::make_tuple(enable_cuda_graph || low < max_kv_len, low);
}
/*!
* \brief Estimate the temporary buffer size and the maximum grid size for the
* partition-kv BatchDecodeWithPagedKVCache kernel
* \tparam AttentionVariant A template type that indicates the attention variant
* \tparam Params A template type that indicates the parameter struct
* \param split_kv Whether to split the KV cache into multiple chunks (output)
* \param max_grid_size The maximum grid size that can be used in a partition-kv kernel (output)
* \param max_num_pages_per_batch The maximum number of pages per batch (output)
* \param new_batch_size The new batch size after the partition (output)
* \param gdy gridDim.y of the kernel launch (output)
* \param batch_size The batch size
* \param kv_indptr_h The host-side page indptr of the Paged KV-Cache
* \param num_qo_heads An integer that indicates the number of heads of query and output
* \param page_size The page size of the Paged KV-Cache
* \param enable_cuda_graph Whether CUDA graph capture is enabled
* \param stream The cuda stream to launch the kernel
* \return status Indicates whether CUDA calls are successful
*/
template <uint32_t GROUP_SIZE, uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE,
typename AttentionVariant, typename Params>
inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched(
bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t& gdy, uint32_t batch_size,
typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size,
bool enable_cuda_graph, cudaStream_t stream) {
using DTypeKV = typename Params::DTypeKV;
using IdType = typename Params::IdType;
constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL);
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
constexpr uint32_t bdx = HEAD_DIM / vec_size;
static_assert(bdx <= 32);
constexpr uint32_t bdy = GROUP_SIZE;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U;
const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE;
gdy = num_kv_heads;
const uint32_t smem_size =
2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) +
std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float));
auto kernel =
BatchDecodeWithPagedKVCacheKernel<POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx,
vec_size, bdx, bdy, bdz, AttentionVariant, Params>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * gdy >= max_grid_size) {
split_kv = false;
max_num_pages_per_batch = 1;
for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
max_num_pages_per_batch = std::max<uint32_t>(
max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]);
}
new_batch_size = batch_size;
} else {
// compute max_num_pages_per_batch and new_batch_size
std::vector<IdType> num_pages(batch_size);
for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
}
std::tie(max_num_pages_per_batch, new_batch_size) =
PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages,
std::max(128 / page_size, 1U));
if (new_batch_size == batch_size && !enable_cuda_graph) {
// do not use partition-kv kernel for short sequence, when not using CUDAGraph
split_kv = false;
} else {
// when using CUDAGraph, we always use partition-kv kernel
split_kv = true;
}
}
return cudaSuccess;
})
}
template <uint32_t HEAD_DIM_CKV, uint32_t HEAD_DIM_KPE, typename AttentionVariant, typename Params>
inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA(
bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t& gdy, uint32_t batch_size,
typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size,
bool enable_cuda_graph, cudaStream_t stream) {
using DTypeKV = typename Params::DTypeKV;
using IdType = typename Params::IdType;
auto compute_capacity = GetCudaComputeCapability();
DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, {
constexpr uint32_t vec_size_ckv = std::max(16UL / sizeof(DTypeKV), HEAD_DIM_CKV / 32UL);
constexpr uint32_t bdx = HEAD_DIM_CKV / vec_size_ckv;
constexpr uint32_t vec_size_kpe = HEAD_DIM_KPE / bdx;
constexpr uint32_t bdy = 8;
constexpr uint32_t tile_size_qo_heads = 2;
constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads;
constexpr uint32_t num_threads = std::max(128U, bdx * bdy);
constexpr uint32_t bdz = num_threads / (bdx * bdy);
gdy = ceil_div(num_qo_heads, qo_heads_per_block);
const uint32_t smem_size =
NUM_STAGES_SMEM * bdy * bdz * (HEAD_DIM_CKV + HEAD_DIM_KPE) * sizeof(DTypeKV) +
std::max(num_threads * sizeof(size_t) * 2, 2 * bdy * bdz * sizeof(float));
auto kernel =
BatchDecodeWithPagedKVCacheKernelMLA<NUM_STAGES_SMEM, vec_size_ckv, vec_size_kpe, bdx, bdy,
bdz, tile_size_qo_heads, AttentionVariant, Params>;
int num_blocks_per_sm = 0;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
num_threads, smem_size));
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * gdy >= max_grid_size) {
split_kv = false;
max_num_pages_per_batch = 1;
for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
max_num_pages_per_batch = std::max<uint32_t>(
max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]);
}
new_batch_size = batch_size;
} else {
// compute max_num_pages_per_batch and new_batch_size
std::vector<IdType> num_pages(batch_size);
for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
}
std::tie(max_num_pages_per_batch, new_batch_size) =
PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages,
std::max(128 / page_size, 1U));
if (new_batch_size == batch_size && !enable_cuda_graph) {
// do not use partition-kv kernel for short sequence, when not using CUDAGraph
split_kv = false;
} else {
// when using CUDAGraph, we always use partition-kv kernel
split_kv = true;
}
}
return cudaSuccess;
});
}
template <uint32_t HEAD_DIM_CKV, uint32_t HEAD_DIM_KPE, uint32_t QO_TILE_LEN,
typename AttentionVariant, typename Params>
inline cudaError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80(
bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch,
uint32_t& new_batch_size, uint32_t& gdy_, uint32_t batch_size,
typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size,
bool enable_cuda_graph, cudaStream_t stream) {
using DTypeKV = typename Params::DTypeKV;
using IdType = typename Params::IdType;
auto [smem_size, gdy, k_warps] =
LaunchSpecForDecodeKernelMlaCuteSM80<HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN, DTypeKV>(
num_qo_heads);
gdy_ = gdy;
const uint32_t num_threads = k_warps * 32;
auto kernel =
BatchDecodeWithPagedKVCacheKernelMlaCuteSM80<HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN, Params>;
int num_blocks_per_sm;
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
// FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
//                                                                    num_threads, smem_size));
// FIXME: cudaOccupancyMaxActiveBlocksPerMultiprocessor occasionally reports 0 blocks per SM
// here; since shared memory is packed with as many q-heads as possible, one block per SM is
// the expected occupancy, so we hard-code num_blocks_per_sm to 1
num_blocks_per_sm = 1;
max_grid_size = num_blocks_per_sm * num_sm;
if (batch_size * gdy >= max_grid_size) {
split_kv = false;
max_num_pages_per_batch = 1;
for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
max_num_pages_per_batch = std::max<uint32_t>(
max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]);
}
new_batch_size = batch_size;
} else {
// compute max_num_pages_per_batch and new_batch_size
std::vector<IdType> num_pages(batch_size);
for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx];
}
std::tie(max_num_pages_per_batch, new_batch_size) =
PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages,
std::max(128 / page_size, 1U));
if (new_batch_size == batch_size && !enable_cuda_graph) {
// do not use partition-kv kernel for short sequence, when not using CUDAGraph
split_kv = false;
} else {
// when using CUDAGraph, we always use partition-kv kernel
split_kv = true;
}
}
return cudaSuccess;
}
/*!
* \brief Split the KV dimension of each request into tiles of kv_chunk_size pages and
* enumerate the resulting (request, kv tile) work units.
* \tparam IdType A template type that indicates the index data type
* \param indptr_h The host-side page indptr of the Paged KV-Cache
* \param batch_size The batch size
* \param kv_chunk_size The KV chunk size in pages
* \return (request_indices, kv_tile_indices, o_indptr) The request index and KV tile
* index of each work unit, and the indptr of partial outputs per request.
*/
template <typename IdType>
inline auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) {
std::vector<IdType> request_indices, kv_tile_indices, o_indptr;
o_indptr.push_back(0);
for (uint32_t batch_idx = 0; batch_idx < batch_size; batch_idx++) {
uint32_t num_tiles_kv = ceil_div(
std::max<uint32_t>(indptr_h[batch_idx + 1] - indptr_h[batch_idx], 1U), kv_chunk_size);
for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) {
request_indices.push_back(batch_idx);
kv_tile_indices.push_back(kv_tile_idx);
}
o_indptr.push_back(o_indptr.back() + num_tiles_kv);
}
return std::make_tuple(request_indices, kv_tile_indices, o_indptr);
}
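/*!
* \brief Byte offsets into the int/float workspaces and flags produced by DecodePlan;
* serialized to/from a flat int64_t vector so the plan can cross the host-language
* boundary.
*/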
struct DecodePlanInfo {
int64_t padded_batch_size;
int64_t v_offset;
int64_t s_offset;
int64_t request_indices_offset;
int64_t kv_tile_indices_offset;
int64_t o_indptr_offset;
int64_t block_valid_mask_offset;
int64_t kv_chunk_size_ptr_offset;
bool enable_cuda_graph;
bool split_kv;
DecodePlanInfo()
: padded_batch_size(0),
v_offset(0),
s_offset(0),
request_indices_offset(0),
kv_tile_indices_offset(0),
o_indptr_offset(0),
block_valid_mask_offset(0),
kv_chunk_size_ptr_offset(0),
enable_cuda_graph(false),
split_kv(false) {}
// convert DecodePlanInfo to std::vector<int64_t>
std::vector<int64_t> ToVector() const {
return {padded_batch_size,
v_offset,
s_offset,
request_indices_offset,
kv_tile_indices_offset,
o_indptr_offset,
block_valid_mask_offset,
kv_chunk_size_ptr_offset,
enable_cuda_graph,
split_kv};
}
// From std::vector<int64_t> to DecodePlanInfo
void FromVector(const std::vector<int64_t>& vec) {
if (vec.size() != 10) {
std::ostringstream err_msg;
err_msg << "DecodePlanInfo::FromVector: vec.size() should be 10, but got " << vec.size();
FLASHINFER_ERROR(err_msg.str());
}
padded_batch_size = vec[0];
v_offset = vec[1];
s_offset = vec[2];
request_indices_offset = vec[3];
kv_tile_indices_offset = vec[4];
o_indptr_offset = vec[5];
block_valid_mask_offset = vec[6];
kv_chunk_size_ptr_offset = vec[7];
enable_cuda_graph = vec[8];
split_kv = vec[9];
}
};
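/*!
* \brief Plan the batch-decode kernel launch: run work estimation, split the KV cache
* into chunks when beneficial, lay out the scheduling metadata in the int workspace
* and the split-KV partial outputs (v, s) in the float workspace, then asynchronously
* copy the metadata from the page-locked host buffer to the device.
*
* Illustrative call sequence (a sketch only; workspace allocation and template
* dispatch are elided, and work_estimation_func is typically bound to one of the
* WorkEstimationDispatched functions above):
* \code{.cpp}
* DecodePlanInfo plan_info;
* FLASHINFER_CUDA_CALL((DecodePlan<HEAD_DIM, POS_ENCODING_MODE, AttentionVariant, Params>(
*     float_buffer, float_workspace_size_in_bytes, int_buffer, page_locked_int_buffer,
*     int_workspace_size_in_bytes, plan_info, kv_indptr_h, batch_size, num_qo_heads,
*     page_size, enable_cuda_graph, stream, work_estimation_func)));
* \endcode
*/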
template <uint32_t HEAD_DIM, PosEncodingMode POS_ENCODING_MODE, typename AttentionVariant,
typename Params, typename WorkEstimationFunc>
inline cudaError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, void* page_locked_int_buffer,
size_t int_workspace_size_in_bytes, DecodePlanInfo& plan_info,
typename Params::IdType* indptr_h, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t page_size, bool enable_cuda_graph,
cudaStream_t stream, WorkEstimationFunc work_estimation_func) {
using DTypeO = typename Params::DTypeO;
using IdType = typename Params::IdType;
bool split_kv;
uint32_t max_grid_size, kv_chunk_size_in_pages, new_batch_size, gdy;
FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages,
new_batch_size, gdy, batch_size, indptr_h, num_qo_heads,
page_size, enable_cuda_graph, stream));
size_t padded_batch_size;
plan_info.enable_cuda_graph = enable_cuda_graph;
plan_info.split_kv = split_kv;
padded_batch_size =
(enable_cuda_graph) ? (split_kv ? max_grid_size / gdy : batch_size) : new_batch_size;
plan_info.padded_batch_size = padded_batch_size;
auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] =
DecodeSplitKVIndptr(indptr_h, batch_size, kv_chunk_size_in_pages);
AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
plan_info.request_indices_offset = int_allocator.aligned_alloc_offset(
padded_batch_size * sizeof(IdType), 16, "batch_decode_request_indices");
plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset(
padded_batch_size * sizeof(IdType), 16, "batch_decode_kv_tile_indices");
plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset(
(padded_batch_size + 1) * sizeof(IdType), 16, "batch_decode_o_indptr");
plan_info.kv_chunk_size_ptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_decode_kv_chunk_size_ptr");
IdType* request_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.request_indices_offset);
IdType* kv_tile_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_tile_indices_offset);
IdType* o_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.o_indptr_offset);
IdType* kv_chunk_size_ptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset);
std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h);
std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h);
std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h);
kv_chunk_size_ptr_h[0] = kv_chunk_size_in_pages * page_size;
if (split_kv) {
AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
plan_info.v_offset = float_allocator.aligned_alloc_offset(
num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(float), 16, "batch_decode_tmp_v");
plan_info.s_offset = float_allocator.aligned_alloc_offset(
num_qo_heads * padded_batch_size * sizeof(float), 16, "batch_decode_tmp_s");
plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
padded_batch_size * sizeof(bool), 16, "batch_decode_block_valid_mask");
bool* block_valid_mask_h =
GetPtrFromBaseOffset<bool>(page_locked_int_buffer, plan_info.block_valid_mask_offset);
for (uint32_t i = 0; i < padded_batch_size; ++i) {
block_valid_mask_h[i] = i < new_batch_size;
}
}
size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream));
return cudaSuccess;
}
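/*!
* \brief Split both the query/output and KV dimensions of each request into tiles and
* enumerate the resulting (request, qo tile, kv tile) work units. cta_tile_q is chosen
* from the average packed qo length (or, under CUDA graphs, the maximum possible one),
* and the KV chunk size via PrefillBinarySearchKVChunkSize.
* \return (split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size,
* request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr), where
* kv_chunk_size is converted from pages to tokens before returning.
*/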
template <typename IdType>
inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h,
uint32_t total_num_rows, uint32_t batch_size,
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim,
uint32_t page_size, uint32_t max_batch_size_if_split,
bool enable_cuda_graph) {
std::vector<IdType> request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr;
merge_indptr.push_back(0);
o_indptr.push_back(0);
const uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
// step 1: determine packed_qo_len_arr and verify qo_indptr contents.
std::vector<int64_t> packed_qo_len_arr(batch_size), kv_len_arr(batch_size);
for (uint32_t i = 0; i < batch_size; ++i) {
packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size);
if (packed_qo_len_arr[i] < 0) {
std::ostringstream err_msg;
err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]"
<< qo_indptr_h[i] << " should be non-negative";
FLASHINFER_ERROR(err_msg.str());
}
kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]);
if (kv_len_arr[i] < 0) {
std::ostringstream err_msg;
err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]"
<< kv_indptr_h[i] << " should be non-negative";
FLASHINFER_ERROR(err_msg.str());
}
}
// step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q
const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U);
uint32_t cta_tile_q;
uint32_t total_num_tiles_q;
if (enable_cuda_graph) {
// When CUDA graphs are enabled, the sequence lengths implied by qo_indptr_h
// can vary across replays. We assume the dummy data used when the graph was
// captured fixes the maximum number of tokens.
const uint64_t max_seq_len = total_num_rows - batch_size + 1;
uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size;
cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim);
// Find an upper bound for the number of tiles, derived from the total
// number of rows and the batch size. The sum of qo lengths rounded
// up to cta_tile_q will not exceed this number derived from the total
// number of rows.
total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1;
} else {
int64_t sum_packed_qo_len = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
sum_packed_qo_len += packed_qo_len_arr[i];
}
const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size;
cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim);
total_num_tiles_q = 0;
for (uint32_t i = 0; i < batch_size; ++i) {
total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q);
}
}
auto [split_kv, kv_chunk_size] =
PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr,
kv_len_arr, cta_tile_q, min_kv_chunk_size);
// step 3: split qo_indptr and kv_indptr
uint32_t new_batch_size = 0;
for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
const int64_t packed_qo_len = packed_qo_len_arr[request_idx];
const int64_t kv_len = std::max(int(kv_len_arr[request_idx]), 1);
const int64_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q);
const int64_t num_tiles_kv = ceil_div(kv_len, kv_chunk_size);
for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) {
for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) {
new_batch_size += 1;
request_indices.push_back(request_idx);
qo_tile_indices.push_back(q_tile_idx);
kv_tile_indices.push_back(kv_tile_idx);
}
}
int64_t qo_len = packed_qo_len / gqa_group_size;
for (uint32_t row = 0; row < qo_len; ++row) {
merge_indptr.push_back(merge_indptr.back() + num_tiles_kv);
}
o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv);
}
const size_t padded_batch_size =
enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size;
FLASHINFER_CHECK(new_batch_size <= padded_batch_size,
"new batch size should not exceed padded batch size");
// step 4: multiply kv_chunk_size by page_size
kv_chunk_size *= page_size;
return std::make_tuple(split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size,
std::move(request_indices), std::move(qo_tile_indices),
std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr));
}
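/*!
* \brief Byte offsets into the int/float workspaces and flags produced by PrefillPlan;
* serialized to/from a flat int64_t vector so the plan can cross the host-language
* boundary.
*/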
struct PrefillPlanInfo {
int64_t padded_batch_size;
int64_t total_num_rows;
int64_t total_num_rows_offset;
int64_t cta_tile_q;
int64_t request_indices_offset;
int64_t qo_tile_indices_offset;
int64_t kv_tile_indices_offset;
int64_t merge_indptr_offset;
int64_t o_indptr_offset;
int64_t kv_chunk_size_ptr_offset;
int64_t v_offset;
int64_t s_offset;
int64_t block_valid_mask_offset;
bool enable_cuda_graph;
bool split_kv;
PrefillPlanInfo()
: padded_batch_size(0),
total_num_rows(0),
total_num_rows_offset(0),
cta_tile_q(0),
request_indices_offset(0),
qo_tile_indices_offset(0),
kv_tile_indices_offset(0),
merge_indptr_offset(0),
o_indptr_offset(0),
kv_chunk_size_ptr_offset(0),
v_offset(0),
s_offset(0),
block_valid_mask_offset(0),
enable_cuda_graph(false),
split_kv(false) {}
// convert PrefillPlanInfo to std::vector<int64_t>
std::vector<int64_t> ToVector() const {
return {padded_batch_size,
total_num_rows,
total_num_rows_offset,
cta_tile_q,
request_indices_offset,
qo_tile_indices_offset,
kv_tile_indices_offset,
merge_indptr_offset,
o_indptr_offset,
kv_chunk_size_ptr_offset,
v_offset,
s_offset,
block_valid_mask_offset,
enable_cuda_graph,
split_kv};
}
// From std::vector<int64_t> to PrefillPlanInfo
void FromVector(const std::vector<int64_t>& vec) {
if (vec.size() != 15) {
std::ostringstream err_msg;
err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 15, but got " << vec.size();
FLASHINFER_ERROR(err_msg.str());
}
padded_batch_size = vec[0];
total_num_rows = vec[1];
total_num_rows_offset = vec[2];
cta_tile_q = vec[3];
request_indices_offset = vec[4];
qo_tile_indices_offset = vec[5];
kv_tile_indices_offset = vec[6];
merge_indptr_offset = vec[7];
o_indptr_offset = vec[8];
kv_chunk_size_ptr_offset = vec[9];
v_offset = vec[10];
s_offset = vec[11];
block_valid_mask_offset = vec[12];
enable_cuda_graph = vec[13];
split_kv = vec[14];
}
};
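/*!
* \brief Plan the batch-prefill kernel launch: derive the qo/kv tiling via
* PrefillSplitQOKVIndptr, lay out the scheduling metadata in the int workspace and,
* when the KV cache is split, the partial outputs (v, s) plus merge metadata in the
* float/int workspaces, then asynchronously copy the metadata from the page-locked
* host buffer to the device.
*/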
template <typename IdType>
inline cudaError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, void* page_locked_int_buffer,
size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info,
IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows,
uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size,
bool enable_cuda_graph, uint32_t sizeof_dtype_o,
cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
<< num_kv_heads;
FLASHINFER_ERROR(err_msg.str());
}
// step 0: get the number of SMs
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
int num_blocks_per_sm = 2;
int max_grid_size = num_blocks_per_sm * num_sm;
uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads;
// step 1: determine kv_chunk_size
auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec,
qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] =
PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads,
num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split,
enable_cuda_graph);
plan_info.cta_tile_q = cta_tile_q;
plan_info.total_num_rows = total_num_rows;
plan_info.enable_cuda_graph = enable_cuda_graph;
plan_info.padded_batch_size = padded_batch_size;
plan_info.split_kv = split_kv;
AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
plan_info.request_indices_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * padded_batch_size, 16, "batch_prefill_request_indices");
plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * padded_batch_size, 16, "batch_prefill_qo_tile_indices");
plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * padded_batch_size, 16, "batch_prefill_kv_tile_indices");
plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * (batch_size + 1),
16, "batch_prefill_o_indptr");
plan_info.kv_chunk_size_ptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr");
if (plan_info.enable_cuda_graph) {
plan_info.total_num_rows_offset =
int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows");
uint32_t* total_num_rows_h =
GetPtrFromBaseOffset<uint32_t>(page_locked_int_buffer, plan_info.total_num_rows_offset);
*total_num_rows_h = qo_indptr_h[batch_size];
}
IdType* request_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.request_indices_offset);
IdType* qo_tile_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.qo_tile_indices_offset);
IdType* kv_tile_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_tile_indices_offset);
IdType* o_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.o_indptr_offset);
IdType* kv_chunk_size_ptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset);
std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h);
std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h);
std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h);
std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h);
kv_chunk_size_ptr_h[0] = kv_chunk_size;
if (split_kv) {
AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
plan_info.v_offset = float_allocator.aligned_alloc_offset(
num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * sizeof(float), 16,
"batch_prefill_tmp_v");
plan_info.s_offset = float_allocator.aligned_alloc_offset(
num_qo_heads * padded_batch_size * cta_tile_q * sizeof(float), 16, "batch_prefill_tmp_s");
plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr");
plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset(
sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask");
IdType* merge_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_indptr_offset);
bool* block_valid_mask_h =
GetPtrFromBaseOffset<bool>(page_locked_int_buffer, plan_info.block_valid_mask_offset);
std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h);
for (uint32_t i = 0; i < padded_batch_size; ++i) {
block_valid_mask_h[i] = i < new_batch_size;
}
}
size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream));
return cudaSuccess;
}
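/*!
* \brief Heuristic cost of one attention work unit, used for load balancing;
* query-side work is weighted twice as heavily as KV-side work.
*/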
inline float cost_function(int qo_len, int kv_len) { return 2 * float(qo_len) + kv_len; }
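/*!
* \brief Concatenate a vector of vectors into a single vector, reserving
* size_after_flatten entries up front to avoid reallocation.
*/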
template <typename T>
std::vector<T> flatten(const std::vector<std::vector<T>>& vec, int size_after_flatten) {
std::vector<T> result;
result.reserve(size_after_flatten);
for (const auto& inner_vec : vec) {
result.insert(result.end(), inner_vec.begin(), inner_vec.end());
}
return result;
}
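/*!
* \brief Exclusive upper bound of the KV range visible to a given qo tile under a
* causal mask, with queries right-aligned against the KV sequence (the leading
* kv_len - qo_len KV tokens are visible to every query); the last qo tile always
* covers the full kv_len. Illustrative example: qo_len = 4, kv_len = 10,
* cluster_tile_q = 2, group_size = 1 gives tile 0 an end of (10 - 4) + 2 = 8 and
* the final tile an end of 10.
*/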
inline int packed_causal_kv_end(int qo_len, int kv_len, int qo_tile_idx, int cluster_tile_q,
int num_qo_tiles, int group_size) {
if (qo_tile_idx + 1 == num_qo_tiles) {
return kv_len;
}
int kv_len_init = kv_len - qo_len; // right aligned
return max(min(kv_len_init + ceil_div((qo_tile_idx + 1) * cluster_tile_q, group_size), kv_len),
0);
}
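/*!
* \brief Byte offsets into the int workspace and flags produced by PrefillSM90Plan;
* serialized to/from a flat int64_t vector so the plan can cross the host-language
* boundary.
*/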
struct PrefillPlanSM90Info {
int64_t qo_tile_indices_offset;
int64_t qo_indptr_offset;
int64_t kv_indptr_offset;
int64_t qo_len_offset;
int64_t kv_len_offset;
int64_t head_indices_offset;
int64_t work_indptr_offset;
int64_t batch_indices_offset;
bool same_schedule_for_all_heads;
PrefillPlanSM90Info()
: qo_tile_indices_offset(0),
qo_indptr_offset(0),
kv_indptr_offset(0),
qo_len_offset(0),
kv_len_offset(0),
head_indices_offset(0),
work_indptr_offset(0),
batch_indices_offset(0),
same_schedule_for_all_heads(false) {}
// convert PrefillPlanSM90Info to std::vector<int64_t>
std::vector<int64_t> ToVector() const {
return {qo_tile_indices_offset, qo_indptr_offset, kv_indptr_offset,
qo_len_offset, kv_len_offset, head_indices_offset,
work_indptr_offset, batch_indices_offset, same_schedule_for_all_heads};
}
// From std::vector<int64_t> to PrefillPlanSM90Info
void FromVector(const std::vector<int64_t>& vec) {
if (vec.size() != 9) {
std::ostringstream err_msg;
err_msg << "PrefillPlanSM90Info::FromVector: vec.size() should be 9, but got " << vec.size();
FLASHINFER_ERROR(err_msg.str());
}
qo_tile_indices_offset = vec[0];
qo_indptr_offset = vec[1];
kv_indptr_offset = vec[2];
qo_len_offset = vec[3];
kv_len_offset = vec[4];
head_indices_offset = vec[5];
work_indptr_offset = vec[6];
batch_indices_offset = vec[7];
same_schedule_for_all_heads = vec[8];
}
};
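/*!
* \brief Plan the SM90 (FA3) batch-prefill kernel launch. Requests are sorted by
* descending kv_len and their qo tiles are assigned greedily to the CTA with the
* smallest accumulated cost (tracked in a MinHeap of cost_function values), with one
* schedule per qo head unless the per-head work count is large, in which case
* same_schedule_for_all_heads is set and a single schedule is shared. The per-CTA
* work lists are then flattened, staged in the page-locked buffer, and copied to the
* device asynchronously.
*/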
template <typename IdType>
inline cudaError_t PrefillSM90Plan(
void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer,
void* page_locked_int_buffer, size_t int_workspace_size_in_bytes,
PrefillPlanSM90Info& plan_info, IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_len_arr_h,
uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads,
uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, bool causal,
bool enable_cuda_graph, uint32_t sizeof_dtype_o, cudaStream_t stream) {
if (num_qo_heads % num_kv_heads != 0) {
std::ostringstream err_msg;
err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads "
<< num_kv_heads;
FLASHINFER_ERROR(err_msg.str());
}
std::vector<std::tuple<int, int, int>> idx_qo_kv_len_vec;
for (uint32_t i = 0; i < batch_size; ++i) {
int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i];
int kv_len = kv_len_arr_h[i];
if (kv_len < 0) {
std::ostringstream err_msg;
err_msg << "kv_len[" << i << "]" << kv_len << " should be non-negative";
FLASHINFER_ERROR(err_msg.str());
}
if (qo_len < 0) {
std::ostringstream err_msg;
err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]"
<< qo_indptr_h[i] << " should be non-negative";
FLASHINFER_ERROR(err_msg.str());
}
idx_qo_kv_len_vec.push_back({i, qo_len, kv_len});
}
std::sort(idx_qo_kv_len_vec.begin(), idx_qo_kv_len_vec.end(),
[](const auto& a, const auto& b) { return std::get<2>(a) > std::get<2>(b); });
int cta_tile_q = 128;
if (head_dim_vo == 64) {
cta_tile_q = 192;
}
int device = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&device));
int num_sm90_ctas = 0;
FLASHINFER_CUDA_CALL(
cudaDeviceGetAttribute(&num_sm90_ctas, cudaDevAttrMultiProcessorCount, device));
MinHeap cta_cost_heap(num_sm90_ctas);
std::vector<std::vector<IdType>> cta_qo_tile_indices(num_sm90_ctas, std::vector<IdType>()),
cta_qo_indptr(num_sm90_ctas, std::vector<IdType>()),
cta_kv_indptr(num_sm90_ctas, std::vector<IdType>()),
cta_qo_len(num_sm90_ctas, std::vector<IdType>()),
cta_kv_len(num_sm90_ctas, std::vector<IdType>()),
cta_head_indices(num_sm90_ctas, std::vector<IdType>()),
cta_batch_indices(num_sm90_ctas, std::vector<IdType>());
int max_num_works_per_head = ceil_div(total_num_rows, cta_tile_q) + batch_size - 1;
plan_info.same_schedule_for_all_heads = max_num_works_per_head > 4096;
for (int qo_head_idx = 0;
qo_head_idx < (plan_info.same_schedule_for_all_heads ? 1 : num_qo_heads); ++qo_head_idx) {
for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec) {
int num_qo_tiles = ceil_div(qo_len, cta_tile_q);
for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) {
auto [cta_idx, accum_cost] = cta_cost_heap.pop();
// NOTE(Zihao): our current FA3 implementation does not fuse query and group heads,
// so the group_size in cost_function is always 1
int effective_kv_len =
causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, cta_tile_q, num_qo_tiles, 1)
: kv_len;
cta_cost_heap.insert({cta_idx, accum_cost + cost_function(cta_tile_q, effective_kv_len)});
cta_qo_tile_indices[cta_idx].push_back(qo_tile_idx);
cta_qo_indptr[cta_idx].push_back(qo_indptr_h[i]);
cta_qo_len[cta_idx].push_back(qo_len);
cta_kv_indptr[cta_idx].push_back(kv_indptr_h[i]);
cta_kv_len[cta_idx].push_back(kv_len);
cta_head_indices[cta_idx].push_back(qo_head_idx);
cta_batch_indices[cta_idx].push_back(i);
}
}
}
std::vector<IdType> work_indptr_vec(num_sm90_ctas + 1, 0);
for (uint32_t i = 0; i < num_sm90_ctas; ++i) {
work_indptr_vec[i + 1] = work_indptr_vec[i] + cta_qo_tile_indices[i].size();
}
int total_num_works = work_indptr_vec.back();
auto qo_tile_indices_vec = flatten(cta_qo_tile_indices, total_num_works);
auto qo_indptr_vec = flatten(cta_qo_indptr, total_num_works);
auto kv_indptr_vec = flatten(cta_kv_indptr, total_num_works);
auto qo_len_vec = flatten(cta_qo_len, total_num_works);
auto kv_len_vec = flatten(cta_kv_len, total_num_works);
auto head_indices_vec = flatten(cta_head_indices, total_num_works);
auto batch_indices_vec = flatten(cta_batch_indices, total_num_works);
AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
int max_total_num_works;
if (enable_cuda_graph) {
max_total_num_works = plan_info.same_schedule_for_all_heads
? max_num_works_per_head
: max_num_works_per_head * num_qo_heads;
} else {
max_total_num_works = total_num_works;
}
plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_tile_indices");
plan_info.qo_indptr_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_offset");
plan_info.kv_indptr_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_kv_offset");
plan_info.qo_len_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works,
16, "batch_prefill_sm90_qo_len");
plan_info.kv_len_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works,
16, "batch_prefill_sm90_kv_len");
plan_info.head_indices_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_head_indices");
plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * (num_sm90_ctas + 1), 16, "batch_prefill_sm90_work_indptr");
plan_info.batch_indices_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_batch_indices");
IdType* qo_tile_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.qo_tile_indices_offset);
IdType* qo_offset_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.qo_indptr_offset);
IdType* kv_offset_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_indptr_offset);
IdType* qo_len_h = GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.qo_len_offset);
IdType* kv_len_h = GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_len_offset);
IdType* head_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.head_indices_offset);
IdType* work_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.work_indptr_offset);
IdType* batch_indices_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.batch_indices_offset);
std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h);
std::copy(qo_indptr_vec.begin(), qo_indptr_vec.end(), qo_offset_h);
std::copy(kv_indptr_vec.begin(), kv_indptr_vec.end(), kv_offset_h);
std::copy(qo_len_vec.begin(), qo_len_vec.end(), qo_len_h);
std::copy(kv_len_vec.begin(), kv_len_vec.end(), kv_len_h);
std::copy(head_indices_vec.begin(), head_indices_vec.end(), head_indices_h);
std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), work_indptr_h);
std::copy(batch_indices_vec.begin(), batch_indices_vec.end(), batch_indices_h);
size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream));
return cudaSuccess;
}
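/*!
* \brief Grid shape and byte offsets produced by TwoStageHolisticPlan, with one set of
* per-task offsets for each of the NUM_TASKS CTA tile sizes; serialized to/from a flat
* int64_t vector so the plan can cross the host-language boundary.
*/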
template <uint32_t NUM_TASKS>
struct HolisticPlanInfo {
int64_t num_blks_x;
int64_t num_blks_y;
struct {
int64_t q_indptr_offset;
int64_t kv_indptr_offset;
int64_t partial_indptr_offset;
int64_t q_len_offset;
int64_t kv_len_offset;
int64_t q_start_offset;
int64_t kv_start_offset;
int64_t kv_end_offset;
int64_t kv_head_idx_offset;
int64_t work_indptr_offset;
} tasks[NUM_TASKS];
int64_t len_kv_chunk_offset;
int64_t partial_o_offset;
int64_t partial_lse_offset;
int64_t merge_indptr_offset;
int64_t merge_o_indices_offset;
int64_t num_qo_len_offset;
static constexpr uint32_t NUM_TASK_ARGS = 10;
static constexpr uint32_t NUM_SHARED_ARGS = 8;
std::vector<int64_t> ToVector() const {
std::vector<int64_t> vec;
vec.push_back(num_blks_x);
vec.push_back(num_blks_y);
for (uint32_t i = 0; i < NUM_TASKS; ++i) {
vec.push_back(tasks[i].q_indptr_offset);
vec.push_back(tasks[i].kv_indptr_offset);
vec.push_back(tasks[i].partial_indptr_offset);
vec.push_back(tasks[i].q_len_offset);
vec.push_back(tasks[i].kv_len_offset);
vec.push_back(tasks[i].q_start_offset);
vec.push_back(tasks[i].kv_start_offset);
vec.push_back(tasks[i].kv_end_offset);
vec.push_back(tasks[i].kv_head_idx_offset);
vec.push_back(tasks[i].work_indptr_offset);
}
vec.push_back(len_kv_chunk_offset);
vec.push_back(partial_o_offset);
vec.push_back(partial_lse_offset);
vec.push_back(merge_indptr_offset);
vec.push_back(merge_o_indices_offset);
vec.push_back(num_qo_len_offset);
return vec;
}
void FromVector(const std::vector<int64_t>& vec) {
if (vec.size() != NUM_SHARED_ARGS + NUM_TASKS * NUM_TASK_ARGS) {
std::ostringstream err_msg;
err_msg << "HolisticPlanInfo::FromVector: vec.size() should be "
<< NUM_SHARED_ARGS + NUM_TASKS * NUM_TASK_ARGS << ", but got " << vec.size();
FLASHINFER_ERROR(err_msg.str());
}
num_blks_x = vec[0];
num_blks_y = vec[1];
for (uint32_t i = 0; i < NUM_TASKS; ++i) {
tasks[i].q_indptr_offset = vec[2 + i * NUM_TASK_ARGS + 0];
tasks[i].kv_indptr_offset = vec[2 + i * NUM_TASK_ARGS + 1];
tasks[i].partial_indptr_offset = vec[2 + i * NUM_TASK_ARGS + 2];
tasks[i].q_len_offset = vec[2 + i * NUM_TASK_ARGS + 3];
tasks[i].kv_len_offset = vec[2 + i * NUM_TASK_ARGS + 4];
tasks[i].q_start_offset = vec[2 + i * NUM_TASK_ARGS + 5];
tasks[i].kv_start_offset = vec[2 + i * NUM_TASK_ARGS + 6];
tasks[i].kv_end_offset = vec[2 + i * NUM_TASK_ARGS + 7];
tasks[i].kv_head_idx_offset = vec[2 + i * NUM_TASK_ARGS + 8];
tasks[i].work_indptr_offset = vec[2 + i * NUM_TASK_ARGS + 9];
}
len_kv_chunk_offset = vec[2 + NUM_TASKS * NUM_TASK_ARGS];
partial_o_offset = vec[3 + NUM_TASKS * NUM_TASK_ARGS];
partial_lse_offset = vec[4 + NUM_TASKS * NUM_TASK_ARGS];
merge_indptr_offset = vec[5 + NUM_TASKS * NUM_TASK_ARGS];
merge_o_indices_offset = vec[6 + NUM_TASKS * NUM_TASK_ARGS];
num_qo_len_offset = vec[7 + NUM_TASKS * NUM_TASK_ARGS];
}
};
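/*!
* \brief Plan the two-stage persistent kernel launch: requests are bucketed into a
* prefill-like task (CTA_TILE_Q = 128) and a decode-like task (CTA_TILE_Q = 16) by
* packed qo length, each work unit is split along KV by a load-derived kv_len_limit
* and assigned greedily to the least-loaded cluster, and split-KV partial outputs are
* given merge metadata for the reduction kernel.
*/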
template <typename IdType>
inline cudaError_t TwoStageHolisticPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, void* page_locked_int_buffer,
size_t int_workspace_size_in_bytes,
HolisticPlanInfo<2>& plan_info, IdType* qo_indptr_h,
IdType* kv_indptr_h, IdType* kv_len_arr_h,
uint32_t batch_size, uint32_t num_qo_heads,
uint32_t num_kv_heads, uint32_t head_dim, bool causal,
cudaStream_t stream) {
constexpr uint32_t NUM_TASKS = 2;
const uint32_t CTA_TILE_Q_SIZES[NUM_TASKS] = {128, 16};
int num_sm = 0;
int dev_id = 0;
uint32_t gqa_group_size = num_qo_heads / num_kv_heads;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
if (head_dim >= 256) {
// NOTE (Yilong): optimize this code path
// gridDim is constrained because of the cooperative-group launch
num_sm *= 1;
} else {
// NOTE(Zihao): two cta per sm
num_sm *= 2;
}
// step 0. determine the number of blocks in x and y dimensions
std::vector<std::tuple<int, int, int>> idx_qo_kv_len_vec[NUM_TASKS];
for (uint32_t i = 0; i < batch_size; ++i) {
if (qo_indptr_h[i + 1] - qo_indptr_h[i] < 0) {
std::ostringstream err_msg;
err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]"
<< qo_indptr_h[i] << " should be non-negative";
FLASHINFER_ERROR(err_msg.str());
}
int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i];
int packed_qo_len = qo_len * gqa_group_size;
int kv_len = kv_len_arr_h[i];
if (packed_qo_len > CTA_TILE_Q_SIZES[1]) {
idx_qo_kv_len_vec[0].push_back({i, qo_len, kv_len});
} else {
idx_qo_kv_len_vec[1].push_back({i, qo_len, kv_len});
}
}
int cluster_size = 1;
int num_clusters = num_sm / cluster_size;
plan_info.num_blks_x = cluster_size;
plan_info.num_blks_y = num_clusters;
auto f = [](int x) {
if (x <= 128) {
// This aligns with CTA_TILE_KV in persistent mainloop
// NOTE (Yilong): Optimize here for smaller batch/seqlen scenarios
return 128;
}
return ceil_div(x, 256) * 256;
};
MinHeap cluster_cost_heap(num_clusters);
AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
// NOTE(Zihao): adjust it later
const int max_total_num_works = 65536;
const int max_num_kv_splits =
4 * num_clusters * cluster_size * (CTA_TILE_Q_SIZES[0] + CTA_TILE_Q_SIZES[1]);
// calculate kv_len_limit first, considering all workloads
int64_t total_kv_lens = 0;
for (uint32_t task = 0; task < NUM_TASKS; ++task) {
int cluster_tile_q = CTA_TILE_Q_SIZES[task] * cluster_size;
for (auto& [_, qo_len, kv_len] : idx_qo_kv_len_vec[task]) {
int packed_qo_len = qo_len * gqa_group_size;
int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q);
for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) {
int effective_kv_len =
causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, cluster_tile_q, num_qo_tiles,
gqa_group_size)
: kv_len;
total_kv_lens += effective_kv_len;
}
}
}
// used for remapping the output offsets
// layout [packed_qo_len x num_kv_tiles, num_kv_heads, head_dim]
int partial_o_nnz = 0;
std::vector<IdType> merge_indptr, merge_o_indices, num_expand_qo_len_vec;
std::vector<IdType> cluster_len_kv_chunk(NUM_TASKS, 0);
merge_indptr.push_back(partial_o_nnz);
for (uint32_t task = 0; task < NUM_TASKS; ++task) {
int cluster_tile_q = CTA_TILE_Q_SIZES[task] * cluster_size;
int kv_len_limit = f(std::max(ceil_div(total_kv_lens * num_kv_heads, num_clusters), 1L));
if (cluster_tile_q >= 64) {
// chunked-prefill workloads are much more expensive than decode
// so we use a smaller kv_len_limit for chunked-prefill workloads
kv_len_limit /= std::min(num_kv_heads, 2U);
}
cluster_len_kv_chunk[task] = kv_len_limit;
std::vector<std::vector<IdType>> cluster_q_indptr(num_clusters, std::vector<IdType>()),
cluster_kv_indptr(num_clusters, std::vector<IdType>()),
cluster_q_len(num_clusters, std::vector<IdType>()),
cluster_kv_len(num_clusters, std::vector<IdType>()),
cluster_q_start(num_clusters, std::vector<IdType>()),
cluster_kv_start(num_clusters, std::vector<IdType>()),
cluster_kv_end(num_clusters, std::vector<IdType>()),
cluster_kv_head_idx(num_clusters, std::vector<IdType>()),
cluster_partial_indptr(num_clusters, std::vector<IdType>());
for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec[task]) {
int packed_qo_len = qo_len * gqa_group_size;
int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q);
// NOTE (Yilong): this ordering corresponds to the layout of the reduction kernel
for (int qo_tile_idx = 0; qo_tile_idx < num_qo_tiles; ++qo_tile_idx) {
int remaining_len = causal
? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, cluster_tile_q,
num_qo_tiles, gqa_group_size)
: kv_len;
int kv_start = 0;
bool split_kv = remaining_len > kv_len_limit;
int num_kv_tiles = split_kv ? ceil_div(remaining_len, kv_len_limit) : 1;
int row_tile_size = std::min(cluster_tile_q, packed_qo_len - qo_tile_idx * cluster_tile_q);
bool zero_kv_len = (remaining_len == 0);
while (remaining_len > 0 || zero_kv_len) {
int actual_len = std::min(remaining_len, kv_len_limit);
for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) {
auto [cluster_idx, accum_cost] = cluster_cost_heap.pop();
cluster_cost_heap.insert(
{cluster_idx, accum_cost + cost_function(cluster_tile_q, actual_len)});
cluster_q_len[cluster_idx].push_back(qo_len);
cluster_kv_len[cluster_idx].push_back(kv_len);
cluster_q_indptr[cluster_idx].push_back(qo_indptr_h[i]);
cluster_kv_indptr[cluster_idx].push_back(kv_indptr_h[i]);
// use kv_chunk to rematerialize num_kv_tiles and kv_tile_idx
cluster_partial_indptr[cluster_idx].push_back(partial_o_nnz);
cluster_q_start[cluster_idx].push_back(qo_tile_idx * cluster_tile_q);
cluster_kv_start[cluster_idx].push_back(kv_start);
cluster_kv_end[cluster_idx].push_back(kv_start + actual_len);
cluster_kv_head_idx[cluster_idx].push_back(kv_head_idx);
}
remaining_len -= actual_len;
zero_kv_len = (remaining_len == 0);
kv_start += actual_len;
if (zero_kv_len) {
break;
}
}
if (split_kv) {
// non-split kv is directly written through
for (int row = 0; row < row_tile_size; ++row) {
merge_indptr.push_back(merge_indptr.back() + num_kv_tiles);
// output layout: [qo_len, num_kv_heads, gqa_group_size, head_dim]
// merge_o_indices is the indices of `gqa_group_size` dimension
auto q = (qo_tile_idx * cluster_tile_q + row) / gqa_group_size,
r = (qo_tile_idx * cluster_tile_q + row) % gqa_group_size;
merge_o_indices.push_back((qo_indptr_h[i] + q) * num_kv_heads * gqa_group_size + r);
}
partial_o_nnz += row_tile_size * num_kv_tiles;
}
}
}
std::vector<IdType> work_indptr_vec(num_clusters + 1, 0);
for (uint32_t i = 0; i < num_clusters; ++i) {
work_indptr_vec[i + 1] = work_indptr_vec[i] + cluster_q_indptr[i].size();
}
int total_num_works = work_indptr_vec.back();
if (total_num_works > max_total_num_works) {
std::ostringstream err_msg;
err_msg << "total_num_works (#q tiles * #kv tiles) " << total_num_works
<< " exceeds max_total_num_works " << max_total_num_works;
FLASHINFER_ERROR(err_msg.str());
}
auto q_indptr_vec = flatten(cluster_q_indptr, total_num_works);
auto kv_indptr_vec = flatten(cluster_kv_indptr, total_num_works);
auto partial_indptr_vec = flatten(cluster_partial_indptr, total_num_works);
auto q_len_vec = flatten(cluster_q_len, total_num_works);
auto kv_len_vec = flatten(cluster_kv_len, total_num_works);
auto q_start_vec = flatten(cluster_q_start, total_num_works);
auto kv_start_vec = flatten(cluster_kv_start, total_num_works);
auto kv_end_vec = flatten(cluster_kv_end, total_num_works);
auto kv_head_idx_vec = flatten(cluster_kv_head_idx, total_num_works);
plan_info.tasks[task].q_indptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "q_indptr");
plan_info.tasks[task].kv_indptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_indptr");
plan_info.tasks[task].partial_indptr_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * max_total_num_works, 16, "partial_indptr");
plan_info.tasks[task].q_len_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "q_len");
plan_info.tasks[task].kv_len_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_len");
plan_info.tasks[task].q_start_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "q_start");
plan_info.tasks[task].kv_start_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_start");
plan_info.tasks[task].kv_end_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_end");
plan_info.tasks[task].kv_head_idx_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "kv_head_idx");
plan_info.tasks[task].work_indptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "work_indptr");
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].q_indptr_offset,
q_indptr_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_indptr_offset,
kv_indptr_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].partial_indptr_offset,
partial_indptr_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].q_len_offset, q_len_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_len_offset, kv_len_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].q_start_offset,
q_start_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_start_offset,
kv_start_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_end_offset, kv_end_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].kv_head_idx_offset,
kv_head_idx_vec);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.tasks[task].work_indptr_offset,
work_indptr_vec);
}
plan_info.len_kv_chunk_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * NUM_TASKS, 16, "len_kv_chunk");
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.len_kv_chunk_offset,
cluster_len_kv_chunk);
if (merge_indptr.size() > max_num_kv_splits) {
std::ostringstream err_msg;
err_msg << "Number of kv splits " << merge_indptr.size() << " exceeds max buffer size "
<< max_num_kv_splits << ". Please increase the threshold.";
FLASHINFER_ERROR(err_msg.str());
}
// update num_qo_len_vec
num_expand_qo_len_vec.push_back(merge_indptr.size() - 1);
// allocate buffer for state merge function
plan_info.merge_indptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_num_kv_splits, 16, "merge_indptr");
plan_info.merge_o_indices_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_num_kv_splits, 16, "merge_o_indices");
plan_info.num_qo_len_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType), 16, "num_qo_len_offset");
// copy data to the page-locked cpu buffer
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.merge_indptr_offset, merge_indptr);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.merge_o_indices_offset, merge_o_indices);
CopyToPageLockedBuffer(page_locked_int_buffer, plan_info.num_qo_len_offset,
num_expand_qo_len_vec);
size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream));
constexpr size_t sizeof_dtype_o = 2; // NOTE (Yilong): assume fp16
// NOTE(Yilong): multiply by num_kv_heads since it is not counted in partial_o_nnz
AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
plan_info.partial_o_offset = float_allocator.aligned_alloc_offset(
max_num_kv_splits * sizeof_dtype_o * head_dim * num_kv_heads, 16, "holistic_partial_o");
plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset(
max_num_kv_splits * sizeof(float) * num_kv_heads, 16, "holistic_partial_lse");
return cudaSuccess;
}
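/*!
* \brief Grid shape and byte offsets produced by MLAPlan; serialized to/from a flat
* int64_t vector so the plan can cross the host-language boundary.
*/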
struct MLAPlanInfo {
int64_t num_blks_x;
int64_t num_blks_y;
int64_t q_indptr_offset;
int64_t kv_indptr_offset;
int64_t partial_indptr_offset;
int64_t merge_packed_offset_start_offset;
int64_t merge_packed_offset_end_offset;
int64_t merge_partial_packed_offset_start_offset;
int64_t merge_partial_packed_offset_end_offset;
int64_t merge_partial_stride_offset;
int64_t q_len_offset;
int64_t kv_len_offset;
int64_t q_start_offset;
int64_t kv_start_offset;
int64_t kv_end_offset;
int64_t work_indptr_offset;
int64_t partial_o_offset;
int64_t partial_lse_offset;
std::vector<int64_t> ToVector() const {
return {num_blks_x,
num_blks_y,
q_indptr_offset,
kv_indptr_offset,
partial_indptr_offset,
merge_packed_offset_start_offset,
merge_packed_offset_end_offset,
merge_partial_packed_offset_start_offset,
merge_partial_packed_offset_end_offset,
merge_partial_stride_offset,
q_len_offset,
kv_len_offset,
q_start_offset,
kv_start_offset,
kv_end_offset,
work_indptr_offset,
partial_o_offset,
partial_lse_offset};
}
void FromVector(const std::vector<int64_t>& vec) {
if (vec.size() != 18) {
std::ostringstream err_msg;
err_msg << "MLAPlanInfo::FromVector: vec.size() should be 18, but got " << vec.size();
FLASHINFER_ERROR(err_msg.str());
}
num_blks_x = vec[0];
num_blks_y = vec[1];
q_indptr_offset = vec[2];
kv_indptr_offset = vec[3];
partial_indptr_offset = vec[4];
merge_packed_offset_start_offset = vec[5];
merge_packed_offset_end_offset = vec[6];
merge_partial_packed_offset_start_offset = vec[7];
merge_partial_packed_offset_end_offset = vec[8];
merge_partial_stride_offset = vec[9];
q_len_offset = vec[10];
kv_len_offset = vec[11];
q_start_offset = vec[12];
kv_start_offset = vec[13];
kv_end_offset = vec[14];
work_indptr_offset = vec[15];
partial_o_offset = vec[16];
partial_lse_offset = vec[17];
}
};
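/*!
* \brief Plan the MLA persistent kernel launch: choose the cluster size from the
* average packed qo length, derive a per-cluster KV length budget from the total
* (causal-masked) KV workload, then assign the KV chunks of each qo tile greedily to
* the least-loaded cluster, recording merge metadata for split-KV partial outputs.
*/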
template <typename IdType>
inline cudaError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_bytes,
void* int_buffer, void* page_locked_int_buffer,
size_t int_workspace_size_in_bytes, MLAPlanInfo& plan_info,
IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_len_arr_h,
uint32_t batch_size, uint32_t num_heads, uint32_t head_dim_o,
bool causal, cudaStream_t stream) {
int num_sm = 0;
int dev_id = 0;
FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id));
// step 0. determine the number of blocks in x and y dimensions
int accum_packed_qo_len = 0;
std::vector<std::tuple<int, int, int>> idx_qo_kv_len_vec;
for (uint32_t i = 0; i < batch_size; ++i) {
    if (qo_indptr_h[i + 1] - qo_indptr_h[i] < 0) {
      std::ostringstream err_msg;
      err_msg << "qo_indptr[" << i + 1 << "] (" << qo_indptr_h[i + 1] << ") - qo_indptr[" << i
              << "] (" << qo_indptr_h[i] << ") should be non-negative";
      FLASHINFER_ERROR(err_msg.str());
    }
int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i];
int packed_qo_len = qo_len * num_heads;
accum_packed_qo_len += packed_qo_len;
int kv_len = kv_len_arr_h[i];
idx_qo_kv_len_vec.push_back({i, qo_len, kv_len});
}
int avg_packed_qo_len = accum_packed_qo_len / batch_size;
int cluster_size;
if (avg_packed_qo_len > 64) {
cluster_size = 2; // two ctas in a cluster
} else {
cluster_size = 1; // one cta in a cluster
}
uint32_t num_clusters = num_sm / cluster_size;
plan_info.num_blks_x = cluster_size;
plan_info.num_blks_y = num_clusters;
const int cta_tile_q = 64;
int cluster_tile_q = cluster_size * cta_tile_q;
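  // Each work item covers cluster_tile_q rows of the "packed" qo layout, in which the
  // head dimension is folded into the row dimension (packed_qo_len = qo_len * num_heads);
  // e.g. with cta_tile_q = 64 and cluster_size = 2, a cluster processes 128 packed rows
  // per qo tile.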
int64_t total_kv_lens = 0;
for (auto& [_, qo_len, kv_len] : idx_qo_kv_len_vec) {
int packed_qo_len = qo_len * num_heads;
int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q);
for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) {
int effective_kv_len = causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx,
cluster_tile_q, num_qo_tiles, num_heads)
: kv_len;
total_kv_lens += effective_kv_len;
}
}
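  // f() snaps the per-cluster KV budget to coarse bucket sizes (32/64/128/192, then
  // multiples of 256), e.g. f(10) == 64 and f(200) == 256; kv_len_limit below is thus the
  // bucketed average effective KV length each cluster should process.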
auto f = [](int x) {
if (x <= 8) {
return 32;
} else if (x <= 16) {
return 64;
} else if (x <= 32) {
return 128;
} else if (x <= 64) {
return 192;
}
return ceil_div(x, 256) * 256;
};
  int kv_len_limit = f(std::max<int64_t>(ceil_div(total_kv_lens, num_clusters), 1));
// step 1. load-balancing scheduling algorithm
MinHeap cluster_cost_heap(num_clusters);
std::vector<std::vector<IdType>> cluster_q_indptr(num_clusters, std::vector<IdType>()),
cluster_kv_indptr(num_clusters, std::vector<IdType>()),
cluster_q_len(num_clusters, std::vector<IdType>()),
cluster_kv_len(num_clusters, std::vector<IdType>()),
cluster_q_start(num_clusters, std::vector<IdType>()),
cluster_kv_start(num_clusters, std::vector<IdType>()),
cluster_kv_end(num_clusters, std::vector<IdType>()),
cluster_partial_indptr(num_clusters, std::vector<IdType>());
std::vector<IdType> merge_packed_offset_start(num_sm, 0), merge_packed_offset_end(num_sm, 0),
merge_partial_packed_offset_start(num_sm, 0), merge_partial_packed_offset_end(num_sm, 0),
merge_partial_stride(num_sm, 0);
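  // Per merge-CTA bookkeeping (at most one merge CTA per SM): each entry records which
  // packed output rows a CTA finalizes and which rows of the partial-output buffer it
  // reduces over, with merge_partial_stride as the distance between successive partial
  // copies of the same row.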
int merge_cta_counter = 0;
int partial_o_nnz = 0;
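  // partial_o_nnz counts the packed rows appended to the partial-output buffer; each
  // split tile records its base offset into that buffer before bumping the counter.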
for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec) {
int packed_qo_len = qo_len * num_heads;
int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q);
for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) {
int remaining_len = causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, cluster_tile_q,
num_qo_tiles, num_heads)
: kv_len;
int kv_start = 0;
bool split_kv = remaining_len > kv_len_limit;
int row_tile_size = std::min(cluster_tile_q, packed_qo_len - qo_tile_idx * cluster_tile_q);
if (split_kv) {
        /*
         * Proof(Zihao): merge_cta_counter <= num_sm (num_sm >= num_clusters * cluster_size)
         *
         * Precondition:
         *   1. kv_len_limit * num_clusters >= total_kv_lens == sum(remaining_len)
         *   2. num_qo_chunks <= max((remaining_len * cluster_size) // kv_len_limit, 1)
         *   3. num_qo_tiles_requires_split <= num_clusters
         * Implication:
         *   1. sum(num_qo_chunks) <= max(sum(remaining_len) * cluster_size / kv_len_limit,
         *        num_qo_tiles_requires_split)
         *   2. sum(num_qo_chunks) <= max(cluster_size * num_clusters, num_qo_tiles_requires_split)
         */
int num_qo_chunks = std::max(remaining_len * cluster_size / kv_len_limit, 1);
// row_chunk_size * num_qo_chunks >= row_tile_size
int row_chunk_size = ceil_div(row_tile_size, num_qo_chunks);
        // identical to row_tile_size computed above
        int current_q_tile_end = row_tile_size;
for (int offset_start = 0; offset_start < row_tile_size; offset_start += row_chunk_size) {
merge_packed_offset_start[merge_cta_counter] =
qo_indptr_h[i] * num_heads + qo_tile_idx * cluster_tile_q + offset_start;
merge_packed_offset_end[merge_cta_counter] =
qo_indptr_h[i] * num_heads + qo_tile_idx * cluster_tile_q +
std::min(offset_start + row_chunk_size, current_q_tile_end);
merge_partial_packed_offset_start[merge_cta_counter] = partial_o_nnz + offset_start;
merge_partial_packed_offset_end[merge_cta_counter] =
partial_o_nnz + ceil_div(remaining_len, kv_len_limit) * row_tile_size;
merge_partial_stride[merge_cta_counter] = row_tile_size;
merge_cta_counter++;
}
}
bool zero_kv_len = (remaining_len == 0);
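      // Greedy load balancing: repeatedly hand the least-loaded cluster (the min-heap
      // root) a KV chunk of at most kv_len_limit tokens for this qo tile; a request with
      // an empty KV range still emits exactly one (empty) work item.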
while (remaining_len > 0 || zero_kv_len) {
auto [cluster_idx, accum_cost] = cluster_cost_heap.pop();
int actual_len = std::min(remaining_len, kv_len_limit);
cluster_cost_heap.insert(
{cluster_idx, accum_cost + cost_function(cluster_tile_q, actual_len)});
cluster_q_len[cluster_idx].push_back(qo_len);
cluster_kv_len[cluster_idx].push_back(kv_len);
cluster_q_indptr[cluster_idx].push_back(qo_indptr_h[i]);
cluster_kv_indptr[cluster_idx].push_back(kv_indptr_h[i]);
if (split_kv) {
cluster_partial_indptr[cluster_idx].push_back(partial_o_nnz);
partial_o_nnz += row_tile_size;
} else {
cluster_partial_indptr[cluster_idx].push_back(-1);
}
cluster_q_start[cluster_idx].push_back(qo_tile_idx * cluster_tile_q);
cluster_kv_start[cluster_idx].push_back(kv_start);
cluster_kv_end[cluster_idx].push_back(kv_start + actual_len);
remaining_len -= actual_len;
kv_start += actual_len;
if (zero_kv_len) break;
}
}
}
FLASHINFER_CHECK(merge_cta_counter <= num_sm,
"Internal Error: merge_cta_counter should be less than or equal to num_sm, "
"please report this bug to the developers");
int max_total_num_works = 16384; // NOTE(Zihao): adjust it later
std::vector<IdType> work_indptr_vec(num_clusters + 1, 0);
for (uint32_t i = 0; i < num_clusters; ++i) {
work_indptr_vec[i + 1] = work_indptr_vec[i] + cluster_q_indptr[i].size();
}
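  // work_indptr is a CSR-style offset array over the flattened work arrays below:
  // cluster i executes work items [work_indptr[i], work_indptr[i + 1]).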
  int total_num_works = work_indptr_vec.back();
  FLASHINFER_CHECK(total_num_works <= max_total_num_works,
                   "Internal Error: total_num_works should be less than or equal to "
                   "max_total_num_works, please report this bug to the developers");
auto q_indptr_vec = flatten(cluster_q_indptr, total_num_works);
auto kv_indptr_vec = flatten(cluster_kv_indptr, total_num_works);
auto partial_indptr_vec = flatten(cluster_partial_indptr, total_num_works);
auto q_len_vec = flatten(cluster_q_len, total_num_works);
auto kv_len_vec = flatten(cluster_kv_len, total_num_works);
auto q_start_vec = flatten(cluster_q_start, total_num_works);
auto kv_start_vec = flatten(cluster_kv_start, total_num_works);
auto kv_end_vec = flatten(cluster_kv_end, total_num_works);
AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
plan_info.q_indptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_q_indptr");
plan_info.kv_indptr_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_kv_indptr");
plan_info.partial_indptr_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * max_total_num_works, 16, "mla_partial_indptr");
plan_info.merge_packed_offset_start_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * num_sm, 16, "mla_merge_packed_offset_start");
plan_info.merge_packed_offset_end_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * num_sm, 16, "mla_merge_packed_offset_end");
plan_info.merge_partial_packed_offset_start_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * num_sm, 16, "mla_merge_partial_packed_offset_start");
plan_info.merge_partial_packed_offset_end_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * num_sm, 16, "mla_merge_partial_packed_offset_end");
plan_info.merge_partial_stride_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * num_sm, 16, "mla_merge_partial_stride");
plan_info.q_len_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_q_len");
plan_info.kv_len_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_kv_len");
plan_info.q_start_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_q_start");
plan_info.kv_start_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_kv_start");
plan_info.kv_end_offset =
int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_kv_end");
plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset(
sizeof(IdType) * max_total_num_works, 16, "mla_work_indptr");
IdType* cluster_q_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.q_indptr_offset);
IdType* cluster_kv_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_indptr_offset);
IdType* cluster_partial_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.partial_indptr_offset);
IdType* cluster_merge_packed_offset_start_h = GetPtrFromBaseOffset<IdType>(
page_locked_int_buffer, plan_info.merge_packed_offset_start_offset);
IdType* cluster_merge_packed_offset_end_h = GetPtrFromBaseOffset<IdType>(
page_locked_int_buffer, plan_info.merge_packed_offset_end_offset);
IdType* cluster_merge_partial_packed_offset_start_h = GetPtrFromBaseOffset<IdType>(
page_locked_int_buffer, plan_info.merge_partial_packed_offset_start_offset);
IdType* cluster_merge_partial_packed_offset_end_h = GetPtrFromBaseOffset<IdType>(
page_locked_int_buffer, plan_info.merge_partial_packed_offset_end_offset);
IdType* cluster_merge_partial_stride_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.merge_partial_stride_offset);
IdType* cluster_q_len_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.q_len_offset);
IdType* cluster_kv_len_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_len_offset);
IdType* cluster_q_start_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.q_start_offset);
IdType* cluster_kv_start_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_start_offset);
IdType* cluster_kv_end_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.kv_end_offset);
IdType* cluster_work_indptr_h =
GetPtrFromBaseOffset<IdType>(page_locked_int_buffer, plan_info.work_indptr_offset);
std::copy(q_indptr_vec.begin(), q_indptr_vec.end(), cluster_q_indptr_h);
std::copy(kv_indptr_vec.begin(), kv_indptr_vec.end(), cluster_kv_indptr_h);
std::copy(partial_indptr_vec.begin(), partial_indptr_vec.end(), cluster_partial_indptr_h);
std::copy(merge_packed_offset_start.begin(), merge_packed_offset_start.end(),
cluster_merge_packed_offset_start_h);
std::copy(merge_packed_offset_end.begin(), merge_packed_offset_end.end(),
cluster_merge_packed_offset_end_h);
std::copy(merge_partial_packed_offset_start.begin(), merge_partial_packed_offset_start.end(),
cluster_merge_partial_packed_offset_start_h);
std::copy(merge_partial_packed_offset_end.begin(), merge_partial_packed_offset_end.end(),
cluster_merge_partial_packed_offset_end_h);
std::copy(merge_partial_stride.begin(), merge_partial_stride.end(),
cluster_merge_partial_stride_h);
std::copy(q_len_vec.begin(), q_len_vec.end(), cluster_q_len_h);
std::copy(kv_len_vec.begin(), kv_len_vec.end(), cluster_kv_len_h);
std::copy(q_start_vec.begin(), q_start_vec.end(), cluster_q_start_h);
std::copy(kv_start_vec.begin(), kv_start_vec.end(), cluster_kv_start_h);
std::copy(kv_end_vec.begin(), kv_end_vec.end(), cluster_kv_end_h);
std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), cluster_work_indptr_h);
size_t num_bytes_to_copy = int_allocator.num_allocated_bytes();
FLASHINFER_CUDA_CALL(cudaMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy,
cudaMemcpyHostToDevice, stream));
  constexpr size_t sizeof_dtype_o = 2;  // assume 2-byte (fp16/bf16) output dtype
AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes);
plan_info.partial_o_offset = float_allocator.aligned_alloc_offset(
2 * num_clusters * cluster_tile_q * sizeof_dtype_o * head_dim_o, 16, "mla_partial_o");
plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset(
2 * num_clusters * cluster_tile_q * sizeof(float), 16, "mla_partial_lse");
return cudaSuccess;
}
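
// A minimal host-side sketch of driving MLAPlan; the workspace handles and example
// shapes below are illustrative assumptions, not requirements:
//   MLAPlanInfo plan_info;
//   std::vector<int32_t> qo_indptr{0, 1, 3};    // batch_size = 2
//   std::vector<int32_t> kv_indptr{0, 13, 38};  // kv indptr per request
//   std::vector<int32_t> kv_len{100, 200};      // kv length per request
//   FLASHINFER_CUDA_CALL(MLAPlan(float_ws, float_ws_bytes, int_ws, pinned_int_ws,
//                                int_ws_bytes, plan_info, qo_indptr.data(),
//                                kv_indptr.data(), kv_len.data(), /*batch_size=*/2,
//                                /*num_heads=*/16, /*head_dim_o=*/512, /*causal=*/true,
//                                stream));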
} // namespace flashinfer
#endif // FLASHINFER_ATTENTION_SCHEDULER_CUH_