sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h

/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <type_traits>
#include "flashinfer/exception.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
// The attention mask types.
enum class TrtllmGenAttentionMaskType {
// Dense mask.
Dense = 0,
// Causal mask.
Causal,
// Sliding window or chunked causal mask.
SlidingOrChunkedCausal,
// Custom mask.
Custom
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// Helper functions to check the mask type.
#define ATTENTION_MASK_TYPE_FUNCTION(MaskType) \
inline bool is##MaskType##Mask(TrtllmGenAttentionMaskType maskType) { \
return (maskType == TrtllmGenAttentionMaskType::MaskType); \
}
ATTENTION_MASK_TYPE_FUNCTION(Dense)
ATTENTION_MASK_TYPE_FUNCTION(Causal)
ATTENTION_MASK_TYPE_FUNCTION(SlidingOrChunkedCausal)
ATTENTION_MASK_TYPE_FUNCTION(Custom)
#undef ATTENTION_MASK_TYPE_FUNCTION
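// The macro above expands to one predicate per mask type, e.g. isDenseMask(...), isCausalMask(...).
// Illustrative usage sketch (not part of the original header):
//   if (isSlidingOrChunkedCausalMask(params.mMaskType)) {
//     // Sliding-window and chunked-causal attention share this mask type; they are presumably
//     // distinguished via mAttentionWindowSize / mChunkedAttentionSize (see the runner params
//     // below).
//   }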
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class FmhaKernelType {
// The context-phase kernels.
Context = 0,
// Choose the best generation kernel based on the heuristic:
// use SwapsMmaAbForGeneration kernels when numHeadsQPerKv <= 16, otherwise
// KeepsMmaAbForGeneration.
Generation = 1,
// Swap tensor A and tensor B of Mma, which only supports numHeadsQPerKv <= 16.
SwapsMmaAbForGeneration,
// Keep tensor A and tensor B of Mma.
KeepsMmaAbForGeneration,
// Speculative decoding (Medusa and Eagle) generation-phase attention kernels, where seqLenQ > 1.
SpecDecodingGeneration
};
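// Illustrative sketch of the Generation heuristic described above (derived from the comment, not
// from the runner's actual implementation):
//   inline FmhaKernelType selectGenerationKernelType(int numHeadsQPerKv) {
//     return numHeadsQPerKv <= 16 ? FmhaKernelType::SwapsMmaAbForGeneration
//                                 : FmhaKernelType::KeepsMmaAbForGeneration;
//   }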
////////////////////////////////////////////////////////////////////////////////////////////////////
// Helper functions to check the fmha kernel type.
#define FMHA_KERNEL_TYPE_FUNCTION(KernelType) \
inline bool is##KernelType##Kernel(FmhaKernelType kernelType) { \
return (kernelType == FmhaKernelType::KernelType); \
}
FMHA_KERNEL_TYPE_FUNCTION(Context)
FMHA_KERNEL_TYPE_FUNCTION(Generation)
FMHA_KERNEL_TYPE_FUNCTION(SwapsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(KeepsMmaAbForGeneration)
FMHA_KERNEL_TYPE_FUNCTION(SpecDecodingGeneration)
#undef FMHA_KERNEL_TYPE_FUNCTION
////////////////////////////////////////////////////////////////////////////////////////////////////
// Note that for variable sequence lengths, the (batchSize, seqLen) dimensions are packed into
// sumOfSeqLens tokens without padding.
enum class QkvLayout {
// SeparateQkv: separate Q, K and V buffers.
// Each has the shape: [batchSize, seqLen, numHeads, headDim].
SeparateQkv = 0,
// PackedQkv: single buffer for Q, K and V.
// Shape: [batchSize, seqLen, numHeadsQ + 2*numHeadsKv, headDim].
PackedQkv,
// PagedKv: paged buffer for K and V. Its shape is [batchSize, 2, maxNumPagesPerSeq], where the 2
// corresponds to K and V. That buffer stores the logical page indices into the paged-KV memory
// pool; each "page" of that pool is a contiguous buffer of shape [numHeadsKv, pageSize, headDim].
PagedKv,
// ContiguousKv:
// Contiguous buffer for Q with shape [batchSize, seqLen, numHeads, headDim].
// Contiguous buffer for Kv with shape [batchSize, seqLen, 2 * numHeads, headDim].
ContiguousKv,
};
// Helper functions to check the QkvLayout type.
#define QKV_LAYOUT_FUNCTION(LayoutType) \
inline bool is##LayoutType(QkvLayout qkvLayout) { return (qkvLayout == QkvLayout::LayoutType); }
QKV_LAYOUT_FUNCTION(SeparateQkv)
QKV_LAYOUT_FUNCTION(PackedQkv)
QKV_LAYOUT_FUNCTION(PagedKv)
QKV_LAYOUT_FUNCTION(ContiguousKv)
#undef QKV_LAYOUT_FUNCTION
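// Illustrative sketch of the PagedKv addressing implied by the comments above (the exact indexing
// is an assumption, not taken from this header). The page-index buffer has shape
// [batchSize, 2, maxNumPagesPerSeq] and each pool page is [numHeadsKv, pageSize, headDim]:
//   int logicalPage = kvPageIdx[(batchIdx * 2 + /*0 = K, 1 = V*/ 0) * maxNumPagesPerSeq
//                               + tokenIdx / pageSize];
//   size_t elemOffset = ((size_t)logicalPage * numHeadsKv + headIdx) * pageSize * headDim
//                     + (size_t)(tokenIdx % pageSize) * headDim;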
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class TileScheduler {
// Static scheduler (Non-persistent).
Static = 0,
// Persistent scheduler.
Persistent
};
////////////////////////////////////////////////////////////////////////////////////////////////////
enum class MultiCtasKvMode {
// No multiCtasKvMode.
Disabled = 0,
// Do the reduction through the global memory and atomic counters.
GmemReduction,
// Do the reduction through the CGA remote shared memory.
CgaSmemReduction
};
// Helper function to check if the multiCtasKv is enabled.
inline bool isMultiCtasKvEnabled(MultiCtasKvMode multiCtasKvMode) {
return multiCtasKvMode != MultiCtasKvMode::Disabled;
}
// Helper function to check the multiCtasKvMode type.
#define MULTI_CTAS_KV_MODE_FUNCTION(Type) \
inline bool is##Type(MultiCtasKvMode multiCtasKvMode) { \
return (multiCtasKvMode == MultiCtasKvMode::Type); \
}
MULTI_CTAS_KV_MODE_FUNCTION(Disabled)
MULTI_CTAS_KV_MODE_FUNCTION(GmemReduction)
MULTI_CTAS_KV_MODE_FUNCTION(CgaSmemReduction)
#undef MULTI_CTAS_KV_MODE_FUNCTION
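// Illustrative dispatch sketch using the helpers above (an assumption about how a caller might
// branch, not taken from the runner):
//   if (isMultiCtasKvEnabled(mode)) {
//     if (isGmemReduction(mode)) {
//       // Reduce partial results through global memory, guarded by atomic counters
//       // (see multiCtasKvCounterPtr below).
//     } else if (isCgaSmemReduction(mode)) {
//       // Reduce partial results through CGA remote shared memory.
//     }
//   }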
////////////////////////////////////////////////////////////////////////////////////////////////////
struct TllmGenFmhaRunnerParams {
// Input layout.
QkvLayout mQkvLayout;
// Attention mask type.
TrtllmGenAttentionMaskType mMaskType;
// The kernel type.
FmhaKernelType mKernelType;
// The tile scheduler.
TileScheduler mTileScheduler;
// Whether the multiCtasKv mode (i.e. multi-block mode) is enabled.
bool mMultiCtasKvMode;
// Input QKV buffers.
void const* qPtr;
void const* kPtr;
void const* vPtr;
// Packed KV buffer
void const* kvPtr;
// Packed QKV buffer
void const* qkvPtr;
// The scaling factor pointer of K.
void const* kSfBasePtr;
// The scaling factor pointer of V.
void const* vSfBasePtr;
// The custom mask ptr.
uint32_t const* customMaskPtr;
// The packed custom mask's offsets of each sequence.
int64_t const* customMaskOffsetsPtr;
// The first sparseMask offsets in the Kv sequence dimension.
int32_t const* firstSparseMaskOffsetsKvPtr;
// The counter for the multiCtasKv mode.
int32_t* multiCtasKvCounterPtr;
// The sequence length buffer for K/V.
int const* seqLensKvPtr;
// The cumulative sequence length buffers for Q and K/V.
int const* cumSeqLensQPtr;
int const* cumSeqLensKvPtr;
// The KV page index buffer (used with QkvLayout::PagedKv).
int const* kvPageIdxPtr;
// Whether to use scale values stored in global memory.
bool useGmemScale;
// The device output scale for FP8 quantization.
float const* outputScalePtr;
float outputScale;
// The device scaling factor for softmax (pre-multiplied by log2(e) so the kernel can use the
// faster exp2 instead of exp).
float const* scaleSoftmaxLog2Ptr;
float scaleSoftmaxLog2;
// The device scale for KV scaling factor.
float const* kvSfScalePtr;
// The device scale for O scaling factor.
float const* oSfScalePtr;
// The scratch space for each CtaKv when the multiCtasKv mode is enabled.
// PartialO, partialMax and partialSum will be stored to the scratch space.
void* multiCtasKvScratchPtr;
// The softmax stats buffer.
// The softmax max/sum values will be stored to the buffer if it is not nullptr.
float2* softmaxStatsPtr;
// The LSE buffer.
float* lsePtr;
// Attention sink
float const* ptrAttentionSinks{nullptr};
// The output buffer.
void* oPtr;
// The output scaling factor buffer.
void* oSfPtr;
// The stride between different keys.
int kStrideKeysValues;
// The stride between different heads for K.
int kStrideHeads;
// The stride between different batches for K.
int kStrideBatch;
// The stride between different values.
int vStrideKeysValues;
// The stride between different heads for V.
int vStrideHeads;
// The stride between different batches for V.
int vStrideBatch;
// Head dimension for Q and K.
int mHeadDimQk;
// Head dimension for V.
int mHeadDimV;
// The number of heads for Q, the number of heads for K/V, and the number of Q heads per K/V head.
int mNumHeadsQ, mNumHeadsKv, mNumHeadsQPerKv;
// The batch size.
int mBatchSize;
// The max sequence length in the contiguous Kv cache.
int mMaxSeqLenCacheKv;
// The max q sequence length.
int mMaxSeqLenQ;
// The max kv sequence length.
int mMaxSeqLenKv;
// The attention window size for sliding window attention (sliding-window-attention is enabled
// when seqLenKv > mAttentionWindowSize).
int mAttentionWindowSize;
// The chunked attention size (chunked-context is enabled when seqLenKv > mChunkedAttentionSize).
int mChunkedAttentionSize;
// The sum of sequence lengths for Q and K/V. (Only used when mSupportsVarSeqLens = true)
int mSumOfSeqLensQ;
int mSumOfSeqLensKv;
// The maximum number of pages per sequence in the paged-kv buffer.
int mMaxNumPagesPerSeqKv;
// The number of tokens per pageKv.
int mNumTokensPerPage;
// The number of pages in memory pool.
int mNumPagesInMemPool;
// The number of multiProcessor for the GPU.
int mMultiProcessorCount;
// Scaling factor for Q.
float mScaleQ;
// Scaling factor for output.
float mScaleOutput;
// The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase
// kernel when inflight batching is enabled.
int mSfStartTokenIdx;
// The SF scale for Kv.
float mScaleSfKv;
// The SF scale for output.
float mScaleSfO;
// The cuda stream.
cudaStream_t stream;
// Whether to enable PDL (Programmatic Dependent Launch).
bool enable_pdl;
// Set the attention mask type.
TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType) {
// maskType is the enum value of tensorrt_llm::kernels::ContextAttentionMaskType;
// convert it to TrtllmGenAttentionMaskType.
switch (maskType) {
case 0: // tensorrt_llm::kernels::ContextAttentionMaskType::PADDING
mMaskType = TrtllmGenAttentionMaskType::Dense;
break;
case 1: // tensorrt_llm::kernels::ContextAttentionMaskType::CAUSAL
mMaskType = TrtllmGenAttentionMaskType::Causal;
break;
case 2: // tensorrt_llm::kernels::ContextAttentionMaskType::SLIDING_OR_CHUNKED_CAUSAL
mMaskType = TrtllmGenAttentionMaskType::SlidingOrChunkedCausal;
break;
case 3: // tensorrt_llm::kernels::ContextAttentionMaskType::CUSTOM_MASK
mMaskType = TrtllmGenAttentionMaskType::Custom;
break;
default:
FLASHINFER_ERROR("Invalid attention mask type");
}
return *this;
}
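// Usage sketch: the argument is the integer value of the TRT-LLM ContextAttentionMaskType enum,
// e.g. params.setAttentionMaskType(1) selects TrtllmGenAttentionMaskType::Causal.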
TllmGenFmhaRunnerParams() {
// NOTE(Zihao): all members are trivially copyable, so zero-initializing the struct with memset is
// safe.
static_assert(std::is_trivially_copyable<TllmGenFmhaRunnerParams>::value &&
                  std::is_standard_layout<TllmGenFmhaRunnerParams>::value,
              "TllmGenFmhaRunnerParams must be trivially copyable and standard layout for memset "
              "to be safe.");
memset(this, 0, sizeof(TllmGenFmhaRunnerParams));
}
};
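////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative setup sketch for TllmGenFmhaRunnerParams (the field values below are assumptions
// for a paged-KV generation call, not a canonical configuration):
//   TllmGenFmhaRunnerParams params;             // zero-initialized by the default constructor
//   params.mQkvLayout = QkvLayout::PagedKv;
//   params.mKernelType = FmhaKernelType::Generation;
//   params.setAttentionMaskType(/*CAUSAL*/ 1);
//   params.mHeadDimQk = params.mHeadDimV = 128;
//   params.mNumHeadsQ = 32;
//   params.mNumHeadsKv = 8;
//   params.mNumHeadsQPerKv = 4;
//   params.mBatchSize = 4;
//   params.mNumTokensPerPage = 64;
//   params.stream = stream;
//   // The device pointers (kvPtr, kvPageIdxPtr, oPtr, ...), strides, sequence-length buffers and
//   // scales must also be filled in before the kernels are launched.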
////////////////////////////////////////////////////////////////////////////////////////////////////
// Parameters that might be updated when selecting kernels.
struct TllmGenSelectKernelParams {
// The FMHA kernel type.
FmhaKernelType mKernelType;
// The headDimV per CTA, which is only used by MLA generation kernels currently.
int mHeadDimPerCtaV;
// The multiCtasKvMode.
MultiCtasKvMode mMultiCtasKvMode;
// Force using GmemReduction for the multiCtasKvMode.
bool mForceGmemReduction;
// The mask type.
TrtllmGenAttentionMaskType mMaskType;
// Reuse smemK for V or not (only works with MLA generation kernels).
bool mReuseSmemKForV;
// Whether we need to select a new kernel because the parameters have been updated.
bool mSelectNewKernel;
// The tile scheduler.
TileScheduler mTileScheduler;
// The tile size for Kv.
int mTileSizeKv;
// Use 2 CTA MMA or not.
bool mUses2CtaMma;
// The constructor.
TllmGenSelectKernelParams(TllmGenFmhaRunnerParams const& params)
    : mKernelType(params.mKernelType),
      mHeadDimPerCtaV(params.mHeadDimV),
      // Note that CgaSmemReduction will be enabled based on the heuristic.
      mMultiCtasKvMode(params.mMultiCtasKvMode ? MultiCtasKvMode::GmemReduction
                                               : MultiCtasKvMode::Disabled),
      mForceGmemReduction(false),
      mMaskType(params.mMaskType),
      mReuseSmemKForV(false),
      mSelectNewKernel(false),
      mTileScheduler(params.mTileScheduler),
      mTileSizeKv(128),
      mUses2CtaMma(false) {}
};
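////////////////////////////////////////////////////////////////////////////////////////////////////
// Illustrative sketch of how the kernel-selection parameters might be derived and refined (an
// assumption about the surrounding runner, not taken from this header):
//   TllmGenSelectKernelParams selectKernelParams(runnerParams);
//   // Kernel-selection heuristics may then update selectKernelParams (e.g. switch
//   // mMultiCtasKvMode to CgaSmemReduction, change mTileSizeKv, or enable mReuseSmemKForV) and
//   // set mSelectNewKernel = true to request that a kernel be selected again with the updated
//   // parameters.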