// Source path: sglang_v0.5.2/flashinfer_0.3.1/include/flashinfer/trtllm/fused_moe/runner.h
// (original listing metadata: 383 lines, 15 KiB, C++)

/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <algorithm>
#include <cstdint>
#include <optional>
#include <string>
#include <tuple>
#include <vector>

#include "DevKernel.h"
#include "RoutingKernel.h"
// #include "flashinfer/trtllm/common/cudaDriverWrapper.h"
#include "flashinfer/trtllm/batched_gemm/KernelRunner.h"
#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
#include "flashinfer/trtllm/common/cudaUtils.h"
namespace tensorrt_llm {
namespace kernels {
namespace trtllmgen_moe {
namespace Routing {
// The type of method in top-K routing, for use in torch custom op.
// Please keep this in sync with the counterpart defined in
// flashinfer/fused_moe/core.py (the numeric values are part of that contract).
enum class RoutingMethodType : int64_t {
  // Default: Softmax -> TopK
  Default = 0,
  // Renormalize: TopK -> Softmax
  Renormalize = 1,
  // DeepSeekV3: Sigmoid -> RoutingBiasAdd -> Top2 in group -> Top4 groups -> Top8 experts from the
  // Top4 groups
  DeepSeekV3 = 2,
  // Llama4: Top1 -> Sigmoid
  Llama4 = 3,
  // RenormalizeNaive: Softmax -> TopK -> Renormalize
  RenormalizeNaive = 4,
  // TopK only (no softmax)
  TopK = 5,
  // Unspecified
  Unspecified = 6,
};

// Returns a human-readable name for `routingMethodType` (suitable for log/error messages).
//
// Every declared enumerator, including `Unspecified`, maps to its own name. Any other value
// (e.g. an out-of-range integer cast from the Python side) maps to "InvalidRoutingMethod".
inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethodType) {
  switch (routingMethodType) {
    case RoutingMethodType::Default:
      return "Default";
    case RoutingMethodType::Renormalize:
      return "Renormalize";
    case RoutingMethodType::DeepSeekV3:
      return "DeepSeekV3";
    case RoutingMethodType::Llama4:
      return "Llama4";
    case RoutingMethodType::RenormalizeNaive:
      return "RenormalizeNaive";
    case RoutingMethodType::TopK:
      return "TopK";
    case RoutingMethodType::Unspecified:
      // A declared enumerator; previously mislabeled as invalid by falling through to default.
      return "Unspecified";
    default:
      // Value outside the enumerator list. TODO: consider throwing instead.
      return "InvalidRoutingMethod";
  }
}
// Upper bound on the number of rows in the permuted, padded activation buffer.
//
// The buffer holds numTokens * expertsPerToken expanded rows. On top of that, up to
// (padding - 1) extra rows are reserved for each of the numExperts experts. This presumably
// covers padding each expert's segment to a multiple of `padding`; confirm against the routing
// kernel. The sum is then rounded up to a multiple of `padding`.
inline int32_t getMaxPermutedPaddedCount(int32_t numTokens, int32_t expertsPerToken,
                                         int32_t numExperts, int32_t padding) {
  auto const worstCasePaddingRows = numExperts * (padding - 1);
  return common::roundUp(numTokens * expertsPerToken + worstCasePaddingRows, padding);
}
// Upper bound on the number of CTAs along the batch (token) dimension, given token tiles of
// `tileTokensDim` rows.
//
// The result is the smaller of two bounds:
//   1. Per-expert bound: (number of experts that can be active) * (tiles needed if a single
//      expert received every token). This is tight for small token counts but pessimistic for
//      large ones, because not every token can be routed to every enabled expert.
//   2. Buffer bound: the number of tiles that fit in the permuted, padded buffer sized by
//      getMaxPermutedPaddedCount. This is tight for large token counts but pessimistic for
//      small ones.
inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts,
                                       int32_t tileTokensDim) {
  // Bound 1: active experts times the worst-case CTA count per expert.
  auto const ctasPerExpert = common::ceilDiv(numTokens, tileTokensDim);
  auto const activeExperts = std::min(numTokens * topK, numExperts);
  auto const perExpertBound = activeExperts * ctasPerExpert;

  // Bound 2: tiles covering the whole permuted buffer.
  auto const permutedRows = getMaxPermutedPaddedCount(numTokens, topK, numExperts, tileTokensDim);
  auto const permutedBufferBound = common::ceilDiv(permutedRows, tileTokensDim);

  return std::min(perExpertBound, permutedBufferBound);
}
// Host-side launcher for the MoE routing stage. Only declarations are visible here; the
// constructors and run() are defined out of line.
class Runner {
public:
// Default constructor. NOTE(review): presumably leaves mTileTokensDim at its in-class
// default of 8 — confirm against the out-of-line definition.
explicit Runner();
// Constructs a runner for the given token tile size. NOTE(review): presumably stored in
// mTileTokensDim — confirm.
explicit Runner(int32_t tileTokensDim);
// Runs routing for `numTokens` tokens over `numExperts` experts on `stream`. The algorithm is
// selected by `routingMethodType` (see RoutingMethodType).
// - nGroups / topkGroups presumably only matter for grouped routing (DeepSeekV3); TODO confirm.
// - localExpertOffset / localNumExperts presumably select the experts handled locally.
// - The int32_t* / void* arguments are routing inputs and outputs (indices, histograms,
//   permutation maps, CTA maps). Their sizes and memory spaces are not visible here; they are
//   presumably device buffers sized like the MoEWorkspace fields — verify in the implementation.
void run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts,
int32_t topK, int32_t nGroups, int32_t topkGroups, int32_t localExpertOffset,
int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes,
int32_t* expertCountHistogram, int32_t* permutedIdxSize,
int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx,
int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert,
int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas,
batchedGemm::trtllm::gen::Dtype dtypeElt, bool useRoutingScalesOnInput,
bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream);
private:
// Token tile size along the batch dimension; defaults to 8.
int32_t mTileTokensDim{8};
};
} // namespace Routing
namespace MoE {
// The type of gated activation function.
// Please keep this in sync with the counterpart defined in flashinfer/fused_moe/core.py
// (the numeric values are part of that contract).
enum class GatedActType : int64_t {
  // SwiGlu
  SwiGlu = 0,
  // GeGlu
  GeGlu = 1,
};

// Returns a human-readable name for `gatedActType`. Each declared enumerator maps to its own
// name; any other value maps to "InvalidGatedActType".
inline std::string serializeGatedActType(GatedActType gatedActType) {
  if (gatedActType == GatedActType::SwiGlu) {
    return "SwiGlu";
  }
  if (gatedActType == GatedActType::GeGlu) {
    return "GeGlu";
  }
  // Value outside the enumerator list. TODO: throw an error instead.
  return "InvalidGatedActType";
}
} // namespace MoE
namespace PermuteGemm1 {
// Runner for the first MoE batched GEMM (GEMM1), which reads tokens through a permutation map.
// It wraps a TrtllmGenBatchedGemmRunner. Only declarations are visible here.
// NOTE(review): it takes a GatedActType, so GEMM1 presumably also applies the gated activation
// (SwiGlu/GeGlu) — confirm in the implementation.
class Runner {
public:
// Configures the runner for the given activation/weight dtypes, tile size, activation type and
// weight layout.
explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct,
batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8,
int tileTokensDim, MoE::GatedActType gatedActType, bool useShuffledMatrixA,
batchedGemm::gemm::MatrixLayout weight_layout);
// Scratch size in bytes needed for this problem shape and `configIndex`. NOTE(review):
// presumably the size of the `bmm1Workspace` buffer passed to run() — confirm.
size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
int32_t numExperts, int32_t numTokens, int32_t configIndex) const;
// Returns a config index usable for this problem shape.
[[nodiscard]] int32_t getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize,
int32_t intermediateSize, int32_t numExperts,
int32_t numTokens) const;
// Returns whether `configIndex` is usable for this problem shape.
[[nodiscard]] bool isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize,
int32_t intermediateSize, int32_t numExperts,
int32_t numTokens) const;
// Returns the candidate config indices. Presumably these are the configs that pass this
// runner's dtype/layout filters; MoE::Runner combines them with the GEMM2 configs.
[[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const;
// Enqueues GEMM1 on `stream` with config `configIndex`.
// - ptrGatedActAlpha / ptrGatedActBeta / ptrClampLimit are presumably activation parameters;
//   leave them null when unused — TODO confirm.
// - `enable_pdl` presumably toggles programmatic dependent launch.
void run(void* hiddenState, void* hiddenStateScale, void* weight, void* weightScale,
void* expertWeights, float* outputScalesScalar, float* outputScalesGateScalar,
float* ptrBias, float* ptrGatedActAlpha, float* ptrGatedActBeta, float* ptrClampLimit,
void* output, void* outputScale, int32_t topK, int32_t hiddenSize,
int32_t intermediateSize, int32_t numExperts, int32_t numTokens,
int32_t* permutedIdxToTokenIdx, int32_t* ptrNumNonExitingCtas,
int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx,
int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput,
int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl);
private:
// Activation and weight dtypes given at construction.
batchedGemm::trtllm::gen::Dtype mDtypeAct;
batchedGemm::trtllm::gen::Dtype mDtypeWeights;
// Token tile size along the batch dimension.
int32_t mTileTokensDim;
// Underlying batched-GEMM runner.
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner mRunner;
};
} // namespace PermuteGemm1
namespace Gemm2 {
// Runner for the second MoE batched GEMM (GEMM2), which consumes the permuted hidden states
// produced upstream. It wraps a TrtllmGenBatchedGemmRunner. Only declarations are visible here.
class Runner {
public:
// Configures the runner for the given activation/weight/output dtypes, tile size and weight
// layout.
explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct,
batchedGemm::trtllm::gen::Dtype dtypeWeights,
batchedGemm::trtllm::gen::Dtype outputDtype, bool useDeepSeekFp8,
int tileTokensDim, bool useShuffledMatrixA,
batchedGemm::gemm::MatrixLayout weight_layout);
// Scratch size in bytes needed for this problem shape and `configIndex`. NOTE(review):
// presumably the size of the `bmm2Workspace` buffer passed to run() — confirm.
size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
int32_t numExperts, int32_t numTokens, int32_t configIndex) const;
// Returns a config index usable for this problem shape.
[[nodiscard]] int32_t getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize,
int32_t intermediateSize, int32_t numExperts,
int32_t numTokens) const;
// Returns whether `configIndex` is usable for this problem shape.
[[nodiscard]] bool isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize,
int32_t intermediateSize, int32_t numExperts,
int32_t numTokens) const;
// Returns the candidate config indices. Presumably these are the configs that pass this
// runner's dtype/layout filters — confirm.
[[nodiscard]] std::vector<int64_t> getPassingConfigIndices() const;
// Enqueues GEMM2 on `stream` with config `configIndex`.
// `enable_pdl` presumably toggles programmatic dependent launch.
void run(void* permutedHiddenState, void* permutedHiddenStateScale, void* weight,
void* weightScale, float* outputScalesScalar, float* ptrBias, void* output,
void* outputScale, int32_t topK, int32_t hiddenSize, int32_t intermediateSize,
int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas,
int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx,
int32_t* ptrCtaIdxXyToMnLimit, void* bmm2Workspace, int device, cudaStream_t stream,
int32_t configIndex, bool enable_pdl);
private:
// Activation, weight and output dtypes given at construction.
batchedGemm::trtllm::gen::Dtype mDtypeAct;
batchedGemm::trtllm::gen::Dtype mDtypeWeights;
batchedGemm::trtllm::gen::Dtype mDtypeOut;
// Token tile size along the batch dimension.
int32_t mTileTokensDim;
// Underlying batched-GEMM runner.
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner mRunner;
};
} // namespace Gemm2
namespace MoE {
// Short alias for the trtllm-gen dtype namespace used below.
namespace btg = batchedGemm::trtllm::gen;
// Problem description and input/output pointers for one MoE::Runner::run call.
// Pointer members default to nullptr and scalar members to 0/false unless set by the caller.
struct MoERunnerArgs {
void* routing_logits = nullptr; // [num_tokens, num_experts] in float, generated after
// gemm(hidden_state, routing_weights)
void* routing_bias = nullptr; // [num_experts] in bfloat16 for now = mDtypeExpW
void* hidden_states = nullptr; // [num_tokens, hidden_size] in fp8 = mDtypeElt
// [hidden_size/128, num_tokens] in float for e4m3 DS recipe
// and [num_tokens, hidden_size/16] in float for e2m1
void* hidden_states_scale = nullptr;
// Gemm input:
// Expert weights and their scale factors for GEMM1 and GEMM2. Layouts are not visible here.
void* gemm1_weights = nullptr;
void* gemm1_weights_scale = nullptr;
void* gemm2_weights = nullptr;
void* gemm2_weights_scale = nullptr;
// Optional per-GEMM parameters. NOTE(review): gemm1_alpha / gemm1_beta / gemm1_clamp_limit
// presumably feed PermuteGemm1::Runner::run's ptrGatedActAlpha / ptrGatedActBeta /
// ptrClampLimit — confirm.
float* gemm1_bias = nullptr;
float* gemm1_alpha = nullptr;
float* gemm1_beta = nullptr;
float* gemm1_clamp_limit = nullptr;
float* gemm2_bias = nullptr;
int32_t num_tokens{0};
int32_t num_experts{0};
// Hidden dimension input of MoE block. It might be padded.
int32_t hidden_size{0};
// Hidden dimension output of MoE block. It is not padded.
// If not provided it is the same as hidden_size.
std::optional<int32_t> hidden_size_output;
// TODO: only compiled routing kernel supports top_k = 8
int32_t top_k{0};
// Number of expert groups; presumably only used by grouped routing (DeepSeekV3).
int32_t n_group{0};
// TODO: only compiled routing kernel supports topk_group = 4
int32_t topk_group{0};
float routed_scaling_factor{0.0f};
int32_t intermediate_size{0};
// Presumably the range of experts handled locally: [local_expert_offset,
// local_expert_offset + local_num_experts) — confirm with the routing implementation.
int32_t local_expert_offset{0};
int32_t local_num_experts{0};
// TODO: support other types
// Element dtype of activations (default Void, so callers must set it), dtype of the expert
// routing weights, and dtype of the output.
btg::Dtype mDtypeElt{btg::Dtype::Void};
btg::Dtype mDtypeExpW{btg::Dtype::Bfloat16};
btg::Dtype mDtypeOut{btg::Dtype::Bfloat16};
// Apply routing scale factors to input activations
bool mUseRoutingScalesOnInput{false};
// Selects the DeepSeek FP8 recipe (see the hidden_states_scale layout note above).
bool mUseDeepSeekFp8{false};
// Output scale scalars. NOTE(review): output1_* presumably go to GEMM1 (outputScalesScalar /
// outputScalesGateScalar) and output2_* to GEMM2 — confirm.
float* output1_scales_scalar = nullptr;
float* output1_scales_gate_scalar = nullptr;
float* output2_scales_scalar = nullptr;
// Output:
void* output = nullptr;
float* output_scale = nullptr;
// finalize
// NOTE(review): presumably controls whether the finalize step runs after GEMM2 — confirm.
bool do_finalize{true};
};
// Caller-provided scratch buffers for the intermediate results of one MoE::Runner::run call.
// All pointers default to nullptr. NOTE(review): these are presumably device allocations sized
// from the problem shape (e.g. via getMaxPermutedPaddedCount / getMaxNumCtasInBatchDim) —
// confirm with the allocation code.
struct MoEWorkspace {
// Routing intermediate outputs:
int32_t* routing_expert_indexes = nullptr;
int32_t* permuted_idx_size = nullptr;
int32_t* total_num_padded_tokens = nullptr; // TODO: duplicate of permuted_idx_size
// Presumably the capacity, in rows, of the permuted buffers — confirm.
int32_t total_max_padded_tokens{0};
// Maps between expanded (token, top-k slot) indices, permuted indices and token indices.
int32_t* expanded_idx_to_permuted_idx = nullptr;
int32_t* permuted_idx_to_expanded_idx = nullptr;
int32_t* permuted_idx_to_token_idx = nullptr;
void* expert_weights = nullptr; // [num_tokens, top_k] in bfloat16 = mDtypeExpW
// Per-CTA batch index / MN limit and the count of CTAs that do work, consumed by the GEMMs.
int32_t* cta_idx_xy_to_batch_idx = nullptr;
int32_t* cta_idx_xy_to_mn_limit = nullptr;
int32_t* num_non_exiting_ctas = nullptr;
void* hidden_states_scale_linear = nullptr;
// Permute intermediate outputs:
void* permuted_hidden_states = nullptr;
float* permuted_hidden_states_scale = nullptr;
// Gemm1 intermediate outputs:
// NOTE(review): ProjUpTileN looks like a GEMM1 tile size along N; its use is not visible here.
int32_t ProjUpTileN{0};
void* gemm1_output = nullptr;
float* gemm1_output_scale = nullptr;
// Activation intermediate outputs:
void* activation_output = nullptr;
float* activation_output_scale = nullptr;
// Gemm2 intermediate outputs:
void* gemm2_output = nullptr;
float* gemm2_output_scale = nullptr;
// Finalize intermediate outputs (placeholder not used)
void* finalize_output = nullptr;
float* finalize_output_scale = nullptr;
// FC1 workspace:
void* bmm1_workspace = nullptr;
// FC2 workspace:
void* bmm2_workspace = nullptr;
};
// Config indices to be used with Batched GEMM runners. One MoE-level config pairs a GEMM1 config
// with a GEMM2 config, so the MoE can be autotuned as a single operation.
//
// Both members have default initializers, so a default-constructed MoEConfig no longer holds
// indeterminate values. The struct remains an aggregate, so `MoEConfig{g1, g2}` still works.
struct MoEConfig {
  // Config index for the GEMM1 (PermuteGemm1) runner; -1 means "not set".
  int64_t gemm1Config{-1};
  // Config index for the GEMM2 runner; -1 means "not set".
  int64_t gemm2Config{-1};
};
// Top-level MoE runner: owns a PermuteGemm1::Runner and a Gemm2::Runner and runs the MoE
// pipeline for a given MoERunnerArgs / MoEWorkspace. Only declarations are visible here.
class Runner {
public:
// FIXME: tileTokensDim is hardcoded for now
// Constructor with separate activation and weight dtypes plus an explicit gated activation.
Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights,
bool useDeepSeekFp8, int tileTokensDim = 8,
GatedActType gatedActType = GatedActType::SwiGlu, bool useShuffledMatrixA = false,
batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK);
// Single-dtype constructor. NOTE(review): presumably uses dtypeElt for both activations and
// weights with the default gated activation — confirm.
Runner(batchedGemm::trtllm::gen::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim = 8,
bool useShuffledMatrixA = false,
batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK);
// Enqueues the MoE pipeline on `stream` using the MoE-level config `configIndex`.
// NOTE(review): configIndex presumably indexes mPassingConfigs.
void run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device,
cudaStream_t stream, int64_t configIndex, bool enable_pdl);
// Returns two workspace sizes in bytes. NOTE(review): presumably (GEMM1 bytes, GEMM2 bytes);
// they are int32_t, so very large workspaces could overflow — verify.
[[nodiscard]] std::tuple<int32_t, int32_t> getWorkspaceSizeInBytes(MoERunnerArgs const& args,
int64_t configIndex) const;
// Returns the MoE-level config indices usable for this problem shape.
[[nodiscard]] std::vector<int64_t> getValidConfigIndices(int32_t topK, int32_t hiddenSize,
int32_t intermediateSize,
int32_t numLocalExperts,
int32_t numTokens) const;
// Returns one MoE-level config index usable for this problem shape.
[[nodiscard]] int64_t getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize,
int32_t intermediateSize,
int32_t numLocalExperts,
int32_t numTokens) const;
private:
// Fills the parameter structs of the convert-SF, activation and finalize kernels from
// `args` and `workspace`.
void setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace,
moe::dev::convertsf::Data& convertSfData,
moe::dev::activation::Data& activationData,
moe::dev::finalize::Data& finalizeData);
private:
PermuteGemm1::Runner mPermuteGemm1;
Gemm2::Runner mGemm2;
// This will be the cartesian product of the passing configs for gemm1 and gemm2
// This allows us to autotune the MoE as one operation instead of tuning gemm1 and gemm2
// separately
std::vector<MoEConfig> mPassingConfigs;
};
} // namespace MoE
} // namespace trtllmgen_moe
} // namespace kernels
} // namespace tensorrt_llm