#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <torch/all.h>

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <optional>

#include "cutlass/array.h"

// Maximum block size used by the per-expert kernels (compute_problem_sizes,
// compute_arg_sorts). Those kernels stride by blockDim.x, so any smaller block
// size chosen by the host launcher is handled correctly too.
constexpr uint64_t THREADS_PER_EXPERT = 512;

// Each expert's token count is rounded up to a multiple of this value when
// computing blockscale offsets.
constexpr int32_t BLOCKSCALE_ROW_ALIGNMENT = 128;

// Block size used by shuffleRowsKernel.
constexpr int SHUFFLE_ROWS_THREADS = 256;

// Upper bound on the block size used by apply_shuffle_mul_sum_kernel.
constexpr int APPLY_SHUFFLE_MAX_THREADS = 256;

/**
 * Counts how many entries of `topk_ids` are routed to each expert and writes
 * the per-expert grouped-GEMM problem shapes.
 *
 * Launch: one block per expert (blockIdx.x is the expert id), any block size.
 *
 * @param topk_ids        Flattened expert ids, `topk_length` entries.
 * @param problem_sizes1  Output, 3 int32 per expert: {count, 2 * n, k}.
 * @param problem_sizes2  Output, 3 int32 per expert: {count, k, n}.
 * @param atomic_buffer   Per-expert int32 scratch; must be zero before launch.
 *                        Holds the expert's count on return.
 */
__global__ void compute_problem_sizes(
    const int* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t n,
    const int64_t k) {
  int const expert_id = blockIdx.x;

  int occurrences = 0;
  for (int64_t i = threadIdx.x; i < topk_length; i += blockDim.x) {
    occurrences += (topk_ids[i] == expert_id);
  }
  // Threads that found nothing have nothing to contribute.
  if (occurrences != 0) {
    atomicAdd(&atomic_buffer[expert_id], occurrences);
  }
  // Makes every thread's atomicAdd visible to thread 0 below.
  __syncthreads();

  if (threadIdx.x == 0) {
    int32_t const final_occurrences = atomic_buffer[expert_id];
    int64_t const base = static_cast<int64_t>(expert_id) * 3;
    problem_sizes1[base] = final_occurrences;
    problem_sizes1[base + 1] = static_cast<int32_t>(2 * n);
    problem_sizes1[base + 2] = static_cast<int32_t>(k);
    problem_sizes2[base] = final_occurrences;
    problem_sizes2[base + 1] = static_cast<int32_t>(k);
    problem_sizes2[base + 2] = static_cast<int32_t>(n);
  }
}

/**
 * Serial exclusive prefix sum over the per-expert token counts.
 *
 * Launch: <<<1, 1>>> (the loop is executed by a single thread).
 *
 * Writes expert_offsets[0..num_experts] (num_experts + 1 entries) and
 * overwrites atomic_buffer[i] with expert i's starting offset, which
 * compute_arg_sorts later uses as a per-expert write cursor.
 */
__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int64_t i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += problem_sizes1[i * 3];
    expert_offsets[i + 1] = tot_offset;
  }
}

/**
 * Same as compute_expert_offsets, and additionally writes blockscale_offsets,
 * a prefix sum of each expert's token count rounded up to a multiple of
 * BLOCKSCALE_ROW_ALIGNMENT.
 *
 * Launch: <<<1, 1>>>.
 */
__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* blockscale_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t tot_offset = 0;
  int32_t tot_rounded_offset = 0;
  expert_offsets[0] = 0;
  blockscale_offsets[0] = 0;
  for (int64_t i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    int32_t const num_tokens = problem_sizes1[i * 3];
    int32_t const rounded_num_tokens =
        (num_tokens + (BLOCKSCALE_ROW_ALIGNMENT - 1)) / BLOCKSCALE_ROW_ALIGNMENT * BLOCKSCALE_ROW_ALIGNMENT;
    tot_offset += num_tokens;
    tot_rounded_offset += rounded_num_tokens;
    expert_offsets[i + 1] = tot_offset;
    blockscale_offsets[i + 1] = tot_rounded_offset;
  }
}

/**
 * Groups the flattened top-k slots by expert.
 *
 * Launch: one block per expert (blockIdx.x is the expert id), any block size.
 * Requires atomic_buffer[e] to hold expert e's starting offset (see
 * compute_expert_offsets). The order of slots within one expert depends on
 * atomic ordering and is not deterministic.
 *
 * For every slot i with topk_ids[i] == expert:
 *   input_permutation[pos]  = i / topk  (the token the slot belongs to)
 *   output_permutation[i]   = pos
 * Slots whose id is outside [0, gridDim.x) are left untouched.
 */
__global__ void compute_arg_sorts(
    const int32_t* __restrict__ topk_ids,
    int32_t* input_permutation,
    int32_t* output_permutation,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t topk) {
  int const expert_id = blockIdx.x;
  for (int64_t i = threadIdx.x; i < topk_length; i += blockDim.x) {
    if (topk_ids[i] == expert_id) {
      int32_t const start = atomicAdd(&atomic_buffer[expert_id], 1);
      input_permutation[start] = static_cast<int32_t>(i / topk);
      output_permutation[i] = start;
    }
  }
}

// Validates a tensor that the kernels above access through a raw int32_t*.
static void check_int32_buffer(
    const torch::Tensor& t, const char* name, int64_t min_numel, const torch::Device& device) {
  TORCH_CHECK(t.scalar_type() == torch::kInt32, name, " must be int32");
  TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
  TORCH_CHECK(t.device() == device, name, " must be on the same device as topk_ids");
  TORCH_CHECK(t.numel() >= min_numel, name, " must have at least ", min_numel, " elements, got ", t.numel());
}

/**
 * Host launcher for the MoE input-preparation pipeline:
 *   1. compute_problem_sizes     — per-expert counts and GEMM shapes
 *   2. compute_(blockscale_)expert_offsets — prefix sums
 *   3. compute_arg_sorts         — input/output permutations
 * All kernels run on the current CUDA stream of topk_ids' device.
 *
 * topk_ids must be a contiguous 2-D int32 tensor [num_tokens, topk].
 */
void get_moe_prepare_input_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  TORCH_CHECK(topk_ids.dim() == 2, "topk_ids must be 2D [num_tokens, topk]");
  TORCH_CHECK(topk_ids.is_contiguous(), "topk_ids must be contiguous");
  // A zero-block launch is an invalid configuration.
  TORCH_CHECK(num_experts > 0, "num_experts must be positive, got ", num_experts);

  int64_t const topk_length = topk_ids.numel();
  torch::Device const device = topk_ids.device();
  check_int32_buffer(expert_offsets, "expert_offsets", num_experts + 1, device);
  if (blockscale_offsets.has_value()) {
    check_int32_buffer(blockscale_offsets.value(), "blockscale_offsets", num_experts + 1, device);
  }
  check_int32_buffer(problem_sizes1, "problem_sizes1", num_experts * 3, device);
  check_int32_buffer(problem_sizes2, "problem_sizes2", num_experts * 3, device);
  check_int32_buffer(input_permutation, "input_permutation", topk_length, device);
  check_int32_buffer(output_permutation, "output_permutation", topk_length, device);

  // Make sure the kernels land on topk_ids' device, matching the stream below.
  const at::cuda::OptionalCUDAGuard device_guard(device_of(topk_ids));
  auto stream = at::cuda::getCurrentCUDAStream(device.index());

  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(device);
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
  int32_t* const atomic_ptr = atomic_buffer.data_ptr<int32_t>();

  // Never launch with 0 threads (empty topk_ids), which is an invalid config.
  int64_t const max_threads = static_cast<int64_t>(THREADS_PER_EXPERT);
  uint32_t const num_threads = static_cast<uint32_t>(std::max<int64_t>(1, std::min(max_threads, topk_length)));
  uint32_t const num_blocks = static_cast<uint32_t>(num_experts);

  compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
      topk_ids.data_ptr<int32_t>(),
      problem_sizes1.data_ptr<int32_t>(),
      problem_sizes2.data_ptr<int32_t>(),
      atomic_ptr,
      topk_length,
      n,
      k);
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  if (blockscale_offsets.has_value()) {
    compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
        problem_sizes1.data_ptr<int32_t>(),
        expert_offsets.data_ptr<int32_t>(),
        blockscale_offsets.value().data_ptr<int32_t>(),
        atomic_ptr,
        num_experts);
  } else {
    compute_expert_offsets<<<1, 1, 0, stream>>>(
        problem_sizes1.data_ptr<int32_t>(), expert_offsets.data_ptr<int32_t>(), atomic_ptr, num_experts);
  }
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
      topk_ids.data_ptr<int32_t>(),
      input_permutation.data_ptr<int32_t>(),
      output_permutation.data_ptr<int32_t>(),
      atomic_ptr,
      topk_length,
      topk_ids.size(1));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/**
 * Public entry point: prepares per-expert problem sizes, offsets and token
 * permutations for a grouped MoE GEMM. See get_moe_prepare_input_caller.
 */
void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  TORCH_CHECK(topk_ids.dtype() == torch::kInt32, "topk_ids must be int32");
  get_moe_prepare_input_caller(
      topk_ids,
      expert_offsets,
      blockscale_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
      output_permutation,
      num_experts,
      n,
      k);
}

/**
 * Gathers rows: output[r, :] = input[dst2src_map[r], :].
 *
 * Launch: one block per destination row (blockIdx.x is the row), any block
 * size. Uses 16-byte vector copies when both row pointers are 16-byte aligned
 * and finishes the remaining columns element by element, so any num_cols and
 * any row alignment are handled. Destination rows whose mapped source index is
 * outside [0, num_src_rows) are left untouched instead of reading out of bounds.
 */
template <typename T>
__global__ void shuffleRowsKernel(
    const T* input,
    const int32_t* dst2src_map,
    T* output,
    int64_t num_src_rows,
    int64_t num_dst_rows,
    int64_t num_cols) {
  static_assert(sizeof(uint4) % sizeof(T) == 0, "element size must divide 16 bytes");

  int64_t const dest_row_idx = blockIdx.x;
  // Bounds-check before touching dst2src_map.
  if (dest_row_idx >= num_dst_rows) {
    return;
  }
  int64_t const source_row_idx = dst2src_map[dest_row_idx];
  if (source_row_idx < 0 || source_row_idx >= num_src_rows) {
    return;
  }

  T const* src_row = input + source_row_idx * num_cols;
  T* dst_row = output + dest_row_idx * num_cols;

  constexpr int64_t ELEM_PER_VEC = sizeof(uint4) / sizeof(T);
  int64_t scalar_start = 0;

  bool const aligned =
      ((reinterpret_cast<uintptr_t>(src_row) | reinterpret_cast<uintptr_t>(dst_row)) % sizeof(uint4)) == 0;
  if (aligned) {
    int64_t const num_vecs = num_cols / ELEM_PER_VEC;
    auto const* src_vec = reinterpret_cast<uint4 const*>(src_row);
    auto* dst_vec = reinterpret_cast<uint4*>(dst_row);
    for (int64_t v = threadIdx.x; v < num_vecs; v += blockDim.x) {
      dst_vec[v] = src_vec[v];
    }
    scalar_start = num_vecs * ELEM_PER_VEC;
  }

  // Tail columns (or the whole row when unaligned).
  for (int64_t c = scalar_start + threadIdx.x; c < num_cols; c += blockDim.x) {
    dst_row[c] = src_row[c];
  }
}

// Launches shuffleRowsKernel<T> for already-validated tensors.
template <typename T>
static void launch_shuffle_rows(
    const torch::Tensor& input_tensor,
    const torch::Tensor& dst2src_map,
    torch::Tensor& output_tensor,
    cudaStream_t stream) {
  int64_t const num_dst_rows = output_tensor.size(0);
  int64_t const num_src_rows = input_tensor.size(0);
  int64_t const num_cols = input_tensor.size(1);
  shuffleRowsKernel<T><<<static_cast<uint32_t>(num_dst_rows), SHUFFLE_ROWS_THREADS, 0, stream>>>(
      reinterpret_cast<const T*>(input_tensor.data_ptr()),
      dst2src_map.data_ptr<int32_t>(),
      reinterpret_cast<T*>(output_tensor.data_ptr()),
      num_src_rows,
      num_dst_rows,
      num_cols);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/**
 * Validates arguments and dispatches shuffleRowsKernel on the input's dtype.
 * input_tensor: [num_src_rows, num_cols]; output_tensor: [num_dst_rows, num_cols];
 * dst2src_map: int32 with at least num_dst_rows entries.
 */
void shuffle_rows_caller(
    const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  TORCH_CHECK(
      input_tensor.scalar_type() == output_tensor.scalar_type(),
      "Input and output tensors must have the same data type");
  TORCH_CHECK(input_tensor.dim() == 2 && output_tensor.dim() == 2, "input and output tensors must be 2D");
  TORCH_CHECK(input_tensor.size(1) == output_tensor.size(1), "input and output tensors must have the same width");
  TORCH_CHECK(dst2src_map.scalar_type() == torch::kInt32, "dst2src_map must be int32");
  TORCH_CHECK(
      dst2src_map.numel() >= output_tensor.size(0), "dst2src_map must have one entry per output row");
  TORCH_CHECK(
      input_tensor.is_contiguous() && output_tensor.is_contiguous() && dst2src_map.is_contiguous(),
      "input, output and dst2src_map must be contiguous");

  // Zero blocks is an invalid launch configuration; nothing to copy anyway.
  if (output_tensor.size(0) == 0 || output_tensor.size(1) == 0) {
    return;
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input_tensor));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index()).stream();

  switch (input_tensor.scalar_type()) {
    case torch::kFloat16:
      launch_shuffle_rows<half>(input_tensor, dst2src_map, output_tensor, stream);
      break;
    case torch::kBFloat16:
      launch_shuffle_rows<__nv_bfloat16>(input_tensor, dst2src_map, output_tensor, stream);
      break;
    case torch::kFloat32:
      launch_shuffle_rows<float>(input_tensor, dst2src_map, output_tensor, stream);
      break;
    case torch::kFloat8_e4m3fn:
      launch_shuffle_rows<__nv_fp8_e4m3>(input_tensor, dst2src_map, output_tensor, stream);
      break;
    case torch::kUInt8:
      launch_shuffle_rows<uint8_t>(input_tensor, dst2src_map, output_tensor, stream);
      break;
    default:
      TORCH_CHECK(false, "[moe replicate input] data type dispatch fail!");
  }
}

/** Public entry point: output_tensor[r] = input_tensor[dst2src_map[r]]. */
void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
}

/**
 * output[i, d] = sum_j factors[i * topk + j] * input[permutation[i * topk + j], d]
 *
 * Launch: one block per output row (blockIdx.x = i), any block size; each
 * thread walks columns d = threadIdx.x, threadIdx.x + blockDim.x, ...
 * Accumulates in float and rounds to scalar_t once. Permutation entries outside
 * [0, num_input_rows) are skipped instead of read out of bounds.
 */
template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
    const scalar_t* __restrict__ input_tensor,  // [num_input_rows, row_stride]
    scalar_t* __restrict__ output_tensor,       // [m, row_stride]
    const int32_t* __restrict__ permutation,    // [m * topk]
    int m,
    int topk,
    int row_stride,
    const scalar_t* __restrict__ factors,  // [m * topk] or nullptr
    int64_t num_input_rows) {
  int64_t const i = blockIdx.x;
  if (i >= m) {
    return;
  }

  for (int d = threadIdx.x; d < row_stride; d += blockDim.x) {
    float sum_val = 0.0f;
    for (int j = 0; j < topk; ++j) {
      int64_t const index_2d = i * topk + j;
      int64_t const src_row = permutation[index_2d];
      if (src_row < 0 || src_row >= num_input_rows) {
        continue;
      }
      float const val = static_cast<float>(input_tensor[src_row * row_stride + d]);
      float const factor = (factors != nullptr) ? static_cast<float>(factors[index_2d]) : 1.0f;
      sum_val += factor * val;
    }
    output_tensor[i * row_stride + d] = static_cast<scalar_t>(sum_val);
  }
}

// Launches apply_shuffle_mul_sum_kernel<scalar_t> for already-validated tensors.
template <typename scalar_t>
static void launch_apply_shuffle_mul_sum(
    const torch::Tensor& input_tensor,
    torch::Tensor& output_tensor,
    const int32_t* perm_ptr,
    const void* factors_ptr,
    int m,
    int topk,
    int row_stride,
    cudaStream_t stream) {
  dim3 const block(static_cast<unsigned>(std::min(APPLY_SHUFFLE_MAX_THREADS, row_stride)));
  dim3 const grid(static_cast<unsigned>(m));  // blockIdx.x = output row i
  apply_shuffle_mul_sum_kernel<scalar_t><<<grid, block, 0, stream>>>(
      input_tensor.data_ptr<scalar_t>(),
      output_tensor.data_ptr<scalar_t>(),
      perm_ptr,
      m,
      topk,
      row_stride,
      static_cast<const scalar_t*>(factors_ptr),
      input_tensor.size(0));
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

/**
 * Validates arguments and dispatches apply_shuffle_mul_sum_kernel on the
 * output dtype (float16 or bfloat16).
 */
void get_apply_shuffle_mul_sum_caller(
    const torch::Tensor& input_tensor,               // [m * topk, row_stride], bf16/f16
    torch::Tensor& output_tensor,                    // [m, row_stride], bf16/f16
    const torch::Tensor& permutation,                // [m * topk], int32
    const std::optional<torch::Tensor>& factors_opt)  // optional [m * topk], bf16/f16
{
  TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
  TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
  TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");
  TORCH_CHECK(permutation.scalar_type() == torch::kInt32, "permutation must be int32");
  TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(), "input_tensor must match output dtype");
  TORCH_CHECK(
      input_tensor.size(1) == output_tensor.size(1), "input_tensor and output_tensor must have the same row_stride");
  TORCH_CHECK(
      input_tensor.is_contiguous() && output_tensor.is_contiguous() && permutation.is_contiguous(),
      "input_tensor, output_tensor and permutation must be contiguous");

  int64_t const m64 = output_tensor.size(0);
  int64_t const row_stride64 = output_tensor.size(1);
  int64_t const perm_size = permutation.size(0);

  // m == 0 would divide by zero below and launch an empty grid.
  if (m64 == 0) {
    TORCH_CHECK(perm_size == 0, "permutation size must match m * topk");
    return;
  }
  TORCH_CHECK(perm_size % m64 == 0, "permutation size must match m * topk");
  int64_t constexpr kIntMax = std::numeric_limits<int>::max();
  TORCH_CHECK(
      m64 <= kIntMax && row_stride64 <= kIntMax && perm_size <= kIntMax, "tensor dimensions exceed int32 range");
  // A zero-thread block is an invalid launch configuration.
  if (row_stride64 == 0) {
    return;
  }

  int const m = static_cast<int>(m64);
  int const topk = static_cast<int>(perm_size / m64);
  int const row_stride = static_cast<int>(row_stride64);

  const void* factors_ptr = nullptr;
  if (factors_opt.has_value()) {
    TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
    TORCH_CHECK(factors_opt->numel() == static_cast<int64_t>(m) * topk, "Factors must have shape [m * topk]");
    TORCH_CHECK(factors_opt->is_contiguous(), "Factors must be contiguous");
    factors_ptr = factors_opt->data_ptr();
  }

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input_tensor));
  cudaStream_t stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index()).stream();
  const int32_t* perm_ptr = permutation.data_ptr<int32_t>();

  switch (output_tensor.scalar_type()) {
    case at::ScalarType::Half:
      launch_apply_shuffle_mul_sum<at::Half>(
          input_tensor, output_tensor, perm_ptr, factors_ptr, m, topk, row_stride, stream);
      break;
    case at::ScalarType::BFloat16:
      launch_apply_shuffle_mul_sum<c10::BFloat16>(
          input_tensor, output_tensor, perm_ptr, factors_ptr, m, topk, row_stride, stream);
      break;
    default:
      TORCH_CHECK(false, "Unsupported output dtype for cast+mul kernel: ", output_tensor.scalar_type());
  }
}

/**
 * @brief Applies a permutation-based shuffle, element-wise multiplication, and reduction over the second dimension.
 *
 * This function performs the equivalent of the following PyTorch expression:
 *
 *   (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
 *
 * Specifically:
 * - `input` rows are gathered using the `permutation` tensor.
 * - The gathered tensor is viewed as (m, topk, k) and multiplied element-wise with `factors` (e.g., top-k weights).
 * - The result is summed along dimension 1 (the top-k dimension) and stored in `output`.
 *   Accumulation is done in float32 and rounded to the output dtype once.
 *
 * @param input       Input tensor of shape (m * topk, k), representing c2.
 * @param output      Output tensor of shape (m, k), where the final reduced results are stored.
 * @param permutation int32 index tensor (e.g., c_map) of shape (m * topk); entry i * topk + j selects the
 *                    row of `input` contributing to output row i.
 * @param factors     Optional scaling factors (e.g., top-k weights) with m * topk elements, e.g. shape
 *                    (m * topk) or (m, topk), in the output dtype.
 */
void apply_shuffle_mul_sum(
    const torch::Tensor& input,
    torch::Tensor& output,
    const torch::Tensor& permutation,
    const std::optional<torch::Tensor>& factors) {
  get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
}