/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh"

namespace moe::dev::routing {

namespace routingDeepSeek {

////////////////////////////////////////////////////////////////////////////////////////////////////

static constexpr int NumThreads = 256;
static constexpr int NumWarps = NumThreads / WarpSize;
static constexpr int NumTopGroupScores = 2;
static constexpr int MaxNumTopExperts = 8;
static constexpr int MaxNumTopGroups = 4;

template <typename KernelParams>
__global__ void routingMainKernel(KernelParams params) {
  // declare types
  using OutputT = typename KernelParams::OutputT;
  using InputT = typename KernelParams::InputT;

  // declare shared memory structure
  // number of experts is bounded by number of threads
  __shared__ float __attribute((aligned(128))) smemScoreSigmoid[NumThreads];
  __shared__ float __attribute((aligned(128))) smemScoreBias[NumThreads];
  // number of expert groups is bounded by number of warps
  __shared__ float __attribute((aligned(128))) smemGroupScores[NumWarps];

  // needed for warp reduce
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<WarpSize>(block);
  // for the final reduction of weight norm, only some lanes need to participate
  int32_t laneIdx = threadIdx.x % WarpSize;
  int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
  // warps outside the range of expert groups do not participate
  if constexpr (KernelParams::UseGroups) {
    if (warpIdx >= params.mNumExpertGroups) {
      return;
    }
  }

  // note that for invalid scores, we simply use a negative value:
  // they work well even with the compacted format used in topK, and
  // sigmoid / bias activated scores cannot be negative
  static constexpr float invalidScoreFloat = -1.F;
  const OutputT invalidScore = OutputT{invalidScoreFloat};

  // load bias already; each warp represents one expert group
  auto threadExpert = threadIdx.x;
  bool expertSelected = threadExpert < params.mNumExperts;
  if constexpr (KernelParams::UseGroups) {
    threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx;
    expertSelected = laneIdx < params.mNumExpertsPerGroup;
  }
  auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert;
  auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore;

  // initialize the mPtrExpertCounts
  if (params.mPtrExpertCounts) {
    int32_t globalThreadIdx = blockIdx.x * NumThreads + threadIdx.x;
    int32_t globalThreadStride = gridDim.x * NumThreads;
    int32_t expertCountsNum = 2 * params.mNumExperts;
    initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0);
  }
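
  // Programmatic dependent launch (PDL, SM90+): cudaTriggerProgrammaticLaunchCompletion() signals
  // that a dependent grid may begin launching early, while cudaGridDependencySynchronize() blocks
  // until all grids this kernel depends on have finished executing.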
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // trigger the secondary kernel when using PDL, then wait on primary
  if constexpr (KernelParams::UsePdl) {
    cudaTriggerProgrammaticLaunchCompletion();
    cudaGridDependencySynchronize();
  }
#endif

  // get our assigned thread score; each warp represents one expert group
  float score =
      expertSelected ? static_cast<float>(params.mPtrScores[scoreIdx]) : invalidScoreFloat;

  // get the sigmoid score
  // note that for invalid values, we simply use a negative value:
  // sigmoid scores are always strictly positive
  auto scoreSigmoid = sigmoid_accurate(score);
  // write the sigmoid score to shared for later use
  if (expertSelected) {
    smemScoreSigmoid[threadExpert] = scoreSigmoid;
  }

  // get the score with bias
  // note that with invalid values, because sigmoid is < 1 and bias is -1,
  // we must get a negative value, which is smaller than any valid value
  auto scoreBias = float{scoreSigmoid + float{biasVal}};
  if (expertSelected) {
    smemScoreBias[threadExpert] = scoreBias;
  }

  // registers for top group score reduction
  float topExpGroupScores[NumTopGroupScores];
  [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores];

  float topGroups[MaxNumTopGroups];  // bound of params.mNumLimitedGroups
  int32_t topGroupIdx[MaxNumTopGroups];
  float expertScoreGroup[MaxNumTopGroups];
  int32_t expertIdxGroup[MaxNumTopGroups];
  float topScores[MaxNumTopExperts];  // bound of params.mTopK
  int32_t topExperts[MaxNumTopExperts];

  if constexpr (KernelParams::UseGroups) {
    topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert,
                     /* minValue */ invalidScoreFloat);

    // get the final group score and write it to shared
    if (cute::elect_one_sync()) {
      auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
      smemGroupScores[warpIdx] = groupScore;
    }
  }

  // make group scores available to all warps
  __syncthreads();

  auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;
  if (warpIdx == 0) {
    // a single warp performs the selection of top groups, and goes on to select the final experts
    if constexpr (KernelParams::UseGroups) {
      float groupScore =
          laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat;

      topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
                       /* minValue */ invalidScoreFloat);

      // final expert selection: get relevant indexes and scores from shared
#pragma unroll
      for (int ii = 0; ii < MaxNumTopGroups; ++ii) {  // bound of params.mNumLimitedGroups
        auto groupIdx = topGroupIdx[ii];
        expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx;
        // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup.
        // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups,
        // thus groupIdx <= params.mNumExpertGroups - 1 =>
        // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup
        // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads,
        // so the access is safe here
        expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected
                                   ? smemScoreBias[expertIdxGroup[ii]]
                                   : invalidScoreFloat;
      }
    } else {
      // without groups, each thread just takes `MaxNumTopGroups` experts
#pragma unroll
      for (int ii = 0; ii < MaxNumTopGroups; ++ii) {
        auto expertIdx = ii * WarpSize + laneIdx;
        expertIdxGroup[ii] = expertIdx;
        expertScoreGroup[ii] =
            expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat;
      }
    }

    topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup,
                     /* minValue */ invalidScoreFloat, params.mTopK);
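
    // note: the top-K selection above uses the bias-adjusted scores (sigmoid + routing bias),
    // whereas the weights written out below are the plain sigmoid scores of the selected experts,
    // normalized over the top-K set and scaled by mRouteScale.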

    // determine our lane's expert index and write to output
    int32_t expertIdx = 0;
#pragma unroll
    for (int ii = 0; ii < params.mTopK; ++ii) {  // bound of params.mTopK
      expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx;
    }
    // determine whether our expert is local to this GPU
    auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;

    float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F;
    auto redNorm = cg::reduce(warp, scoreNorm, cg::plus<float>{});
    auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm};

    // write expert idx out already
    auto idxTopK = blockIdx.x * params.mTopK + laneIdx;
    if (laneIdx < params.mTopK && params.mPtrExpertIdx != nullptr) {
      PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(finalScore),
                                          static_cast<int16_t>(expertIdx)};
      params.mPtrExpertIdx[idxTopK] = packedScore;
    }
    if (laneIdx < params.mTopK && params.mPtrExpertWeights != nullptr) {
      params.mPtrExpertWeights[idxTopK] = finalScore;
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads)
    routingIndicesClusterKernel(KernelParams params) {
  using OutputT = typename KernelParams::OutputT;

  int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
  int32_t const clusterBlockRank = blockIdx.x;

  // @todo: try to move it into routingPermutation
  // then wait on primary grid
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
  }

  routingPermutation<KernelParams, OutputT, NumThreads, NumWarps, MaxNumTopExperts,
                     /*LoadExpertIdxFromGlobal=*/true>(params, nullptr, warpIdx, clusterBlockRank);
}
#else
__global__ void routingIndicesClusterKernel(KernelParams params) {
  assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures");
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
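
// The cooperative kernel builds the permutation in a single launch: each block histograms its
// share of (token, expert) pairs in shared memory, the per-block counts are combined into
// mPtrExpertCounts with global atomics, a grid.sync() makes the totals visible everywhere, and an
// exclusive scan over the padded per-expert CTA counts yields the final offsets.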
template <typename KernelParams>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
__global__ void __launch_bounds__(NumThreads) routingIndicesCoopKernel(KernelParams params) {
  // number of experts is bounded by number of threads
  __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads];
  __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads];
  // needed for the exclusive sum of token offsets
  using Scan = cub::BlockScan<int32_t, NumThreads>;
  __shared__ typename Scan::TempStorage tempStorage;
  // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory.
  static constexpr int MaxExpandedIdxPerThread = 64;

  // Initialize grid.
  cg::grid_group grid = cg::this_grid();
  // Note: the following is more efficient than grid.block_index() because we don't use y and z.
  int32_t const gridBlockIdx = blockIdx.x;
  int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x;
  int32_t const numBlocks = gridDim.x;
  int32_t const numThreadsPerGrid = numBlocks * NumThreads;
  int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0);
  auto expandedIdxSize = params.mNumTokens * params.mTopK;

  // pre-fill the counts with 0
  smemExpertCount[threadIdx.x] = 0;
  __syncthreads();

  // then wait on primary grid
  if constexpr (KernelParams::UsePdl) {
    cudaGridDependencySynchronize();
  }

  // each thread has some number of "expanded indexes" assigned to it
  // for each of these, we keep the associated expert and offset within expert in registers
  int32_t expertIndexes[MaxExpandedIdxPerThread];
  int32_t expertOffsets[MaxExpandedIdxPerThread];
  auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2;

  // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a
  // time, and branch between a fast path without bound checks and a slow path with bound checks.
  int constexpr IterStride = 4;
  static_assert(MaxExpandedIdxPerThread % IterStride == 0);

  // Define a lambda to avoid code duplication in both branches.
  auto loopBody = [&](int ii, int expandedIdx) {
    int32_t expertIdx = params.mPtrExpertIdx[expandedIdx].idx;
    expertIndexes[ii] = expertIdx;
    // check whether this expert is local to our GPU at all and ignore if not
    auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx;
    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
    expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0;
  };

#pragma unroll
  for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) {
    // Whether it's safe to do multiple iterations without bound checks.
    bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize;
    if (takeFastPath) {
#pragma unroll
      for (int32_t jj = 0; jj < IterStride; jj++) {
        int const ii = ii0 + jj;
        auto expandedIdx = static_cast<int64_t>(gridThreadIdx) + ii * numThreadsPerGrid;
        loopBody(ii, expandedIdx);
      }
    } else {
      bool doBreak = false;
#pragma unroll
      for (int32_t jj = 0; jj < IterStride; jj++) {
        int const ii = ii0 + jj;
        auto expandedIdx = static_cast<int64_t>(gridThreadIdx) + ii * numThreadsPerGrid;
        if (expandedIdx >= expandedIdxSize) {
          doBreak = true;
          break;
        }
        loopBody(ii, expandedIdx);
      }
      if (doBreak) {
        break;
      }
    }
  }

  // Make histogram (token counts per expert) available to all threads in the block.
  __syncthreads();

  //
  // Each thread now represents one expert
  //

  // Add the local bin count to the common bin count and get a per-CTA offset.
  int32_t const localExpertCount = smemExpertCount[threadIdx.x];
  int32_t blockExpertOffset = 0;
  if (threadIdx.x < params.mNumExperts) {
    blockExpertOffset = atomicAdd(&params.mPtrExpertCounts[threadIdx.x], localExpertCount);
  }

  // Sync to wait for completion of the histogram reduction.
  grid.sync();

  // Get total count for this expert.
  int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0;

  // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency.

  // Compute the runtime config for projections
  // Whether or not an expert is local is taken into account when smemExpertCount is computed,
  // so we do not need to take it into account here.
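  // divUpLog2 and mulLog2 are assumed to be the usual power-of-two helpers from the routing
  // header, i.e. divUpLog2(c, l) ~ (c + (1 << l) - 1) >> l and mulLog2(c, l) ~ c << l, so numCta
  // is the number of padded tiles of 2^mPaddingLog2 tokens needed to cover this expert's count.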
  const int32_t numCta = divUpLog2(count, params.mPaddingLog2);
  int32_t ctaOffset;
  int32_t numNonExitingCtas;
  Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas);

  for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) {
    const int32_t localExpertIdx =
        (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2;
    params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx;
    params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] =
        min(mulLog2(ctaOffset + cta + 1, params.mPaddingLog2),
            mulLog2(ctaOffset, params.mPaddingLog2) + count);
  }

  // get the padded offset associated with this expert
  const int32_t offset = mulLog2(ctaOffset, params.mPaddingLog2);
  const int32_t permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2);

  // write out padded count
  if (gridBlockIdx == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) {
    params.mPtrPermutedIdxSize[0] = permutedIdxSize;
    params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
  }

  // write expert offsets to shared
  smemExpertOffset[threadIdx.x] = offset + blockExpertOffset;

  // make expert offsets available to all threads
  __syncthreads();

  // trigger the secondary kernel when using PDL
  // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
  // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
  // TODO: this is not sufficient to ensure visibility in the next kernel!
  if constexpr (KernelParams::UsePdl) {
    cudaTriggerProgrammaticLaunchCompletion();
  }
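
  // index spaces: an "expanded" index is tokenIdx * mTopK + k, while a "permuted" index is the
  // position of that (token, expert) pair in the padded, expert-grouped layout consumed
  // downstream (FC1); non-local experts map to permutedIdx == -1.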
  // each thread has the same "expanded indexes" assigned to it as above
  // at this point, we know the final offsets of experts and the offsets within
  // experts, which allows writing the final index values
#pragma unroll
  for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) {
    auto expandedIdx = static_cast<int64_t>(gridThreadIdx) + ii * numThreadsPerGrid;
    if (expandedIdx >= expandedIdxSize) {
      break;
    }
    auto expertIdx = expertIndexes[ii];
    // check whether this expert is local to our GPU at all
    auto localExpertIdx = static_cast<int32_t>(expertIdx) - params.mLocalExpertsStartIdx;
    auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent &&
                         (localExpertIdx & params.mLocalExpertsStrideLog2) == 0;
    auto tokenIdx = expandedIdx / params.mTopK;
    auto permutedIdx =
        isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1};
    if (params.mPtrExpandedIdxToPermutedIdx != nullptr) {
      params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx;
    }
    if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) {
      params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx;
    }
  }
}
#else
__global__ void routingIndicesCoopKernel(KernelParams params) {
  assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures");
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

void run(Data& data, void* stream) {
  TORCH_CHECK(data.mPtrExpertIdx != nullptr || data.mPtrPermutedIdxSize != nullptr ||
                  data.mPtrExpertWeights != nullptr,
              "Routing kernel requires at least one output parameter");
  if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr)
    TORCH_CHECK(data.mPtrExpertIdx != nullptr && data.mPtrPermutedIdxSize,
                "If permuted index is required, `mPtrExpertIdx` is also required");
  TORCH_CHECK(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet");
  TORCH_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= ",
              MaxNumTopGroups, " top groups, got ", data.mNumLimitedGroups);
  TORCH_CHECK(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= ",
              MaxNumTopExperts, ", got ", data.mTopK);
  TORCH_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got ",
              data.mTopK);
  TORCH_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize,
              "Routing kernel expects top K * top groups <= warp size (for now), got ", data.mTopK,
              " * ", data.mNumLimitedGroups);
  TORCH_CHECK(data.mNumExperts >= MaxNumTopExperts, "Routing kernel expects ", MaxNumTopExperts,
              " to be at most #experts ", data.mNumExperts);
  TORCH_CHECK(data.mNumExperts <= NumThreads, "Routing kernel expects #experts ", data.mNumExperts,
              " <= #threads ", NumThreads);
  TORCH_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups,
              "Routing kernel expects top groups ", data.mNumLimitedGroups,
              " to be limited by #expert groups ", data.mNumExpertGroups);
  if (data.mNumExpertGroups > 1) {
    TORCH_CHECK(data.mNumExpertGroups <= NumWarps, "Routing kernel expects #expert groups ",
                data.mNumExpertGroups, " to be <= #warps ", NumWarps);
    TORCH_CHECK(data.mNumExperts % data.mNumExpertGroups == 0, "Routing kernel expects #experts ",
                data.mNumExperts, " to be a multiple of #expert groups ", data.mNumExpertGroups);
    TORCH_CHECK(data.mNumExperts / data.mNumExpertGroups <= WarpSize,
                "Routing kernel expects #experts per group <= warp size, got ",
                data.mNumExperts / data.mNumExpertGroups);
  } else {
    TORCH_CHECK(data.mNumExperts <= WarpSize * MaxNumTopGroups, "Routing kernel expects #experts ",
                data.mNumExperts, " <= WarpSize * MaxNumTopGroups ", WarpSize * MaxNumTopGroups);
    TORCH_CHECK(data.mTopK <= NumWarps, "Routing kernel expects top K ", data.mTopK,
                " to be <= #warps ", NumWarps);
  }
  TORCH_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts ", data.mNumExperts,
              " to be a multiple of 4.");
  TORCH_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ",
              data.mPaddingLog2);

  int const numBlocks = data.mNumTokens;
  bool const useSingleCluster = data.mNumTokens <= 1024;

  if (!useSingleCluster) {
    // Reset the global histograms (not used in the single-cluster code path).
    // Covers both the cooperative and two-kernel code paths.
    TORCH_CHECK(data.mPtrExpertCounts != nullptr,
                "When #tokens is large, `mPtrExpertCounts` is a required input.");
  } else {
    // Set it to nullptr for the single-cluster code path, as it won't be used.
    data.mPtrExpertCounts = nullptr;
  }

  // Number of blocks we can use in the cooperative kernel
  // The number of blocks must be:
  //   >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉
  //   <= numSms, assuming an occupancy of 1 block/SM
  //
  // If too small for the given numTokens, fall back to the less performant two-step method.
  //
  // The upper bound is a strict requirement. The number of blocks should be determined by
  // querying the device properties, or chosen conservatively low.
  // /!\ The following number is not portable!! (but works on H100 and B200)
  int const numBlocksCoop = 128;
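  // A more portable alternative (sketch, not used here) would query the SM count instead of
  // hard-coding it, e.g.:
  //   int device = 0;
  //   cudaGetDevice(&device);
  //   int numSms = 0;
  //   cudaDeviceGetAttribute(&numSms, cudaDevAttrMultiProcessorCount, device);
  //   int const numBlocksCoop = std::min(numSms, 128);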

  // Maximum number of tokens supported by the kernel using a cooperative launch.
  int const maxTokensCoop = (numBlocksCoop * NumThreads * 64) / data.mTopK;

  LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, /*coopLaunch=*/false, routingMainKernel, numBlocks,
                                 NumThreads,
                                 /*smemSize=*/0,  // No dynamic smem
                                 stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);

  if (data.mPtrPermutedIdxSize != nullptr) {
    if (useSingleCluster) {
      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, /*coopLaunch=*/false, routingIndicesClusterKernel,
                                     NumBlocksPerCluster, NumThreads,
                                     /*smemSize=*/0,  // No dynamic smem
                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
    } else if (data.mNumTokens <= maxTokensCoop) {
      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, /*coopLaunch=*/true, routingIndicesCoopKernel,
                                     numBlocksCoop, NumThreads,
                                     /*smemSize=*/0,  // No dynamic smem
                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
    } else {
      const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;

      const int32_t histogramEltsPerBlock = 8 * NumThreads;
      const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * NumThreads;

      // Limit grid size (both kernels use a grid-stride loop).
      const int32_t maxNumBlocks = 1024;

      int const numBlocksHistogram = std::min(
          (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks);
      int const numBlocksOffsets =
          std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks);

      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, /*coopLaunch=*/false, routingIndicesHistogramKernel,
                                     numBlocksHistogram, NumThreads,
                                     /*smemSize=*/0,  // No dynamic smem
                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
      LAUNCH_ROUTING_WITH_EXTRA_FLAG(data, /*coopLaunch=*/false, routingIndicesOffsetsKernel,
                                     numBlocksOffsets, NumThreads,
                                     /*smemSize=*/0,  // No dynamic smem
                                     stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true);
    }
  }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace routingDeepSeek
}  // namespace moe::dev::routing