sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/page.cuh

/*
* Copyright (c) 2023 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_PAGE_CUH_
#define FLASHINFER_PAGE_CUH_
#include <driver_types.h>
#include <vector>
#include "exception.h"
#include "fastdiv.cuh"
#include "layout.cuh"
#include "utils.cuh"
#include "vec_dtypes.cuh"
namespace flashinfer {
/*!
 * \brief Paged key-value cache
 * \tparam DType The data type of the key-value cache
 * \tparam IdType The index data type of the kv-cache
 * \note The layout of the last 3 dimensions (NHD or HND) is selected at runtime via the
 *       constructor's QKVLayout argument, not via a template parameter.
 */
template <typename DType, typename IdType>
struct paged_kv_t {
  uint_fastdiv page_size;
  uint32_t num_heads;
  uint32_t head_dim;
  uint32_t batch_size;
  uint32_t stride_page;
  uint32_t stride_n;
  uint32_t stride_h;

  // Internal layout:
  // [max_num_pages, num_heads, page_size, head_dim] if layout == HND
  // [max_num_pages, page_size, num_heads, head_dim] if layout == NHD
  DType* k_data;
  DType* v_data;
  IdType* indices;
  // [batch_size + 1] The page indptr array, with the first element 0, the last element nnz_pages
  IdType* indptr;
  // [batch_size] The offset of the last page for each request in the batch
  IdType* last_page_len;
  // [batch_size] The start position of each request in the batch.
  IdType* rope_pos_offset;

  /*!
   * \brief Construct an empty paged key-value cache
   */
  __host__ __device__ __forceinline__ paged_kv_t()
      : num_heads(0),
        page_size(),
        head_dim(0),
        batch_size(0),
        stride_page(0),
        stride_n(0),
        stride_h(0),
        k_data(nullptr),
        v_data(nullptr),
        indices(nullptr),
        indptr(nullptr),
        last_page_len(nullptr),
        rope_pos_offset(nullptr) {}

  /*!
   * \brief Construct a paged key-value cache
   * \param num_heads The number of heads
   * \param page_size The size of each page
   * \param head_dim The dimension of each head
   * \param batch_size The batch size
   * \param layout The layout of the last 3 dimensions in KV-Cache.
   * \param k_data The start pointer of the key cache; k_data must be contiguous
   * \param v_data The start pointer of the value cache; v_data must be contiguous
   * \param indices The page indices array
   * \param indptr The page indptr array
   * \param last_page_len The offset of the last page for each request in the batch
   * \param rope_pos_offset The start position of each request in the batch.
   */
  __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
                                      uint32_t batch_size, QKVLayout layout, DType* k_data,
                                      DType* v_data, IdType* indices, IdType* indptr,
                                      IdType* last_page_len, IdType* rope_pos_offset = nullptr)
      : num_heads(num_heads),
        page_size(page_size),
        head_dim(head_dim),
        batch_size(batch_size),
        indices(indices),
        indptr(indptr),
        last_page_len(last_page_len),
        rope_pos_offset(rope_pos_offset) {
    stride_page = num_heads * page_size * head_dim;
    this->k_data = k_data;
    this->v_data = v_data;
    stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim;
    stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim;
  }
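
  // Worked example (illustration only, not part of the upstream API docs): with num_heads = 8,
  // page_size = 16 and head_dim = 128, the contiguous constructor above yields
  //   stride_page = 8 * 16 * 128 = 16384 elements per page,
  //   NHD: stride_n = 8 * 128 = 1024, stride_h = 128          ([page_size, num_heads, head_dim])
  //   HND: stride_n = 128,            stride_h = 16 * 128 = 2048  ([num_heads, page_size, head_dim])
  // so page_idx * stride_page + head_idx * stride_h + entry_idx * stride_n + feat_idx (see
  // get_elem_offset) addresses the same logical element under either layout.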

  /*!
   * \brief Construct a paged key-value cache with custom kv-cache strides
   * \param num_heads The number of heads
   * \param page_size The size of each page
   * \param head_dim The dimension of each head
   * \param batch_size The batch size
   * \param layout The layout of the last 3 dimensions in KV-Cache.
   * \param k_data The start pointer of the key cache; k_data does not have to be contiguous
   * \param v_data The start pointer of the value cache; v_data does not have to be contiguous
   * \param kv_strides Custom strides for each dimension of k_data and v_data
   * \param indices The page indices array
   * \param indptr The page indptr array
   * \param last_page_len The offset of the last page for each request in the batch
   * \param rope_pos_offset The start position of each request in the batch.
   */
  __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim,
                                      uint32_t batch_size, QKVLayout layout, DType* k_data,
                                      DType* v_data, const int64_t* kv_strides, IdType* indices,
                                      IdType* indptr, IdType* last_page_len,
                                      IdType* rope_pos_offset = nullptr)
      : num_heads(num_heads),
        page_size(page_size),
        head_dim(head_dim),
        batch_size(batch_size),
        indices(indices),
        indptr(indptr),
        last_page_len(last_page_len),
        rope_pos_offset(rope_pos_offset) {
    stride_page = kv_strides[0];
    this->k_data = k_data;
    this->v_data = v_data;
    stride_n = layout == QKVLayout::kHND ? kv_strides[2] : kv_strides[1];
    stride_h = layout == QKVLayout::kHND ? kv_strides[1] : kv_strides[2];
  }

  __host__ __device__ __forceinline__ uint32_t get_length(uint32_t batch_idx) const {
    if (indptr[batch_idx + 1] == indptr[batch_idx]) {
      return 0;
    }
    return (indptr[batch_idx + 1] - indptr[batch_idx] - 1) * page_size + last_page_len[batch_idx];
  }

  /*!
   * \brief Compute the offset of an element in the allocated buffer.
   * \param page_idx The page index
   * \param head_idx The head index
   * \param entry_idx The page entry index
   * \param feat_idx The feature index
   */
  __host__ __device__ __forceinline__ size_t get_elem_offset(size_t page_idx, size_t head_idx,
                                                             size_t entry_idx,
                                                             size_t feat_idx) const {
    return page_idx * stride_page + head_idx * stride_h + entry_idx * stride_n + feat_idx;
  }

  /*!
   * \brief Compute the offset of an element inside the page.
   * \param head_idx The head index
   * \param entry_idx The page entry index
   * \param feat_idx The feature index
   */
  __host__ __device__ __forceinline__ size_t get_elem_offset_in_page(size_t head_idx,
                                                                     size_t entry_idx,
                                                                     size_t feat_idx) const {
    return head_idx * stride_h + entry_idx * stride_n + feat_idx;
  }

  __device__ __forceinline__ DType* get_k_ptr(IdType page_iter, uint32_t head_idx,
                                              uint32_t entry_idx, uint32_t feat_idx) const {
    return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
  }

  __device__ __forceinline__ size_t protective_get_kv_offset(IdType page_iter, uint32_t head_idx,
                                                             uint32_t entry_idx, uint32_t feat_idx,
                                                             IdType last_indptr) const {
    if (page_iter < last_indptr) {
      return get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
    } else {
      return 0;
    }
  }

  __device__ __forceinline__ DType* protective_get_k_ptr(IdType page_iter, uint32_t head_idx,
                                                         uint32_t entry_idx, uint32_t feat_idx,
                                                         IdType last_indptr) const {
    return k_data + protective_get_kv_offset(page_iter, head_idx, entry_idx, feat_idx, last_indptr);
  }

  __device__ __forceinline__ DType* get_v_ptr(IdType page_iter, uint32_t head_idx,
                                              uint32_t entry_idx, uint32_t feat_idx) const {
    return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx);
  }

  __device__ __forceinline__ DType* protective_get_v_ptr(IdType page_iter, uint32_t head_idx,
                                                         uint32_t entry_idx, uint32_t feat_idx,
                                                         IdType last_indptr) const {
    return v_data + protective_get_kv_offset(page_iter, head_idx, entry_idx, feat_idx, last_indptr);
  }
};
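
// Illustrative usage (a minimal sketch, not taken from the upstream sources): constructing a
// contiguous NHD paged KV cache on the host. The device pointers below (k_buf, v_buf,
// page_indices_d, page_indptr_d, last_page_len_d) are hypothetical buffers that the caller is
// assumed to have allocated and populated.
//
//   using flashinfer::paged_kv_t;
//   using flashinfer::QKVLayout;
//   constexpr uint32_t num_heads = 8, page_size = 16, head_dim = 128, batch_size = 2;
//   paged_kv_t<half, int32_t> paged_kv(num_heads, page_size, head_dim, batch_size,
//                                      QKVLayout::kNHD, k_buf, v_buf, page_indices_d,
//                                      page_indptr_d, last_page_len_d);
//   // paged_kv is then passed by value to the append kernels declared below.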

/*!
 * \brief CUDA kernel to append new keys/values to the paged key-value cache in the decode phase
 * \tparam head_dim The dimension of each head
 * \tparam vec_size The vector size used in the kernel
 * \tparam DType The data type of the key-value cache
 * \tparam IdType The index data type of the kv-cache
 * \param paged_kv The paged key-value cache
 * \param key The key to be appended
 * \param value The value to be appended
 */
template <uint32_t head_dim, uint32_t vec_size, typename DType, typename IdType>
__global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t<DType, IdType> paged_kv,
                                               DType* __restrict__ key,
                                               DType* __restrict__ value) {
  uint32_t tx = threadIdx.x, ty = threadIdx.y;
  uint32_t num_heads = paged_kv.num_heads;
  uint32_t batch_idx = blockIdx.x;
  uint32_t head_idx = ty;

  uint32_t seq_len =
      (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size +
      paged_kv.last_page_len[batch_idx];
  uint32_t page_iter = paged_kv.indptr[batch_idx] + (seq_len - 1) / paged_kv.page_size;
  uint32_t entry_idx = (seq_len - 1) % paged_kv.page_size;

  DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
  DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
  vec_t<DType, vec_size>::memcpy(
      k_ptr, key + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size);
  vec_t<DType, vec_size>::memcpy(
      v_ptr, value + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size);
}

/*!
 * \brief CUDA kernel to append new keys/values to the paged key-value cache in the prefill phase
 * \tparam head_dim The dimension of each head
 * \tparam vec_size The vector size used in the kernel
 * \tparam DType The data type of the key-value cache
 * \tparam IdType The index data type of the kv-cache
 * \param paged_kv The paged key-value cache
 * \param append_key The keys to be appended
 * \param append_value The values to be appended
 * \param batch_indices The batch indices of the elements to be appended
 * \param positions The positions of the elements to be appended
 * \param nnz The number of elements to be appended
 */
template <uint32_t head_dim, uint32_t vec_size, typename DType, typename IdType>
__global__ void AppendPagedKVCacheKernel(paged_kv_t<DType, IdType> paged_kv,
                                         DType* __restrict__ append_key,
                                         DType* __restrict__ append_value,
                                         IdType* __restrict__ batch_indices,
                                         IdType* __restrict__ positions, uint32_t nnz,
                                         size_t append_k_stride_n, size_t append_k_stride_h,
                                         size_t append_v_stride_n, size_t append_v_stride_h) {
  uint32_t tx = threadIdx.x, ty = threadIdx.y;
  uint32_t num_heads = paged_kv.num_heads;
  uint32_t head_idx = ty;
  uint32_t cta_id = blockIdx.x;
  uint32_t num_ctas = gridDim.x;

#pragma unroll 4
  for (uint32_t i = cta_id; i < nnz; i += num_ctas) {
    uint32_t page_iter, entry_idx;
    paged_kv.page_size.divmod(paged_kv.indptr[batch_indices[i]] * paged_kv.page_size + positions[i],
                              page_iter, entry_idx);
    DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
    DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size);
    vec_t<DType, vec_size>::memcpy(
        k_ptr, append_key + i * append_k_stride_n + head_idx * append_k_stride_h + tx * vec_size);
    vec_t<DType, vec_size>::memcpy(
        v_ptr, append_value + i * append_v_stride_n + head_idx * append_v_stride_h + tx * vec_size);
  }
}

template <typename IdType>
__global__ void BlockSparseIndicesToVectorSparseOffsetsKernel(
    IdType* __restrict__ block_sparse_indices, IdType* __restrict__ block_sparse_indptr,
    IdType* __restrict__ vector_sparse_offsets, IdType* __restrict__ vector_sparse_indptr,
    IdType* __restrict__ kv_lens, const uint32_t stride_block, const uint32_t stride_n,
    const uint32_t batch_size, const uint_fastdiv block_size) {
#pragma unroll 1
  for (int b = blockIdx.x; b < batch_size; b += gridDim.x) {
#pragma unroll 2
    for (int pos = threadIdx.x; pos < kv_lens[b]; pos += blockDim.x) {
      uint32_t q, r;
      block_size.divmod(pos, q, r);
      vector_sparse_offsets[vector_sparse_indptr[b] + pos] =
          block_sparse_indices[block_sparse_indptr[b] + q] * stride_block + r * stride_n;
    }
  }
}

template <typename IdType>
cudaError_t BlockSparseIndicesToVectorSparseOffset(
    IdType* block_sparse_indices, IdType* block_sparse_indptr, IdType* vector_sparse_offsets,
    IdType* vector_sparse_indptr, IdType* kv_lens, const int64_t stride_block,
    const int64_t stride_n, const int64_t batch_size, const uint32_t block_size,
    cudaStream_t stream = nullptr) {
  int dev_id = 0;
  int num_sms = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

  uint32_t num_threads = 512;
  uint_fastdiv block_size_fastdiv(block_size);
  auto kernel = BlockSparseIndicesToVectorSparseOffsetsKernel<IdType>;
  void* args[] = {(void*)&block_sparse_indices,
                  (void*)&block_sparse_indptr,
                  (void*)&vector_sparse_offsets,
                  (void*)&vector_sparse_indptr,
                  (void*)&kv_lens,
                  (void*)&stride_block,
                  (void*)&stride_n,
                  (void*)&batch_size,
                  (void*)&block_size_fastdiv};
  FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, num_sms, num_threads, args, 0, stream));
  return cudaSuccess;
}
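
// Illustrative mapping (a sketch, not upstream documentation): for request b and token position
// pos, the kernel above computes q = pos / block_size and r = pos % block_size, then writes
//   vector_sparse_offsets[vector_sparse_indptr[b] + pos]
//       = block_sparse_indices[block_sparse_indptr[b] + q] * stride_block + r * stride_n;
// e.g. with block_size = 16 and stride_block = 16 * stride_n, a request whose pages are [7, 2]
// maps pos = 20 to 2 * stride_block + 4 * stride_n, i.e. entry 4 of page 2.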

/*!
 * \brief Append new keys/values to the paged key-value cache in the decode phase
 * \tparam DType The data type of the key-value cache
 * \tparam IdType The index data type of the kv-cache
 * \param paged_kv The paged key-value cache
 * \param key The key to be appended
 * \param value The value to be appended
 * \param stream The CUDA stream to execute kernels.
 * \return status Indicates whether CUDA calls are successful
 */
template <typename DType, typename IdType>
cudaError_t AppendPagedKVCacheDecode(paged_kv_t<DType, IdType> paged_kv, DType* key, DType* value,
                                     cudaStream_t stream = nullptr) {
  uint32_t head_dim = paged_kv.head_dim;
  uint32_t batch_size = paged_kv.batch_size;
  uint32_t num_heads = paged_kv.num_heads;
  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
    constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
    uint32_t bdx = HEAD_DIM / vec_size;
    uint32_t bdy = num_heads;
    // NOTE(Zihao): could be slow for small batch size, will optimize later
    dim3 nblks(batch_size);
    dim3 nthrs(bdx, bdy);
    auto kernel = AppendPagedKVCacheDecodeKernel<HEAD_DIM, vec_size, DType, IdType>;
    void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value};
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
  });
  return cudaSuccess;
}
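
// Illustrative host-side call (a minimal sketch under assumed types): key_d and value_d are
// hypothetical device buffers of shape [batch_size, num_heads, head_dim] holding the single new
// token per request, matching the indexing used by AppendPagedKVCacheDecodeKernel above.
//
//   cudaError_t status =
//       flashinfer::AppendPagedKVCacheDecode<half, int32_t>(paged_kv, key_d, value_d, stream);
//   if (status != cudaSuccess) { /* handle the error */ }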

/*!
 * \brief Append new keys/values to the paged key-value cache
 * \tparam DType The data type of the key-value cache
 * \tparam IdType The index data type of the kv-cache
 * \param paged_kv The paged key-value cache
 * \param append_key The keys to be appended
 * \param append_value The values to be appended
 * \param batch_indices The batch index of each element to be appended
 * \param positions The position (within its request) of each element to be appended
 * \param nnz The number of elements to be appended
 * \param append_k_stride_n The stride (in elements) between tokens in append_key
 * \param append_k_stride_h The stride (in elements) between heads in append_key
 * \param append_v_stride_n The stride (in elements) between tokens in append_value
 * \param append_v_stride_h The stride (in elements) between heads in append_value
 * \param stream The CUDA stream to execute kernels.
 * \return status Indicates whether CUDA calls are successful
 */
template <typename DType, typename IdType>
cudaError_t AppendPagedKVCache(paged_kv_t<DType, IdType> paged_kv, DType* append_key,
                               DType* append_value, IdType* batch_indices, IdType* positions,
                               uint32_t nnz, size_t append_k_stride_n, size_t append_k_stride_h,
                               size_t append_v_stride_n, size_t append_v_stride_h,
                               cudaStream_t stream = nullptr) {
  uint32_t head_dim = paged_kv.head_dim;
  uint32_t num_heads = paged_kv.num_heads;
  int dev_id = 0;
  int num_sms = 0;
  int num_blocks_per_sm = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
    constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32);
    uint32_t bdx = HEAD_DIM / vec_size;
    uint32_t bdy = num_heads;
    uint32_t num_threads = bdx * bdy;
    uint32_t smem_size = 0;
    auto kernel = AppendPagedKVCacheKernel<HEAD_DIM, vec_size, DType, IdType>;
    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
                                                                       num_threads, smem_size));
    num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms));
    dim3 nblks(num_blocks_per_sm * num_sms);
    dim3 nthrs(bdx, bdy);
    void* args[] = {(void*)&paged_kv,          (void*)&append_key,        (void*)&append_value,
                    (void*)&batch_indices,     (void*)&positions,         (void*)&nnz,
                    (void*)&append_k_stride_n, (void*)&append_k_stride_h, (void*)&append_v_stride_n,
                    (void*)&append_v_stride_h};
    FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
  });
  return cudaSuccess;
}
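
// Illustrative host-side call (a minimal sketch under assumed types): append nnz new tokens whose
// keys/values are laid out as [nnz, num_heads, head_dim] in hypothetical contiguous device buffers
// append_key_d / append_value_d, with batch_indices_d[i] giving the request and positions_d[i] the
// token position of element i.
//
//   size_t stride_n = num_heads * head_dim, stride_h = head_dim;
//   cudaError_t status = flashinfer::AppendPagedKVCache<half, int32_t>(
//       paged_kv, append_key_d, append_value_d, batch_indices_d, positions_d, nnz,
//       /*append_k_stride_n=*/stride_n, /*append_k_stride_h=*/stride_h,
//       /*append_v_stride_n=*/stride_n, /*append_v_stride_h=*/stride_h, stream);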

template <typename DType, typename IdType>
struct paged_kv_mla_t {
  uint_fastdiv page_size;
  uint32_t head_dim_ckv;
  uint32_t head_dim_kpe;
  uint32_t batch_size;
  uint32_t stride_page_ckv;
  uint32_t stride_page_kpe;
  uint32_t stride_n_ckv;
  uint32_t stride_n_kpe;

  // Internal layout:
  // [max_num_pages, page_size, head_dim]
  DType* ckv_data;
  DType* kpe_data;
  IdType* indices;
  // [batch_size + 1] The page indptr array, with the first element 0, the last element nnz_pages
  IdType* indptr;
  // [batch_size] The offset of the last page for each request in the batch
  IdType* last_page_len;
  // [batch_size] The start position of each request in the batch.
  IdType* rope_pos_offset;

  /*!
   * \brief Construct an empty paged key-value cache
   */
  __host__ __device__ __forceinline__ paged_kv_mla_t()
      : head_dim_ckv(0),
        head_dim_kpe(0),
        batch_size(0),
        stride_page_ckv(0),
        stride_page_kpe(0),
        stride_n_ckv(0),
        stride_n_kpe(0),
        ckv_data(nullptr),
        kpe_data(nullptr),
        indices(nullptr),
        indptr(nullptr),
        last_page_len(nullptr),
        rope_pos_offset(nullptr) {}

  /*!
   * \brief Construct a paged MLA (multi-head latent attention) KV cache
   * \param page_size The size of each page
   * \param head_dim_compressed_kv The dimension of compressed-kv
   * \param head_dim_kpe The dimension of k-pe
   * \param batch_size The batch size
   * \param compressed_kv_data The start pointer of the compressed-kv cache; the cache must be contiguous
   * \param kpe_data The start pointer of the k-pe cache; the cache must be contiguous
   * \param indices The page indices array
   * \param indptr The page indptr array
   * \param last_page_len The offset of the last page for each request in the batch
   * \param rope_pos_offset The start position of each request in the batch.
   */
  __host__ __forceinline__ paged_kv_mla_t(uint32_t page_size, uint32_t head_dim_compressed_kv,
                                          uint32_t head_dim_kpe, uint32_t batch_size,
                                          DType* compressed_kv_data, DType* kpe_data,
                                          IdType* indices, IdType* indptr, IdType* last_page_len,
                                          IdType* rope_pos_offset = nullptr)
      : page_size(page_size),
        head_dim_ckv(head_dim_compressed_kv),
        head_dim_kpe(head_dim_kpe),
        batch_size(batch_size),
        ckv_data(compressed_kv_data),
        kpe_data(kpe_data),
        indices(indices),
        indptr(indptr),
        last_page_len(last_page_len),
        rope_pos_offset(rope_pos_offset) {
    stride_page_ckv = page_size * head_dim_ckv;
    stride_n_ckv = head_dim_ckv;
    stride_page_kpe = page_size * head_dim_kpe;
    stride_n_kpe = head_dim_kpe;
  }

  /*!
   * \brief Construct a paged MLA KV cache with custom strides
   * \param page_size The size of each page
   * \param head_dim_compressed_kv The dimension of compressed-kv
   * \param head_dim_kpe The dimension of k-pe
   * \param batch_size The batch size
   * \param compressed_kv_data The start pointer of the compressed-kv cache; it does not have to be contiguous
   * \param compressed_kv_strides Custom strides for each dimension of the compressed-kv cache
   * \param kpe_data The start pointer of the k-pe cache; it does not have to be contiguous
   * \param kpe_strides Custom strides for each dimension of the k-pe cache
   * \param indices The page indices array
   * \param indptr The page indptr array
   * \param last_page_len The offset of the last page for each request in the batch
   * \param rope_pos_offset The start position of each request in the batch.
   */
  __host__ __forceinline__ paged_kv_mla_t(uint32_t page_size, uint32_t head_dim_compressed_kv,
                                          uint32_t head_dim_kpe, uint32_t batch_size,
                                          DType* compressed_kv_data,
                                          const int64_t* compressed_kv_strides, DType* kpe_data,
                                          const int64_t* kpe_strides, IdType* indices,
                                          IdType* indptr, IdType* last_page_len,
                                          IdType* rope_pos_offset = nullptr)
      : page_size(page_size),
        head_dim_ckv(head_dim_compressed_kv),
        head_dim_kpe(head_dim_kpe),
        batch_size(batch_size),
        ckv_data(compressed_kv_data),
        kpe_data(kpe_data),
        indices(indices),
        indptr(indptr),
        last_page_len(last_page_len),
        rope_pos_offset(rope_pos_offset) {
    stride_page_ckv = compressed_kv_strides[0];
    stride_n_ckv = compressed_kv_strides[1];
    stride_page_kpe = kpe_strides[0];
    stride_n_kpe = kpe_strides[1];
  }

  __host__ __device__ __forceinline__ uint32_t get_length(uint32_t batch_idx) const {
    if (indptr[batch_idx + 1] == indptr[batch_idx]) {
      return 0;
    }
    return (indptr[batch_idx + 1] - indptr[batch_idx] - 1) * page_size + last_page_len[batch_idx];
  }

  __host__ __device__ __forceinline__ size_t get_elem_offset_ckv(size_t page_idx, size_t entry_idx,
                                                                 size_t feat_idx) const {
    return page_idx * stride_page_ckv + entry_idx * stride_n_ckv + feat_idx;
  }

  __device__ __forceinline__ size_t protective_get_offset_ckv(IdType page_iter, uint32_t entry_idx,
                                                              uint32_t feat_idx,
                                                              IdType last_indptr) const {
    if (page_iter < last_indptr) {
      return get_elem_offset_ckv(__ldg(indices + page_iter), entry_idx, feat_idx);
    } else {
      return 0;
    }
  }

  __host__ __device__ __forceinline__ size_t get_elem_offset_kpe(size_t page_idx, size_t entry_idx,
                                                                 size_t feat_idx) const {
    return page_idx * stride_page_kpe + entry_idx * stride_n_kpe + feat_idx;
  }

  __device__ __forceinline__ size_t protective_get_offset_kpe(IdType page_iter, uint32_t entry_idx,
                                                              uint32_t feat_idx,
                                                              IdType last_indptr) const {
    if (page_iter < last_indptr) {
      return get_elem_offset_kpe(__ldg(indices + page_iter), entry_idx, feat_idx);
    } else {
      return 0;
    }
  }

  __device__ __forceinline__ DType* get_ckv_ptr(size_t page_idx, size_t entry_idx,
                                                size_t feat_idx) const {
    return ckv_data + get_elem_offset_ckv(__ldg(indices + page_idx), entry_idx, feat_idx);
  }

  __device__ __forceinline__ DType* get_kpe_ptr(size_t page_idx, size_t entry_idx,
                                                size_t feat_idx) const {
    return kpe_data + get_elem_offset_kpe(__ldg(indices + page_idx), entry_idx, feat_idx);
  }
};
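
// Illustrative usage (a minimal sketch, not from the upstream sources): constructing a contiguous
// MLA paged cache. ckv_buf and kpe_buf are hypothetical device buffers shaped
// [max_num_pages, page_size, 512] and [max_num_pages, page_size, 64], matching the head dimensions
// asserted by AppendPagedKVMlaCache below.
//
//   paged_kv_mla_t<half, int32_t> paged_kv_mla(
//       /*page_size=*/64, /*head_dim_compressed_kv=*/512, /*head_dim_kpe=*/64, batch_size,
//       ckv_buf, kpe_buf, page_indices_d, page_indptr_d, last_page_len_d);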

template <uint32_t head_dim_ckv, uint32_t head_dim_kpe, uint32_t vec_size, typename DType,
          typename IdType>
__global__ void AppendPagedKVMlaCacheKernel(paged_kv_mla_t<DType, IdType> paged_kv_mla,
                                            DType* __restrict__ append_ckv,
                                            DType* __restrict__ append_kpe,
                                            IdType* __restrict__ batch_indices,
                                            IdType* __restrict__ positions, uint32_t nnz,
                                            size_t append_ckv_stride_n,
                                            size_t append_kpe_stride_n) {
  uint32_t tx = threadIdx.x;
  uint32_t cta_id = blockIdx.x;
  uint32_t num_ctas = gridDim.x;

#pragma unroll 4
  for (uint32_t i = cta_id; i < nnz; i += num_ctas) {
    uint32_t page_iter, entry_idx;
    paged_kv_mla.page_size.divmod(
        paged_kv_mla.indptr[batch_indices[i]] * paged_kv_mla.page_size + positions[i], page_iter,
        entry_idx);
    DType* ckv_ptr = paged_kv_mla.get_ckv_ptr(page_iter, entry_idx, tx * vec_size);
    vec_t<DType, vec_size>::memcpy(ckv_ptr, append_ckv + i * append_ckv_stride_n + tx * vec_size);
    if (tx * vec_size < head_dim_kpe) {
      DType* kpe_ptr = paged_kv_mla.get_kpe_ptr(page_iter, entry_idx, tx * vec_size);
      vec_t<DType, vec_size>::memcpy(kpe_ptr, append_kpe + i * append_kpe_stride_n + tx * vec_size);
    }
  }
}

template <typename DType, typename IdType>
cudaError_t AppendPagedKVMlaCache(paged_kv_mla_t<DType, IdType> paged_kv, DType* append_ckv,
                                  DType* append_kpe, IdType* batch_indices, IdType* positions,
                                  uint32_t nnz, size_t append_ckv_stride_n,
                                  size_t append_kpe_stride_n, cudaStream_t stream = nullptr) {
  int dev_id = 0;
  int num_sms = 0;
  int num_blocks_per_sm = 0;
  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
  FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));

  uint32_t head_dim_ckv = paged_kv.head_dim_ckv;
  uint32_t head_dim_kpe = paged_kv.head_dim_kpe;
  constexpr uint32_t HEAD_CKV_DIM = 512;
  constexpr uint32_t HEAD_KPE_DIM = 64;
  FLASHINFER_CHECK(head_dim_ckv == HEAD_CKV_DIM, "head_dim_ckv must be equal to 512");
  FLASHINFER_CHECK(head_dim_kpe == HEAD_KPE_DIM, "head_dim_kpe must be equal to 64");

  constexpr uint32_t vec_size = 2;
  uint32_t bdx = HEAD_CKV_DIM / vec_size;
  uint32_t num_threads = bdx;
  uint32_t smem_size = 0;
  auto kernel = AppendPagedKVMlaCacheKernel<HEAD_CKV_DIM, HEAD_KPE_DIM, vec_size, DType, IdType>;
  FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel,
                                                                     num_threads, smem_size));
  num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms));
  dim3 nblks(num_blocks_per_sm * num_sms);
  dim3 nthrs(bdx);
  void* args[] = {(void*)&paged_kv,
                  (void*)&append_ckv,
                  (void*)&append_kpe,
                  (void*)&batch_indices,
                  (void*)&positions,
                  (void*)&nnz,
                  (void*)&append_ckv_stride_n,
                  (void*)&append_kpe_stride_n};
  FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream));
  return cudaSuccess;
}
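
// Illustrative host-side call (a minimal sketch under assumed types): append nnz compressed-kv /
// k-pe rows stored contiguously as [nnz, 512] and [nnz, 64] in hypothetical device buffers
// append_ckv_d and append_kpe_d.
//
//   cudaError_t status = flashinfer::AppendPagedKVMlaCache<half, int32_t>(
//       paged_kv_mla, append_ckv_d, append_kpe_d, batch_indices_d, positions_d, nnz,
//       /*append_ckv_stride_n=*/512, /*append_kpe_stride_n=*/64, stream);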

}  // namespace flashinfer

#endif  // FLASHINFER_PAGE_CUH_