#include #include "flashinfer/comm/trtllm_moe_allreduce_fusion.cuh" #include "pytorch_extension_utils.h" using namespace flashinfer::trtllm_moe_allreduce_fusion; #define DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(scalar_type, c_type, ...) \ [&] { \ switch (scalar_type) { \ case at::ScalarType::Half: { \ using c_type = half; \ return __VA_ARGS__(); \ } \ case at::ScalarType::BFloat16: { \ using c_type = __nv_bfloat16; \ return __VA_ARGS__(); \ } \ default: \ TORCH_CHECK(false, \ "Unsupported dtype in DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE: ", scalar_type); \ } \ }() void trtllm_moe_allreduce_fusion( int64_t world_size, int64_t world_rank, int64_t token_num, int64_t hidden_size, at::Tensor& workspace_ptrs, bool launch_with_pdl, at::Tensor& residual_in, at::Tensor& rms_gamma, double rms_eps, double scale_factor, int64_t moe_reduction_device_num_experts, at::Tensor& moe_reduction_scale_input, at::Tensor& moe_reduction_active_experts_token_input, at::Tensor& moe_reduction_token_input, std::optional layout_code, std::optional moe_allreduce_out, std::optional residual_out, std::optional norm_out, std::optional quant_out, std::optional scale_out) { const c10::cuda::OptionalCUDAGuard device_guard( moe_reduction_active_experts_token_input.device()); auto stream = at::cuda::getCurrentCUDAStream(); DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE( moe_reduction_active_experts_token_input.scalar_type(), c_type, [&] { MoeReductionAllReduceFusionParams params; params.nranks = world_size; params.rank = world_rank; params.size = token_num * hidden_size; params.hidden_dim = hidden_size; params.workspace = reinterpret_cast(workspace_ptrs.data_ptr()); params.moe_allreduce_out = moe_allreduce_out.has_value() ? reinterpret_cast(moe_allreduce_out.value().data_ptr()) : nullptr; params.residual_in = reinterpret_cast(residual_in.data_ptr()); params.residual_out = residual_out.has_value() ? reinterpret_cast(residual_out.value().data_ptr()) : nullptr; params.norm_out = norm_out.has_value() ? reinterpret_cast(norm_out.value().data_ptr()) : nullptr; params.quant_out = quant_out.has_value() ? reinterpret_cast(quant_out.value().data_ptr()) : nullptr; params.scale_out = scale_out.has_value() ? reinterpret_cast(scale_out.value().data_ptr()) : nullptr; params.rms_gamma = reinterpret_cast(rms_gamma.data_ptr()); params.rms_eps = static_cast(rms_eps); params.scale_factor = static_cast(scale_factor); params.layout = layout_code.has_value() ? static_cast(layout_code.value()) : QuantizationSFLayout::SWIZZLED_128x4; params.stream = stream; params.moe_reduction_device_num_experts = moe_reduction_device_num_experts; params.moe_reduction_scale_input = reinterpret_cast(moe_reduction_scale_input.data_ptr()); params.moe_reduction_active_experts_token_input = reinterpret_cast(moe_reduction_active_experts_token_input.data_ptr()); params.moe_reduction_token_input = reinterpret_cast(moe_reduction_token_input.data_ptr()); auto status = moereduction_allreduce_fusion_op(params, launch_with_pdl); TORCH_CHECK(status == cudaSuccess, "moereduction_allreduce_fusion_op failed with error code ", cudaGetErrorString(status)); }); } void trtllm_moe_finalize_allreduce_fusion( at::Tensor const& allreduce_in, at::Tensor const& residual_in, at::Tensor const& norm_weight, at::Tensor const& expanded_idx_to_permuted_idx, at::Tensor& norm_out, at::Tensor& residual_out, bool launch_with_pdl, at::Tensor& workspace, int64_t const world_rank, int64_t const world_size, double const eps, std::optional const& shared_expert_output, std::optional const& expert_scale_factor) { DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in.scalar_type(), c_type, [&] { MoeFinalizeAllReduceFusionParams params; int hidden_dim = residual_in.size(-1); int top_k = expanded_idx_to_permuted_idx.size(-1); params.quant_out = nullptr; params.scale_out = nullptr; params.nranks = static_cast(world_size); params.rank = static_cast(world_rank); // size: num_token * hidden_dim params.size = residual_in.numel(); params.hidden_dim = hidden_dim; // workspace: AR scratch space params.workspace = reinterpret_cast(workspace.mutable_data_ptr()); params.rms_gamma = norm_weight.data_ptr(); params.rms_eps = static_cast(eps); params.residual_in = residual_in.data_ptr(); params.stream = at::cuda::getCurrentCUDAStream(norm_weight.get_device()); // MOE Reduction specific params params.top_k = top_k; params.allreduce_in = allreduce_in.data_ptr(); params.expert_scale_factor = expert_scale_factor.has_value() ? expert_scale_factor.value().data_ptr() : nullptr; TORCH_CHECK(expanded_idx_to_permuted_idx.scalar_type() == at::ScalarType::Int, "expanded_idx_to_permuted_idx must be int32"); params.expanded_idx_to_permuted_idx = static_cast(expanded_idx_to_permuted_idx.data_ptr()); params.shared_expert_output = shared_expert_output.has_value() ? shared_expert_output.value().data_ptr() : nullptr; // output tensors params.norm_out = norm_out.mutable_data_ptr(); params.residual_out = residual_out.mutable_data_ptr(); auto status = moefinalize_allreduce_fusion_op(params, launch_with_pdl); TORCH_CHECK(status == cudaSuccess, "moefinalize_allreduce_fusion_op failed with error code ", cudaGetErrorString(status)); }); } TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) { m.def("trtllm_moe_allreduce_fusion", &trtllm_moe_allreduce_fusion); m.def("trtllm_moe_finalize_allreduce_fusion", &trtllm_moe_finalize_allreduce_fusion); }