#include <c10/cuda/CUDAGuard.h>
#include <cudaTypedefs.h>
#include <torch/all.h>

// CUDA scalar types (half, __nv_bfloat16, __nv_fp8_e4m3) used by the
// shuffle-rows dispatch further below.
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>

#include <iostream>

#include "cutlass/array.h"

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(
    const int* __restrict__ topk_ids,
    int32_t* problem_sizes1,
    int32_t* problem_sizes2,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t n,
    const int64_t k) {
  int expert_id = blockIdx.x;

  int occurrences = 0;
  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    occurrences += (topk_ids[i] == expert_id);
  }
  atomicAdd(&atomic_buffer[expert_id], occurrences);
  __syncthreads();

  if (threadIdx.x == 0) {
    int final_occurrences = atomic_buffer[expert_id];
    problem_sizes1[expert_id * 3] = final_occurrences;
    problem_sizes1[expert_id * 3 + 1] = static_cast<int32_t>(2 * n);
    problem_sizes1[expert_id * 3 + 2] = static_cast<int32_t>(k);
    problem_sizes2[expert_id * 3] = final_occurrences;
    problem_sizes2[expert_id * 3 + 1] = static_cast<int32_t>(k);
    problem_sizes2[expert_id * 3 + 2] = static_cast<int32_t>(n);
  }
}
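
// problem_sizes1/problem_sizes2 hold one (M, N, K) triple per expert, stored
// as three consecutive int32 values.  For expert e:
//   problem_sizes1[3*e .. 3*e+2] = {num_tokens_e, 2 * n, k}
//   problem_sizes2[3*e .. 3*e+2] = {num_tokens_e, k, n}
// e.g. with topk_ids = [0, 2, 0, 1], n = 4, k = 8, expert 0 receives two
// tokens, so its triples become {2, 8, 8} and {2, 8, 4}.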

__global__ void compute_expert_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t tot_offset = 0;
  expert_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    tot_offset += problem_sizes1[i * 3];
    expert_offsets[i + 1] = tot_offset;
  }
}
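
// expert_offsets becomes an exclusive prefix sum over the per-expert token
// counts (rows [expert_offsets[e], expert_offsets[e+1]) belong to expert e),
// e.g. counts {2, 0, 3} -> expert_offsets = {0, 2, 2, 5}.  atomic_buffer[e]
// is reset to expert e's start offset so that compute_arg_sorts below can
// hand out destination slots with atomicAdd.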

__global__ void compute_expert_blockscale_offsets(
    const int32_t* __restrict__ problem_sizes1,
    int32_t* expert_offsets,
    int32_t* blockscale_offsets,
    int32_t* atomic_buffer,
    const int64_t num_experts) {
  int32_t tot_offset = 0;
  int32_t tot_rounded_offset = 0;
  expert_offsets[0] = 0;
  blockscale_offsets[0] = 0;
  for (int i = 0; i < num_experts; ++i) {
    atomic_buffer[i] = tot_offset;
    int num_tokens = problem_sizes1[i * 3];
    int rounded_num_tokens = (num_tokens + (128 - 1)) / 128 * 128;
    tot_offset += num_tokens;
    tot_rounded_offset += rounded_num_tokens;
    expert_offsets[i + 1] = tot_offset;
    blockscale_offsets[i + 1] = tot_rounded_offset;
  }
}
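
// Same prefix sum as above, plus blockscale_offsets in which each expert's
// token count is first rounded up to a multiple of 128 (presumably the
// block-scale granularity expected downstream).  e.g. counts {5, 130} ->
// expert_offsets = {0, 5, 135}, blockscale_offsets = {0, 128, 384}.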

__global__ void compute_arg_sorts(
    const int32_t* __restrict__ topk_ids,
    int32_t* input_permutation,
    int32_t* output_permutation,
    int32_t* atomic_buffer,
    const int64_t topk_length,
    const int64_t topk) {
  int expert_id = blockIdx.x;

  for (int i = threadIdx.x; i < topk_length; i += THREADS_PER_EXPERT) {
    if (topk_ids[i] == expert_id) {
      int start = atomicAdd(&atomic_buffer[expert_id], 1);
      input_permutation[start] = i / topk;
      output_permutation[i] = start;
    }
  }
}
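
// After this kernel:
//   input_permutation[dst] = source token row (i / topk) to gather into
//                            permuted row `dst`;
//   output_permutation[i]  = permuted row assigned to top-k slot `i`.
// Destination slots are claimed atomically starting from each expert's
// offset that the offset kernels wrote into atomic_buffer.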

void get_moe_prepare_input_caller(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  auto stream = at::cuda::getCurrentCUDAStream(topk_ids.device().index());
  auto options_int32 = torch::TensorOptions().dtype(torch::kInt32).device(topk_ids.device());
  torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);

  uint32_t num_threads = static_cast<uint32_t>(min(THREADS_PER_EXPERT, topk_ids.numel()));
  uint32_t num_blocks = static_cast<uint32_t>(num_experts);

  compute_problem_sizes<<<num_blocks, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(problem_sizes1.data_ptr()),
      static_cast<int32_t*>(problem_sizes2.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      n,
      k);
  if (blockscale_offsets.has_value()) {
    compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        num_experts);
  } else {
    compute_expert_offsets<<<1, 1, 0, stream>>>(
        static_cast<const int32_t*>(problem_sizes1.data_ptr()),
        static_cast<int32_t*>(expert_offsets.data_ptr()),
        static_cast<int32_t*>(atomic_buffer.data_ptr()),
        num_experts);
  }
  compute_arg_sorts<<<num_blocks, num_threads, 0, stream>>>(
      static_cast<const int32_t*>(topk_ids.data_ptr()),
      static_cast<int32_t*>(input_permutation.data_ptr()),
      static_cast<int32_t*>(output_permutation.data_ptr()),
      static_cast<int32_t*>(atomic_buffer.data_ptr()),
      topk_ids.numel(),
      topk_ids.size(1));
}

void prepare_moe_input(
    const torch::Tensor& topk_ids,
    torch::Tensor& expert_offsets,
    const std::optional<torch::Tensor>& blockscale_offsets,
    torch::Tensor& problem_sizes1,
    torch::Tensor& problem_sizes2,
    torch::Tensor& input_permutation,
    torch::Tensor& output_permutation,
    const int64_t num_experts,
    const int64_t n,
    const int64_t k) {
  TORCH_CHECK(topk_ids.dtype() == torch::kInt32);
  get_moe_prepare_input_caller(
      topk_ids,
      expert_offsets,
      blockscale_offsets,
      problem_sizes1,
      problem_sizes2,
      input_permutation,
      output_permutation,
      num_experts,
      n,
      k);
  return;
}
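
// Call-site sketch (illustrative only; the tensor shapes follow from the
// kernels above, and `opts_i32` is a hypothetical int32/CUDA TensorOptions):
//
//   auto expert_offsets     = torch::empty({num_experts + 1}, opts_i32);
//   auto problem_sizes1     = torch::empty({num_experts, 3}, opts_i32);
//   auto problem_sizes2     = torch::empty({num_experts, 3}, opts_i32);
//   auto input_permutation  = torch::empty({topk_ids.numel()}, opts_i32);
//   auto output_permutation = torch::empty({topk_ids.numel()}, opts_i32);
//   prepare_moe_input(topk_ids, expert_offsets, std::nullopt, problem_sizes1,
//                     problem_sizes2, input_permutation, output_permutation,
//                     num_experts, n, k);
//
// where topk_ids is a [num_tokens, topk] int32 tensor of expert ids.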

template <typename T>
__global__ void shuffleRowsKernel(
    const T* input,
    const int32_t* dst2src_map,
    T* output,
    int64_t num_src_rows,
    int64_t num_dst_rows,
    int64_t num_cols) {
  int64_t dest_row_idx = blockIdx.x;
  int64_t const source_row_idx = dst2src_map[dest_row_idx];

  if (blockIdx.x < num_dst_rows) {
    // Load 128-bits per thread
    constexpr uint64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
    using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;

    // Duplicate and permute rows
    auto const* source_row_ptr = reinterpret_cast<DataElem const*>(input + source_row_idx * num_cols);
    auto* dest_row_ptr = reinterpret_cast<DataElem*>(output + dest_row_idx * num_cols);

    auto const start_offset = threadIdx.x;
    auto const stride = blockDim.x;
    auto const num_elems_in_col = num_cols / ELEM_PER_THREAD;

    for (auto elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) {
      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
    }
  }
}
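
// One thread block copies one output row from its mapped source row using
// 128-bit (16-byte) vector loads/stores.  This assumes num_cols is a
// multiple of ELEM_PER_THREAD (i.e. 16 / sizeof(T)) and that rows are
// 16-byte aligned; otherwise the DataElem reinterpret_cast is invalid.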

#define DECLARE_SHUFFLE_ROWS(T)       \
  __global__ void shuffleRowsKernel(  \
      const T* input,                 \
      const int32_t* dst2src_map,     \
      T* output,                      \
      int64_t num_src_rows,           \
      int64_t num_dest_rows,          \
      int64_t num_cols);

DECLARE_SHUFFLE_ROWS(float);
DECLARE_SHUFFLE_ROWS(half);
DECLARE_SHUFFLE_ROWS(__nv_bfloat16);
DECLARE_SHUFFLE_ROWS(__nv_fp8_e4m3);
DECLARE_SHUFFLE_ROWS(uint8_t);

#define SHUFFLE_ROWS(T)                                     \
  shuffleRowsKernel<T><<<blocks, threads, 0, stream>>>(     \
      reinterpret_cast<const T*>(input),                    \
      static_cast<const int32_t*>(dst2src_map.data_ptr()),  \
      reinterpret_cast<T*>(output),                         \
      num_src_rows,                                         \
      num_dst_rows,                                         \
      num_cols)

#define DTYPE_DISPATCH_CASE(T, CUDA_T) \
  case T:                              \
    SHUFFLE_ROWS(CUDA_T);              \
    break;

void shuffle_rows_caller(
    const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  TORCH_CHECK(
      input_tensor.scalar_type() == output_tensor.scalar_type(),
      "Input and output tensors must have the same data type");
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  uint32_t blocks = static_cast<uint32_t>(output_tensor.size(0));
  uint32_t threads = 256;
  int64_t num_dst_rows = output_tensor.size(0);
  int64_t num_src_rows = input_tensor.size(0);
  int64_t num_cols = input_tensor.size(1);
  const void* input = input_tensor.data_ptr();
  void* output = output_tensor.data_ptr();
  switch (input_tensor.scalar_type()) {
    DTYPE_DISPATCH_CASE(torch::kFloat16, half);
    DTYPE_DISPATCH_CASE(torch::kBFloat16, __nv_bfloat16);
    DTYPE_DISPATCH_CASE(torch::kFloat32, float);
    DTYPE_DISPATCH_CASE(torch::kFloat8_e4m3fn, __nv_fp8_e4m3);
    DTYPE_DISPATCH_CASE(torch::kUInt8, uint8_t);
    default:
      TORCH_CHECK(false, "[moe replicate input] data type dispatch failed!");
  }
  return;
}

void shuffle_rows(const torch::Tensor& input_tensor, const torch::Tensor& dst2src_map, torch::Tensor& output_tensor) {
  shuffle_rows_caller(input_tensor, dst2src_map, output_tensor);
  return;
}
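
// shuffle_rows is a row gather: output[i, :] = input[dst2src_map[i], :].
// The output may have more rows than the input, e.g. when replicating each
// token once per selected expert ahead of the grouped GEMM.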

template <typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
    const scalar_t* __restrict__ input_tensor,  // [m * topk, row_stride]
    scalar_t* __restrict__ output_tensor,       // [m, row_stride]
    const int32_t* __restrict__ permutation,    // [m * topk]
    int m,
    int topk,
    int row_stride,
    const scalar_t* __restrict__ factors)  // [m * topk] or nullptr
{
  int i = blockIdx.x;  // output row index in [0, m)
  if (i >= m) return;

  // Stride over columns so rows wider than blockDim.x are still fully covered.
  for (int d = threadIdx.x; d < row_stride; d += blockDim.x) {
    scalar_t sum_val = 0.0;

    for (int j = 0; j < topk; ++j) {
      int index_2d = i * topk + j;
      int src_row = permutation[index_2d];
      if (src_row >= m) continue;

      scalar_t val = input_tensor[src_row * row_stride + d];

      scalar_t factor = 1.0;
      if (factors != nullptr) {
        factor = factors[index_2d];
      }

      sum_val += factor * val;
    }

    output_tensor[i * row_stride + d] = sum_val;
  }
}
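
// Launch contract: grid.x = m (one block per output row); threads within a
// block stride over the columns, so row_stride may exceed blockDim.x.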

void get_apply_shuffle_mul_sum_caller(
    const torch::Tensor& input_tensor,                // [m * topk, row_stride], bf16/f16
    torch::Tensor& output_tensor,                     // [m, row_stride], bf16/f16
    const torch::Tensor& permutation,                 // [m * topk], int32
    const std::optional<torch::Tensor>& factors_opt)  // optional [m * topk], bf16/f16
{
  TORCH_CHECK(input_tensor.dim() == 2, "input_tensor must be 2D [m * topk, row_stride]");
  TORCH_CHECK(output_tensor.dim() == 2, "output_tensor must be 2D [m, row_stride]");
  TORCH_CHECK(permutation.dim() == 1, "permutation must be 1D [m * topk]");

  int m = output_tensor.size(0);
  int topk = int(permutation.size(0) / m);
  int row_stride = output_tensor.size(1);

  TORCH_CHECK(permutation.size(0) == m * topk, "permutation size must match m * topk");

  dim3 block(std::min(256, row_stride));
  dim3 grid(m);  // one block per output row
  auto stream = at::cuda::getCurrentCUDAStream(input_tensor.device().index());

  const int32_t* perm_ptr = permutation.data_ptr<int32_t>();

  const void* factors_ptr = nullptr;
  if (factors_opt.has_value()) {
    TORCH_CHECK(factors_opt->dtype() == output_tensor.dtype(), "Factors must match output dtype");
    TORCH_CHECK(factors_opt->numel() == m * topk, "Factors must have shape [m * topk]");
    factors_ptr = factors_opt->data_ptr();
  }

  if (output_tensor.scalar_type() == at::ScalarType::Half) {
    apply_shuffle_mul_sum_kernel<at::Half><<<grid, block, 0, stream>>>(
        input_tensor.data_ptr<at::Half>(),
        output_tensor.data_ptr<at::Half>(),
        perm_ptr,
        m,
        topk,
        row_stride,
        static_cast<const at::Half*>(factors_ptr));
  } else if (output_tensor.scalar_type() == at::ScalarType::BFloat16) {
    apply_shuffle_mul_sum_kernel<c10::BFloat16><<<grid, block, 0, stream>>>(
        input_tensor.data_ptr<c10::BFloat16>(),
        output_tensor.data_ptr<c10::BFloat16>(),
        perm_ptr,
        m,
        topk,
        row_stride,
        static_cast<const c10::BFloat16*>(factors_ptr));
  } else {
    TORCH_CHECK(false, "Unsupported output dtype for apply_shuffle_mul_sum kernel: ", output_tensor.scalar_type());
  }
}

/**
 * @brief Applies a permutation-based shuffle, element-wise multiplication, and reduction over the second dimension.
 *
 * This function performs the equivalent of the following PyTorch expression:
 *
 *   (c2[c_map].view(m, topk, k) * topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
 *
 * Specifically:
 * - `input` is shuffled using the `permutation` tensor.
 * - The shuffled tensor is reshaped and multiplied element-wise with `factors` (e.g., top-k weights).
 * - The result is summed along dimension 1 (the top-k dimension) and stored in `output`.
 *
 * @param input       Input tensor of shape (m * topk, k), representing c2.
 * @param output      Output tensor of shape (m, k), where the final reduced results are stored.
 * @param permutation Index tensor (e.g., c_map) that maps positions in `input` to the shuffled layout.
 * @param factors     Optional scaling factors (e.g., top-k weights), shape (m * topk) or (m, topk).
 */
void apply_shuffle_mul_sum(
    const torch::Tensor& input,
    torch::Tensor& output,
    const torch::Tensor& permutation,
    const std::optional<torch::Tensor>& factors) {
  get_apply_shuffle_mul_sum_caller(input, output, permutation, factors);
}
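
// Note: the dispatcher above supports Half and BFloat16 outputs only;
// `permutation` must be an int32 tensor, and `factors`, when provided, must
// match the output dtype and contain exactly m * topk elements.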