/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>
#include <c10/core/ScalarType.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Exception.h>
#include <cuda_runtime.h>
#include <flashinfer/exception.h>
#include <nvrtc.h>
#include <torch/library.h>

#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/GemmGatedActOptions.h"
#include "flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
#include "flashinfer/trtllm/fused_moe/DevKernel.h"
#include "flashinfer/trtllm/fused_moe/RoutingKernel.h"
#include "flashinfer/trtllm/fused_moe/runner.h"
#include "nv_internal/tensorrt_llm/kernels/quantization.h"
#include "nv_internal/tensorrt_llm/thop/thUtils.h"

namespace flashinfer {

namespace btg = batchedGemm::trtllm::gen;
using tensorrt_llm::kernels::trtllmgen_moe::MoE::GatedActType;
using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType;
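
// This file binds the TensorRT-LLM Gen (trtllm-gen) fused MoE kernels to PyTorch:
//   * trtllm_fp8_per_tensor_scale_moe - FP8 activations/weights with per-tensor scales
//   * trtllm_fp8_block_scale_moe      - FP8 with DeepSeek-style 128-element block scales
//   * trtllm_fp4_block_scale_moe      - packed FP4 weights with block scale factors
// Each launcher validates dtypes/shapes, runs the routing kernel, and then dispatches the
// batched GEMM + activation + finalize pipeline through the TRTLLM-Gen MoE::Runner.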
at::Tensor trtllm_fp8_per_tensor_scale_moe_launcher(
    at::Tensor const& routing_logits, std::optional<at::Tensor> routing_bias,
    at::Tensor const& hidden_states, at::Tensor const& gemm1_weights,
    at::Tensor const& output1_scales_scalar, at::Tensor const& output1_scales_gate_scalar,
    at::Tensor const& gemm2_weights, at::Tensor const& output2_scales_scalar,
    int64_t const num_experts, int64_t const top_k, int64_t const n_group, int64_t const topk_group,
    int64_t const intermediate_size, int64_t const local_expert_offset,
    int64_t const local_num_experts, double const routed_scaling_factor,
    bool const use_routing_scales_on_input, int64_t const tile_tokens_dim,
    int64_t const routing_method_type, bool enable_pdl) {
  auto device = hidden_states.device();

  static const std::tuple<int, int> device_props = [&device] {
    int major, minor;
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device.index());
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device.index());
    return std::make_tuple(major, minor);
  }();

  TORCH_CHECK(std::get<0>(device_props) == 10,
              "This kernel requires 10.x architecture. Current device has SM ",
              std::get<0>(device_props), std::get<1>(device_props));

  if (use_routing_scales_on_input) {
    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::BFloat16,
                "routing_logits must be bfloat16.");
  } else {
    TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float,
                "routing_logits must be float.");
  }
  TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
  TORCH_CHECK(routing_logits.sizes()[1] == num_experts, "routing_logits has incorrect shape.");
  if (routing_bias.has_value()) {
    TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16,
                "routing_bias must be bfloat16.");
    TORCH_CHECK(routing_bias.value().dim() == 1, "routing_bias must be 1D.");
    TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts,
                "routing_bias has incorrect shape.");
  }

  if (n_group <= 0 || topk_group <= 0) {
    TORCH_CHECK(top_k == 1, "Current routing kernel (no groups) only supports top_k=1.");
  } else {
    TORCH_CHECK(top_k <= 8, "Current routing kernel (with groups) only supports top_k<=8.");
    TORCH_CHECK(topk_group <= 4,
                "Current routing kernel (with groups) only supports topk_group<=4.");
    TORCH_CHECK(topk_group <= n_group, "n_group must not be smaller than topk_group.");
    TORCH_CHECK(num_experts % n_group == 0, "num_experts must be divisible by n_group");
    // This check ensures we have enough experts in the selected groups to handle the top_k routing
    TORCH_CHECK(top_k < (topk_group * num_experts / n_group),
                "top_k must be less than total number of experts in selected groups");
  }
  TORCH_CHECK(num_experts % 4 == 0,
              "Routing kernel expects that num_experts must be divisible by 4");
  TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k");
  TORCH_CHECK(local_num_experts + local_expert_offset <= num_experts,
              "num_experts must be greater or equal to local_num_experts + local_expert_offset");

  tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs args;
  tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace;

  // Convert PyTorch dtype to TensorRT-LLM dtype
  auto dtype = hidden_states.dtype();
  if (dtype == at::ScalarType::Half) {
    args.mDtypeElt = btg::Dtype::Fp16;
  } else if (dtype == at::ScalarType::BFloat16) {
    args.mDtypeElt = btg::Dtype::Bfloat16;
  } else if (dtype == at::ScalarType::Float8_e4m3fn) {
    args.mDtypeElt = btg::Dtype::E4m3;
  } else {
    TORCH_CHECK(false, "Unsupported input dtype for MoE: ", dtype);
  }

  args.routing_logits = routing_logits.data_ptr();
  auto const routing_bias_dtype =
      routing_bias.has_value() ? routing_bias.value().scalar_type() : at::ScalarType::BFloat16;
  args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr;
  args.hidden_states = hidden_states.data_ptr();
  args.gemm1_weights = gemm1_weights.data_ptr();
  args.output1_scales_scalar = output1_scales_scalar.data_ptr<float>();
  args.output1_scales_gate_scalar = output1_scales_gate_scalar.data_ptr<float>();
  args.gemm2_weights = gemm2_weights.data_ptr();
  args.output2_scales_scalar = output2_scales_scalar.data_ptr<float>();
  args.num_tokens = hidden_states.sizes()[0];
  args.num_experts = num_experts;
  args.hidden_size = hidden_states.sizes()[1];
  args.hidden_size_output = args.hidden_size;
  args.top_k = top_k;
  args.n_group = n_group;
  args.topk_group = topk_group;
  args.local_expert_offset = local_expert_offset;
  args.local_num_experts = local_num_experts;
  args.routed_scaling_factor = routed_scaling_factor;
  args.intermediate_size = intermediate_size;
  args.mUseRoutingScalesOnInput = use_routing_scales_on_input;

  // allocate workspace for routing kernel
  at::Tensor num_tokens_per_expert = at::detail::empty_cuda({num_experts}, at::ScalarType::Int,
                                                            routing_logits.device(), std::nullopt);
  int32_t max_num_padded_tokens =
      tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
          args.num_tokens, top_k, num_experts, tile_tokens_dim);
  at::Tensor total_num_padded_tokens =
      at::empty({}, at::TensorOptions().device(routing_logits.device()).dtype(at::ScalarType::Int));
  at::Tensor expanded_idx_to_permuted_idx = at::detail::empty_cuda(
      {args.num_tokens * args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor permuted_idx_to_token_idx = at::detail::empty_cuda(
      {max_num_padded_tokens}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor expert_weights = at::detail::empty_cuda(
      {args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits.device(), std::nullopt);
  at::Tensor expert_indexes = at::detail::empty_cuda(
      {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor expert_count_histogram = at::detail::empty_cuda(
      {2 * 256},
      at::ScalarType::Int,  // 256 is the max number of threads per block and max number of experts
      routing_logits.device(), std::nullopt);

  // allocate workspace for activation/gemm/finalize kernels
  at::Tensor gemm1_output =
      at::detail::empty_cuda({max_num_padded_tokens, 2 * intermediate_size},
                             at::ScalarType::Float8_e4m3fn, hidden_states.device(), std::nullopt);
  at::Tensor gemm1_output_scale =
      at::detail::empty_cuda({2 * intermediate_size / 128, max_num_padded_tokens},
                             at::ScalarType::Float, hidden_states.device(), std::nullopt);
  at::Tensor activation_output =
      at::detail::empty_cuda({max_num_padded_tokens, intermediate_size},
                             at::ScalarType::Float8_e4m3fn, hidden_states.device(), std::nullopt);
  at::Tensor activation_output_scale =
      at::detail::empty_cuda({intermediate_size / 128, max_num_padded_tokens},
                             at::ScalarType::Float, hidden_states.device(), std::nullopt);
  at::Tensor gemm2_output =
      at::detail::empty_cuda({max_num_padded_tokens, args.hidden_size}, at::ScalarType::BFloat16,
                             hidden_states.device(), std::nullopt);

  int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
      args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
  at::Tensor cta_idx_xy_to_batch_idx = at::detail::empty_cuda(
      {max_num_ctas}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor cta_idx_xy_to_mn_limit = at::detail::empty_cuda(
      {max_num_ctas}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor num_non_exiting_ctas =
      at::empty({}, at::TensorOptions().device(routing_logits.device()).dtype(at::ScalarType::Int));
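
  // The routing kernel selects the top-k experts for every token and fills the permutation /
  // CTA-scheduling metadata that the grouped GEMM, activation, and finalize kernels consume.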
  tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
  auto const& stream = at::cuda::getCurrentCUDAStream(routing_logits.get_device());
  routing_runner.run(routing_logits.data_ptr(), args.routing_bias, args.num_tokens,
                     args.num_experts, args.top_k, args.n_group, args.topk_group,
                     args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor,
                     expert_indexes.data_ptr<int>(), expert_count_histogram.data_ptr<int>(),
                     total_num_padded_tokens.data_ptr<int>(),
                     expanded_idx_to_permuted_idx.data_ptr<int>(),
                     nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/,
                     permuted_idx_to_token_idx.data_ptr<int>(), expert_weights.data_ptr(),
                     num_tokens_per_expert.data_ptr<int>(), cta_idx_xy_to_batch_idx.data_ptr<int>(),
                     cta_idx_xy_to_mn_limit.data_ptr<int>(), num_non_exiting_ctas.data_ptr<int>(),
                     args.mDtypeElt, use_routing_scales_on_input, false /* use_deep_seek_fp8 */,
                     static_cast<RoutingMethodType>(routing_method_type), stream);

  // MoE kernel except routing
  TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "hidden_states must be fp8.");
  TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "gemm1_weights must be fp8.");
  TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D.");
  TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even.");
  TORCH_CHECK(intermediate_size == gemm1_weights.sizes()[1] / 2,
              "intermediate_size has incorrect shape.");
  TORCH_CHECK(gemm1_weights.sizes()[2] == hidden_states.sizes()[1],
              "the third dimension of weights must be equal to hidden_size.");
  TORCH_CHECK(intermediate_size % 128 == 0,
              "the second dimension of weights must be a multiple of 128.");

  TORCH_CHECK(output1_scales_scalar.scalar_type() == at::ScalarType::Float,
              "output1_scales_scalar must be float.");
  TORCH_CHECK(output1_scales_scalar.dim() == 1, "output1_scales_scalar must be 1D.");
  TORCH_CHECK(output1_scales_scalar.sizes()[0] == local_num_experts,
              "output1_scales_scalar has incorrect dim 0.");
  TORCH_CHECK(output1_scales_gate_scalar.scalar_type() == at::ScalarType::Float,
              "output1_scales_gate_scalar must be float.");
  TORCH_CHECK(output1_scales_gate_scalar.dim() == 1, "output1_scales_gate_scalar must be 1D.");
  TORCH_CHECK(output1_scales_gate_scalar.sizes()[0] == local_num_experts,
              "output1_scales_gate_scalar has incorrect dim 0.");

  TORCH_CHECK(gemm2_weights.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "gemm2_weights must be fp8.");
  TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
  TORCH_CHECK(gemm2_weights.sizes()[2] == intermediate_size,
              "the third dimension of weights must be equal to intermediate_size.");

  TORCH_CHECK(output2_scales_scalar.scalar_type() == at::ScalarType::Float,
              "output2_scales_scalar must be float.");
  TORCH_CHECK(output2_scales_scalar.dim() == 1, "output2_scales_scalar must be 1D.");
  TORCH_CHECK(output2_scales_scalar.sizes()[0] == local_num_experts,
              "output2_scales_scalar has incorrect dim 0.");

  // allocate output
  at::Tensor output =
      at::detail::empty_cuda({args.num_tokens, args.hidden_size}, at::ScalarType::BFloat16,
                             hidden_states.device(), std::nullopt);

  // setup workspace
  workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr<int>();
  workspace.total_max_padded_tokens = max_num_padded_tokens;
  workspace.ProjUpTileN = tile_tokens_dim;
  workspace.routing_expert_indexes = expert_indexes.data_ptr<int>();
  workspace.permuted_idx_size = total_num_padded_tokens.data_ptr<int>();
  workspace.expanded_idx_to_permuted_idx =
      expanded_idx_to_permuted_idx.data_ptr<int>();  // Needed by activation/finalize kernels
  workspace.permuted_idx_to_token_idx =
      permuted_idx_to_token_idx.data_ptr<int>();  // Needed by permuteGemm1 kernel
  workspace.expert_weights = expert_weights.data_ptr();  // Consumed by finalize kernel

  workspace.cta_idx_xy_to_batch_idx = cta_idx_xy_to_batch_idx.data_ptr<int>();
  workspace.cta_idx_xy_to_mn_limit = cta_idx_xy_to_mn_limit.data_ptr<int>();
  workspace.num_non_exiting_ctas = num_non_exiting_ctas.data_ptr<int>();

  // gemm1 intermediate ws
  workspace.gemm1_output = gemm1_output.data_ptr();
  workspace.gemm1_output_scale = gemm1_output_scale.data_ptr<float>();
  // activation intermediate ws
  workspace.activation_output = activation_output.data_ptr();
  workspace.activation_output_scale = activation_output_scale.data_ptr<float>();
  // gemm2 intermediate ws
  workspace.gemm2_output = gemm2_output.data_ptr();
  workspace.gemm2_output_scale = nullptr;
  args.output = output.data_ptr();
  args.output_scale = nullptr;
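
  // Build the FP8 per-tensor-scale runner and let it select a default kernel configuration for
  // this problem shape; the chosen config determines the per-GEMM workspace sizes queried below.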
  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
      args.mDtypeElt, args.mUseDeepSeekFp8, tile_tokens_dim, /*useShuffledMatrixA*/ true);

  auto const moeConfigIndex =
      moe_runner.getDefaultValidConfigIndex(args.top_k, args.hidden_size, args.intermediate_size,
                                            args.local_num_experts, args.num_tokens);

  auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
  at::Tensor workspace_fc1 = at::detail::empty_cuda(
      {std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
  at::Tensor workspace_fc2 = at::detail::empty_cuda(
      {std::get<1>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
  workspace.bmm1_workspace = workspace_fc1.data_ptr();
  workspace.bmm2_workspace = workspace_fc2.data_ptr();
  auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device());
  moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex,
                 enable_pdl);
  return output;
}
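
// Public FP8 per-tensor-scale entry point: validates the activation dtype and forwards all
// arguments to the launcher above. Registered as a Torch custom op at the bottom of this file.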
at::Tensor trtllm_fp8_per_tensor_scale_moe(
    at::Tensor routing_logits, std::optional<at::Tensor> routing_bias, at::Tensor hidden_states,
    at::Tensor gemm1_weights, at::Tensor output1_scales_scalar,
    at::Tensor output1_scales_gate_scalar, at::Tensor gemm2_weights,
    at::Tensor output2_scales_scalar, int64_t num_experts, int64_t top_k, int64_t n_group,
    int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset,
    int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input,
    int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) {
  auto dtype = hidden_states.dtype();
  if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16 ||
      dtype == at::ScalarType::Float8_e4m3fn) {
    return trtllm_fp8_per_tensor_scale_moe_launcher(
        routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar,
        output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, num_experts, top_k,
        n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts,
        routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type,
        enable_pdl);
  } else {
    TORCH_CHECK(false, "Unsupported input type: ", dtype);
  }
}
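
// FP8 block-scale (DeepSeek FP8) launcher: weights carry one float scale per 128x128 block and
// activations one float scale per 128-element group, as enforced by the shape checks below.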
at::Tensor trtllm_fp8_block_scale_moe_launcher(
    at::Tensor const& routing_logits, std::optional<at::Tensor> routing_bias,
    at::Tensor const& hidden_states, at::Tensor const& hidden_states_scale,
    at::Tensor const& gemm1_weights, at::Tensor const& gemm1_weights_scale,
    at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale,
    int64_t const num_experts, int64_t const top_k, int64_t const n_group, int64_t const topk_group,
    int64_t const intermediate_size, int64_t const local_expert_offset,
    int64_t const local_num_experts, double const routed_scaling_factor,
    int64_t const tile_tokens_dim, int64_t const routing_method_type,
    tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex,
    bool enable_pdl) {
  auto device = hidden_states.device();

  static const std::tuple<int, int> device_props = [&device] {
    int major, minor;
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device.index());
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device.index());
    return std::make_tuple(major, minor);
  }();

  TORCH_CHECK(std::get<0>(device_props) == 10,
              "This kernel requires 10.x architecture. Current device has SM ",
              std::get<0>(device_props), std::get<1>(device_props));

  TORCH_CHECK(routing_logits.scalar_type() == at::ScalarType::Float,
              "routing_logits must be float.");
  TORCH_CHECK(routing_logits.dim() == 2, "routing_logits must be 2D.");
  TORCH_CHECK(routing_logits.sizes()[0] == hidden_states.sizes()[0],
              "routing_logits and hidden_states must have the same number of tokens.");
  TORCH_CHECK(routing_logits.sizes()[1] == num_experts,
              "routing_logits dim1 must match num_experts.");
  if (routing_bias.has_value()) {
    TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16 ||
                    routing_bias.value().scalar_type() == at::ScalarType::Float,
                "routing_bias must be bfloat16 or float.");
    TORCH_CHECK(routing_bias.value().dim() == 1, "routing_bias must be 1D.");
    TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts,
                "routing_bias has incorrect shape.");
  }

  if (n_group <= 0 || topk_group <= 0) {
    TORCH_CHECK(top_k == 1, "Current routing kernel (no groups) only supports top_k=1.");
  } else {
    TORCH_CHECK(top_k <= 8, "Current routing kernel (with groups) only supports top_k<=8.");
    TORCH_CHECK(topk_group <= 4,
                "Current routing kernel (with groups) only supports topk_group<=4.");
    TORCH_CHECK(topk_group <= n_group, "n_group must not be smaller than topk_group.");
    TORCH_CHECK(num_experts % n_group == 0, "num_experts must be divisible by n_group");
    // This check ensures we have enough experts in the selected groups to handle the top_k routing
    TORCH_CHECK(top_k < (topk_group * num_experts / n_group),
                "top_k must be less than total number of experts in selected groups");
  }
  TORCH_CHECK(num_experts % 4 == 0,
              "Routing kernel expects that num_experts must be divisible by 4");
  TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k");

  tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs args;
  tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace;

  // Convert PyTorch dtype to TensorRT-LLM dtype
  auto dtype = hidden_states.dtype();
  if (dtype == at::ScalarType::Half) {
    args.mDtypeElt = btg::Dtype::Fp16;
  } else if (dtype == at::ScalarType::BFloat16) {
    args.mDtypeElt = btg::Dtype::Bfloat16;
  } else if (dtype == at::ScalarType::Float8_e4m3fn) {
    args.mDtypeElt = btg::Dtype::E4m3;
  } else {
    TORCH_CHECK(false, "Unsupported input dtype for MoE: ", dtype);
  }

  auto const routing_bias_dtype =
      routing_bias.has_value() ? routing_bias.value().scalar_type() : at::ScalarType::BFloat16;
  args.mDtypeExpW =
      routing_bias_dtype == at::ScalarType::BFloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
  args.routing_logits = routing_logits.data_ptr<float>();
  args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr;
  args.hidden_states = hidden_states.data_ptr();
  args.hidden_states_scale = hidden_states_scale.data_ptr<float>();
  args.gemm1_weights = gemm1_weights.data_ptr();
  args.gemm1_weights_scale = gemm1_weights_scale.data_ptr<float>();
  args.gemm2_weights = gemm2_weights.data_ptr();
  args.gemm2_weights_scale = gemm2_weights_scale.data_ptr<float>();
  args.num_tokens = hidden_states.sizes()[0];
  args.num_experts = num_experts;
  args.hidden_size = hidden_states.sizes()[1];
  args.hidden_size_output = args.hidden_size;
  args.top_k = top_k;
  args.n_group = n_group;
  args.topk_group = topk_group;
  args.local_expert_offset = local_expert_offset;
  args.local_num_experts = local_num_experts;
  args.routed_scaling_factor = routed_scaling_factor;
  args.intermediate_size = intermediate_size;
  args.mUseDeepSeekFp8 = true;

  // allocate workspace for routing kernel
  at::Tensor num_tokens_per_expert = at::detail::empty_cuda({num_experts}, at::ScalarType::Int,
                                                            routing_logits.device(), std::nullopt);
  int32_t max_num_padded_tokens =
      tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
          args.num_tokens, top_k, num_experts, tile_tokens_dim);
  at::Tensor total_num_padded_tokens =
      at::empty({}, at::TensorOptions().device(routing_logits.device()).dtype(at::ScalarType::Int));
  at::Tensor expanded_idx_to_permuted_idx = at::detail::empty_cuda(
      {args.num_tokens * args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor permuted_idx_to_token_idx = at::detail::empty_cuda(
      {max_num_padded_tokens}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor expert_weights = at::detail::empty_cuda(
      {args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits.device(), std::nullopt);
  at::Tensor expert_indexes = at::detail::empty_cuda(
      {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
  at::Tensor expert_count_histogram = at::detail::empty_cuda(
      {size_of_expert_count_histogram},
      at::ScalarType::Int,  // 256 is the max number of threads per block and max number of experts
      routing_logits.device(), std::nullopt);

  // allocate workspace for activation/gemm/finalize kernels
  at::Tensor gemm1_output =
      at::detail::empty_cuda({max_num_padded_tokens, 2 * intermediate_size},
                             at::ScalarType::Float8_e4m3fn, hidden_states.device(), std::nullopt);
  at::Tensor gemm1_output_scale =
      at::detail::empty_cuda({2 * intermediate_size / 128, max_num_padded_tokens},
                             at::ScalarType::Float, hidden_states.device(), std::nullopt);
  at::Tensor activation_output =
      at::detail::empty_cuda({max_num_padded_tokens, intermediate_size},
                             at::ScalarType::Float8_e4m3fn, hidden_states.device(), std::nullopt);
  at::Tensor activation_output_scale =
      at::detail::empty_cuda({intermediate_size / 128, max_num_padded_tokens},
                             at::ScalarType::Float, hidden_states.device(), std::nullopt);
  at::Tensor gemm2_output =
      at::detail::empty_cuda({max_num_padded_tokens, args.hidden_size}, at::ScalarType::BFloat16,
                             hidden_states.device(), std::nullopt);

  int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
      args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
  at::Tensor cta_idx_xy_to_batch_idx = at::detail::empty_cuda(
      {max_num_ctas}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor cta_idx_xy_to_mn_limit = at::detail::empty_cuda(
      {max_num_ctas}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
  at::Tensor num_non_exiting_ctas =
      at::empty({}, at::TensorOptions().device(routing_logits.device()).dtype(at::ScalarType::Int));
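
  // The routing call below passes use_routing_scales_on_input = false and use_deep_seek_fp8 = true
  // (the two unnamed boolean arguments after args.mDtypeElt).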
  tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
  auto const& stream = at::cuda::getCurrentCUDAStream(routing_logits.get_device());
  routing_runner.run(
      routing_logits.data_ptr<float>(), args.routing_bias, args.num_tokens, args.num_experts,
      args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
      args.routed_scaling_factor, expert_indexes.data_ptr<int>(),
      expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(),
      expanded_idx_to_permuted_idx.data_ptr<int>(),
      nullptr /*permuted_idx_to_expanded_idx.data_ptr<int>()*/,
      permuted_idx_to_token_idx.data_ptr<int>(), expert_weights.data_ptr(),
      num_tokens_per_expert.data_ptr<int>(), cta_idx_xy_to_batch_idx.data_ptr<int>(),
      cta_idx_xy_to_mn_limit.data_ptr<int>(), num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt,
      false, true, static_cast<RoutingMethodType>(routing_method_type), stream);

  // MoE kernel except routing
  TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "hidden_states must be fp8.");
  TORCH_CHECK(hidden_states_scale.scalar_type() == at::ScalarType::Float,
              "hidden_states_scale must be float.");
  TORCH_CHECK(hidden_states_scale.dim() == 2, "hidden_states_scale must be 2D.");
  TORCH_CHECK(hidden_states_scale.sizes()[0] == hidden_states.sizes()[1] / 128,
              "hidden_states_scale dim0 must match hidden_states dim1 / 128.");
  TORCH_CHECK(hidden_states_scale.sizes()[1] == args.num_tokens,
              "hidden_states_scale dim1 must match num_tokens.");
  TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "gemm1_weights must be fp8.");

  TORCH_CHECK(gemm1_weights.dim() == 3 || gemm1_weights.dim() == 4,
              "gemm1_weights must be 3D or 4D.");
  {
    int64_t Mn = 0, K = 0;
    if (gemm1_weights.dim() == 3) {
      // MajorK [num_experts, M, K]
      Mn = gemm1_weights.sizes()[1];
      K = gemm1_weights.sizes()[2];
    } else if (gemm1_weights.dim() == 4) {
      // BlockMajorK [num_experts, K/block_k, M, block_k]
      Mn = gemm1_weights.sizes()[2];
      int64_t block_k = gemm1_weights.sizes()[3];
      K = gemm1_weights.sizes()[1] * block_k;
    }
    TORCH_CHECK(Mn % 2 == 0, "the second dimension of weights must be even.");
    TORCH_CHECK(intermediate_size == Mn / 2, "intermediate_size has incorrect shape.");
    TORCH_CHECK(K == hidden_states.sizes()[1],
                "the third dimension of weights must be equal to hidden_size.");
  }
  TORCH_CHECK(gemm1_weights_scale.scalar_type() == at::ScalarType::Float,
              "gemm1_weights_scale must be float.");
  TORCH_CHECK(gemm1_weights_scale.dim() == 3, "gemm1_weights_scale must be 3D.");

  TORCH_CHECK(gemm1_weights_scale.sizes()[0] == local_num_experts,
              "gemm1_weights_scale has incorrect shape.");
  TORCH_CHECK(intermediate_size % 128 == 0,
              "the second dimension of weights must be a multiple of 128.");
  TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size / 128,
              "gemm1_weights_scale has incorrect shape.");
  TORCH_CHECK(gemm1_weights_scale.sizes()[2] == args.hidden_size / 128,
              "gemm1_weights_scale has incorrect shape.");
  TORCH_CHECK(gemm2_weights.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "gemm2_weights must be fp8.");

  TORCH_CHECK(gemm2_weights.dim() == 3 || gemm2_weights.dim() == 4,
              "gemm2_weights must be 3D or 4D.");
  {
    int64_t K = 0;
    if (gemm2_weights.dim() == 3) {
      // MajorK [num_experts, M, K]
      K = gemm2_weights.sizes()[2];
    } else if (gemm2_weights.dim() == 4) {
      // BlockMajorK [num_experts, K/block_k, M, block_k]
      int64_t block_k = gemm2_weights.sizes()[3];
      K = gemm2_weights.sizes()[1] * block_k;
    }
    TORCH_CHECK(K == intermediate_size,
                "the third dimension of weights must be equal to intermediate_size.");
  }
  TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float,
              "gemm2_weights_scale must be float.");
  TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
  TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts,
              "gemm2_weights_scale has incorrect shape.");
  TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size / 128,
              "gemm2_weights_scale has incorrect shape.");
  TORCH_CHECK(gemm2_weights_scale.sizes()[2] == intermediate_size / 128,
              "gemm2_weights_scale has incorrect shape.");

  // allocate output
  at::Tensor output =
      at::detail::empty_cuda({args.num_tokens, args.hidden_size}, at::ScalarType::BFloat16,
                             hidden_states.device(), std::nullopt);

  // setup workspace
  workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr<int>();
  workspace.total_max_padded_tokens = max_num_padded_tokens;
  workspace.ProjUpTileN = tile_tokens_dim;
  workspace.routing_expert_indexes = expert_indexes.data_ptr<int>();
  workspace.permuted_idx_size = total_num_padded_tokens.data_ptr<int>();
  workspace.expanded_idx_to_permuted_idx =
      expanded_idx_to_permuted_idx.data_ptr<int>();  // Needed by activation/finalize kernels
  workspace.permuted_idx_to_token_idx =
      permuted_idx_to_token_idx.data_ptr<int>();  // Needed by permuteGemm1 kernel
  workspace.expert_weights = expert_weights.data_ptr();  // Consumed by finalize kernel

  workspace.cta_idx_xy_to_batch_idx = cta_idx_xy_to_batch_idx.data_ptr<int>();
  workspace.cta_idx_xy_to_mn_limit = cta_idx_xy_to_mn_limit.data_ptr<int>();
  workspace.num_non_exiting_ctas = num_non_exiting_ctas.data_ptr<int>();

  // gemm1 intermediate ws
  workspace.gemm1_output = gemm1_output.data_ptr();
  workspace.gemm1_output_scale = gemm1_output_scale.data_ptr<float>();
  // activation intermediate ws
  workspace.activation_output = activation_output.data_ptr();
  workspace.activation_output_scale = activation_output_scale.data_ptr<float>();
  // gemm2 intermediate ws
  workspace.gemm2_output = gemm2_output.data_ptr();
  workspace.gemm2_output_scale = nullptr;
  args.output = output.data_ptr();
  args.output_scale = nullptr;

  auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);
  at::Tensor workspace_fc1 = at::detail::empty_cuda(
      {std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
  at::Tensor workspace_fc2 = at::detail::empty_cuda(
      {std::get<1>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
  workspace.bmm1_workspace = workspace_fc1.data_ptr();
  workspace.bmm2_workspace = workspace_fc2.data_ptr();

  auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device());
  moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex,
                 enable_pdl);
  return output;
}
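
// Public FP8 block-scale entry point: builds a DeepSeek-FP8 runner with the requested weight
// layout, resolves the default kernel config for this problem shape, and calls the launcher.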
at::Tensor trtllm_fp8_block_scale_moe(
    at::Tensor const& routing_logits, std::optional<at::Tensor> routing_bias,
    at::Tensor const& hidden_states, at::Tensor const& hidden_states_scale,
    at::Tensor const& gemm1_weights, at::Tensor const& gemm1_weights_scale,
    at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale, int64_t num_experts,
    int64_t top_k, int64_t n_group, int64_t topk_group, int64_t intermediate_size,
    int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor,
    int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
    int64_t weight_layout, bool enable_pdl) {
  auto dtype = hidden_states.dtype();
  if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16 ||
      dtype == at::ScalarType::Float8_e4m3fn) {
    using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;

    btg::Dtype mDtypeElt{btg::Dtype::E4m3};  // FP8 runner so hard-coded
    bool mUseDeepSeekFp8{true};              // Always true for BlockScaleMoe

    TORCH_CHECK(0 <= weight_layout && weight_layout <= 2,
                "the value of weight_layout is not recognized");

    // Properly initialize the runner using make_unique like in the original code
    auto mRunner = std::make_unique<RunnerType>(
        mDtypeElt, mUseDeepSeekFp8, tile_tokens_dim, use_shuffled_weight,
        static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));

    // Always use fallback config (equivalent to moeConfigIndex == -1 case from original code)
    auto const num_tokens = hidden_states.sizes()[0];
    auto const hidden_size = hidden_states.sizes()[1];

    int64_t moeConfigIndex = mRunner->getDefaultValidConfigIndex(
        top_k, hidden_size, intermediate_size, local_num_experts, num_tokens);

    return trtllm_fp8_block_scale_moe_launcher(
        routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights,
        gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group,
        topk_group, intermediate_size, local_expert_offset, local_num_experts,
        routed_scaling_factor, tile_tokens_dim, routing_method_type, *mRunner, moeConfigIndex,
        enable_pdl);
  } else {
    TORCH_CHECK(false, "Unsupported input type: ", dtype);
  }
}
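
// FP4 block-scale launcher. Weights are packed E2m1 (two 4-bit values per byte); activations may
// be packed FP4, FP8, or BF16. Scale factors use 16-element (E2m1) or 32-element (MxE2m1) vectors.
// When do_finalize is false, the un-finalized gemm2 output plus the routing metadata are returned
// instead of the reduced output tensor.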
// TODO(siyuan): This launcher supports flexible weight and activation types.
// We should cleanup other launchers and only use this one in the future.
std::vector<at::Tensor> trtllm_fp4_block_scale_moe_launcher(
    std::optional<at::Tensor> const& routing_logits, at::Tensor& expert_indices,
    at::Tensor& expert_weights, std::optional<at::Tensor> const& routing_bias,
    at::Tensor const& hidden_states, std::optional<at::Tensor> const& hidden_states_scale,
    at::Tensor const& gemm1_weights, at::Tensor const& gemm1_weights_scale,
    std::optional<at::Tensor> const& gemm1_bias, std::optional<at::Tensor> const& gemm1_alpha,
    std::optional<at::Tensor> const& gemm1_beta, std::optional<at::Tensor> const& gemm1_clamp_limit,
    at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale,
    std::optional<at::Tensor> const& gemm2_bias,
    std::optional<at::Tensor> const& output1_scales_scalar,
    std::optional<at::Tensor> const& output1_scales_gate_scalar,
    std::optional<at::Tensor> const& output2_scales_scalar, int64_t const num_experts,
    int64_t const top_k, std::optional<int64_t> const n_group,
    std::optional<int64_t> const topk_group, int64_t const intermediate_size,
    int64_t const local_expert_offset, int64_t const local_num_experts,
    std::optional<double> const routed_scaling_factor, int64_t const tile_tokens_dim,
    int64_t const routing_method_type, bool const do_finalize,
    tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, btg::Dtype dtype_act,
    btg::Dtype dtype_weights, int64_t const moeConfigIndex, bool enable_pdl, at::Tensor& output) {
  auto device = hidden_states.device();

  static const std::tuple<int, int> device_props = [&device] {
    int major, minor;
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device.index());
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device.index());
    return std::make_tuple(major, minor);
  }();

  TORCH_CHECK(std::get<0>(device_props) == 10,
              "This kernel requires 10.x architecture. Current device has SM ",
              std::get<0>(device_props), std::get<1>(device_props));

  TORCH_CHECK(dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::Bfloat16 ||
                  dtype_act == btg::Dtype::E4m3 || dtype_act == btg::Dtype::MxE4m3,
              "Only E2m1, Bfloat16, MxE4m3 and E4m3 are supported by block scale MoE");
  if (dtype_act == btg::Dtype::E2m1) {
    TORCH_CHECK(dtype_weights == btg::Dtype::E2m1,
                "Only E2m1 and MxE2m1 are supported by block scale MoE with E2m1 activation");
    TORCH_CHECK(hidden_states_scale.has_value(),
                "hidden_states_scale is required for E2m1 activation");
    TORCH_CHECK(output1_scales_scalar.has_value(),
                "output1_scales_scalar is required for E2m1 activation");
    TORCH_CHECK(output1_scales_gate_scalar.has_value(),
                "output1_scales_gate_scalar is required for E2m1 activation");
    TORCH_CHECK(output2_scales_scalar.has_value(),
                "output2_scales_scalar is required for E2m1 activation");
  } else if (dtype_act == btg::Dtype::Bfloat16 || dtype_act == btg::Dtype::E4m3 ||
             dtype_act == btg::Dtype::MxE4m3) {
    TORCH_CHECK(dtype_weights == btg::Dtype::MxE2m1,
                "Only MxE2m1 weights are supported by block scale MoE with Bfloat16, E4m3 or "
                "MxE4m3 activation");
  } else {
    TORCH_CHECK(false, "Invalid dtype_act");
  }

  if (dtype_act == btg::Dtype::E4m3) {
    TORCH_CHECK(output1_scales_scalar.has_value(),
                "output1_scales_scalar is required for E4m3 activation");
    TORCH_CHECK(output1_scales_gate_scalar.has_value(),
                "output1_scales_gate_scalar is required for E4m3 activation");
    TORCH_CHECK(output2_scales_scalar.has_value(),
                "output2_scales_scalar is required for E4m3 activation");
  }

  if (routing_logits.has_value()) {
    TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::Float ||
                    routing_logits.value().scalar_type() == at::ScalarType::BFloat16,
                "routing_logits must be float or bfloat16.");
    TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D.");
    TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts,
                "routing_logits has incorrect shape.");
  }
  if (routing_bias.has_value()) {
    TORCH_CHECK(routing_bias.value().scalar_type() == at::ScalarType::BFloat16,
                "routing_bias must be bfloat16.");
    TORCH_CHECK(routing_bias.value().dim() == 1, "routing_bias must be 1D.");
    TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts,
                "routing_bias has incorrect shape.");
  }

  if (n_group.value_or(0) != 0) {
    TORCH_CHECK(
        static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3,
        "Routing kernel with groups implies DeepSeekV3 routing method.");
    TORCH_CHECK(topk_group.has_value(), "if n_group is given, topk_group must be given");
    TORCH_CHECK(num_experts % n_group.value() == 0, "num_experts must be divisible by n_group");
    TORCH_CHECK(top_k <= 8 && top_k > 0,
                "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.");
    TORCH_CHECK(
        topk_group.value() <= 4 && topk_group.value() > 0,
        "Current routing kernel (with groups) only supports topk_group<=4 && topk_group > 0.");
    TORCH_CHECK(topk_group.value() <= n_group.value(),
                "n_group must not be smaller than topk_group.");
    // This check ensures we have enough experts in the selected groups to handle the top_k routing
    TORCH_CHECK(top_k < (topk_group.value() * num_experts / n_group.value()),
                "top_k must be less than total number of experts in selected groups");
  } else if (static_cast<RoutingMethodType>(routing_method_type) ==
                 RoutingMethodType::Renormalize ||
             static_cast<RoutingMethodType>(routing_method_type) ==
                 RoutingMethodType::RenormalizeNaive ||
             static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::TopK) {
    TORCH_CHECK(
        top_k <= 8 && top_k > 0,
        "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && top_k>0.");
  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
    TORCH_CHECK(top_k == 1, "Current routing kernel (no groups, Llama4) only supports top_k=1.");
  }

  TORCH_CHECK(num_experts % 4 == 0,
              "Routing kernel expects that num_experts must be divisible by 4");
  TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k");

  tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs args;
  tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace;

  // setup args
  // note: the assumption is that output data type is always Bfloat16 (the default)
  auto routing_bias_dtype = at::ScalarType::BFloat16;
  if (routing_bias.has_value()) {
    routing_bias_dtype = routing_bias.value().scalar_type();
  } else if (routing_logits.has_value()) {
    routing_bias_dtype = routing_logits.value().scalar_type();
  }
  args.mDtypeElt = dtype_act;
  args.mDtypeExpW =
      routing_bias_dtype == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16;
  args.routing_logits = routing_logits.has_value() ? routing_logits.value().data_ptr() : nullptr;
  args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr;
  args.hidden_states = hidden_states.data_ptr();
  args.hidden_states_scale =
      hidden_states_scale.has_value() ? hidden_states_scale.value().data_ptr() : nullptr;
  args.gemm1_weights = gemm1_weights.data_ptr();
  args.gemm1_weights_scale = gemm1_weights_scale.data_ptr();
  args.gemm1_bias = gemm1_bias.has_value() ? gemm1_bias.value().data_ptr<float>() : nullptr;
  args.gemm1_alpha = gemm1_alpha.has_value() ? gemm1_alpha.value().data_ptr<float>() : nullptr;
  args.gemm1_beta = gemm1_beta.has_value() ? gemm1_beta.value().data_ptr<float>() : nullptr;
  args.gemm1_clamp_limit =
      gemm1_clamp_limit.has_value() ? gemm1_clamp_limit.value().data_ptr<float>() : nullptr;
  args.gemm2_weights = gemm2_weights.data_ptr();
  args.gemm2_weights_scale = gemm2_weights_scale.data_ptr();
  args.gemm2_bias = gemm2_bias.has_value() ? gemm2_bias.value().data_ptr<float>() : nullptr;
  args.num_tokens = hidden_states.sizes()[0];
  args.num_experts = num_experts;
  // * 2 to compensate for the fact that sizeof(hidden_states.dtype) is 1 because we pack 2 e2m1
  // into 1 byte.
  auto const hidden_states_hidden_size =
      dtype_act == btg::Dtype::E2m1 ? hidden_states.sizes()[1] * 2 : hidden_states.sizes()[1];
  args.hidden_size = hidden_states_hidden_size;
  args.hidden_size_output = args.hidden_size;
  args.top_k = top_k;
  args.n_group = n_group.value_or(0);
  args.topk_group = topk_group.value_or(0);
  args.local_expert_offset = local_expert_offset;
  args.local_num_experts = local_num_experts;
  args.routed_scaling_factor = routed_scaling_factor.value_or(1.0);
  args.intermediate_size = intermediate_size;

  // allocate workspace for routing kernel
  at::Tensor num_tokens_per_expert = at::detail::empty_cuda({num_experts}, at::ScalarType::Int,
                                                            hidden_states.device(), std::nullopt);
  int32_t max_num_padded_tokens =
      tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
          args.num_tokens, top_k, num_experts, tile_tokens_dim);
  at::Tensor total_num_padded_tokens =
      at::empty({1}, at::TensorOptions().device(hidden_states.device()).dtype(at::ScalarType::Int));
  at::Tensor expanded_idx_to_permuted_idx = at::detail::empty_cuda(
      {args.num_tokens, args.top_k}, at::ScalarType::Int, hidden_states.device(), std::nullopt);

  at::Tensor permuted_idx_to_token_idx = at::detail::empty_cuda(
      {max_num_padded_tokens}, at::ScalarType::Int, hidden_states.device(), std::nullopt);
  // at::Tensor expert_weights = at::detail::empty_cuda(
  //     {args.num_tokens, args.top_k}, routing_bias_dtype, hidden_states.device(), std::nullopt);
  // at::Tensor expert_indexes = at::detail::empty_cuda(
  //     {args.num_tokens, args.top_k}, at::ScalarType::Int, hidden_states.device(), std::nullopt);
  at::Tensor expert_count_histogram = at::detail::empty_cuda(
      {2 * 256},
      at::ScalarType::Int,  // 256 is the max number of threads per block and max number of experts
      hidden_states.device(), std::nullopt);

  auto const sf_vec_size = dtype_weights == btg::Dtype::MxE2m1 ? 32 : 16;

  // allocate workspace for activation/gemm/finalize kernels
  auto const gemm1_output_hidden =
      dtype_act == btg::Dtype::E2m1 ? intermediate_size / 2 : intermediate_size;
  at::Tensor gemm1_output = at::detail::empty_cuda(
      {max_num_padded_tokens, gemm1_output_hidden},
      dtype_act == btg::Dtype::Bfloat16 ? at::ScalarType::BFloat16 : at::ScalarType::Float8_e4m3fn,
      hidden_states.device(), std::nullopt);

  std::optional<at::Tensor> gemm1_output_scale = std::nullopt;
  if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) {
    int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens,
                                                                intermediate_size / sf_vec_size);
    gemm1_output_scale = at::detail::empty_cuda({sf_size}, at::ScalarType::Float8_e4m3fn,
                                                hidden_states.device(), std::nullopt);
  }

  at::Tensor gemm2_output =
      at::detail::empty_cuda({max_num_padded_tokens, args.hidden_size}, at::ScalarType::BFloat16,
                             hidden_states.device(), std::nullopt);

  int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
      args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim);
  at::Tensor cta_idx_xy_to_batch_idx = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int,
                                                              hidden_states.device(), std::nullopt);
  at::Tensor cta_idx_xy_to_mn_limit = at::detail::empty_cuda({max_num_ctas}, at::ScalarType::Int,
                                                             hidden_states.device(), std::nullopt);
  at::Tensor num_non_exiting_ctas =
      at::empty({1}, at::TensorOptions().device(hidden_states.device()).dtype(at::ScalarType::Int));

  //
  // TopK routing
  //

  tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim);
  auto const& stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device());
  routing_runner.run(
      args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k,
      args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts,
      args.routed_scaling_factor, expert_indices.data_ptr<int>(),
      expert_count_histogram.data_ptr<int>(), total_num_padded_tokens.data_ptr<int>(),
      expanded_idx_to_permuted_idx.data_ptr<int>(),
      nullptr, /*permuted_idx_to_expanded_idx.data_ptr<int>(),*/
      permuted_idx_to_token_idx.data_ptr<int>(), expert_weights.data_ptr(),
      num_tokens_per_expert.data_ptr<int>(), cta_idx_xy_to_batch_idx.data_ptr<int>(),
      cta_idx_xy_to_mn_limit.data_ptr<int>(), num_non_exiting_ctas.data_ptr<int>(), args.mDtypeElt,
      false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */,
      static_cast<RoutingMethodType>(routing_method_type), stream);

  //
  // FC13 (gemm1) + FC2 (gemm2)
  //

  if (dtype_act == btg::Dtype::E2m1) {
    TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Byte, "hidden_states must be byte.");
  } else if (dtype_act == btg::Dtype::E4m3 || dtype_act == btg::Dtype::MxE4m3) {
    TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn,
                "hidden_states must be fp8.");
  } else if (dtype_act == btg::Dtype::Bfloat16) {
    TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::BFloat16,
                "hidden_states must be bfloat16.");
  } else {
    TORCH_CHECK(false, "Invalid dtype_act");
  }

  if (hidden_states_scale.has_value()) {
    TORCH_CHECK(hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn,
                "hidden_states_scale must be fp8.");

    TORCH_CHECK(
        hidden_states_scale.value().numel() == tensorrt_llm::computeLinearLayoutSFSize(
                                                   args.num_tokens, args.hidden_size / sf_vec_size),
        "hidden_states_scale has incorrect size");
  }

  TORCH_CHECK(gemm1_weights.scalar_type() == torch_ext::FLOAT4_E2M1X2,
              "gemm1_weights must be byte.");

  TORCH_CHECK(gemm1_weights.dim() == 3, "gemm1_weights must be 3D.");
  TORCH_CHECK(gemm1_weights.sizes()[1] % 2 == 0, "the second dimension of weights must be even.");
  TORCH_CHECK(intermediate_size == gemm1_weights.sizes()[1] / 2,
              "intermediate_size has incorrect dim 1.");
  // This check passes even though the actual shape of the weights[2] and hidden_states[1] is
  // 2 times larger due to the fact that 2 e2m1 are packed into 1 byte.
  TORCH_CHECK(
      gemm1_weights.sizes()[2] ==
          (dtype_act == btg::Dtype::E2m1 ? hidden_states.sizes()[1] : hidden_states.sizes()[1] / 2),
      "the third dimension of weights must be equal to hidden_size.");

  TORCH_CHECK(gemm1_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "gemm1_weights_scale must be fp8.");

  TORCH_CHECK(gemm1_weights_scale.dim() == 3, "gemm1_weights_scale must be 3D.");
  TORCH_CHECK(gemm1_weights_scale.sizes()[0] == local_num_experts,
              "gemm1_weights_scale has incorrect dim 0.");
  TORCH_CHECK(intermediate_size % sf_vec_size == 0,
              "the second dimension of weights must be a multiple of ", sf_vec_size);
  TORCH_CHECK(gemm1_weights_scale.sizes()[1] == 2 * intermediate_size,
              "gemm1_weights_scale has incorrect dim 1.");
  TORCH_CHECK(gemm1_weights_scale.sizes()[2] == args.hidden_size / sf_vec_size,
              "gemm1_weights_scale has incorrect dim 2.");

  if (gemm1_bias.has_value()) {
    TORCH_CHECK(gemm1_bias.value().scalar_type() == at::ScalarType::Float,
                "gemm1_bias must be float, got ", c10::toString(gemm1_bias.value().scalar_type()));
    TORCH_CHECK(gemm1_bias.value().dim() == 2, "gemm1_bias must be 2D.");
    TORCH_CHECK(gemm1_bias.value().sizes()[0] == local_num_experts,
                "gemm1_bias has incorrect dim 0.");
    TORCH_CHECK(gemm1_bias.value().sizes()[1] == 2 * intermediate_size,
                "gemm1_bias has incorrect dim 1.");
  }

  if (gemm1_alpha.has_value()) {
    TORCH_CHECK(gemm1_alpha.value().scalar_type() == at::ScalarType::Float,
                "gemm1_alpha must be float, got ",
                c10::toString(gemm1_alpha.value().scalar_type()));
    TORCH_CHECK(gemm1_alpha.value().dim() == 1, "gemm1_alpha must be 1D.");
    TORCH_CHECK(gemm1_alpha.value().sizes()[0] == local_num_experts,
                "gemm1_alpha has incorrect dim 0.");
  }
  if (gemm1_beta.has_value()) {
    TORCH_CHECK(gemm1_beta.value().scalar_type() == at::ScalarType::Float,
                "gemm1_beta must be float, got ", c10::toString(gemm1_beta.value().scalar_type()));
    TORCH_CHECK(gemm1_beta.value().dim() == 1, "gemm1_beta must be 1D.");
    TORCH_CHECK(gemm1_beta.value().sizes()[0] == local_num_experts,
                "gemm1_beta has incorrect dim 0.");
  }

  TORCH_CHECK(gemm2_weights.scalar_type() == torch_ext::FLOAT4_E2M1X2,
              "gemm2_weights must be byte.");

  TORCH_CHECK(gemm2_weights.dim() == 3, "gemm2_weights must be 3D.");
  // / 2 to compensate for the fact that we pack 2 e2m1 into 1 byte.
  TORCH_CHECK(gemm2_weights.sizes()[2] == intermediate_size / 2,
              "the third dimension of weights must be equal to intermediate_size.");

  TORCH_CHECK(gemm2_weights_scale.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "gemm2_weights_scale must be fp8.");

  TORCH_CHECK(gemm2_weights_scale.dim() == 3, "gemm2_weights_scale must be 3D.");
  TORCH_CHECK(gemm2_weights_scale.sizes()[0] == local_num_experts,
              "gemm2_weights_scale has incorrect dim 0.");
  TORCH_CHECK(gemm2_weights_scale.sizes()[1] == args.hidden_size,
              "gemm2_weights_scale has incorrect dim 1.");
  TORCH_CHECK(gemm2_weights_scale.sizes()[2] == intermediate_size / sf_vec_size,
              "gemm2_weights_scale has incorrect dim 2.");

  if (output1_scales_scalar.has_value()) {
    TORCH_CHECK(output1_scales_scalar.value().scalar_type() == at::ScalarType::Float,
                "output1_scales_scalar must be float.");
    TORCH_CHECK(output1_scales_scalar.value().dim() == 1, "output1_scales_scalar must be 1D.");
    TORCH_CHECK(output1_scales_scalar.value().sizes()[0] == local_num_experts,
                "output1_scales_scalar has incorrect dim 0.");
  }

  if (output1_scales_gate_scalar.has_value()) {
    TORCH_CHECK(output1_scales_gate_scalar.value().scalar_type() == at::ScalarType::Float,
                "output1_scales_gate_scalar must be float.");
    TORCH_CHECK(output1_scales_gate_scalar.value().dim() == 1,
                "output1_scales_gate_scalar must be 1D.");
    TORCH_CHECK(output1_scales_gate_scalar.value().sizes()[0] == local_num_experts,
                "output1_scales_gate_scalar has incorrect dim 0.");
  }

  if (output2_scales_scalar.has_value()) {
    TORCH_CHECK(output2_scales_scalar.value().scalar_type() == at::ScalarType::Float,
                "output2_scales_scalar must be float.");
    TORCH_CHECK(output2_scales_scalar.value().dim() == 1, "output2_scales_scalar must be 1D.");
    TORCH_CHECK(output2_scales_scalar.value().sizes()[0] == local_num_experts,
                "output2_scales_scalar has incorrect dim 0.");
  }

  // setup workspace
  workspace.total_num_padded_tokens = total_num_padded_tokens.data_ptr<int>();
  workspace.total_max_padded_tokens = max_num_padded_tokens;
  workspace.ProjUpTileN = tile_tokens_dim;
  workspace.routing_expert_indexes = expert_indices.data_ptr<int>();
  workspace.permuted_idx_size = total_num_padded_tokens.data_ptr<int>();
  workspace.expanded_idx_to_permuted_idx =
      expanded_idx_to_permuted_idx.data_ptr<int>();  // Needed by permute/finalize kernels
  workspace.permuted_idx_to_token_idx =
      permuted_idx_to_token_idx.data_ptr<int>();  // Needed by permuteGemm1 kernel
  workspace.expert_weights = expert_weights.data_ptr();  // Consumed by finalize kernel

  workspace.cta_idx_xy_to_batch_idx = cta_idx_xy_to_batch_idx.data_ptr<int>();
  workspace.cta_idx_xy_to_mn_limit = cta_idx_xy_to_mn_limit.data_ptr<int>();
  workspace.num_non_exiting_ctas = num_non_exiting_ctas.data_ptr<int>();

  workspace.hidden_states_scale_linear = nullptr;

  // gemm1 intermediate ws
  workspace.gemm1_output = gemm1_output.data_ptr();
  workspace.gemm1_output_scale =
      gemm1_output_scale.has_value()
          ? reinterpret_cast<float*>(gemm1_output_scale.value().data_ptr())
          : nullptr;

  // gemm2 intermediate ws
  workspace.gemm2_output = gemm2_output.data_ptr();
  workspace.gemm2_output_scale = nullptr;
  args.output = output.data_ptr();
  args.output_scale = nullptr;
  args.output1_scales_scalar =
      output1_scales_scalar.has_value() ? output1_scales_scalar.value().data_ptr<float>() : nullptr;
  args.output1_scales_gate_scalar = output1_scales_gate_scalar.has_value()
                                        ? output1_scales_gate_scalar.value().data_ptr<float>()
                                        : nullptr;
  args.output2_scales_scalar =
      output2_scales_scalar.has_value() ? output2_scales_scalar.value().data_ptr<float>() : nullptr;
  args.do_finalize = do_finalize;

  auto const workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex);

  at::Tensor workspace_fc1 = at::detail::empty_cuda(
      {std::get<0>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
  at::Tensor workspace_fc2 = at::detail::empty_cuda(
      {std::get<1>(workspace_sizes)}, at::ScalarType::Char, hidden_states.device(), std::nullopt);
  workspace.bmm1_workspace = workspace_fc1.data_ptr();
  workspace.bmm2_workspace = workspace_fc2.data_ptr();
  auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device());
  moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex,
                 enable_pdl);

  if (!do_finalize) {
    return {gemm2_output, expert_weights, expanded_idx_to_permuted_idx};
  }
  return {output};
}
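
// Public FP4 entry point: infers the activation and weight encodings from the tensor dtypes and
// the scale-factor vector sizes (16 elements per scale -> E2m1, 32 -> MxE2m1 / MxE4m3), builds the
// runner, resolves the kernel config (config_index == -1 selects the default), and forwards to the
// launcher above.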
std::vector<at::Tensor> trtllm_fp4_block_scale_moe(
    std::optional<at::Tensor> const& routing_logits, at::Tensor& topk_ids,
    at::Tensor& expert_weights, std::optional<at::Tensor> const& routing_bias,
    at::Tensor const& hidden_states, std::optional<at::Tensor> const& hidden_states_scale,
    at::Tensor const& gemm1_weights, at::Tensor const& gemm1_weights_scale,
    std::optional<at::Tensor> const& gemm1_bias, std::optional<at::Tensor> const& gemm1_alpha,
    std::optional<at::Tensor> const& gemm1_beta, std::optional<at::Tensor> const& gemm1_clamp_limit,
    at::Tensor const& gemm2_weights, at::Tensor const& gemm2_weights_scale,
    std::optional<at::Tensor> const& gemm2_bias,
    std::optional<at::Tensor> const& output1_scales_scalar,
    std::optional<at::Tensor> const& output1_scales_gate_scalar,
    std::optional<at::Tensor> const& output2_scales_scalar, int64_t num_experts, int64_t top_k,
    std::optional<int64_t> n_group, std::optional<int64_t> topk_group, int64_t intermediate_size,
    int64_t local_expert_offset, int64_t local_num_experts,
    std::optional<double> routed_scaling_factor, int64_t tile_tokens_dim,
    int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t gated_act_type,
    at::Tensor& output, int64_t config_index) {
  using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner;

  int const num_tokens = hidden_states.sizes()[0];
  int hidden_size = hidden_states.sizes()[1];
  if (hidden_states.scalar_type() == torch_ext::FLOAT4_E2M1X2) hidden_size *= 2;
  int hidden_states_scale_vec_size = -1;
  if (hidden_states_scale.has_value()) {
    hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel();
  }
  int weight_scale_vec_size =
      (local_num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel();
  TORCH_CHECK(weight_scale_vec_size == 16 || weight_scale_vec_size == 32,
              "unsupported weight_scale_vec_size.");
  auto mDtypeWeights = weight_scale_vec_size == 16 ? btg::Dtype::E2m1 : btg::Dtype::MxE2m1;

  TORCH_CHECK(gemm1_weights.scalar_type() == at::ScalarType::Byte &&
                  gemm2_weights.scalar_type() == at::ScalarType::Byte,
              "weights must be fp4 packed in uint8.");
  TORCH_CHECK(hidden_states.scalar_type() == torch_ext::FLOAT4_E2M1X2 ||
                  hidden_states.scalar_type() == at::ScalarType::BFloat16 ||
                  hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn,
              "hidden_states must be bf16, fp8 or uint8 (packed fp4).");
  auto mDtypeAct = btg::Dtype::Bfloat16;
  if (hidden_states.scalar_type() == torch_ext::FLOAT4_E2M1X2) {
    TORCH_CHECK(hidden_states_scale.has_value() &&
                    hidden_states_scale.value().scalar_type() == at::ScalarType::Float8_e4m3fn,
                "hidden_states_scale must be provided for fp4 activation.");
    if (hidden_states_scale_vec_size == 16) {
      mDtypeAct = btg::Dtype::E2m1;
    } else if (hidden_states_scale_vec_size == 32) {
      mDtypeAct = btg::Dtype::MxE2m1;
    } else {
      TORCH_CHECK(false, "unsupported hidden_states_scale shape.");
    }
  } else if (hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn) {
    if (hidden_states_scale.has_value()) {
      if (hidden_states_scale_vec_size == 32) {
        mDtypeAct = btg::Dtype::MxE4m3;
      } else {
        TORCH_CHECK(false, "unsupported hidden_states_scale shape.");
      }
    } else {
      mDtypeAct = btg::Dtype::E4m3;
    }
  }
  bool mUseDeepSeekFp8{false};  // FP4 doesn't use DeepSeek FP8

  // Properly initialize the runner using make_unique like in the original code
  auto mRunner = std::make_unique<RunnerType>(
      mDtypeAct, mDtypeWeights, mUseDeepSeekFp8, (int32_t)tile_tokens_dim,
      static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);

  if (config_index == -1) {
    config_index = mRunner->getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
                                                       local_num_experts, num_tokens);
  }

  return trtllm_fp4_block_scale_moe_launcher(
      routing_logits, topk_ids, expert_weights, routing_bias, hidden_states, hidden_states_scale,
      gemm1_weights, gemm1_weights_scale, gemm1_bias, gemm1_alpha, gemm1_beta, gemm1_clamp_limit,
      gemm2_weights, gemm2_weights_scale, gemm2_bias, output1_scales_scalar,
      output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group,
      intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor,
      tile_tokens_dim, routing_method_type, do_finalize, *mRunner, mDtypeAct, mDtypeWeights,
      config_index, enable_pdl, output);
}
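
// Autotuning helpers: construct a runner for the requested dtypes and return either the default
// kernel config index or the full list of valid indices for a given problem shape. A chosen index
// can be passed back to trtllm_fp4_block_scale_moe via config_index.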
int64_t trtllm_get_default_moe_configs(int64_t const tile_tokens_dim, int64_t const dtype_act_,
                                       int64_t const dtype_weights_, bool const useDeepSeekFp8,
                                       int64_t const top_k, int64_t const hidden_size,
                                       int64_t const intermediate_size,
                                       int64_t const num_local_experts,
                                       int64_t const gated_act_type, int64_t const num_tokens) {
  auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
  auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
      static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
  return moe_runner.getDefaultValidConfigIndex(top_k, hidden_size, intermediate_size,
                                               num_local_experts, num_tokens);
}

std::vector<int64_t> trtllm_get_valid_moe_configs(
    int64_t const tile_tokens_dim, int64_t const dtype_act_, int64_t const dtype_weights_,
    bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size,
    int64_t const intermediate_size, int64_t const num_local_experts, int64_t const gated_act_type,
    int64_t const num_tokens) {
  auto dtype_act = static_cast<btg::Dtype>(dtype_act_);
  auto dtype_weights = static_cast<btg::Dtype>(dtype_weights_);
  tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner(
      dtype_act, dtype_weights, useDeepSeekFp8, (int32_t)tile_tokens_dim,
      static_cast<GatedActType>(gated_act_type), /*useShuffledMatrixA*/ true);
  return moe_runner.getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts,
                                          num_tokens);
}
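
// The cubin loader helpers are compiled into this translation unit inside their own namespace.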
namespace trtllm_cubin_loader {
#include <flashinfer/cubin_loader.h>
}
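
// Register the MoE entry points and the config helpers as Torch custom ops.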
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("trtllm_fp8_per_tensor_scale_moe", trtllm_fp8_per_tensor_scale_moe);
  m.def("trtllm_fp8_block_scale_moe", trtllm_fp8_block_scale_moe);
  m.def("trtllm_fp4_block_scale_moe", trtllm_fp4_block_scale_moe);
  m.def("trtllm_get_default_moe_configs", trtllm_get_default_moe_configs);
  m.def("trtllm_get_valid_moe_configs", trtllm_get_valid_moe_configs);
}

}  // namespace flashinfer