// Source path: sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/attention/hopper/sparse_mainloop.cuh
//
// File info (from the original listing): 363 lines, 16 KiB, Plaintext

/*
* Copyright (c) 2024 by FlashInfer team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_
#define FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
#include "../../math.cuh"
#include "block_sparse_gather.cuh"
#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "named_barrier.cuh"
#include "utils.cuh"
namespace flashinfer {
using namespace cute;

/*!
 * \brief Producer-side mainloop for the Hopper (SM90) attention kernel when the KV-cache is
 *   addressed through an index array (sparse / paged KV).
 *
 * Q is loaded with a TMA copy (SM90_TMA_LOAD). K and V rows are not contiguous in global
 * memory, so they are gathered with SM80 cp.async (SM80_CP_ASYNC_CACHEGLOBAL_ZFILL) through
 * BlockSparseIndexedGather over `kv_indices`.
 *
 * \tparam AdditionalParams Extra kernel parameters, forwarded unchanged into Params.
 * \tparam Ktraits Kernel traits: dtypes, tile shapes, smem layouts, pipeline type.
 * \tparam CAUSAL If true, KV tiles past the last query row of the Q tile are skipped.
 * \tparam MULTIITEMSCORING If true, the same cap applies, and the producer may jump over the
 *   KV tiles between `num_kv_tiles_outside_items_window` and `num_kv_tiles_prefix`.
 */
template <typename AdditionalParams, typename Ktraits, bool CAUSAL, bool MULTIITEMSCORING = false>
struct SparseCollectiveMainloop {
  using DTypeQ = typename Ktraits::DTypeQ;
  using DTypeKV = typename Ktraits::DTypeKV;
  using IdType = typename Ktraits::IdType;
  using TileShape_QKD = typename Ktraits::TileShape_QKD;
  using TileShape_PDV = typename Ktraits::TileShape_PDV;
  static constexpr int CTA_Q = get<0>(TileShape_QKD{});
  static constexpr int CTA_KV = get<1>(TileShape_QKD{});
  static constexpr int NUM_STAGES = Ktraits::NUM_STAGES;
  static constexpr int HEAD_DIM_QK = Ktraits::HEAD_DIM_QK;
  static constexpr int HEAD_DIM_VO = Ktraits::HEAD_DIM_VO;
  static_assert(HEAD_DIM_QK == HEAD_DIM_VO);
  // K/V gathers are partitioned over one warpgroup's worth of threads.
  static constexpr int NUM_COPY_THREADS = cutlass::NumThreadsPerWarpGroup;
  using GmemTiledCopyQ = cute::SM90_TMA_LOAD;
  // Number of KV elements per 128-bit vectorized access.
  static constexpr auto AlignmentKV = 128 / cutlass::sizeof_bits<DTypeKV>::value;
  using AlignmentTypeKV = cute::uint_byte_t<static_cast<int>(sizeof(DTypeKV)) * AlignmentKV>;
  // NOTE(Zihao): use SM80_CP_ASYNC for sparse loading of KV-cache
  using GmemCopyAtomKV = cute::Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL_ZFILL<AlignmentTypeKV>, DTypeKV>;
  using GmemTiledCopyK =
      decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy<
               GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV,
               cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>,
               decltype(cute::get<1>(TileShape_QKD{})), decltype(cute::get<2>(TileShape_QKD{}))>());
  using GmemTiledCopyV =
      decltype(cutlass::gemm::collective::detail::make_simt_gmem_tiled_copy<
               GmemCopyAtomKV, NUM_COPY_THREADS, AlignmentKV,
               cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>,
               decltype(cute::get<2>(TileShape_PDV{})), decltype(cute::get<1>(TileShape_PDV{}))>());
  using SmemLayoutQ = typename Ktraits::SmemLayoutQ;
  using SmemLayoutK = typename Ktraits::SmemLayoutK;
  using SmemLayoutV = typename Ktraits::SmemLayoutV;
  using SmemLayoutVt = typename Ktraits::SmemLayoutVt;
  using ShapeT = cute::Shape<int32_t, int32_t, int32_t>;
  using StrideT = cute::Shape<int64_t, _1, int64_t>;  // (N, D, H)
  using LayoutT = cute::Layout<ShapeT, StrideT>;
  using ShapeLseT = cute::Shape<int32_t, int32_t>;
  using StrideLseT = cute::Shape<_1, int64_t>;
  using LayoutLseT = cute::Layout<ShapeLseT, StrideLseT>;
  using TMA_Q = decltype(make_tma_copy(
      GmemTiledCopyQ{},
      make_tensor(make_gmem_ptr(static_cast<DTypeQ const*>(nullptr)),
                  repeat_like(StrideT{}, int32_t(0)), StrideT{}),
      SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{}));  // no mcast for Q
  static constexpr bool USE_TMA_LOAD_KV = false;
  static constexpr int NUM_MMA_THREADS = size(typename Ktraits::TiledMmaQK{});
  using MainloopPipeline = typename Ktraits::MainloopPipeline;
  using PipelineParams = typename MainloopPipeline::Params;
  using PipelineState = typename MainloopPipeline::PipelineState;
  // Bytes the Q TMA load will deposit into shared memory (expected-tx count for barrier_Q).
  static constexpr uint32_t TmaTransactionBytesQ =
      static_cast<uint32_t>(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v<DTypeQ> / 8);
  static constexpr bool UseSchedulerBarrier =
      cutlass::sizeof_bits_v<DTypeQ> == 8 ? HEAD_DIM_VO >= 128 : HEAD_DIM_VO <= 128;
  using WarpScheduler = WarpScheduler<Ktraits, UseSchedulerBarrier>;

  // Host side kernel arguments
  struct Arguments {
    DTypeQ const* Q_ptr;
    LayoutT layout_Q;
    DTypeKV const* K_ptr;
    LayoutT layout_K;
    DTypeKV const* V_ptr;
    LayoutT layout_V;
    IdType const* kv_indices;
    int window_left;
    AdditionalParams additional_params;
  };

  // Device side kernel params
  struct Params {
    LayoutT layout_Q;
    LayoutT layout_K;
    LayoutT layout_V;
    TMA_Q tma_load_Q;
    DTypeKV* K_ptr;
    DTypeKV* V_ptr;
    IdType* kv_indices;
    int window_left;
    AdditionalParams additional_params;
  };

  /*!
   * \brief Converts host Arguments into device Params.
   *
   * Builds the Q TMA descriptor from `args.Q_ptr` / `args.layout_Q`. All other fields are
   * copied through; the K/V/index pointers lose their const qualifier.
   */
  static Params to_underlying_arguments(Arguments const& args) {
    Tensor mQ = make_tensor(make_gmem_ptr(args.Q_ptr), args.layout_Q);
    TMA_Q tma_load_Q =
        make_tma_copy(GmemTiledCopyQ{}, mQ, SmemLayoutQ{}, select<0, 2>(TileShape_QKD{}), _1{});
    return {args.layout_Q,
            args.layout_K,
            args.layout_V,
            tma_load_Q,
            const_cast<DTypeKV*>(args.K_ptr),
            const_cast<DTypeKV*>(args.V_ptr),
            const_cast<IdType*>(args.kv_indices),
            args.window_left,
            args.additional_params};
  }

  /*! \brief Prefetches the Q TMA descriptor (the only TMA descriptor used here). */
  CUTLASS_DEVICE
  static void prefetch_tma_descriptors(Params const& mainloop_params) {
    cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor());
  }

  /*!
   * \brief Number of CTA_KV-sized KV tiles that Q tile `q_tile_idx` has to visit.
   *
   * Starts from ceil(kv_len / CTA_KV). With CAUSAL or MULTIITEMSCORING, the count is capped
   * at ceil(((q_tile_idx + 1) * CTA_Q + kv_len - qo_len) / CTA_KV). That drops tiles lying
   * entirely past this Q tile's last row, with queries aligned to the end of the KV sequence.
   *
   * \param mainloop_params Unused; kept for interface compatibility.
   */
  CUTLASS_DEVICE
  int get_num_kv_tiles(Params const& mainloop_params, int q_tile_idx, const int qo_len,
                       const int kv_len) {
    int num_kv_tiles = cute::ceil_div(kv_len, CTA_KV);
    // CAUSAL and MULTIITEMSCORING use the same cap; applying it once is equivalent.
    if constexpr (CAUSAL || MULTIITEMSCORING) {
      num_kv_tiles = std::min(num_kv_tiles,
                              cute::ceil_div((q_tile_idx + 1) * CTA_Q + kv_len - qo_len, CTA_KV));
    }
    return num_kv_tiles;
  }

  /*!
   * \brief Producer loop for one work tile: issues the Q TMA load and the K/V cp.async gathers.
   *
   * Tile order:
   * - KV tiles are visited from `num_kv_tiles - 1` down to `swa_begin_kv_tile_idx`. The walk
   *   may jump when MULTIITEMSCORING is enabled; see `kv_tile_idx_decrement`.
   * - K runs one tile ahead of V:
   *   1. K of the last tile;
   *   2. then repeatedly K of the next (lower) tile together with V of the current tile;
   *   3. finally V of the last K tile that was issued.
   *   The numbers of K and V loads are therefore equal.
   * - Only the last (highest-index) tile can be partial. It is the only tile copied with a
   *   row predicate (`copy_if` against `valid_last_kv_tile_size`).
   *
   * Q is loaded by one elected lane of warp 0 of this warpgroup. That lane first syncs on the
   * kQueryEmpty named barrier with the MMA threads. The producer then waits on `barrier_O`
   * with parity `(work_idx + 1) % 2` before issuing the remaining K/V loads.
   *
   * \tparam LEFT_SLIDING_WINDOW If true, the loop stops at the first tile of the left window
   *   (`get_swa_begin_kv_tile_idx`) instead of tile 0.
   * \param num_kv_tiles_outside_items_window Only read when MULTIITEMSCORING.
   * \param num_kv_tiles_prefix Only read when MULTIITEMSCORING.
   */
  template <bool LEFT_SLIDING_WINDOW, typename BlockCoord, typename Scheduler,
            typename SharedStorage>
  CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline_k,
                           MainloopPipeline pipeline_v, PipelineState& smem_pipe_write_k,
                           PipelineState& smem_pipe_write_v, SharedStorage& shared_storage,
                           Scheduler& scheduler, typename Scheduler::Params const& scheduler_params,
                           typename Scheduler::WorkTileInfo& work_tile_info,
                           BlockCoord const& block_coord, int work_idx,
                           const int num_kv_tiles_outside_items_window = 0,
                           const int num_kv_tiles_prefix = 0) {
    int thread_idx = threadIdx.x;
    // Broadcast from lane 0 so the warp index is warp-uniform.
    int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (thread_idx / 32) % 4, 0);
    Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
    Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
    Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
    Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
    auto [q_tile_idx, qo_head_idx, kv_head_idx, qo_indptr, kv_indptr, qo_len, kv_len, batch_idx] =
        block_coord;

    // Prepare the TMA loads
    Tensor gQ = get_local_tile_tensor(mQ, select<0, 2>(TileShape_QKD{}), qo_head_idx, qo_indptr,
                                      qo_len)(_, _, q_tile_idx);  // (Q, D)
    Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
    Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
    auto [tQgQ, tQsQ] =
        tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, group_modes<0, 2>(sQ_x),
                      group_modes<0, 2>(gQ_x));  // (TMA), (TMA)

    int num_kv_tiles = get_num_kv_tiles(mainloop_params, q_tile_idx, qo_len, kv_len);
    int kv_tile_idx = num_kv_tiles - 1;
    int swa_begin_kv_tile_idx = 0;
    if constexpr (LEFT_SLIDING_WINDOW) {
      swa_begin_kv_tile_idx = get_swa_begin_kv_tile_idx<CTA_Q, CTA_KV>(mainloop_params.window_left,
                                                                      q_tile_idx, qo_len, kv_len);
    }

    constexpr int HEAD_DIM_QK = get<2>(TileShape_QKD{});
    constexpr int HEAD_DIM_VO = get<1>(TileShape_PDV{});
    constexpr int CTA_KV = get<1>(TileShape_QKD{});

    // K/V rows of this request are gathered through kv_indices[kv_indptr:].
    auto indexed_gather = BlockSparseIndexedGather<IdType>(mainloop_params.kv_indices + kv_indptr);
    Tensor mK = make_block_sparse_tensor(  // (kv_len, D_K)
        make_gmem_ptr(mainloop_params.K_ptr + kv_head_idx * stride<2>(mainloop_params.layout_K)),
        make_shape(kv_len, HEAD_DIM_QK), stride<0>(mainloop_params.layout_K), indexed_gather);
    Tensor mV = make_block_sparse_tensor(  // (kv_len, D_V)
        make_gmem_ptr(mainloop_params.V_ptr + kv_head_idx * stride<2>(mainloop_params.layout_V)),
        make_shape(kv_len, HEAD_DIM_VO), stride<0>(mainloop_params.layout_V), indexed_gather);

    Tensor gK =
        local_tile(mK, select<1, 2>(TileShape_QKD{}), make_coord(_, _0{}));  // (KV, D_K, kv)
    Tensor gV =
        local_tile(mV, select<2, 1>(TileShape_PDV{}), make_coord(_, _0{}));  // (KV, D_V, kv)
    // Identity (coordinate) tensors, used to build the row predicate for the partial last tile.
    Tensor cK = cute::make_identity_tensor(gK.shape());
    Tensor cV = cute::make_identity_tensor(gV.shape());

    GmemTiledCopyK gmem_tiled_copy_k;
    GmemTiledCopyV gmem_tiled_copy_v;
    auto gmem_thr_copy_k = gmem_tiled_copy_k.get_slice(thread_idx);
    auto gmem_thr_copy_v = gmem_tiled_copy_v.get_slice(thread_idx);

    Tensor tKgK = gmem_thr_copy_k.partition_S(gK);  // (CPY, CPY_KV, CPY_D, kv)
    Tensor tKsK = gmem_thr_copy_k.partition_D(sK);  // (CPY, CPY_KV, CPY_D, PIPE)

    Tensor tVgV = gmem_thr_copy_v.partition_S(gV);  // (CPY, CPY_KV, CPY_D, kv)
    Tensor tVsV = gmem_thr_copy_v.partition_D(sV);  // (CPY, CPY_KV, CPY_D, PIPE)

    Tensor tKcK = gmem_thr_copy_k.partition_D(cK);  // (CPY, CPY_KV, CPY_D)
    Tensor tKcKGroup = flatten_1(tKcK);             // (CPY, (CPY_KV, CPY_D))
    Tensor tVcV = gmem_thr_copy_v.partition_D(cV);  // (CPY, CPY_KV, CPY_D)
    Tensor tVcVGroup = flatten_1(tVcV);             // (CPY, (CPY_KV, CPY_D))

    // Number of rows of the last KV tile that lie inside kv_len. Only that tile is predicated.
    int valid_last_kv_tile_size = std::min<int>(kv_len - kv_tile_idx * CTA_KV, CTA_KV);
    auto k_predicate_fn = [&](auto coords) {
      auto s_coords = tKcKGroup(_0{}, coords);
      return elem_less(get<0>(s_coords), valid_last_kv_tile_size);
    };
    auto v_predicate_fn = [&](auto coords) {
      auto s_coords = tVcVGroup(_0{}, coords);
      return elem_less(get<0>(s_coords), valid_last_kv_tile_size);
    };

    // Next (lower) KV tile to visit. With MULTIITEMSCORING, stepping down from tile
    // `num_kv_tiles_outside_items_window` (when it is not inside the prefix) jumps straight to
    // the last prefix tile, `num_kv_tiles_prefix - 1`.
    auto kv_tile_idx_decrement = [&](int kv_tile_idx) {
      int result = kv_tile_idx - 1;
      if constexpr (MULTIITEMSCORING) {
        if ((kv_tile_idx == num_kv_tiles_outside_items_window) &
            (kv_tile_idx >= num_kv_tiles_prefix)) {
          result = num_kv_tiles_prefix - 1;
        }
      }
      return result;
    };

    // load last k-tile
    {
      pipeline_k.producer_acquire(smem_pipe_write_k);
      Tensor tKgKiGroup = flatten_1(tKgK(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
      Tensor tKsKiGroup =
          flatten_1(tKsK(_, _, _, smem_pipe_write_k.index()));  // (CPY, (CPY_KV, CPY_D))
      copy_if(gmem_tiled_copy_k, k_predicate_fn, tKgKiGroup, tKsKiGroup);

      pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive);
      ++smem_pipe_write_k;
    }

    // load Q tile
    if (warp_idx_in_warpgroup == 0) {
      cutlass::arch::NamedBarrier::sync(NUM_MMA_THREADS + cutlass::NumThreadsPerWarp,
                                        static_cast<int>(NamedBarriers::kQueryEmpty));

      int lane_predicate = cute::elect_one_sync();
      if (lane_predicate) {
        shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
        copy(mainloop_params.tma_load_Q.with(
                 reinterpret_cast<cutlass::arch::ClusterTransactionBarrier::ValueType&>(
                     shared_storage.barrier_Q),
                 /*mcast_mask=*/0),
             tQgQ, tQsQ);
      }
    }

    shared_storage.barrier_O.wait((work_idx + 1) % 2);

    if (kv_tile_idx == swa_begin_kv_tile_idx) {
      // Single tile to visit: its K is already in flight, so only its (predicated) V remains.
      pipeline_v.producer_acquire(smem_pipe_write_v);
      Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
      Tensor tVsViGroup =
          flatten_1(tVsV(_, _, _, smem_pipe_write_v.index()));  // (CPY, (CPY_KV, CPY_D))
      copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup);

      pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive);
      ++smem_pipe_write_v;
    } else {
      // load second last k-tile and last v-tile
      pipeline_k.producer_acquire(smem_pipe_write_k);
      Tensor tKgKi = tKgK(_, _, _, kv_tile_idx_decrement(kv_tile_idx));  // (CPY, CPY_KV, CPY_D)
      Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index());            // (CPY, CPY_KV, CPY_D)
      copy(gmem_tiled_copy_k, tKgKi, tKsKi);

      pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive);
      ++smem_pipe_write_k;

      pipeline_v.producer_acquire(smem_pipe_write_v);
      Tensor tVgViGroup = flatten_1(tVgV(_, _, _, kv_tile_idx));  // (CPY, (CPY_KV, CPY_D))
      Tensor tVsViGroup =
          flatten_1(tVsV(_, _, _, smem_pipe_write_v.index()));  // (CPY, (CPY_KV, CPY_D))
      copy_if(gmem_tiled_copy_v, v_predicate_fn, tVgViGroup, tVsViGroup);

      pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive);
      kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx);
      ++smem_pipe_write_v;

      // load remaining k/v tiles (all full tiles, so no predicate is needed)
#pragma unroll 2
      for (; kv_tile_idx > swa_begin_kv_tile_idx;
           kv_tile_idx = kv_tile_idx_decrement(kv_tile_idx)) {
        pipeline_k.producer_acquire(smem_pipe_write_k);
        Tensor tKgKi = tKgK(_, _, _, kv_tile_idx_decrement(kv_tile_idx));  // (CPY, CPY_KV, CPY_D)
        Tensor tKsKi = tKsK(_, _, _, smem_pipe_write_k.index());            // (CPY, CPY_KV, CPY_D)
        copy(gmem_tiled_copy_k, tKgKi, tKsKi);

        pipeline_k.producer_commit(smem_pipe_write_k, cutlass::arch::cpasync_barrier_arrive);
        ++smem_pipe_write_k;

        pipeline_v.producer_acquire(smem_pipe_write_v);
        Tensor tVgVi = tVgV(_, _, _, kv_tile_idx);                // (CPY, CPY_KV, CPY_D)
        Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index());  // (CPY, CPY_KV, CPY_D)
        copy(gmem_tiled_copy_v, tVgVi, tVsVi);

        pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive);
        ++smem_pipe_write_v;
      }
      scheduler.prefetch_next_work(scheduler_params, work_tile_info);

      // load first v tile
      // Here kv_tile_idx is the index of the last K tile issued above: normally
      // swa_begin_kv_tile_idx, or wherever the multi-item-scoring jump landed. V must be
      // loaded for that same tile. The previous hard-coded tile 0 paired the wrong V rows
      // with K whenever that index was non-zero, e.g. with a left sliding window. Without a
      // sliding window or jump, kv_tile_idx is 0 here, as before.
      {
        pipeline_v.producer_acquire(smem_pipe_write_v);
        Tensor tVgVi = tVgV(_, _, _, kv_tile_idx);                // (CPY, CPY_KV, CPY_D)
        Tensor tVsVi = tVsV(_, _, _, smem_pipe_write_v.index());  // (CPY, CPY_KV, CPY_D)
        copy(gmem_tiled_copy_v, tVgVi, tVsVi);

        pipeline_v.producer_commit(smem_pipe_write_v, cutlass::arch::cpasync_barrier_arrive);
        ++smem_pipe_write_v;
      }
    }

    scheduler.broadcast_next_work(work_tile_info);
  }

  /*! \brief Calls producer_tail on the K and V pipelines at their current write states. */
  CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v,
                                PipelineState& smem_pipe_write_k,
                                PipelineState& smem_pipe_write_v) {
    pipeline_k.producer_tail(smem_pipe_write_k);
    pipeline_v.producer_tail(smem_pipe_write_v);
  }
};

}  // namespace flashinfer
#endif // FLASHINFER_ATTENTION_HOPPER_SPARSE_MAINLOOP_CUH_