/***************************************************************************************************
 * Copyright (c) 2011-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are not permit-
 * ted.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
#pragma once

#ifdef _WIN32
#define _USE_MATH_DEFINES
#include <math.h>
#endif

#include <cutlass/cutlass.h>

#include <cmath>
#include <cstdint>
#include <cute/tensor.hpp>

#include "../common.h"
#include "fmhaRunnerParams.h"

////////////////////////////////////////////////////////////////////////////////////////////////////
using Dtype = Data_type;

struct KernelParams {
  // TMA descriptor for Q.
  CUtensorMap tmaQ_;
  // TMA descriptor for K.
  CUtensorMap tmaK_;
  // TMA descriptor for V.
  CUtensorMap tmaV_;
  // The descriptor for O.
  CUtensorMap tmaO_;

  // For FP4 KV cache, additional scaling factors are needed.
  // TMA descriptor for the K scaling factor.
  CUtensorMap tmaKSf_;
  // TMA descriptor for the V scaling factor.
  CUtensorMap tmaVSf_;

  // The grid dimensions; these might differ from the actual grid that the kernel is launched with
  // for persistent kernels on Hopper GPUs.
  int32_t logicalGridDimX, logicalGridDimY, logicalGridDimZ;

  // The output pointer (used by STG for the last tile).
  void* ptrO;
  // The output SF pointer (used for FP4 output).
  void* ptrSfO;
  // The attention sinks pointer (an additional value per head in the denominator of the softmax).
  float const* ptrAttentionSinks;

  // The cumulative sequence lengths for Q.
  int32_t const* ptrCumSeqLensQ;
  // The cumulative sequence lengths for K/V.
  int32_t const* ptrCumSeqLensKv;
  // The packed custom mask.
  uint32_t const* ptrCustomMask;
  // The packed custom mask's offsets of each sequence.
  int64_t const* ptrCustomMaskOffsets;
  // The debug output matrix O.
  float* ptrDebugO;
  // The first sparseMask offsets in the Kv sequence dimension.
  int32_t const* ptrFirstSparseMaskOffsetsKv;
  // The counter for the multiCtasKv mode.
  int32_t* ptrMultiCtasKvCounter;
  // The device output scale for FP8 quantization. Only needed by trt-llm fp8 kernels as the
  // scales have to be on the device currently.
  float const* ptrOutputScale;
  // The page indexes of the paged-kv buffer with shape of [batchSize, 2, maxNumPagesPerSeq].
  int32_t const* ptrPageIdxKv;
  // The partial matrix O for each CtaKv when the multiCtasKv mode is enabled.
  void* ptrPartialO;
  // The partial softmax stats (max/sum) for each CtaKv when the multiCtasKv mode is enabled.
  float2* ptrPartialStats;
  // The scaling factors for K.
  float const* ptrSageAttnSfsK;
  // The scaling factors for P.
  float const* ptrSageAttnSfsP;
  // The scaling factors for Q.
  float const* ptrSageAttnSfsQ;
  // The scaling factors for V.
  float const* ptrSageAttnSfsV;
  // The device scaling factor for softmax (multiplied by log2 to use faster exp2). Only needed by
  // trt-llm fp8 kernels as the scales have to be on the device currently.
  float const* ptrScaleSoftmaxLog2;
  // The SF scale for Kv on device. Only needed by trt-llm kernels as the scales have to be on the
  // device currently.
  float const* ptrScaleSfKv;
  // The SF scale for O on device. Only needed by trt-llm kernels as the scales have to be on the
  // device currently.
  float const* ptrScaleSfO;
  // The sequence lengths for K/V. Required by pagedKv kernels to avoid unnecessary computation
  // based on (ptrCumSeqLensKv[batchIdx + 1] - ptrCumSeqLensKv[batchIdx]).
  int32_t const* ptrSeqLensKv;
  // The softmax stats buffer.
  float2* ptrSoftmaxStats;

  // The attention window size for sliding window attention.
  int32_t mAttentionWindowSize;
  // The batch size.
  int32_t mBatchSize;
  // The chunked attention size in log2.
  int32_t mChunkedAttentionSizeLog2;
  // The log of the Sage Attention block size for K.
  int32_t mLogNumEltsPerSageAttnBlkK;
  // The log of the Sage Attention block size for P.
  int32_t mLogNumEltsPerSageAttnBlkP;
  // The log of the Sage Attention block size for Q.
  int32_t mLogNumEltsPerSageAttnBlkQ;
  // The log of the Sage Attention block size for V.
  int32_t mLogNumEltsPerSageAttnBlkV;
  // The maximum sequence lengths for Q and K/V.
  int32_t mMaxSeqLenQ, mMaxSeqLenKv;
  // The maximum number of CTAs for Q.
  int32_t mMaxNumCtasQ;
  // The maximum number of CTAs for K/V.
  int32_t mMaxNumCtasKv;
  // The maximum number of pages per sequence for the paged-kv buffer.
  int32_t mMaxNumPagesPerSeqKv;
  // The number of heads for K/V.
  int32_t mNumHeadsKv;
  // The number of heads for Q.
  int32_t mNumHeadsQ;
  // The number of Q heads per K/V head (i.e. mNumHeadsQ / mNumHeadsKv).
  int32_t mNumHeadsQPerKv;
  // The hidden size of O.
  int64_t mNumHiddenEltsO;
  // The total number of pages in the paged-kv memory pool.
  int32_t mNumPagesInMemPool;
  // The output scale for FP8 quantization.
  float mOutputScale;
  // The scaling factor for softmax (multiplied by log2 to use faster exp2).
  float mScaleSoftmaxLog2;
  // The SF scale for Kv.
  float mScaleSfKv;
  // The SF scale for O.
  float mScaleSfO;
  // The start token index in the SF tensor. Used for FP4 SF offset calculation in the generation
  // phase kernel when inflight batching is enabled in TRT-LLM.
  int32_t mStartTokenIdxSfO;
  // The sum of sequence lengths for Q and K/V.
  int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;

  // Create the TMA shape/stride for Q.
  template <class FmhaOptions>
  static auto makeTmaShapeStrideQ(FmhaOptions const& options, bool groupsHeadsQ, int32_t tileSizeQ,
                                  int32_t numEltsInClampedHeadDimQ) {
    //
    // Q has the shape [numTokens * numHeadsQPerKv, numHeadsKv * 1, headDim] when grouping headsQ;
    // otherwise it is [numTokens, numHeadsQPerKv * numHeadsKv, headDim].

    // The number of grouped heads for the A matrix of MMA.
    int32_t numGroupedHeads{1};
    if (groupsHeadsQ) {
      numGroupedHeads = std::min(tileSizeQ, options.mNumHeadsQPerKv);
    }

    // The number of heads.
    int32_t numHeads{options.mNumHeadsQ};
    if (groupsHeadsQ) {
      numHeads /= numGroupedHeads;
    }
    // Make sure the math works.
    TORCH_CHECK(numHeads * numGroupedHeads == options.mNumHeadsQ, "internal error");

    // The number of tokens.
    int32_t numTokens{options.mSumOfSeqLensQ};

    // This maps to the flattened TMA shape for Q: (headDim, numGroupedHeads, numHeads, numTokens).
    auto shape = std::vector<uint64_t>{
        static_cast<uint64_t>(options.mHeadDimQk), static_cast<uint64_t>(numGroupedHeads),
        static_cast<uint64_t>(numHeads), static_cast<uint64_t>(numTokens)};

    // The hidden dimension when the tensor contains only Q (i.e. not QKV packed).
    int32_t const hiddenDimQ{options.mNumHeadsQ * options.mHeadDimQk};

    // The hidden dimension when the Q, K and V tensors are packed.
    int32_t hiddenDimQkv{hiddenDimQ};
    if (isPackedQkv(options.mQkvLayout)) {
      TORCH_CHECK(!groupsHeadsQ, "internal error");
      hiddenDimQkv += options.mNumHeadsKv * (options.mHeadDimQk + options.mHeadDimV);
    }

    // The stride between tokens.
    int32_t strideTokens{hiddenDimQkv};

    // The stride between heads.
    int32_t strideHeads{groupsHeadsQ ? numGroupedHeads * options.mHeadDimQk : options.mHeadDimQk};

    // The stride between grouped heads.
    int32_t strideGroupedHeads{options.mHeadDimQk};

    // Assemble the strides (1, strideGroupedHeads, strideHeads, strideTokens) to match the shape
    // above.
    auto stride = std::vector<uint64_t>{1, static_cast<uint64_t>(strideGroupedHeads),
                                        static_cast<uint64_t>(strideHeads),
                                        static_cast<uint64_t>(strideTokens)};

    // The tile shape for TMA.
    auto tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ), 1, 1,
                                            static_cast<uint32_t>(tileSizeQ)};
    if (groupsHeadsQ) {
      if (isSpecDecodingGenerationKernel(options.mKernelType)) {
        TORCH_CHECK((tileSizeQ % numGroupedHeads == 0), "internal error");
        tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                           static_cast<uint32_t>(numGroupedHeads), 1,
                                           static_cast<uint32_t>(tileSizeQ / numGroupedHeads)};
      } else {
        tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                           static_cast<uint32_t>(tileSizeQ), 1, 1};
      }
    }

    return std::make_tuple(shape, stride, tileShapes);
  }
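
  // Illustrative example for makeTmaShapeStrideQ (hypothetical configuration, not taken from any
  // kernel metadata): with mNumHeadsQ = 32, mNumHeadsQPerKv = 8, mHeadDimQk = 128,
  // mSumOfSeqLensQ = 4096, tileSizeQ = 128 and groupsHeadsQ = true, we get
  // numGroupedHeads = min(128, 8) = 8 and numHeads = 32 / 8 = 4, so the TMA shape is
  // (128, 8, 4, 4096). For a non-packed Q tensor the strides are then
  // (1, 128, 8 * 128, 32 * 128) = (1, 128, 1024, 4096): grouped heads are contiguous within a K/V
  // head and tokens are strided by the full Q hidden size.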

  // Create the TMA shape/stride for O.
  template <class FmhaOptions>
  static auto makeTmaShapeStrideO(FmhaOptions const& options) {
    //
    // TODO: refactor this as makeTmaShapeStrideQ when removing cutlass tma copy.
    //

    // The number of tokens.
    int32_t numTokens{options.mSumOfSeqLensQ};

    // The number of heads per K/V head.
    int32_t numHeadsQPerKv{options.mNumHeadsQPerKv};

    // The batch dimension.
    int32_t batchSize{1};

    // The cute tensor shape for Q/O: (numTokens, headDim, ((numHeadsKv, numHeadsQPerKv),
    // batchSize)). This maps to the flattened TMA shape for Q/O: (headDim, numTokens, numHeadsKv,
    // numHeadsQPerKv, batchSize). Note that the TMA descriptor expects the first dimension's
    // stride to be 1, so swap the first two dimensions so that the headDim dimension comes first.
    auto shape = std::vector<uint64_t>{
        static_cast<uint64_t>(options.mHeadDimV), static_cast<uint64_t>(numTokens),
        static_cast<uint64_t>(options.mNumHeadsKv), static_cast<uint64_t>(numHeadsQPerKv),
        static_cast<uint64_t>(batchSize)};

    // The hidden dimension.
    int32_t const hiddenDimO{options.mNumHeadsQ * options.mHeadDimV};

    // The stride between tokens.
    int32_t strideTokens{hiddenDimO};

    // The stride between Q heads.
    int32_t strideHeadsQ{options.mNumHeadsKv * options.mHeadDimV};

    // The stride between sequences.
    int32_t strideBatch{0};

    // The stride between K/V heads.
    int32_t strideHeadsKv{options.mHeadDimV};
    // Assemble the stride (strideTokens, 1, ((strideHeadsKv, strideHeadsQ), strideBatch)).
    // Swap the first two dimensions as mentioned before.
    auto stride = std::vector<uint64_t>{
        1, static_cast<uint64_t>(strideTokens), static_cast<uint64_t>(strideHeadsKv),
        static_cast<uint64_t>(strideHeadsQ), static_cast<uint64_t>(strideBatch)};

    return std::make_tuple(shape, stride);
  }
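
  // Illustrative example for makeTmaShapeStrideO (hypothetical configuration): with
  // mNumHeadsQ = 32, mNumHeadsKv = 8, mNumHeadsQPerKv = 4, mHeadDimV = 128 and
  // mSumOfSeqLensQ = 4096, the TMA shape is (128, 4096, 8, 4, 1) and the strides are
  // (1, 32 * 128, 128, 8 * 128, 0): Q heads that share a K/V head are strideHeadsQ apart in the
  // flat O buffer, and the batch stride is 0 because the token dimension already spans all
  // sequences.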

  // Create the shape for K and V.
  template <class FmhaOptions>
  static auto makeShapeKv(FmhaOptions const& options, KernelParams const& params) {
    // The number of keys/vals. WARNING: The if/else-if branches are ordered by priority.
    int32_t numKeysVals{options.mMaxSeqLenKv};
    if (isPagedKv(options.mQkvLayout)) {
      numKeysVals = options.mNumTokensPerPage;
    } else if (isContiguousKv(options.mQkvLayout)) {
      numKeysVals = options.mMaxSeqLenCacheKv;
    } else {
      numKeysVals = options.mSumOfSeqLensKv;
    }

    // The number of K/V heads (the grouped Q heads are packed into the sequence length for
    // mGroupsHeadsQ).
    int32_t numHeadsKv{options.mNumHeadsKv};

    // The batch dimension. WARNING: The if/else-if branches are ordered by priority.
    int32_t batchSize{options.mBatchSize};
    if (isPagedKv(options.mQkvLayout)) {
      batchSize = params.mNumPagesInMemPool;
    } else if (isContiguousKv(options.mQkvLayout)) {
      batchSize = options.mBatchSize;
    } else {
      batchSize = 1;
    }

    // Return the number of keys, the number of heads and the batch.
    return std::make_tuple(numKeysVals, numHeadsKv, batchSize);
  }
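
  // Illustrative note for makeShapeKv (hypothetical values): for a paged K/V cache with
  // mNumTokensPerPage = 32, the "numKeysVals" dimension covers a single page (32 tokens) and the
  // "batch" dimension covers the whole page pool (params.mNumPagesInMemPool), so a page index
  // from ptrPageIdxKv selects a page as if it were a batch entry. For a contiguous cache the
  // dimensions are the per-sequence capacity (mMaxSeqLenCacheKv) and the real batch size instead.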

  // Compute the strides for K and V.
  template <class FmhaOptions>
  static auto makeStrideKv(FmhaOptions const& options, bool isK) {
    int strideKeysVals = 0;
    int strideHeads = 0;
    int strideBatch = 0;

    if (isK) {
      strideKeysVals = options.kStrideKeysValues;
      strideHeads = options.kStrideHeads;
      strideBatch = options.kStrideBatch;
    } else {
      strideKeysVals = options.vStrideKeysValues;
      strideHeads = options.vStrideHeads;
      strideBatch = options.vStrideBatch;
    }
    // The 3 strides (the other ones are 1 and 0).
    return std::make_tuple(strideKeysVals, strideHeads, strideBatch);
  }

  // Create the TMA shape/stride for K and V.
  template <class FmhaOptions>
  static auto makeTmaShapeStrideKv(FmhaOptions const& options, KernelParams const& params,
                                   Data_type dtypeKv, bool isK) {
    // The shape elements.
    auto [numKeys, numHeadsKv, batchSize] = makeShapeKv(options, params);
    // The stride elements.
    auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);

    // The headDim.
    // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
    int32_t headDim = isK ? options.mHeadDimQk : options.mHeadDimV;
    if (isPagedKv(options.mQkvLayout) || isContiguousKv(options.mQkvLayout)) {
      headDim = std::max(options.mHeadDimQk, options.mHeadDimV);
    }

    // For K, the cute layout is (numKeys, headDim, ((numHeadsQPerKv, numHeadsKv),
    // batchSize)):(strideKeys, _1, _0, strideHeads, strideBatch). Cute swaps the first two
    // dimensions (to make sure the stride of the first dimension is 1) and ignores the
    // numHeadsQPerKv dimension (its stride is always 0). For V, the headDim dimension is already
    // the first dimension so no swapping is needed.

    // Therefore, the resulting TMA layout is 4D: (headDim, numKeys, numHeadsKv, batchSize):(1,
    // strideKeys, strideHeads, strideBatch).

    // Note that for FP4 KV input, elements are stored as uint8_t, where each byte packs 2 FP4
    // elements. The column index and strides need to be divided by 2.
    auto const colIdxDivisor = dtypeKv == DATA_TYPE_E2M1 ? 2 : 1;
    auto shape = std::vector<uint64_t>{
        static_cast<uint64_t>(headDim / colIdxDivisor), static_cast<uint64_t>(numKeys),
        static_cast<uint64_t>(options.mNumHeadsKv), static_cast<uint64_t>(batchSize)};
    auto stride = std::vector<uint64_t>{1, static_cast<uint64_t>(strideKeys / colIdxDivisor),
                                        static_cast<uint64_t>(strideHeads / colIdxDivisor),
                                        static_cast<uint64_t>(strideBatch / colIdxDivisor)};

    return std::make_tuple(shape, stride);
  }
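
  // Illustrative example for makeTmaShapeStrideKv (hypothetical configuration): for an FP4
  // (E2M1) paged K/V cache with headDim = 128 and mNumTokensPerPage = 32, two FP4 values share
  // one uint8_t, so the descriptor sees headDim / 2 = 64 byte-sized columns per token and all
  // K/V strides are also halved; for FP16/BF16/FP8 caches colIdxDivisor is 1 and the
  // shape/strides are used as-is.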

  // Create the TMA shape/stride for KV scaling factors.
  template <class FmhaOptions>
  static auto makeTmaShapeStrideKvSf(FmhaOptions const& options, KernelParams const& params,
                                     bool isK) {
    // The shape elements.
    auto [numKeys, numHeadsKv, batchSize] = makeShapeKv(options, params);
    // The stride elements.
    auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);

    // The headDim.
    // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
    int32_t headDim = isK ? options.mHeadDimQk : options.mHeadDimV;
    if (isPagedKv(options.mQkvLayout) || isContiguousKv(options.mQkvLayout)) {
      headDim = std::max(options.mHeadDimQk, options.mHeadDimV);
    }
    // The number of elements per SF.
    int32_t NumEltsPerSf = 16;

    // The KV shape is (headDim, numKeys, numHeadsKv, batchSize).
    // Therefore, the KV SF shape should be (headDim / NumEltsPerSf, numKeys, numHeadsKv,
    // batchSize). Considering that TMA requires the box width to be a multiple of 16B, without
    // changing the underlying layout, we reshape it into (16, numKeys * headDim / NumEltsPerSf /
    // 16, numHeadsKv, batchSize).

    // Note that this only works for the pagedKv layout.
    TORCH_CHECK(isPagedKv(options.mQkvLayout), "The qkvLayout is not supported.");

    auto shape = std::vector<uint64_t>{
        16, static_cast<uint64_t>(numKeys * headDim / NumEltsPerSf / 16),
        static_cast<uint64_t>(options.mNumHeadsKv), static_cast<uint64_t>(batchSize)};
    auto stride = std::vector<uint64_t>{1, 16, static_cast<uint64_t>(strideHeads / NumEltsPerSf),
                                        static_cast<uint64_t>(strideBatch / NumEltsPerSf)};

    return std::make_tuple(shape, stride);
  }
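
  // Illustrative example for makeTmaShapeStrideKvSf (hypothetical configuration): with
  // headDim = 128 and NumEltsPerSf = 16, each token carries 128 / 16 = 8 SF bytes. For a page of
  // numKeys = 32 tokens, the logical SF box (8, 32) is reshaped into (16, 32 * 8 / 16) = (16, 16)
  // so that the leading TMA box dimension is 16 bytes wide, which satisfies the 16B box-width
  // requirement without moving any data.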

  // Prepare pointers for TMA descriptors.
  static std::tuple<void const*, void const*, void const*> getDevicePtrs(
      TllmGenFmhaRunnerParams const& runnerParams, int32_t bytesPerElt) {
    // Declare the q, k, v ptrs.
    void const *qPtr{runnerParams.qPtr}, *kPtr{runnerParams.kPtr}, *vPtr{runnerParams.vPtr};

    // Set the Q, K and V pointers from the packed QKV tensor.
    if (isPackedQkv(runnerParams.mQkvLayout)) {
      qPtr = runnerParams.qkvPtr;
      kPtr = reinterpret_cast<void const*>(reinterpret_cast<char const*>(runnerParams.qkvPtr) +
                                           runnerParams.mNumHeadsQ * runnerParams.mHeadDimQk *
                                               bytesPerElt);
      vPtr = reinterpret_cast<void const*>(reinterpret_cast<char const*>(runnerParams.qkvPtr) +
                                           (runnerParams.mNumHeadsQ + runnerParams.mNumHeadsKv) *
                                               runnerParams.mHeadDimQk * bytesPerElt);
    }
    // Set the K and V pointers from the pagedKv tensor.
    else if (isPagedKv(runnerParams.mQkvLayout)) {
      // Note that the offsets will be fully handled by the pageIdx buffer.
      kPtr = runnerParams.kPtr;
      vPtr = runnerParams.vPtr;
    }
    // Set the K and V pointers from the contiguousKv tensor.
    else if (isContiguousKv(runnerParams.mQkvLayout)) {
      kPtr = runnerParams.kvPtr;
      // The maximum headDim of K and V.
      // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
      int32_t const maxHeadDimKv{std::max(runnerParams.mHeadDimQk, runnerParams.mHeadDimV)};
      vPtr = reinterpret_cast<void const*>(
          reinterpret_cast<char const*>(runnerParams.kvPtr) +
          runnerParams.mNumHeadsKv * runnerParams.mMaxSeqLenCacheKv * maxHeadDimKv * bytesPerElt);
    }

    // Return the pointers.
    return std::make_tuple(qPtr, kPtr, vPtr);
  }
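
  // Illustrative example for getDevicePtrs (hypothetical configuration): for a packed QKV tensor
  // with mNumHeadsQ = 32, mNumHeadsKv = 8, mHeadDimQk = mHeadDimV = 128 and bytesPerElt = 2
  // (BF16), each token stores [Q | K | V] contiguously, so kPtr starts 32 * 128 * 2 = 8192 bytes
  // after qPtr and vPtr starts (32 + 8) * 128 * 2 = 10240 bytes after qPtr; the per-token stride
  // is then carried by the TMA strides rather than by these base pointers.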

  // Build tma descriptors.
  template <class FmhaOptions>
  static CUtensorMap buildNdTmaDescriptor(FmhaOptions const& options, Data_type dtypeElt,
                                          std::vector<uint64_t> const& shapes,
                                          std::vector<uint64_t> const& strides,
                                          std::vector<uint32_t> const& tileShapes, void* gmemAddr,
                                          bool swizzled = true) {
    CUtensorMap desc{};
    // The data type.
    CUtensorMapDataType tmaDataFormat;
    if (dtypeElt == DATA_TYPE_E2M1 || dtypeElt == DATA_TYPE_E4M3) {
      tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if (dtypeElt == DATA_TYPE_FP16) {
      tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if (dtypeElt == DATA_TYPE_BF16) {
      tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else {
      TORCH_CHECK(false, "Unexpected dtype ", static_cast<int32_t>(dtypeElt));
    }

    // The swizzle type.
    CUtensorMapSwizzle swizzleType;
    int32_t numBytesInLeadingDim = tileShapes[0] * get_size_in_bits(dtypeElt) / 8 /*bits*/;
    if (!swizzled) {
      swizzleType = CU_TENSOR_MAP_SWIZZLE_NONE;
    } else if ((numBytesInLeadingDim % 128) == 0) {
      swizzleType = CU_TENSOR_MAP_SWIZZLE_128B;
    } else if ((numBytesInLeadingDim % 64) == 0) {
      swizzleType = CU_TENSOR_MAP_SWIZZLE_64B;
    } else if ((numBytesInLeadingDim % 32) == 0) {
      swizzleType = CU_TENSOR_MAP_SWIZZLE_32B;
    } else {
      TORCH_CHECK(false, "Unexpected numBytesInLeadingDim ", numBytesInLeadingDim);
    }
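
    // Illustrative example (hypothetical values): a BF16 tile with tileShapes[0] = 64 has a
    // 64 * 2 = 128-byte leading dimension and selects CU_TENSOR_MAP_SWIZZLE_128B, while
    // tileShapes[0] = 32 (64 bytes) would fall through to CU_TENSOR_MAP_SWIZZLE_64B. When the
    // caller passes swizzled == false (e.g. K/V tiles that still need a dtype transform before
    // the MMA), no swizzle is applied and the raw layout is preserved.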

    // The gmem address must be 16B-aligned.
    TORCH_CHECK((reinterpret_cast<uint64_t>(gmemAddr) & 0b1111) == 0);

    // Check that each shape is in the range [1, 2^32].
    int32_t dim = shapes.size();
    // Max five dimensions and min three dimensions.
    TORCH_CHECK((dim <= 5) && (dim >= 3));
    // Check the shape range.
    for (int32_t ii = 0; ii < dim; ++ii) {
      TORCH_CHECK(shapes[ii] >= (uint64_t(1)));        // Size must be min 1.
      TORCH_CHECK(shapes[ii] <= (uint64_t(1) << 32));  // Size must be max 2^32.
    }

    // The TMA descriptor does not store the zeroth stride and assumes it is 1.
    TORCH_CHECK(static_cast<int32_t>(strides.size()) == dim);
    TORCH_CHECK(strides[0] == 1);

    // Build the strides in bytes.
    // cuTensorMapEncodeTiled ignores the stride of the first dimension (implicitly 1).
    std::vector<uint64_t> stridesInBytes(dim - 1);
    for (int32_t ii = 0; ii < dim - 1; ++ii) {
      stridesInBytes[ii] = strides[ii + 1] *
                           std::max(get_size_in_bits(dtypeElt), static_cast<size_t>(8)) / 8 /*bits*/;
    }

    // Set the tile strides to 1.
    std::vector<uint32_t> tileStrides(dim, 1);

    // Build the descriptor.
    CUresult result =
        cuTensorMapEncodeTiled(&desc, tmaDataFormat,
                               /*tensorRank=*/dim, gmemAddr, shapes.data(), stridesInBytes.data(),
                               tileShapes.data(), tileStrides.data(),
                               /*interleave=*/CU_TENSOR_MAP_INTERLEAVE_NONE, swizzleType,
                               /*l2Promotion=*/CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
                               /*oobFill=*/CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

    if (result != CUDA_SUCCESS) {
      char const* err_str;
      cuGetErrorString(result, &err_str);
      std::cerr << "Error: Failed to initialize the TMA descriptor due to " << err_str << std::endl;
      std::cerr << "tmaFormat: " << static_cast<int>(tmaDataFormat) << " dim: " << dim
                << " gmem: " << gmemAddr << std::endl;
      // Only print the dimensions that actually exist (dim can be 3, 4 or 5).
      auto printVec = [](char const* name, auto const& vec, int32_t count) {
        std::cerr << name;
        for (int32_t ii = 0; ii < count; ++ii) {
          std::cerr << " " << vec[ii];
        }
        std::cerr << std::endl;
      };
      printVec("Shape:", shapes, dim);
      printVec("Stride:", stridesInBytes, dim - 1);
      printVec("tileShapes:", tileShapes, dim);
      printVec("tileStrides:", tileStrides, dim);
      std::cerr << "swizzleType: " << int(swizzleType) << std::endl;
      TORCH_CHECK(false);
    }

    return desc;
  }

  // Setup the kernel parameters.
  template <class FmhaOptions_, class KernelMeta>
  static KernelParams setKernelParams(FmhaOptions_ const& options, KernelMeta const& kernelMeta,
                                      int32_t maxNumCtasQ, int32_t maxNumCtasKv) {
    // Create the return struct.
    KernelParams params;

    // Get the device pointers for TMA descriptors.
    auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bytes(kernelMeta.mDataTypeKv));

    // The maximum headDim of K and V.
    // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
    int32_t const maxHeadDimKv{std::max(options.mHeadDimQk, options.mHeadDimV)};

    // Set the number of pages in the memory pool for paged K/V cache.
    if (isPagedKv(options.mQkvLayout)) {
      params.mNumPagesInMemPool = options.mNumPagesInMemPool == 0
                                      ? options.mMaxNumPagesPerSeqKv * 2 * options.mBatchSize
                                      : options.mNumPagesInMemPool;
    }
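
    // Illustrative note (hypothetical values): when the caller does not provide
    // mNumPagesInMemPool, it is derived as mMaxNumPagesPerSeqKv * 2 * mBatchSize, i.e. enough
    // pages for both the K and the V halves of every sequence; e.g. mMaxNumPagesPerSeqKv = 64 and
    // mBatchSize = 8 would give a pool of 1024 pages.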

    // The number of elements in 128B for Q.
    int32_t numEltsIn128BQ = (128 * 8) / get_size_in_bits(kernelMeta.mDataTypeQ);
    // The number of head elts (per token) in each block of shared memory.
    int32_t numEltsInClampedHeadDimQ = std::min(numEltsIn128BQ, options.mHeadDimQk);
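    // Illustrative example (hypothetical values): for BF16 Q (16 bits per element),
    // numEltsIn128BQ = (128 * 8) / 16 = 64, so a head dimension of 128 is clamped to 64 elements
    // per TMA box (the head is split across two boxes), while a head dimension of 64 or less is
    // kept as-is.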

    // Shape/stride for gmem tensor Q.
    auto [shapeQ, strideQ, tileShapeQ] = makeTmaShapeStrideQ(
        options, kernelMeta.mGroupsHeadsQ, kernelMeta.mTileSizeQ, numEltsInClampedHeadDimQ);
    // Build the tma descriptor for Q.
    params.tmaQ_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeQ, shapeQ, strideQ, tileShapeQ,
                                        const_cast<void*>(qPtr));

    // The number of keys per tile.
    int32_t numKeysPerTile = isPagedKv(options.mQkvLayout)
                                 ? std::min(options.mNumTokensPerPage, kernelMeta.mTileSizeKv)
                                 : kernelMeta.mTileSizeKv;
    // The number of elements in 128B for K/V.
    int32_t numEltsIn128BKv = (128 * 8) / get_size_in_bits(kernelMeta.mDataTypeKv);
    // The number of head elts (per token) in each block of shared memory (see above explanation).
    int32_t numEltsInClampedHeadDimKv = std::min(numEltsIn128BKv, maxHeadDimKv);

    // Shape/stride for gmem tensor Kv.
    auto [shapeK, strideK] =
        makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeKv, /*isK*/ true);
    auto [shapeV, strideV] =
        makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeKv, /*isK*/ false);
    // Do we have to transform K/V before the MMA?
    bool const transformsKv{kernelMeta.mDataTypeKv != kernelMeta.mDataTypeQ};
    // Note that for FP4 KV input, elements are stored as uint8_t, where each byte packs 2 FP4
    // elements.
    auto const numEltsDivisor = kernelMeta.mDataTypeKv == DATA_TYPE_E2M1 ? 2 : 1;
    // The tileShapes for K/V.
    std::vector<uint32_t> tileShapeKv(shapeK.size(), 1);
    tileShapeKv[0] = numEltsInClampedHeadDimKv / numEltsDivisor;
    tileShapeKv[1] = numKeysPerTile;
    // Build the tma descriptor for K.
    params.tmaK_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeK, strideK,
                                        tileShapeKv, const_cast<void*>(kPtr),
                                        /*swizzled = */ !transformsKv);
    // Build the tma descriptor for V.
    params.tmaV_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeV, strideV,
                                        tileShapeKv, const_cast<void*>(vPtr),
                                        /*swizzled = */ !transformsKv);

    // If the KV dtype is E2M1, additional scaling factors are needed for dequantization.
    if (kernelMeta.mDataTypeKv == DATA_TYPE_E2M1) {
      // The number of elements per SF.
      int32_t NumEltsPerSf = 16;
      // Compute the shape and stride for the SF tensor.
      // FIXME: assume K and V use the same shape.
      auto [shapeKvSf, strideKvSf] = makeTmaShapeStrideKvSf(options, params, /*isK*/ true);

      // The tileShapes for the K/V SF.
      std::vector<uint32_t> tileShapeKvSf(shapeKvSf.size(), 1);
      tileShapeKvSf[0] = 16;
      tileShapeKvSf[1] = numKeysPerTile * maxHeadDimKv / NumEltsPerSf / 16;

      // The tile box is reshaped from (headDim / NumEltsPerSf, tileSizeKv) into
      // (16, tileSizeKv * headDim / NumEltsPerSf / 16). See makeTmaShapeStrideKvSf for details.
      // Build the tma descriptor for the K SF.
      params.tmaKSf_ = buildNdTmaDescriptor(options, DATA_TYPE_E4M3, shapeKvSf, strideKvSf,
                                            tileShapeKvSf, const_cast<void*>(options.kSfBasePtr),
                                            /*swizzled = */ false);

      // Build the tma descriptor for the V SF.
      params.tmaVSf_ = buildNdTmaDescriptor(options, DATA_TYPE_E4M3, shapeKvSf, strideKvSf,
                                            tileShapeKvSf, const_cast<void*>(options.vSfBasePtr),
                                            /*swizzled = */ false);
    }

    // Shape/stride for gmem tensor O.
    auto [shapeO, strideO] = makeTmaShapeStrideO(options);
    // The tileShapes for O.
    std::vector<uint32_t> tileShapeO(shapeO.size(), 1);
    tileShapeO[0] = numEltsInClampedHeadDimQ;
    tileShapeO[1] = kernelMeta.mTileSizeQ;
    // Build the tma descriptor for O.
    params.tmaO_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeQ, shapeO, strideO, tileShapeO,
                                        const_cast<void*>(options.oPtr));

    // Set the other kernel parameters.
    params.ptrCumSeqLensQ = options.cumSeqLensQPtr;
    params.ptrCumSeqLensKv = options.cumSeqLensKvPtr;

    // The packed custom mask.
    params.ptrCustomMask = options.customMaskPtr;
    // The packed custom mask's offsets of each sequence.
    params.ptrCustomMaskOffsets = options.customMaskOffsetsPtr;
    // The first sparseMask offsets in the Kv sequence dimension.
    params.ptrFirstSparseMaskOffsetsKv = options.firstSparseMaskOffsetsKvPtr;

    // The output buffer.
    params.ptrO = options.oPtr;
    // The output scaling factor buffer.
    params.ptrSfO = options.oSfPtr;

    // TRT-LLM restriction: the quantization scales must be on the device.
    params.ptrOutputScale = options.outputScalePtr;

    // The sequence lengths for Kv.
    params.ptrSeqLensKv = options.seqLensKvPtr;

    // The attention sinks.
    params.ptrAttentionSinks = options.ptrAttentionSinks;

    // The partial buffers' pointers when the multiCtasKv mode is enabled.
    int64_t partialStatsBufferSize = options.mMultiProcessorCount * kernelMeta.mStepQ;
    params.ptrMultiCtasKvCounter = options.multiCtasKvCounterPtr;
    params.ptrPartialStats = reinterpret_cast<float2*>(options.multiCtasKvScratchPtr);
    params.ptrPartialO = params.ptrPartialStats + partialStatsBufferSize;

    params.ptrPageIdxKv = options.kvPageIdxPtr;
    params.ptrScaleSoftmaxLog2 = options.scaleSoftmaxLog2Ptr;

    params.ptrScaleSfKv = options.kvSfScalePtr;
    params.ptrScaleSfO = options.oSfScalePtr;
    params.mScaleSfO = options.mScaleSfO;

    params.mAttentionWindowSize = options.mAttentionWindowSize;
    if (isSlidingOrChunkedCausalMask(
            static_cast<TrtllmGenAttentionMaskType>(kernelMeta.mMaskType)) &&
        options.mChunkedAttentionSize != INT_MAX) {
      TORCH_CHECK((options.mChunkedAttentionSize & (options.mChunkedAttentionSize - 1)) == 0,
                  "Chunked attention size must be a power of 2");
      params.mChunkedAttentionSizeLog2 = std::log2(options.mChunkedAttentionSize);
    } else {
      // The default 0 means that chunked attention is disabled.
      params.mChunkedAttentionSizeLog2 = 0;
    }
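
    // Illustrative example (hypothetical value): a chunked-causal mask with
    // mChunkedAttentionSize = 8192 passes the power-of-2 check and stores
    // mChunkedAttentionSizeLog2 = 13, presumably so the chunk index can be derived with a shift
    // rather than a division.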
    params.mMaxSeqLenQ = options.mMaxSeqLenQ;
    params.mMaxSeqLenKv = options.mMaxSeqLenKv;
    params.mMaxNumCtasQ = maxNumCtasQ;
    params.mMaxNumCtasKv = maxNumCtasKv;
    params.mMaxNumPagesPerSeqKv = options.mMaxNumPagesPerSeqKv;
    // TODO: just use mMaxSeqLenQ for the number of MTP tokens.
    params.mSumOfSeqLensQ = options.mSumOfSeqLensQ;
    params.mSumOfSeqLensKv = options.mSumOfSeqLensKv;
    params.mBatchSize = options.mBatchSize;
    params.mNumHeadsQ = options.mNumHeadsQ;
    params.mNumHeadsKv = options.mNumHeadsKv;
    params.mNumHeadsQPerKv = options.mNumHeadsQPerKv;
    params.mNumHiddenEltsO = options.mNumHeadsQ * options.mHeadDimQk;
    // todo(Yingyi): might take a scalar tensor later
    params.mOutputScale = options.outputScale;
    params.mScaleSoftmaxLog2 = options.scaleSoftmaxLog2;
    params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
    params.mScaleSfKv = options.mScaleSfKv;
    params.ptrSoftmaxStats = options.softmaxStatsPtr;
    return params;
  }
};