/***************************************************************************************************
 * Copyright (c) 2011-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification, are not permit-
 * ted.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/

#pragma once

#ifdef _WIN32
#define _USE_MATH_DEFINES
#include <cmath>
#endif

#include <cuda.h>

#include <algorithm>
#include <climits>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

#include "../common.h"
#include "fmhaRunnerParams.h"

////////////////////////////////////////////////////////////////////////////////////////////////////

using Dtype = Data_type;

struct KernelParams {
  // TMA descriptor for Q.
  CUtensorMap tmaQ_;
  // TMA descriptor for K.
  CUtensorMap tmaK_;
  // TMA descriptor for V.
  CUtensorMap tmaV_;
  // The descriptor for O.
  CUtensorMap tmaO_;

  // For FP4 KV cache, additional scaling factors are needed.
  // TMA descriptor for the K scaling factor.
  CUtensorMap tmaKSf_;
  // TMA descriptor for the V scaling factor.
  CUtensorMap tmaVSf_;

  // The grid dimensions. These might differ from the actual grid the kernel is launched with
  // for persistent kernels on Hopper GPUs.
  int32_t logicalGridDimX, logicalGridDimY, logicalGridDimZ;

  // The output pointer (used by STG for the last tile).
  void* ptrO;
  // The output SF pointer (used for FP4 output).
  void* ptrSfO;

  // The attention sinks pointer (additional value per head in the denominator of the softmax).
  float const* ptrAttentionSinks;
  // The cumulative sequence lengths for Q.
  int32_t const* ptrCumSeqLensQ;
  // The cumulative sequence lengths for K/V.
  int32_t const* ptrCumSeqLensKv;
  // The packed custom mask.
  uint32_t const* ptrCustomMask;
  // The packed custom mask's offsets of each sequence.
  int64_t const* ptrCustomMaskOffsets;
  // The debug output matrix O.
  float* ptrDebugO;
  // The first sparseMask offsets in the Kv sequence dimension.
  int32_t const* ptrFirstSparseMaskOffsetsKv;
  // The counter for the multiCtasKv mode.
  int32_t* ptrMultiCtasKvCounter;
  // The device output scale for FP8 quantization. Only needed by trt-llm fp8 kernels as the
  // scales have to be on the device currently.
  float const* ptrOutputScale;
  // The page indexes of the paged-kv buffer with shape of [batchSize, 2, maxNumPagesPerSeq].
  int32_t const* ptrPageIdxKv;
  // The partial matrix O for each CtaKv when the multiCtasKv mode is enabled.
  void* ptrPartialO;
  // The partial softmax stats (max/sum) for each CtaKv when the multiCtasKv mode is enabled.
  float2* ptrPartialStats;
  // The scaling factors for K.
  float const* ptrSageAttnSfsK;
  // The scaling factors for P.
  float const* ptrSageAttnSfsP;
  // The scaling factors for Q.
  float const* ptrSageAttnSfsQ;
  // The scaling factors for V.
  float const* ptrSageAttnSfsV;
  // The device scaling factor for softmax (multiplied by log2 to use faster exp2). Only needed by
  // trt-llm fp8 kernels as the scales have to be on the device currently.
  float const* ptrScaleSoftmaxLog2;
  // The SF scale for Kv on device. Only needed by trt-llm kernels as the scales have to be on the
  // device currently.
  float const* ptrScaleSfKv;
  // The SF scale for O on device. Only needed by trt-llm kernels as the scales have to be on the
  // device currently.
  float const* ptrScaleSfO;
  // The sequence lengths for K/V. Required by pagedKv kernels to avoid unnecessary computation
  // based on (ptrCumSeqLensKv[batchIdx + 1] - ptrCumSeqLensKv[batchIdx]).
  int32_t const* ptrSeqLensKv;
  // The softmax stats buffer.
  float2* ptrSoftmaxStats;

  // The attention window size for sliding window attention.
  int32_t mAttentionWindowSize;
  // The batch size.
  int32_t mBatchSize;
  // The chunked attention size in log2.
  int32_t mChunkedAttentionSizeLog2;
  // The log of the Sage Attention block size for K.
  int32_t mLogNumEltsPerSageAttnBlkK;
  // The log of the Sage Attention block size for P.
  int32_t mLogNumEltsPerSageAttnBlkP;
  // The log of the Sage Attention block size for Q.
  int32_t mLogNumEltsPerSageAttnBlkQ;
  // The log of the Sage Attention block size for V.
  int32_t mLogNumEltsPerSageAttnBlkV;
  // The maximum sequence lengths for Q and K/V.
  int32_t mMaxSeqLenQ, mMaxSeqLenKv;
  // The maximum number of CTAs for Q.
  int32_t mMaxNumCtasQ;
  // The maximum number of CTAs for K/V.
  int32_t mMaxNumCtasKv;
  // The maximum number of pages per sequence for the paged-kv buffer.
  int32_t mMaxNumPagesPerSeqKv;
  // The number of heads for K/V.
  int32_t mNumHeadsKv;
  // The number of heads for Q.
  int32_t mNumHeadsQ;
  // The number of Q heads per K/V head (i.e. mNumHeadsQ / mNumHeadsKv).
  int32_t mNumHeadsQPerKv;
  // The hidden size of O.
  int64_t mNumHiddenEltsO;
  // The total number of pages in the paged-kv memory pool.
  int32_t mNumPagesInMemPool;
  // The output scale for FP8 quantization.
  float mOutputScale;
  // The scaling factor for softmax (multiplied by log2 to use faster exp2).
  float mScaleSoftmaxLog2;
  // The SF scale for Kv.
  float mScaleSfKv;
  // The SF scale for O.
  float mScaleSfO;
  // The start token index in the SF tensor. Used for FP4 SF offset calculation in the generation
  // phase kernel when inflight batching is enabled in TRT-LLM.
  int32_t mStartTokenIdxSfO;
  // The sum of sequence lengths for Q and K/V.
  int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;

  // Create the TMA shape/stride for Q.
  template <class FmhaOptions>
  static auto makeTmaShapeStrideQ(FmhaOptions const& options, bool groupsHeadsQ, int32_t tileSizeQ,
                                  int32_t numEltsInClampedHeadDimQ) {
    //
    // The Q has a shape of [numTokens * numHeadsQPerKv, numHeadsKv * 1, headDim]
    // when grouping headsQ, otherwise it would be [numTokens, numHeadsQPerKv * numHeadsKv,
    // headDim].
    //

    // The number of grouped heads for the A matrix of the MMA.
    int32_t numGroupedHeads{1};
    if (groupsHeadsQ) {
      numGroupedHeads = std::min(tileSizeQ, options.mNumHeadsQPerKv);
    }
    // The number of heads.
    int32_t numHeads{options.mNumHeadsQ};
    if (groupsHeadsQ) {
      numHeads /= numGroupedHeads;
    }
    // Make sure the math works.
    TORCH_CHECK(numHeads * numGroupedHeads == options.mNumHeadsQ, "internal error");
    // The number of tokens.
    int32_t numTokens{options.mSumOfSeqLensQ};

    // This maps to the flattened TMA shape for Q:
    // (headDim, numGroupedHeads, numHeads, numTokens).
    auto shape = std::vector<uint64_t>{
        static_cast<uint64_t>(options.mHeadDimQk), static_cast<uint64_t>(numGroupedHeads),
        static_cast<uint64_t>(numHeads), static_cast<uint64_t>(numTokens)};

    // The hidden dimension when the tensor contains only Q (i.e. not QKV packed).
    int32_t const hiddenDimQ{options.mNumHeadsQ * options.mHeadDimQk};
    // The hidden dimension when the Q, K and V tensors are packed.
    int32_t hiddenDimQkv{hiddenDimQ};
    if (isPackedQkv(options.mQkvLayout)) {
      TORCH_CHECK(!groupsHeadsQ, "internal error");
      hiddenDimQkv += options.mNumHeadsKv * (options.mHeadDimQk + options.mHeadDimV);
    }

    // The stride between tokens.
    int32_t strideTokens{hiddenDimQkv};
    // The stride between heads.
    int32_t strideHeads{groupsHeadsQ ? numGroupedHeads * options.mHeadDimQk : options.mHeadDimQk};
    // The stride between grouped heads.
    int32_t strideGroupedHeads{options.mHeadDimQk};

    // Assemble the stride (1, strideGroupedHeads, strideHeads, strideTokens). The headDim
    // dimension comes first so that the first TMA dimension has stride 1.
    auto stride = std::vector<uint64_t>{1, static_cast<uint64_t>(strideGroupedHeads),
                                        static_cast<uint64_t>(strideHeads),
                                        static_cast<uint64_t>(strideTokens)};

    // The tile shape for TMA.
    auto tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ), 1, 1,
                                            static_cast<uint32_t>(tileSizeQ)};
    if (groupsHeadsQ) {
      if (isSpecDecodingGenerationKernel(options.mKernelType)) {
        TORCH_CHECK((tileSizeQ % numGroupedHeads == 0), "internal error");
        tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                           static_cast<uint32_t>(numGroupedHeads), 1,
                                           static_cast<uint32_t>(tileSizeQ / numGroupedHeads)};
      } else {
        tileShapes = std::vector<uint32_t>{static_cast<uint32_t>(numEltsInClampedHeadDimQ),
                                           static_cast<uint32_t>(tileSizeQ), 1, 1};
      }
    }

    return std::make_tuple(shape, stride, tileShapes);
  }
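  // Worked example (illustrative only; the numbers below are assumed, not taken from any real
  // configuration): for a non-grouped Q tensor with mHeadDimQk = 128, mNumHeadsQ = 32,
  // mSumOfSeqLensQ = 4096, tileSizeQ = 128 and an FP16 Q (so numEltsInClampedHeadDimQ = 64),
  // makeTmaShapeStrideQ would return roughly
  //   shape      = {128, 1, 32, 4096}   // (headDim, numGroupedHeads, numHeads, numTokens)
  //   stride     = {1, 128, 128, 4096}  // hiddenDimQkv = 32 * 128 for a Q-only (non-packed) tensor
  //   tileShapes = {64, 1, 1, 128}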
  // Create the TMA shape/stride for O.
  template <class FmhaOptions>
  static auto makeTmaShapeStrideO(FmhaOptions const& options) {
    //
    // TODO: refactor this as makeTmaShapeStrideQ when removing the cutlass tma copy.
    //

    // The number of tokens.
    int32_t numTokens{options.mSumOfSeqLensQ};
    // The number of heads per K/V head.
    int32_t numHeadsQPerKv{options.mNumHeadsQPerKv};
    // The batch dimension.
    int32_t batchSize{1};

    // The cute tensor shape for Q/O: (numTokens, headDim, ((numHeadsKv, numHeadsQPerKv),
    // batchSize)). This maps to the flattened TMA shape for Q/O: (headDim, numTokens, numHeadsKv,
    // numHeadsQPerKv, batchSize). Note that the TMA descriptor expects the first dimension's
    // stride to be 1, so swap the first two dimensions so that the headDim dimension comes first.
    auto shape = std::vector<uint64_t>{
        static_cast<uint64_t>(options.mHeadDimV), static_cast<uint64_t>(numTokens),
        static_cast<uint64_t>(options.mNumHeadsKv), static_cast<uint64_t>(numHeadsQPerKv),
        static_cast<uint64_t>(batchSize)};

    // The hidden dimension.
    int32_t const hiddenDimO{options.mNumHeadsQ * options.mHeadDimV};
    // The stride between tokens.
    int32_t strideTokens{hiddenDimO};
    // The stride between Q heads.
    int32_t strideHeadsQ{options.mNumHeadsKv * options.mHeadDimV};
    // The stride between sequences.
    int32_t strideBatch{0};
    // The stride between K/V heads.
    int32_t strideHeadsKv{options.mHeadDimV};

    // Assemble the stride (strideTokens, 1, ((strideHeadsKv, strideHeadsQ), strideBatch)).
    // Swap the first two dimensions as mentioned above.
    auto stride = std::vector<uint64_t>{
        1, static_cast<uint64_t>(strideTokens), static_cast<uint64_t>(strideHeadsKv),
        static_cast<uint64_t>(strideHeadsQ), static_cast<uint64_t>(strideBatch)};

    return std::make_tuple(shape, stride);
  }
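  // Worked example (illustrative only, assumed values): with mHeadDimV = 128, mNumHeadsQ = 32,
  // mNumHeadsKv = 8 (so mNumHeadsQPerKv = 4) and mSumOfSeqLensQ = 4096, makeTmaShapeStrideO
  // would return
  //   shape  = {128, 4096, 8, 4, 1}     // (headDim, numTokens, numHeadsKv, numHeadsQPerKv, batch)
  //   stride = {1, 4096, 128, 1024, 0}  // strideTokens = 32 * 128, strideHeadsKv = 128,
  //                                     // strideHeadsQ = 8 * 128, strideBatch = 0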
  // Create the shape for K and V.
  template <class FmhaOptions>
  static auto makeShapeKv(FmhaOptions const& options, KernelParams const& params) {
    // The number of keys/vals. WARNING: the if/else-if branches are sorted by priority.
    int32_t numKeysVals{options.mMaxSeqLenKv};
    if (isPagedKv(options.mQkvLayout)) {
      numKeysVals = options.mNumTokensPerPage;
    } else if (isContiguousKv(options.mQkvLayout)) {
      numKeysVals = options.mMaxSeqLenCacheKv;
    } else {
      numKeysVals = options.mSumOfSeqLensKv;
    }

    // The number of K/V heads.
    int32_t numHeadsKv{options.mNumHeadsKv};

    // The batch dimension. WARNING: the if/else-if branches are sorted by priority.
    int32_t batchSize{options.mBatchSize};
    if (isPagedKv(options.mQkvLayout)) {
      batchSize = params.mNumPagesInMemPool;
    } else if (isContiguousKv(options.mQkvLayout)) {
      batchSize = options.mBatchSize;
    } else {
      batchSize = 1;
    }

    // Return the number of keys, heads and the batch size.
    return std::make_tuple(numKeysVals, numHeadsKv, batchSize);
  }

  // Compute the strides for K and V.
  template <class FmhaOptions>
  static auto makeStrideKv(FmhaOptions const& options, bool isK) {
    int strideKeysVals = 0;
    int strideHeads = 0;
    int strideBatch = 0;
    if (isK) {
      strideKeysVals = options.kStrideKeysValues;
      strideHeads = options.kStrideHeads;
      strideBatch = options.kStrideBatch;
    } else {
      strideKeysVals = options.vStrideKeysValues;
      strideHeads = options.vStrideHeads;
      strideBatch = options.vStrideBatch;
    }
    // The 3 strides (the other ones are 1 and 0).
    return std::make_tuple(strideKeysVals, strideHeads, strideBatch);
  }

  // Create the TMA shape/stride for K or V.
  template <class FmhaOptions>
  static auto makeTmaShapeStrideKv(FmhaOptions const& options, KernelParams const& params,
                                   Data_type dtypeKv, bool isK) {
    // The shape elements.
    auto [numKeys, numHeadsKv, batchSize] = makeShapeKv(options, params);
    // The stride elements.
    auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);

    // The headDim.
    // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
    int32_t headDim = isK ? options.mHeadDimQk : options.mHeadDimV;
    if (isPagedKv(options.mQkvLayout) || isContiguousKv(options.mQkvLayout)) {
      headDim = std::max(options.mHeadDimQk, options.mHeadDimV);
    }

    // For K, the cute layout is: (numKeys, headDim, ((numHeadsQPerKv, numHeadsKv),
    // batchSize)):(strideKeys, _1, _0, strideHeads, strideBatch). Cute swaps the first two
    // dimensions (to make sure the stride of the first dimension is 1) and ignores the
    // numHeadsQPerKv dimension (its stride is always 0). For V, the headDim dimension is already
    // the first dimension, so no swapping is needed.
    // Therefore, the resulting TMA layout is 4D: (headDim, numKeys, numHeadsKv, batchSize):(1,
    // strideKeys, strideHeads, strideBatch).
    // Note that for FP4 KV input, elements are stored as uint8_t, where each byte packs 2 FP4
    // elements. The column index and strides need to be divided by 2.
    auto const colIdxDivisor = dtypeKv == DATA_TYPE_E2M1 ? 2 : 1;
    auto shape = std::vector<uint64_t>{
        static_cast<uint64_t>(headDim / colIdxDivisor), static_cast<uint64_t>(numKeys),
        static_cast<uint64_t>(options.mNumHeadsKv), static_cast<uint64_t>(batchSize)};
    auto stride = std::vector<uint64_t>{1, static_cast<uint64_t>(strideKeys / colIdxDivisor),
                                        static_cast<uint64_t>(strideHeads / colIdxDivisor),
                                        static_cast<uint64_t>(strideBatch / colIdxDivisor)};

    return std::make_tuple(shape, stride);
  }
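  // Worked example (illustrative only; the strides below assume a paged pool laid out as
  // [page, headKv, tokenInPage, headDim], which may not match a given integration): with an E2M1
  // (FP4) KV cache, padded headDim = 128, mNumTokensPerPage = 32 and mNumHeadsKv = 8, two FP4
  // elements share one uint8_t, so colIdxDivisor = 2 and makeTmaShapeStrideKv would return
  //   shape  = {64, 32, 8, mNumPagesInMemPool}  // (headDim / 2, tokensPerPage, numHeadsKv, pages)
  //   stride = {1, 64, 2048, 16384}             // element strides 128 / 4096 / 32768 divided by 2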
  // Create the TMA shape/stride for the KV scaling factors.
  template <class FmhaOptions>
  static auto makeTmaShapeStrideKvSf(FmhaOptions const& options, KernelParams const& params,
                                     bool isK) {
    // The shape elements.
    auto [numKeys, numHeadsKv, batchSize] = makeShapeKv(options, params);
    // The stride elements.
    auto [strideKeys, strideHeads, strideBatch] = makeStrideKv(options, isK);

    // The headDim.
    // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
    int32_t headDim = isK ? options.mHeadDimQk : options.mHeadDimV;
    if (isPagedKv(options.mQkvLayout) || isContiguousKv(options.mQkvLayout)) {
      headDim = std::max(options.mHeadDimQk, options.mHeadDimV);
    }

    // The number of elements per SF.
    int32_t NumEltsPerSf = 16;

    // The KV shape is: (headDim, numKeys, numHeadsKv, batchSize).
    // Therefore, the KV SF shape should be (headDim / NumEltsPerSf, numKeys, numHeadsKv,
    // batchSize). Considering that TMA requires the box width to be a multiple of 16B, without
    // changing the underlying layout, we reshape it into
    // (16, numKeys * headDim / NumEltsPerSf / 16, numHeadsKv, batchSize).
    // Note that this only works for the pagedKv layout.
    TORCH_CHECK(isPagedKv(options.mQkvLayout), "The qkvLayout is not supported.");
    auto shape = std::vector<uint64_t>{
        16, static_cast<uint64_t>(numKeys * headDim / NumEltsPerSf / 16),
        static_cast<uint64_t>(options.mNumHeadsKv), static_cast<uint64_t>(batchSize)};
    auto stride = std::vector<uint64_t>{1, 16, static_cast<uint64_t>(strideHeads / NumEltsPerSf),
                                        static_cast<uint64_t>(strideBatch / NumEltsPerSf)};

    return std::make_tuple(shape, stride);
  }

  // Prepare the pointers for the TMA descriptors.
  static std::tuple<void const*, void const*, void const*> getDevicePtrs(
      TllmGenFmhaRunnerParams const& runnerParams, int32_t bytesPerElt) {
    // Declare the q, k, v ptrs.
    void const *qPtr{runnerParams.qPtr}, *kPtr{runnerParams.kPtr}, *vPtr{runnerParams.vPtr};

    // Set the Q, K and V pointers from the packed QKV tensor.
    if (isPackedQkv(runnerParams.mQkvLayout)) {
      qPtr = runnerParams.qkvPtr;
      kPtr = reinterpret_cast<void const*>(
          reinterpret_cast<char const*>(runnerParams.qkvPtr) +
          runnerParams.mNumHeadsQ * runnerParams.mHeadDimQk * bytesPerElt);
      vPtr = reinterpret_cast<void const*>(
          reinterpret_cast<char const*>(runnerParams.qkvPtr) +
          (runnerParams.mNumHeadsQ + runnerParams.mNumHeadsKv) * runnerParams.mHeadDimQk *
              bytesPerElt);
    }
    // Set the K and V pointers from the pagedKv tensor.
    else if (isPagedKv(runnerParams.mQkvLayout)) {
      // Note that the offsets will be fully handled by the pageIdx buffer.
      kPtr = runnerParams.kPtr;
      vPtr = runnerParams.vPtr;
    }
    // Set the K and V pointers from the contiguousKv tensor.
    else if (isContiguousKv(runnerParams.mQkvLayout)) {
      kPtr = runnerParams.kvPtr;
      // The maximum headDim of K and V.
      // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
      int32_t const maxHeadDimKv{std::max(runnerParams.mHeadDimQk, runnerParams.mHeadDimV)};
      vPtr = reinterpret_cast<void const*>(
          reinterpret_cast<char const*>(runnerParams.kvPtr) +
          runnerParams.mNumHeadsKv * runnerParams.mMaxSeqLenCacheKv * maxHeadDimKv * bytesPerElt);
    }

    // Return the pointers.
    return std::make_tuple(qPtr, kPtr, vPtr);
  }
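  // Worked example (illustrative only, assumed values): for a packed QKV tensor with
  // mNumHeadsQ = 32, mNumHeadsKv = 8, mHeadDimQk = 128 and BF16 elements (bytesPerElt = 2),
  // getDevicePtrs would return
  //   qPtr = qkvPtr
  //   kPtr = qkvPtr + 32 * 128 * 2 bytes        // skip the Q heads of the first token
  //   vPtr = qkvPtr + (32 + 8) * 128 * 2 bytes  // skip the Q and K heads of the first token
  // i.e. the per-token layout is [Q heads | K heads | V heads], and the token stride is carried
  // by the TMA strides built above (hiddenDimQkv).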
  // Build the TMA descriptors.
  template <class FmhaOptions>
  static CUtensorMap buildNdTmaDescriptor(FmhaOptions const& options, Data_type dtypeElt,
                                          std::vector<uint64_t> const& shapes,
                                          std::vector<uint64_t> const& strides,
                                          std::vector<uint32_t> const& tileShapes, void* gmemAddr,
                                          bool swizzled = true) {
    CUtensorMap desc{};
    // The data type.
    CUtensorMapDataType tmaDataFormat;
    if (dtypeElt == DATA_TYPE_E2M1 || dtypeElt == DATA_TYPE_E4M3) {
      tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_UINT8;
    } else if (dtypeElt == DATA_TYPE_FP16) {
      tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
    } else if (dtypeElt == DATA_TYPE_BF16) {
      tmaDataFormat = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
    } else {
      TORCH_CHECK(false, "Unexpected dtype ", static_cast<int32_t>(dtypeElt));
    }

    // The swizzle type.
    CUtensorMapSwizzle swizzleType;
    int32_t numBytesInLeadingDim = tileShapes[0] * get_size_in_bits(dtypeElt) / 8 /*bits*/;
    if (!swizzled) {
      swizzleType = CU_TENSOR_MAP_SWIZZLE_NONE;
    } else if ((numBytesInLeadingDim % 128) == 0) {
      swizzleType = CU_TENSOR_MAP_SWIZZLE_128B;
    } else if ((numBytesInLeadingDim % 64) == 0) {
      swizzleType = CU_TENSOR_MAP_SWIZZLE_64B;
    } else if ((numBytesInLeadingDim % 32) == 0) {
      swizzleType = CU_TENSOR_MAP_SWIZZLE_32B;
    } else {
      TORCH_CHECK(false, "Unexpected numBytesInLeadingDim ", numBytesInLeadingDim);
    }

    // The gmem address must be 16B-aligned.
    TORCH_CHECK((reinterpret_cast<uint64_t>(gmemAddr) & 0b1111) == 0);

    // The shape must be in the range [1, 2^32].
    int32_t dim = shapes.size();
    // Max 5 dimensions and min 3 dimensions.
    TORCH_CHECK((dim <= 5) && (dim >= 3));
    // Check the shape range.
    for (int32_t ii = 0; ii < dim; ++ii) {
      TORCH_CHECK(shapes[ii] >= (uint64_t(1)));        // Size must be min 1.
      TORCH_CHECK(shapes[ii] <= (uint64_t(1) << 32));  // Size must be max 2^32.
    }

    // The TMA descriptor does not store the zeroth stride and assumes it is 1.
    TORCH_CHECK(static_cast<int32_t>(strides.size()) == dim);
    TORCH_CHECK(strides[0] == 1);

    // Build the strides in bytes.
    // cuTensorMapEncodeTiled ignores the stride of the first dimension (implicitly 1).
    std::vector<uint64_t> stridesInBytes(dim - 1);
    for (int32_t ii = 0; ii < dim - 1; ++ii) {
      stridesInBytes[ii] =
          strides[ii + 1] * std::max(get_size_in_bits(dtypeElt), static_cast<size_t>(8)) /
          8 /*bits*/;
    }

    // Set the tile (element) strides to 1, i.e. dense elements within the box.
    std::vector<uint32_t> tileStrides(dim, 1);

    // Build the descriptor.
    CUresult result = cuTensorMapEncodeTiled(
        &desc, tmaDataFormat,
        /*tensorRank=*/dim, gmemAddr, shapes.data(), stridesInBytes.data(), tileShapes.data(),
        tileStrides.data(),
        /*interleave=*/CU_TENSOR_MAP_INTERLEAVE_NONE, swizzleType,
        /*l2Promotion=*/CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
        /*oobFill=*/CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);

    if (result != CUDA_SUCCESS) {
      char const* err_str;
      cuGetErrorString(result, &err_str);
      std::cerr << "Error: Failed to initialize the TMA descriptor due to " << err_str << std::endl;
      std::cerr << "tmaFormat: " << static_cast<int32_t>(tmaDataFormat) << " dim: " << dim
                << " gmem: " << gmemAddr << std::endl;
      // Print only the dimensions that actually exist (the shape can be 3D, 4D or 5D).
      std::cerr << "Shape:";
      for (int32_t ii = 0; ii < dim; ++ii) {
        std::cerr << " " << shapes[ii];
      }
      std::cerr << std::endl << "Stride:";
      for (int32_t ii = 0; ii < dim - 1; ++ii) {
        std::cerr << " " << stridesInBytes[ii];
      }
      std::cerr << std::endl << "tileShapes:";
      for (int32_t ii = 0; ii < dim; ++ii) {
        std::cerr << " " << tileShapes[ii];
      }
      std::cerr << std::endl << "tileStrides:";
      for (int32_t ii = 0; ii < dim; ++ii) {
        std::cerr << " " << tileStrides[ii];
      }
      std::cerr << std::endl << "swizzleType: " << int(swizzleType) << std::endl;
      TORCH_CHECK(false);
    }

    return desc;
  }
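  // Worked example (illustrative only): with BF16 elements and tileShapes[0] = 64, the leading
  // box dimension spans 64 * 2 = 128 bytes, so CU_TENSOR_MAP_SWIZZLE_128B is selected; with an
  // E4M3 (or packed FP4) tile whose tileShapes[0] = 32, the box spans 32 bytes and
  // CU_TENSOR_MAP_SWIZZLE_32B is selected. Callers pass swizzled = false (SWIZZLE_NONE) when the
  // kernel transforms K/V after the load, as done for the descriptors built below.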
  // Setup the kernel parameters.
  template <class FmhaOptions_, class KernelMeta>
  static KernelParams setKernelParams(FmhaOptions_ const& options, KernelMeta const& kernelMeta,
                                      int32_t maxNumCtasQ, int32_t maxNumCtasKv) {
    // Create the return struct.
    KernelParams params;

    // Get the device pointers for the TMA descriptors.
    auto [qPtr, kPtr, vPtr] = getDevicePtrs(options, get_size_in_bytes(kernelMeta.mDataTypeKv));

    // The maximum headDim of K and V.
    // Note that contiguousKv or pagedKv will pad K and V to maxHeadDimKv.
    int32_t const maxHeadDimKv{std::max(options.mHeadDimQk, options.mHeadDimV)};

    // Set the number of pages in the memory pool for the paged K/V cache.
    if (isPagedKv(options.mQkvLayout)) {
      params.mNumPagesInMemPool = options.mNumPagesInMemPool == 0
                                      ? options.mMaxNumPagesPerSeqKv * 2 * options.mBatchSize
                                      : options.mNumPagesInMemPool;
    }

    // The number of elements in 128B for Q.
    int32_t numEltsIn128BQ = (128 * 8) / get_size_in_bits(kernelMeta.mDataTypeQ);
    // The number of head elts (per token) in each block of shared memory.
    int32_t numEltsInClampedHeadDimQ = std::min(numEltsIn128BQ, options.mHeadDimQk);

    // Shape/stride for the gmem tensor Q.
    auto [shapeQ, strideQ, tileShapeQ] = makeTmaShapeStrideQ(
        options, kernelMeta.mGroupsHeadsQ, kernelMeta.mTileSizeQ, numEltsInClampedHeadDimQ);
    // Build the TMA descriptor for Q.
    params.tmaQ_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeQ, shapeQ, strideQ, tileShapeQ,
                                        const_cast<void*>(qPtr));

    // The number of keys per tile.
    int32_t numKeysPerTile = isPagedKv(options.mQkvLayout)
                                 ? std::min(options.mNumTokensPerPage, kernelMeta.mTileSizeKv)
                                 : kernelMeta.mTileSizeKv;
    // The number of elements in 128B for K/V.
    int32_t numEltsIn128BKv = (128 * 8) / get_size_in_bits(kernelMeta.mDataTypeKv);
    // The number of head elts (per token) in each block of shared memory (see the explanation
    // above).
    int32_t numEltsInClampedHeadDimKv = std::min(numEltsIn128BKv, maxHeadDimKv);

    // Shape/stride for the gmem tensors K/V.
    auto [shapeK, strideK] =
        makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeKv, /*isK=*/true);
    auto [shapeV, strideV] =
        makeTmaShapeStrideKv(options, params, kernelMeta.mDataTypeKv, /*isK=*/false);

    // Do we have to transform K/V before the MMA?
    bool const transformsKv{kernelMeta.mDataTypeKv != kernelMeta.mDataTypeQ};
    // Note that for FP4 KV input, elements are stored as uint8_t, where each byte packs 2 FP4
    // elements.
    auto const numEltsDivisor = kernelMeta.mDataTypeKv == DATA_TYPE_E2M1 ? 2 : 1;

    // The tileShapes for K/V.
    std::vector<uint32_t> tileShapeKv(shapeK.size(), 1);
    tileShapeKv[0] = numEltsInClampedHeadDimKv / numEltsDivisor;
    tileShapeKv[1] = numKeysPerTile;

    // Build the TMA descriptor for K.
    params.tmaK_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeK, strideK,
                                        tileShapeKv, const_cast<void*>(kPtr),
                                        /*swizzled=*/!transformsKv);
    // Build the TMA descriptor for V.
    params.tmaV_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeKv, shapeV, strideV,
                                        tileShapeKv, const_cast<void*>(vPtr),
                                        /*swizzled=*/!transformsKv);

    // If the KV dtype is E2m1, additional scaling factors are needed for dequantization.
    if (kernelMeta.mDataTypeKv == DATA_TYPE_E2M1) {
      // The number of elements per SF.
      int32_t NumEltsPerSf = 16;
      // Compute the shape and stride for the SF tensor.
      // FIXME: assumes K and V use the same shape.
      auto [shapeKvSf, strideKvSf] = makeTmaShapeStrideKvSf(options, params, /*isK=*/true);

      // The tileShapes for the K/V SFs.
      std::vector<uint32_t> tileShapeKvSf(shapeKvSf.size(), 1);
      tileShapeKvSf[0] = 16;
      tileShapeKvSf[1] = numKeysPerTile * maxHeadDimKv / NumEltsPerSf / 16;

      // The tile box is reshaped from (headDim / NumEltsPerSf, tileSizeKv) into (16, tileSizeKv *
      // headDim / NumEltsPerSf / 16). See makeTmaShapeStrideKvSf for details.
      // Build the TMA descriptor for the K SF.
      params.tmaKSf_ =
          buildNdTmaDescriptor(options, DATA_TYPE_E4M3, shapeKvSf, strideKvSf, tileShapeKvSf,
                               const_cast<void*>(options.kSfBasePtr), /*swizzled=*/false);
      // Build the TMA descriptor for the V SF.
      params.tmaVSf_ =
          buildNdTmaDescriptor(options, DATA_TYPE_E4M3, shapeKvSf, strideKvSf, tileShapeKvSf,
                               const_cast<void*>(options.vSfBasePtr), /*swizzled=*/false);
    }

    // Shape/stride for the gmem tensor O.
    auto [shapeO, strideO] = makeTmaShapeStrideO(options);
    // The tileShapes for O.
    std::vector<uint32_t> tileShapeO(shapeO.size(), 1);
    tileShapeO[0] = numEltsInClampedHeadDimQ;
    tileShapeO[1] = kernelMeta.mTileSizeQ;
    // Build the TMA descriptor for O.
    params.tmaO_ = buildNdTmaDescriptor(options, kernelMeta.mDataTypeQ, shapeO, strideO, tileShapeO,
                                        const_cast<void*>(options.oPtr));

    // Set the other kernel parameters.
    params.ptrCumSeqLensQ = options.cumSeqLensQPtr;
    params.ptrCumSeqLensKv = options.cumSeqLensKvPtr;
    // The packed custom mask.
    params.ptrCustomMask = options.customMaskPtr;
    // The packed custom mask's offsets of each sequence.
    params.ptrCustomMaskOffsets = options.customMaskOffsetsPtr;
    // The first sparseMask offsets in the Kv sequence dimension.
    params.ptrFirstSparseMaskOffsetsKv = options.firstSparseMaskOffsetsKvPtr;
    // The output buffer.
    params.ptrO = options.oPtr;
    // The output scaling factor buffer.
    params.ptrSfO = options.oSfPtr;
    // TRT-LLM restriction: the quantization scales must be on the device.
    params.ptrOutputScale = options.outputScalePtr;
    // The sequence lengths for Kv.
    params.ptrSeqLensKv = options.seqLensKvPtr;
    // The attention sinks.
    params.ptrAttentionSinks = options.ptrAttentionSinks;

    // The partial buffers' pointers when the multiCtasKv mode is enabled.
    int64_t partialStatsBufferSize = options.mMultiProcessorCount * kernelMeta.mStepQ;
    params.ptrMultiCtasKvCounter = options.multiCtasKvCounterPtr;
    params.ptrPartialStats = reinterpret_cast<float2*>(options.multiCtasKvScratchPtr);
    params.ptrPartialO = params.ptrPartialStats + partialStatsBufferSize;

    params.ptrPageIdxKv = options.kvPageIdxPtr;
    params.ptrScaleSoftmaxLog2 = options.scaleSoftmaxLog2Ptr;
    params.ptrScaleSfKv = options.kvSfScalePtr;
    params.ptrScaleSfO = options.oSfScalePtr;
    params.mScaleSfO = options.mScaleSfO;

    params.mAttentionWindowSize = options.mAttentionWindowSize;
    if (isSlidingOrChunkedCausalMask(
            static_cast<TrtllmGenAttentionMaskType>(kernelMeta.mMaskType)) &&
        options.mChunkedAttentionSize != INT_MAX) {
      TORCH_CHECK((options.mChunkedAttentionSize & (options.mChunkedAttentionSize - 1)) == 0,
                  "Chunked attention size must be a power of 2");
      params.mChunkedAttentionSizeLog2 = std::log2(options.mChunkedAttentionSize);
    } else {
      // Default 0 means that chunked attention is disabled.
      params.mChunkedAttentionSizeLog2 = 0;
    }
    params.mMaxSeqLenQ = options.mMaxSeqLenQ;
    params.mMaxSeqLenKv = options.mMaxSeqLenKv;
    params.mMaxNumCtasQ = maxNumCtasQ;
    params.mMaxNumCtasKv = maxNumCtasKv;
    params.mMaxNumPagesPerSeqKv = options.mMaxNumPagesPerSeqKv;
    // TODO: just use mMaxSeqLenQ for the number of MTP tokens.
    params.mSumOfSeqLensQ = options.mSumOfSeqLensQ;
    params.mSumOfSeqLensKv = options.mSumOfSeqLensKv;
    params.mBatchSize = options.mBatchSize;
    params.mNumHeadsQ = options.mNumHeadsQ;
    params.mNumHeadsKv = options.mNumHeadsKv;
    params.mNumHeadsQPerKv = options.mNumHeadsQPerKv;
    params.mNumHiddenEltsO = options.mNumHeadsQ * options.mHeadDimQk;
    // todo(Yingyi): might take a scalar tensor later.
    params.mOutputScale = options.outputScale;
    params.mScaleSoftmaxLog2 = options.scaleSoftmaxLog2;
    params.mStartTokenIdxSfO = options.mSfStartTokenIdx;
    params.mScaleSfKv = options.mScaleSfKv;
    params.ptrSoftmaxStats = options.softmaxStatsPtr;

    return params;
  }
};
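
////////////////////////////////////////////////////////////////////////////////////////////////////

// Usage sketch (illustrative only; `options`, `kernelMeta` and the CTA counts come from the
// surrounding runner code and are assumed here, not defined by this header):
//
//   KernelParams params = KernelParams::setKernelParams(options, kernelMeta,
//                                                       /*maxNumCtasQ=*/maxNumCtasQ,
//                                                       /*maxNumCtasKv=*/maxNumCtasKv);
//   // The struct is then passed by value as the kernel parameter blob when launching the
//   // TRT-LLM Gen FMHA kernel (e.g. through the CUDA driver launch API).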