#include <ATen/Dispatch.h>
#include <ATen/core/TensorBody.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Exception.h>

#include <cstddef>
#include <cstdint>
#include <type_traits>

namespace {

using std::int32_t;
using std::int64_t;
using std::size_t;
using std::uint64_t;

// Each warp stores one token's K and V rows, moving 256 bytes of each per
// loop iteration (32 lanes x 8 bytes).
template <typename T>
__global__ void store_kv_cache_256x1(
    uint64_t* __restrict__ k_cache,
    uint64_t* __restrict__ v_cache,
    const T* __restrict__ out_loc,
    const size_t length,
    const uint64_t* __restrict__ k,
    const uint64_t* __restrict__ v,
    const size_t kv_cache_stride,
    const size_t kv_input_stride,
    const size_t num_items) {
  const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
  const auto warp_id = idx / 32;
  const auto lane_id = idx % 32;
  if (warp_id >= length) return;

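  // out_loc[warp_id] is the destination slot in the cache for this warp's
  // token. Pointers and strides are in 8-byte (uint64_t) units, so each lane
  // moves one aligned 8-byte word per access, fully coalesced.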
  const auto offset = out_loc[warp_id];
  const auto k_dst = k_cache + offset * kv_cache_stride;
  const auto v_dst = v_cache + offset * kv_cache_stride;
  const auto k_src = k + warp_id * kv_input_stride;
  const auto v_src = v + warp_id * kv_input_stride;
  for (size_t i = 0; i < num_items; ++i) {
    k_dst[lane_id + i * 32] = k_src[lane_id + i * 32];
    v_dst[lane_id + i * 32] = v_src[lane_id + i * 32];
  }
}

// Each warp splits in half: lanes 0-15 copy K while lanes 16-31 copy V,
// moving 128 bytes of each per loop iteration (16 lanes x 8 bytes).
template <typename T>
__global__ void store_kv_cache_128x2(
    uint64_t* __restrict__ k_cache,
    uint64_t* __restrict__ v_cache,
    const T* __restrict__ out_loc,
    const size_t length,
    const uint64_t* __restrict__ k,
    const uint64_t* __restrict__ v,
    const size_t kv_cache_stride,
    const size_t kv_input_stride,
    const size_t num_items) {
  const auto idx = blockIdx.x * blockDim.x + threadIdx.x;
  const auto warp_id = idx / 32;
  const auto lane_id = idx % 32;
  if (warp_id >= length) return;

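  // The lower half-warp handles K, the upper half handles V; copy_id re-bases
  // the lane index to 0-15 within each half so both streams copy in lockstep.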
  const auto offset = out_loc[warp_id];
  const auto copy_k = lane_id < 16;
  const auto copy_id = lane_id % 16;
  const auto cache = copy_k ? k_cache : v_cache;
  const auto input = copy_k ? k : v;
  const auto dst = cache + offset * kv_cache_stride;
  const auto src = input + warp_id * kv_input_stride;
  for (size_t i = 0; i < num_items; ++i) {
    dst[copy_id + i * 16] = src[copy_id + i * 16];
  }
}

}  // namespace

auto store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v) -> void {
  const auto max_tokens = k_cache.size(0);
  const auto num_tokens = out_loc.size(0);
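  // Flatten everything to 2D ([tokens, flattened head data]) so each token's
  // K/V row is one contiguous span that a single warp can copy.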
  k_cache = k_cache.view({max_tokens, -1});
  v_cache = v_cache.view({max_tokens, -1});
  k = k.view({num_tokens, -1});
  v = v.view({num_tokens, -1});

  TORCH_CHECK(
      k_cache.is_cuda() && v_cache.is_cuda() && out_loc.is_cuda() && k.is_cuda() && v.is_cuda(),
      "All tensors must be CUDA tensors");
  TORCH_CHECK(k_cache.sizes() == v_cache.sizes(), "k_cache and v_cache must have the same size");
  TORCH_CHECK(k_cache.strides() == v_cache.strides(), "k_cache and v_cache must have the same strides");
  TORCH_CHECK(k.sizes() == v.sizes(), "k and v must have the same size");
  TORCH_CHECK(k.strides() == v.strides(), "k and v must have the same strides");
  TORCH_CHECK(k.stride(-1) == 1 && k_cache.stride(-1) == 1, "k and k_cache must be contiguous in the head dimension");
  TORCH_CHECK(k.size(-1) == k_cache.size(-1), "k and k_cache must have the same head size");
  TORCH_CHECK(out_loc.dim() == 1 && out_loc.is_contiguous(), "out_loc must be a 1D contiguous tensor");
  static_assert(sizeof(uint64_t) == 8, "uint64_t must be 8 bytes; the copy kernels assume that");

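  // Strides below are converted from bytes to 8-byte words: the kernels copy
  // through uint64_t*. This assumes the data pointers are 8-byte aligned,
  // which holds for freshly allocated tensors but could be violated by views
  // with a misaligned storage offset.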
  const auto length = out_loc.size(0);
  const auto elem_size = k.element_size();
  const auto size_bytes = elem_size * k.size(-1);
  const auto kv_cache_stride_bytes = elem_size * k_cache.stride(-2);
  const auto kv_input_stride_bytes = elem_size * k.stride(-2);
  const auto kv_cache_stride = kv_cache_stride_bytes / 8;
  const auto kv_input_stride = kv_input_stride_bytes / 8;

  const auto k_cache_ptr = static_cast<uint64_t*>(k_cache.data_ptr());
  const auto v_cache_ptr = static_cast<uint64_t*>(v_cache.data_ptr());
  const auto k_ptr = static_cast<const uint64_t*>(k.data_ptr());
  const auto v_ptr = static_cast<const uint64_t*>(v.data_ptr());
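  // One warp per token: 256 threads = 8 warps per block, so the grid is
  // ceil(length / 8) blocks.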
  const auto num_threads = 256;
  const auto num_warps = num_threads / 32;
  const auto num_blocks = (length + num_warps - 1) / num_warps;
  const auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_INTEGRAL_TYPES(out_loc.scalar_type(), "store_kv_cache", [&] {
    if constexpr (!std::is_same_v<scalar_t, int32_t> && !std::is_same_v<scalar_t, int64_t>) {
      // Do not instantiate the kernels for index types other than int32/int64;
      // AT_DISPATCH_INTEGRAL_TYPES would otherwise also stamp out 8- and
      // 16-bit variants.
      TORCH_CHECK(false, "out_loc must be of type int32 or int64, got: ", out_loc.scalar_type());
    } else {
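      // Pick the widest copy that divides the row size: 256 bytes per
      // iteration with a full warp per tensor, else 128 bytes per iteration
      // with half-warps handling K and V simultaneously.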
      if (size_bytes % 256 == 0) {
        const auto items_per_warp = size_bytes / 256;
        store_kv_cache_256x1<<<num_blocks, num_threads, 0, stream>>>(
            k_cache_ptr,
            v_cache_ptr,
            out_loc.data_ptr<scalar_t>(),
            length,
            k_ptr,
            v_ptr,
            kv_cache_stride,
            kv_input_stride,
            items_per_warp);
      } else if (size_bytes % 128 == 0) {
        const auto items_per_warp = size_bytes / 128;
        store_kv_cache_128x2<<<num_blocks, num_threads, 0, stream>>>(
            k_cache_ptr,
            v_cache_ptr,
            out_loc.data_ptr<scalar_t>(),
            length,
            k_ptr,
            v_ptr,
            kv_cache_stride,
            kv_input_stride,
            items_per_warp);
      } else {
        TORCH_CHECK(
            false,
            "The last dimension of k and v must span a multiple of 128 bytes, got: ",
            size_bytes);
      }
    }
  });
}
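
// Usage sketch (hypothetical shapes, not part of the original interface):
//   k_cache, v_cache : [max_tokens, num_heads, head_dim] CUDA tensors
//   k, v             : [num_tokens, num_heads, head_dim] CUDA tensors
//   out_loc          : [num_tokens] int32/int64 tensor of destination slots
// store_kv_cache(k_cache, v_cache, out_loc, k, v) scatters row i of k/v into
// row out_loc[i] of the caches.
//
// A minimal binding sketch, assuming a hypothetical extension namespace
// "kvcache" (commented out so this file compiles unchanged):
//
//   #include <torch/library.h>
//
//   TORCH_LIBRARY(kvcache, m) {
//     m.def(
//         "store_kv_cache(Tensor(a!) k_cache, Tensor(b!) v_cache, "
//         "Tensor out_loc, Tensor k, Tensor v) -> ()");
//   }
//
//   TORCH_LIBRARY_IMPL(kvcache, CUDA, m) {
//     m.impl("store_kv_cache", &store_kv_cache);
//   }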