/*
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuda_runtime.h>

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <type_traits>

#include "flashinfer/exception.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

// The attention mask types.
enum class TrtllmGenAttentionMaskType {
  // Dense mask.
  Dense = 0,
  // Causal mask.
  Causal,
  // Sliding window or chunked causal mask.
  SlidingOrChunkedCausal,
  // Custom mask.
  Custom
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to check the mask type.

#define ATTENTION_MASK_TYPE_FUNCTION(MaskType)                            \
  inline bool is##MaskType##Mask(TrtllmGenAttentionMaskType maskType) {  \
    return (maskType == TrtllmGenAttentionMaskType::MaskType);           \
  }

ATTENTION_MASK_TYPE_FUNCTION(Dense)
ATTENTION_MASK_TYPE_FUNCTION(Causal)
ATTENTION_MASK_TYPE_FUNCTION(SlidingOrChunkedCausal)
ATTENTION_MASK_TYPE_FUNCTION(Custom)

#undef ATTENTION_MASK_TYPE_FUNCTION
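
// For illustration, each instantiation above expands to a helper of the form below (sketch only;
// the real definitions are generated by the macro), e.g. for Causal:
//
//   inline bool isCausalMask(TrtllmGenAttentionMaskType maskType) {
//     return (maskType == TrtllmGenAttentionMaskType::Causal);
//   }
//
// so callers can write, e.g., `if (isCausalMask(runnerParams.mMaskType)) { ... }`.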

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class FmhaKernelType {
  // The context-phase kernels.
  Context = 0,
  // Choose the best generation kernel based on the heuristic:
  // use SwapsMmaAbForGeneration kernels when numHeadsQPerKv <= 16, otherwise
  // KeepsMmaAbForGeneration (see the sketch after the helper functions below).
  Generation = 1,
  // Swaps tensor A and tensor B of the MMA; only supports numHeadsQPerKv <= 16.
  SwapsMmaAbForGeneration,
  // Keeps tensor A and tensor B of the MMA.
  KeepsMmaAbForGeneration,
  // Speculative decoding (Medusa and Eagle) generation-phase attention kernels, where seqLenQ > 1.
  SpecDecodingGeneration
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Helper functions to check the FMHA kernel type.

#define FMHA_KERNEL_TYPE_FUNCTION(KernelType)                        \
  inline bool is##KernelType##Kernel(FmhaKernelType kernelType) {   \
    return (kernelType == FmhaKernelType::KernelType);              \
  }

FMHA_KERNEL_TYPE_FUNCTION(Context)
FMHA_KERNEL_TYPE_FUNCTION(Generation)
FMHA_KERNEL_TYPE_FUNCTION(SwapsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(KeepsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(SpecDecodingGeneration)

#undef FMHA_KERNEL_TYPE_FUNCTION
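
// A minimal sketch (hypothetical helper, not used elsewhere in this header) of the generation
// heuristic mentioned above: pick the generation kernel flavor from numHeadsQPerKv.
//
//   inline FmhaKernelType selectGenerationKernelType(int numHeadsQPerKv) {
//     return (numHeadsQPerKv <= 16) ? FmhaKernelType::SwapsMmaAbForGeneration
//                                   : FmhaKernelType::KeepsMmaAbForGeneration;
//   }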

////////////////////////////////////////////////////////////////////////////////////////////////////

// Note that the (batchSize, seqLen) dimensions will be packed as sumOfSeqLens without padding for
// variable sequence lengths.
enum class QkvLayout {
  // SeparateQkv: separate Q, K and V buffers.
  // Each has the shape: [batchSize, seqLen, numHeads, headDim].
  SeparateQkv = 0,
  // PackedQkv: single buffer for Q, K and V.
  // Shape: [batchSize, seqLen, numHeadsQ + 2*numHeadsKv, headDim].
  PackedQkv,
  // PagedKv: paged buffer for K and V. Its shape is [batchSize, 2, maxNumPagesPerSeq], where the 2
  // corresponds to K and V. That buffer stores the logical page indices of the paged-KV memory
  // pool. Each "page" of that pool is a contiguous buffer of shape [numHeadsKv, pageSize, headDim]
  // (see the indexing sketch after the helper functions below).
  PagedKv,
  // ContiguousKv:
  // Contiguous buffer for Q with shape [batchSize, seqLen, numHeads, headDim].
  // Contiguous buffer for Kv with shape [batchSize, seqLen, 2 * numHeads, headDim].
  ContiguousKv,
};

// Helper functions to check the QkvLayout type.

#define QKV_LAYOUT_FUNCTION(LayoutType) \
  inline bool is##LayoutType(QkvLayout qkvLayout) { return (qkvLayout == QkvLayout::LayoutType); }

QKV_LAYOUT_FUNCTION(SeparateQkv)
QKV_LAYOUT_FUNCTION(PackedQkv)
QKV_LAYOUT_FUNCTION(PagedKv)
QKV_LAYOUT_FUNCTION(ContiguousKv)

#undef QKV_LAYOUT_FUNCTION
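
// For illustration only: a minimal sketch (hypothetical helper, assuming the row-major layouts
// documented above) of how a K/V token is located with the PagedKv layout. The page-index buffer
// has shape [batchSize, 2, maxNumPagesPerSeq] and each page has shape [numHeadsKv, pageSize,
// headDim]; the returned value is the element offset of the token's headDim vector (for the given
// head) inside the paged-KV memory pool.
inline int64_t pagedKvTokenOffsetSketch(int const* kvPageIdxPtr, int batchIdx, int kOrV,
                                        int maxNumPagesPerSeq, int tokenIdx, int headIdx,
                                        int numHeadsKv, int pageSize, int headDim) {
  // Logical page index of the page that holds this token.
  int pageIdx = kvPageIdxPtr[(batchIdx * 2 + kOrV) * maxNumPagesPerSeq + tokenIdx / pageSize];
  // Element offset of the token inside its [numHeadsKv, pageSize, headDim] page.
  int64_t inPageOffset =
      (static_cast<int64_t>(headIdx) * pageSize + (tokenIdx % pageSize)) * headDim;
  // Element offset into the paged-KV memory pool.
  return static_cast<int64_t>(pageIdx) * numHeadsKv * pageSize * headDim + inPageOffset;
}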

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class TileScheduler {
  // Static scheduler (non-persistent).
  Static = 0,
  // Persistent scheduler.
  Persistent
};

////////////////////////////////////////////////////////////////////////////////////////////////////

enum class MultiCtasKvMode {
  // The multiCtasKv mode is disabled.
  Disabled = 0,
  // Do the reduction through global memory and atomic counters.
  GmemReduction,
  // Do the reduction through the CGA remote shared memory.
  CgaSmemReduction
};

// Helper function to check if the multiCtasKv mode is enabled.
inline bool isMultiCtasKvEnabled(MultiCtasKvMode multiCtasKvMode) {
  return multiCtasKvMode != MultiCtasKvMode::Disabled;
}

// Helper functions to check the MultiCtasKvMode type.

#define MULTI_CTAS_KV_MODE_FUNCTION(Type)                    \
  inline bool is##Type(MultiCtasKvMode multiCtasKvMode) {    \
    return (multiCtasKvMode == MultiCtasKvMode::Type);       \
  }

MULTI_CTAS_KV_MODE_FUNCTION(Disabled)
MULTI_CTAS_KV_MODE_FUNCTION(GmemReduction)
MULTI_CTAS_KV_MODE_FUNCTION(CgaSmemReduction)

#undef MULTI_CTAS_KV_MODE_FUNCTION
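
// For illustration only: a minimal sketch (hypothetical helper, not part of this header) of the
// reduction the multiCtasKv modes perform. Each CTA processes a disjoint KV chunk and produces a
// partial max, a partial sum of exponentials, and a partially accumulated output; the partials are
// merged by rescaling to a common max (shown here for a single output element):
//
//   struct PartialSoftmaxState { float max, sum, acc; };
//
//   inline PartialSoftmaxState mergePartials(PartialSoftmaxState a, PartialSoftmaxState b) {
//     PartialSoftmaxState m;
//     m.max = fmaxf(a.max, b.max);
//     float scaleA = expf(a.max - m.max), scaleB = expf(b.max - m.max);
//     m.sum = a.sum * scaleA + b.sum * scaleB;
//     m.acc = a.acc * scaleA + b.acc * scaleB;
//     return m;
//   }
//
// GmemReduction exchanges these partials through global memory and atomic counters, while
// CgaSmemReduction exchanges them through the CGA remote shared memory.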

////////////////////////////////////////////////////////////////////////////////////////////////////

struct TllmGenFmhaRunnerParams {
  // The input layout.
  QkvLayout mQkvLayout;
  // The attention mask type.
  TrtllmGenAttentionMaskType mMaskType;
  // The kernel type.
  FmhaKernelType mKernelType;
  // The tile scheduler.
  TileScheduler mTileScheduler;
  // The multiCtasKvMode (i.e. multiBlockMode).
  bool mMultiCtasKvMode;

  // The input Q, K and V buffers.
  void const* qPtr;
  void const* kPtr;
  void const* vPtr;
  // The packed KV buffer.
  void const* kvPtr;
  // The packed QKV buffer.
  void const* qkvPtr;
  // The scaling factor pointer of K.
  void const* kSfBasePtr;
  // The scaling factor pointer of V.
  void const* vSfBasePtr;
  // The custom mask ptr.
  uint32_t const* customMaskPtr;
  // The packed custom mask's offsets of each sequence.
  int64_t const* customMaskOffsetsPtr;
  // The first sparseMask offsets in the Kv sequence dimension.
  int32_t const* firstSparseMaskOffsetsKvPtr;
  // The counter for the multiCtasKv mode.
  int32_t* multiCtasKvCounterPtr;
  // The sequence length buffer for K/V.
  int const* seqLensKvPtr;
  // The cumulative sequence length buffers for Q and K/V.
  int const* cumSeqLensQPtr;
  int const* cumSeqLensKvPtr;
  // The KV page indices.
  int const* kvPageIdxPtr;
  // Whether to read the scaling factors from the device (global memory) pointers rather than from
  // the host-side values below.
  bool useGmemScale;
  // The device output scale for FP8 quantization.
  float const* outputScalePtr;
  float outputScale;
  // The device scaling factor for softmax (multiplied by log2(e) to use the faster exp2).
  float const* scaleSoftmaxLog2Ptr;
  float scaleSoftmaxLog2;
  // The device scale for the KV scaling factor.
  float const* kvSfScalePtr;
  // The device scale for the O scaling factor.
  float const* oSfScalePtr;
  // The scratch space for each CtaKv when the multiCtasKv mode is enabled.
  // PartialO, partialMax and partialSum will be stored to the scratch space.
  void* multiCtasKvScratchPtr;
  // The softmax stats buffer.
  // The softmax max/sum values will be stored to the buffer if it is not nullptr.
  float2* softmaxStatsPtr;
  // The LSE (log-sum-exp) buffer.
  float* lsePtr;

  // The attention sinks (optional; nullptr if unused).
  float const* ptrAttentionSinks{nullptr};
  // The output buffer.
  void* oPtr;
  // The output scaling factor buffer.
  void* oSfPtr;

  // The stride between different keys.
  int kStrideKeysValues;
  // The stride between different heads for K.
  int kStrideHeads;
  // The stride between different batches for K.
  int kStrideBatch;

  // The stride between different values.
  int vStrideKeysValues;
  // The stride between different heads for V.
  int vStrideHeads;
  // The stride between different batches for V.
  int vStrideBatch;

  // The head dimension for Q and K.
  int mHeadDimQk;
  // The head dimension for V.
  int mHeadDimV;
  // The numbers of heads for Q and K/V, and the number of Q heads per KV head.
  int mNumHeadsQ, mNumHeadsKv, mNumHeadsQPerKv;
  // The batch size.
  int mBatchSize;
  // The max sequence length in the contiguous KV cache.
  int mMaxSeqLenCacheKv;
  // The max Q sequence length.
  int mMaxSeqLenQ;
  // The max KV sequence length.
  int mMaxSeqLenKv;
  // The attention window size for sliding window attention (sliding-window attention is enabled
  // when seqLenKv > mAttentionWindowSize).
  int mAttentionWindowSize;
  // The chunked attention size (chunked-context is enabled when seqLenKv > mChunkedAttentionSize).
  int mChunkedAttentionSize;
  // The sums of sequence lengths for Q and K/V (only used when mSupportsVarSeqLens = true).
  int mSumOfSeqLensQ;
  int mSumOfSeqLensKv;
  // The maximum number of pages per sequence in the paged-KV buffer.
  int mMaxNumPagesPerSeqKv;
  // The number of tokens per KV page.
  int mNumTokensPerPage;
  // The number of pages in the memory pool.
  int mNumPagesInMemPool;
  // The number of multiprocessors on the GPU.
  int mMultiProcessorCount;
  // The scaling factor for Q.
  float mScaleQ;
  // The scaling factor for the output.
  float mScaleOutput;
  // The start token index in the SF tensor. Used for FP4 SF offset calculation in the
  // generation-phase kernel when inflight batching is enabled.
  int mSfStartTokenIdx;
  // The SF scale for KV.
  float mScaleSfKv;
  // The SF scale for the output.
  float mScaleSfO;
  // The cuda stream.
  cudaStream_t stream;
  // Whether to enable PDL (Programmatic Dependent Launch).
  bool enable_pdl;

  // Set the attention mask type.
  TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType) {
    // maskType is the enum value of tensorrt_llm::kernels::ContextAttentionMaskType;
    // convert ContextAttentionMaskType to TrtllmGenAttentionMaskType.
    switch (maskType) {
      case 0:  // tensorrt_llm::kernels::ContextAttentionMaskType::PADDING
        mMaskType = TrtllmGenAttentionMaskType::Dense;
        break;
      case 1:  // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL
        mMaskType = TrtllmGenAttentionMaskType::Causal;
        break;
      case 2:  // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL
        mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
        break;
      case 3:  // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK
        mMaskType = TrtllmGenAttentionMaskType::Custom;
        break;
      default:
        FLASHINFER_ERROR("Invalid attention mask type");
    }
    return *this;
  }

  TllmGenFmhaRunnerParams() {
    // NOTE(Zihao): all fields are trivially-copyable scalar/pointer types, so zero-initializing
    // the struct with memset is safe.
    static_assert(std::is_trivially_copyable<TllmGenFmhaRunnerParams>::value &&
                      std::is_standard_layout<TllmGenFmhaRunnerParams>::value,
                  "TllmGenFmhaRunnerParams must be trivially copyable and standard layout for "
                  "memset to be safe.");
    memset(this, 0, sizeof(TllmGenFmhaRunnerParams));
  }
};
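
// A minimal usage sketch (illustrative values only; which fields must be filled depends on the
// selected kernel and layout):
//
//   TllmGenFmhaRunnerParams runnerParams;  // Zero-initialized by the default constructor.
//   runnerParams.mQkvLayout = QkvLayout::PagedKv;
//   runnerParams.mKernelType = FmhaKernelType::Generation;
//   runnerParams.mTileScheduler = TileScheduler::Static;
//   runnerParams.setAttentionMaskType(/*ContextAttentionMaskType::CAUSAL=*/1);
//   runnerParams.mBatchSize = batchSize;
//   runnerParams.stream = stream;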

////////////////////////////////////////////////////////////////////////////////////////////////////

// Parameters that might be updated when selecting kernels.

struct TllmGenSelectKernelParams {
  // The FMHA kernel type.
  FmhaKernelType mKernelType;
  // The headDimV per CTA, which is currently only used by MLA generation kernels.
  int mHeadDimPerCtaV;
  // The multiCtasKvMode.
  MultiCtasKvMode mMultiCtasKvMode;
  // Force using GmemReduction for the multiCtasKvMode.
  bool mForceGmemReduction;
  // The mask type.
  TrtllmGenAttentionMaskType mMaskType;
  // Reuse smemK for V or not (only works with MLA generation kernels).
  bool mReuseSmemKForV;
  // Whether a new kernel needs to be selected because the parameters have been updated.
  bool mSelectNewKernel;
  // The tile scheduler.
  TileScheduler mTileScheduler;
  // The tile size for Kv.
  int mTileSizeKv;
  // Use 2-CTA MMA or not.
  bool mUses2CtaMma;

  // The constructor.
  TllmGenSelectKernelParams(TllmGenFmhaRunnerParams params)
      : mKernelType(params.mKernelType),
        mHeadDimPerCtaV(params.mHeadDimV),
        // Note that CgaSmemReduction may be enabled later based on the heuristic.
        mMultiCtasKvMode(params.mMultiCtasKvMode ? MultiCtasKvMode::GmemReduction
                                                 : MultiCtasKvMode::Disabled),
        mForceGmemReduction(false),
        mMaskType(params.mMaskType),
        mReuseSmemKForV(false),
        mSelectNewKernel(false),
        mTileScheduler(params.mTileScheduler),
        mTileSizeKv(128),
        mUses2CtaMma(false) {}
};
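
// A minimal usage sketch (illustrative; the exact update flow depends on the kernel-selection
// heuristics in the runner):
//
//   TllmGenSelectKernelParams selectKernelParams(runnerParams);
//   // The heuristics may refine fields such as mMultiCtasKvMode or mTileSizeKv and set
//   // mSelectNewKernel = true, in which case kernel selection is repeated with the updated values.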