/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh"

namespace moe::dev::routing {
namespace routingLlama4 {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Block size of `routingIndicesClusterKernel`.
static constexpr int NumThreads = 1024;
static constexpr int NumWarps = NumThreads / WarpSize;
// Upper bound on `Data::mTopK` accepted by `run`.
static constexpr int MaxNumTopExperts = 1;
// Upper bound on `Data::mNumExperts` accepted by `run`.
static constexpr int MaxNumExperts = 128;
// Token limits of the single-cluster path: one token per thread when the packed top-1 result is
// the input, one token per warp when raw scores are the input.
static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads;
static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps;
// Row stride (in `int32_t`) of the shared-memory table used by `routingIndicesWarpKernel`.
static constexpr int WarpKernelSmemStride = 33;
// with further optimization to `routingIndicesWarpKernel`, this limit may
// increase. For now, it is a good cut-off point for when the block-wise
// operations are more efficient end-to-end.
static constexpr int WarpKernelMaxNumTokens = 4;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Warp-wide top-1 search over the scores of a single token.
//
// Every lane inspects the experts `laneIdx + i * WarpSize` for `i` in `[0, VecSize)`, keeps its
// own best candidate, and hands it to `topk::reduceTopK`, which fills `warpMaxScore` and
// `warpMaxExpertIdx`.
//
//   warp             - tile of the calling warp
//   warpMaxScore     - [out] best score, written by `topk::reduceTopK`
//   warpMaxExpertIdx - [out] index of the best expert, written by `topk::reduceTopK`
//   laneIdx          - lane of the calling thread within `warp`
//   numExperts       - number of valid entries in `ptrScores`; experts at or beyond this index
//                      are treated as having score `-INFINITY`
//   ptrScores        - scores of the token, indexed by expert
template <typename DataType, int VecSize>
__forceinline__ __device__ void routingTopKExperts(
    cg::thread_block_tile<WarpSize> const& warp, DataType (&warpMaxScore)[MaxNumTopExperts],
    int32_t (&warpMaxExpertIdx)[MaxNumTopExperts], int32_t const laneIdx,
    int32_t const numExperts, DataType const* ptrScores) {
  DataType minScore = DataType{-INFINITY};
  DataType maxScore = minScore;
  int32_t maxExpertIdx{0};

  // Non-vectorized loading: directly access ptrScores with expertIdx
  for (int i = 0; i < VecSize; ++i) {
    auto expertIdx = i * WarpSize + laneIdx;
    auto newScore = expertIdx < numExperts ? ptrScores[expertIdx] : minScore;
    // Strict `>`: on equal scores the lane keeps the expert it saw first, i.e. the lowest index.
    // NOTE(review): an earlier comment here asked for `>=` (highest index wins) to match
    // `reduceTopK`; confirm which tie-break `reduceTopK` applies and that this one agrees.
    if (newScore > maxScore) {
      maxScore = newScore;
      maxExpertIdx = expertIdx;
    }
  }

  topk::reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Single-warp routing kernel for at most `WarpKernelMaxNumTokens` tokens.
//
// Launch layout: one block of `WarpSize` threads, no dynamic shared memory (see `run`).
// Input is either `params.mPtrScores` (raw scores, top-1 computed here) or `params.mPtrExpertIdx`
// (packed top-1 score and index). The kernel writes the CTA mapping buffers, the padded permuted
// size, the number of non-exiting CTAs and, when the pointers are non-null, the expert weights
// and the expanded/permuted index maps.
template <typename KernelParams>
__global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params) {
  // types used in this kernel
  using OutputT = typename KernelParams::OutputT;
  using InputT = typename KernelParams::InputT;
  using TypePacked = PackedScoreIdx<OutputT>;
  // use the default cub warp-scan, with shfl
  using Scan = cub::WarpScan<int32_t>;
  __shared__ typename Scan::TempStorage tempStorage;

  // each thread encodes 4 experts in one `int32_t`. The assumption is that
  // we don't have more than 127 tokens, but `WarpKernelMaxNumTokens` must be
  // smaller than that because other approaches will be more efficient for
  // 127 tokens.
  static constexpr int ExpertsPerThread = sizeof(int32_t);
  static_assert(WarpKernelMaxNumTokens <= 127);
  // this is a full table of which token is routed to which expert.
  // the assumption here is that there are no more than 128 experts.
  // we use a stride of 33 instead of 32 to avoid shared memory bank conflicts.
  __shared__ int32_t __attribute((
      aligned(128))) smemExpertTokenCountFull[WarpKernelMaxNumTokens][WarpKernelSmemStride];
  static_assert(WarpKernelSmemStride == WarpSize + 1);
  static_assert(MaxNumExperts / sizeof(int32_t) <= WarpSize);

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WarpSize>(block);

#pragma unroll
  for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx) {
    // reset full shared memory field to 0
    smemExpertTokenCountFull[tokenIdx][threadIdx.x] = 0;
  }
  __syncwarp();

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // then wait on primary grid
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
  }
#endif

  if (params.mPtrScores != nullptr) {
    // if we use `mPtrScores` as input, we need to perform the top-1 reduction
    // for each token, we load the scores then use `reduceTopK` for this.
    // each thread works on 4 experts, so a local reduction is done before
    for (int tokenIdx = 0; tokenIdx < params.mNumTokens; ++tokenIdx) {
      auto scoreOffset = tokenIdx * params.mNumExperts;
      int32_t warpMaxExpertIdx[MaxNumTopExperts];
      InputT warpMaxScore[MaxNumTopExperts];

      // Use routingTopKExperts function instead of inline logic
      routingTopKExperts<InputT, MaxNumExperts / WarpSize>(
          warp, warpMaxScore, warpMaxExpertIdx, threadIdx.x, params.mNumExperts,
          params.mPtrScores + scoreOffset);

      if (cute::elect_one_sync()) {
        // one thread updates the count linking token to chosen expert
        auto expertTokenCount = 0;
        setBits(expertTokenCount, 1, warpMaxExpertIdx[0] % ExpertsPerThread);
        smemExpertTokenCountFull[tokenIdx][warpMaxExpertIdx[0] / ExpertsPerThread] =
            expertTokenCount;
        // we also compute the final score here and write it out if required
        auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
        if (params.mPtrExpertWeights != nullptr) {
          params.mPtrExpertWeights[tokenIdx] = finalScore;
        }
      }
    }
  } else {
    // if we do not have `mPtrScores` as input, we expect that `mPtrExpertIdx`
    // contains the top-1 packed score and index already.
    // Each thread represents a token here, and we extract the relevant score
    // The assumption is that the #tokens is limited by warp-size
    static_assert(WarpKernelMaxNumTokens <= WarpSize);
    TypePacked scoreIdx =
        threadIdx.x < params.mNumTokens ? params.mPtrExpertIdx[threadIdx.x] : TypePacked{};
    int32_t expertTokenCount = 0;
    setBits(expertTokenCount, 1, scoreIdx.idx % ExpertsPerThread);
    if (threadIdx.x < params.mNumTokens) {
      smemExpertTokenCountFull[threadIdx.x][scoreIdx.idx / ExpertsPerThread] = expertTokenCount;
    }
    // we also compute the final score here and write it out if required
    auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})};
    if (params.mPtrExpertWeights != nullptr && threadIdx.x < params.mNumTokens) {
      params.mPtrExpertWeights[threadIdx.x] = finalScore;
    }
  }

  // make the full table available to all threads
  __syncwarp();

  // at this point, each thread keeps a count of its 4 assigned experts in
  // `expertCount`, as well as the offsets for all tokens w.r.t. these 4 experts
  // in `expertOffset`.
  int32_t expertCount = 0;
  int32_t expertOffset[WarpKernelMaxNumTokens + 1];
#pragma unroll
  for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens + 1; ++tokenIdx) {
    // the entry at index `mNumTokens` is still written: it holds the total count and is read as
    // `expertOffset[tokenIdx + 1]` for the last token below
    if (tokenIdx > params.mNumTokens) break;
    // simple reduction for `expertCount`, and scan for `expertOffset`
    auto expertTokenCount =
        tokenIdx < params.mNumTokens ? smemExpertTokenCountFull[tokenIdx][threadIdx.x] : 0;
    expertOffset[tokenIdx] = expertCount;
    expertCount += expertTokenCount;
  }

  // at this point, we are ready for the scan across all experts to get the
  // thread-wise offsets across experts
  // first, we need to reduce across our 4 experts into `numCta`
  int32_t numCta = 0;
#pragma unroll
  for (int ii = 0; ii < ExpertsPerThread; ++ii) {
    auto count = getBits(expertCount, ii);
    numCta += divUpLog2(count, params.mPaddingLog2);
  }
  // second, we perform the exclusive sum across the warp
  int32_t ctaOffset;
  int32_t numNonExitingCtas;
  Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

  // finally, we perform a scan across our local experts, starting with the
  // warp-wide scan result (`ctaOffset`)
  auto ctaOffsetExp = ctaOffset;
#pragma unroll
  for (int ii = 0; ii < ExpertsPerThread; ++ii) {
    auto count = getBits(expertCount, ii);
    auto finalNumCta = divUpLog2(count, params.mPaddingLog2);
    auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
    // during the scan for expert offsets, we can already write out
    // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit`
    for (int cta = 0; cta < finalNumCta; ++cta) {
      params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx;
      params.mPtrCtaIdxXyToMnLimit[ctaOffsetExp + cta] =
          min(mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2),
              mulLog2(ctaOffsetExp, params.mPaddingLog2) + count);
    }
    ctaOffsetExp += finalNumCta;
  }

  // at this point, we can write out padded count from the warp-aggregate
  if (cute::elect_one_sync()) {
    const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2);
    params.mPtrPermutedIdxSize[0] = permutedIdxSize;
    params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
  }

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
  // we can trigger the next kernel at this point
  if constexpr (KernelParams::UsePdl) {
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif
#endif

  // at this point, all values for offsets are ready, except the final offsets
  // within the padded index (`permutedIdx`)
  // for this, we perform a scan similar to the one directly after the warp-scan:
  // here, we keep the local offset for each of the thread's experts in a field
  // of registers
  auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
  int32_t finalExpertOffset[ExpertsPerThread];
  finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2);
#pragma unroll
  for (int ii = 1; ii < ExpertsPerThread; ++ii) {
    finalExpertOffset[ii] =
        finalExpertOffset[ii - 1] +
        divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2);
  }

#pragma unroll
  for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx) {
    // at this point, we can calculate the final index:
    // we simply loop over all tokens, and all experts assigned to this thread.
    // For each pair, we determine whether that token was routed to that expert
    // based on whether the offset for that token changed.
    // we can then easily compute the final `expertIdx` and `permutedIdx` relative
    // to this token and expert, and write them out.
    if (tokenIdx >= params.mNumTokens) break;

#pragma unroll
    for (int ii = 0; ii < ExpertsPerThread; ++ii) {
      // determine whether the offset for this expert and token changes
      auto localOffsetToken = getBits(expertOffset[tokenIdx], ii);
      auto isTokenRouted = getBits(expertOffset[tokenIdx + 1], ii) > localOffsetToken;
      // the expert index of this expert
      auto expertIdx = threadIdx.x * ExpertsPerThread + ii;
      auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
      // NOTE(review): the last term masks with `mLocalExpertsStrideLog2` itself rather than
      // `(1 << mLocalExpertsStrideLog2) - 1`; the two only agree for some stride values.
      // Left unchanged here - confirm the intended stride test.
      auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                           (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
      // the permuted index: we add the local offset relative to this expert and token
      // to the global offset from the scan for this expert
      auto permutedIdx = isLocalExpert ? finalExpertOffset[ii] + localOffsetToken : int32_t{-1};
      // write out `mPtrExpandedIdxToPermutedIdx` if required
      if (params.mPtrExpandedIdxToPermutedIdx != nullptr && isTokenRouted) {
        params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx;
      }
      // write out `mPtrPermutedIdxToTokenIdx` if required
      if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted) {
        params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
      }
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Single-cluster routing kernel (requires SM90+; the fallback below only asserts).
//
// Launch layout: `NumBlocksPerCluster` blocks of `NumThreads` threads forming one cluster, no
// dynamic shared memory (see `run`). When `params.mPtrScores` is given, each warp computes the
// top-1 expert of one token and stores the packed result in shared memory before the cluster
// barrier; the remaining work is delegated to `routingPermutation`.
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
    routingIndicesClusterKernel(KernelParams params) {
  // number of tokens/expanded idx is bounded by total number of warps
  using OutputT = typename KernelParams::OutputT;
  using InputT = typename KernelParams::InputT;
  using TypePacked = PackedScoreIdx<OutputT>;

  __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps];

  uint32_t const clusterBlockRank = blockIdx.x;
  // broadcast from lane 0 so that the warp index is uniform across the warp
  int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
  int32_t const laneIdx = cutlass::arch::LaneId();

  // TODO(mjoux): expand to more tokens (possibly)
  auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx;
  auto scoreOffset = warpTokenIdx * params.mNumExperts;
  bool validToken = warpTokenIdx < params.mNumTokens;

  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WarpSize>(block);

  // then wait on primary grid
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
  }

  if (params.mPtrScores != nullptr) {
    // in this case, each warp represents a token
    // we then exchange all token max scores, s.t. afterwards, each thread
    // represents a token
    InputT warpMaxScore[MaxNumTopExperts];
    int32_t warpMaxExpertIdx[MaxNumTopExperts];

    if (validToken) {
      routingTopKExperts<InputT, MaxNumExperts / WarpSize>(
          warp, warpMaxScore, warpMaxExpertIdx, laneIdx, params.mNumExperts,
          params.mPtrScores + scoreOffset);
      if (cute::elect_one_sync()) {
        auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
        // cast to the packed struct's own index type to avoid a narrowing brace-init
        TypePacked packedScore{finalScore,
                               static_cast<decltype(TypePacked::idx)>(warpMaxExpertIdx[0])};
        smemPackedScoreIdx[warpIdx] = packedScore;
      }
    }
    // make packed scores available to all threads in cluster
    __cluster_barrier_arrive();
    __cluster_barrier_wait();
  }

  // NOTE(review): the explicit template arguments below are reconstructed and should be checked
  // against the declaration of `routingPermutation` in RoutingKernel.cuh.
  routingPermutation<KernelParams, OutputT, NumThreads, NumWarps, MaxNumTopExperts,
                     /*LoadExpertIdxFromGlobal=*/false>(params, smemPackedScoreIdx, warpIdx,
                                                        clusterBlockRank);
}
#else
__global__ void routingIndicesClusterKernel(KernelParams params) {
  assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// this kernel is needed in case we have scores as input for the histogram kernel
//
// Launch layout: any number of blocks of `NumThreadsHist` threads (see `run`). The kernel zeroes
// the `2 * mNumExperts` entries of `mPtrExpertCounts`, then lets each warp compute the top-1
// expert of its tokens and store the packed score/index in `mPtrExpertIdx`.
template <typename KernelParams>
__global__ void __launch_bounds__(NumThreadsHist)
    routingIndicesHistogramScoresKernel(KernelParams params) {
  using OutputT = typename KernelParams::OutputT;
  using InputT = typename KernelParams::InputT;
  using TypePacked = PackedScoreIdx<OutputT>;
  static constexpr int VecSize = MaxNumExperts / WarpSize;
  // we assume that #experts is a multiple of 4, so VecSize must be 4.
  static_assert(VecSize == 4);

  int32_t const laneIdx = cutlass::arch::LaneId();
  int32_t const warpIdx = threadIdx.x / WarpSize;
  int32_t const globalWarpIdx = blockIdx.x * NumWarpsHist + warpIdx;
  int32_t const globalWarpStride = gridDim.x * NumWarpsHist;
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WarpSize>(block);

  // initialize the mPtrExpertCounts
  // The global thread index and stride must use this kernel's own block size
  // (`NumThreadsHist`); using the cluster kernel's `NumThreads` would leave gaps in the sweep
  // whenever the array is longer than one block.
  int32_t expertCountsNum = 2 * params.mNumExperts;
  int32_t globalThreadIdx = blockIdx.x * NumThreadsHist + threadIdx.x;
  int32_t globalThreadStride = gridDim.x * NumThreadsHist;
  initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
  // Wait on primary grid and trigger secondary kernel.
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
    cudaTriggerProgrammaticLaunchCompletion();
  }
#endif

  // in this case, each warp represents a token, and we use a grid-stride loop
  // over all warps/tokens
  for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) {
    auto scoreOffset = tokenIdx * params.mNumExperts;
    int32_t warpMaxExpertIdx[MaxNumTopExperts];
    InputT warpMaxScore[MaxNumTopExperts];

    routingTopKExperts<InputT, VecSize>(warp, warpMaxScore, warpMaxExpertIdx, laneIdx,
                                        params.mNumExperts, params.mPtrScores + scoreOffset);

    if (cute::elect_one_sync()) {
      auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})};
      // cast to the packed struct's own index type to avoid a narrowing brace-init
      TypePacked packedScore{finalScore,
                             static_cast<decltype(TypePacked::idx)>(warpMaxExpertIdx[0])};
      params.mPtrExpertIdx[tokenIdx] = packedScore;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Host entry point: validates `data` and launches the routing kernels on `stream`.
//
// Path selection:
//   - single warp:    few tokens (at most `WarpKernelMaxNumTokens`, strictly fewer when scores
//                     are the input)
//   - single cluster: #tokens within `MaxNumTokensSingleCluster(Scores)`
//   - otherwise:      optional scores kernel (or a memset of the expert counts), followed by the
//                     histogram and offsets kernels
// Precondition failures are reported through `TORCH_CHECK`.
void run(Data const& data, void* stream) {
  TORCH_CHECK(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr,
              "Routing kernel requires at least one input parameter");
  TORCH_CHECK(
      data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr &&
          data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr,
      "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers");
  TORCH_CHECK(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= ",
              MaxNumTopExperts, ", got ", data.mTopK);
  TORCH_CHECK(data.mNumExperts <= MaxNumExperts, "Routing kernel expects #experts ",
              data.mNumExperts, " to be at most max #experts ", MaxNumExperts);
  static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads");
  static_assert(MaxNumExperts <= NumThreadsHist, "#experts must be bounded by #threads");
  TORCH_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts ", data.mNumExperts,
              " to be a multiple of 4.");
  TORCH_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ",
              data.mPaddingLog2);

  bool const useSingleWarp =
      (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) ||
      data.mNumTokens < WarpKernelMaxNumTokens;
  bool const useSingleCluster =
      data.mNumTokens <= (data.mPtrScores != nullptr ? MaxNumTokensSingleClusterScores
                                                     : MaxNumTokensSingleCluster);
  if (!useSingleCluster) {
    TORCH_CHECK(data.mPtrExpertIdx != nullptr,
                "When #tokens is large, `mPtrExpertIdx` is a required input.");
    TORCH_CHECK(data.mPtrExpertCounts != nullptr,
                "When #tokens is large, `mPtrExpertCounts` is a required input.");
  }

  if (useSingleWarp) {
    LAUNCH_ROUTING(data,
                   /*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize,
                   /*smemSize=*/0,  // No dynamic smem
                   stream);
  } else if (useSingleCluster) {
    LAUNCH_ROUTING(data,
                   /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster,
                   NumThreads,
                   /*smemSize=*/0,  // No dynamic smem
                   stream);
  } else {
    const uint32_t expandedIdxSize = data.mNumTokens * data.mTopK;

    const uint32_t histogramEltsPerBlock = 8 * NumThreadsHist;
    const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreadsHist;

    // Limit grid size (all kernels use a grid-stride loop).
    const uint32_t maxNumBlocks = 1024;

    int const numBlocksHistogram = std::min(
        (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
    int const numBlocksOffsets =
        std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);

    if (data.mPtrScores != nullptr) {
      // the scores kernel also zeroes `mPtrExpertCounts`, so no memset is needed on this path
      LAUNCH_ROUTING(data,
                     /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks,
                     NumThreadsHist,
                     /*smemSize=*/0,  // No dynamic smem
                     stream);
    } else {
      // Reset the global histograms.
      CHECK_CUDA_ERROR(cudaMemsetAsync(data.mPtrExpertCounts, 0,
                                       static_cast<size_t>(2 * data.mNumExperts) * sizeof(int32_t),
                                       (cudaStream_t)stream));
    }
    LAUNCH_ROUTING(data,
                   /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram,
                   NumThreadsHist,
                   /*smemSize=*/0,  // No dynamic smem
                   stream);
    LAUNCH_ROUTING(data,
                   /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets,
                   NumThreadsHist,
                   /*smemSize=*/0,  // No dynamic smem
                   stream);
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace routingLlama4
}  // namespace moe::dev::routing