sglang_v0.5.2/flashinfer_0.3.1/csrc/trtllm_mnnvl_allreduce.cu

#include "flashinfer/comm/trtllm_mnnvl_allreduce.cuh"
#include "pytorch_extension_utils.h"
using namespace flashinfer::trtllm_mnnvl_allreduce;
#define DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(scalar_type, c_type, ...)                   \
  [&] {                                                                                         \
    switch (scalar_type) {                                                                      \
      case at::ScalarType::Float: {                                                             \
        using c_type = float;                                                                   \
        return __VA_ARGS__();                                                                   \
      }                                                                                         \
      case at::ScalarType::Half: {                                                              \
        using c_type = half;                                                                    \
        return __VA_ARGS__();                                                                   \
      }                                                                                         \
      case at::ScalarType::BFloat16: {                                                          \
        using c_type = __nv_bfloat16;                                                           \
        return __VA_ARGS__();                                                                   \
      }                                                                                         \
      default:                                                                                  \
        TORCH_CHECK(false, "Unsupported dtype in DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE: ", \
                    scalar_type);                                                               \
    }                                                                                           \
  }()
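
// Dtype dispatch: the macro above expands to an immediately-invoked lambda that
// switches on the runtime scalar type and binds the compile-time alias `c_type`
// for each supported dtype. Illustrative use (a sketch, not the preprocessor's
// literal output; `run_kernel` is a placeholder name):
//
//   DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(t.scalar_type(), c_type, [&] {
//     run_kernel<c_type>(...);  // instantiated for float, half, __nv_bfloat16
//   });
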
void trtllm_mnnvl_all_reduce(at::Tensor& in, int64_t multicast_buffer_ptr, int64_t buffer_ptrs_dev,
                             int64_t buffer_M, at::Tensor& buffer_flags_mnnvl, int64_t nranks,
                             int64_t rank, bool wait_for_results, bool launch_with_pdl,
                             std::optional<at::Tensor> out) {
  const c10::cuda::OptionalCUDAGuard device_guard(in.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.scalar_type(), c_type, [&] {
    // Extract parameters from tensors
    int64_t num_tokens = in.size(0);
    int64_t token_dim = in.size(1);

    // Validate input parameters
    TORCH_CHECK(token_dim % (sizeof(float2) / sizeof(c_type)) == 0,
                "token_dim must be divisible by ", sizeof(float2) / sizeof(c_type));
    TORCH_CHECK(nranks >= 2 && nranks <= 64, "nranks must be between 2 and 64, got ", nranks);
    TORCH_CHECK(rank >= 0 && rank < nranks, "rank must be between 0 and nranks-1, got ", rank);
    TORCH_CHECK(out.has_value() || !wait_for_results,
                "out tensor must be provided if wait_for_results is true");

    // Create the parameters struct
    AllReduceParams<c_type> params;
    params.nranks = nranks;
    params.rank = rank;
    params.buffer_M = buffer_M;
    params.num_tokens = num_tokens;
    params.token_dim = token_dim;
    params.buffer_ptrs_dev = reinterpret_cast<void**>(buffer_ptrs_dev);
    params.multicast_ptr = reinterpret_cast<void*>(multicast_buffer_ptr);
    params.buffer_flags = buffer_flags_mnnvl.data_ptr();
    params.wait_for_results = wait_for_results;
    params.launch_with_pdl = launch_with_pdl;
    params.input = in.data_ptr();
    params.output = out.has_value() ? out.value().data_ptr() : nullptr;
    params.stream = stream.stream();

    auto status = twoshot_allreduce_dispatch_world_size<c_type>(params);
    TORCH_CHECK(status == cudaSuccess,
                "twoshot_allreduce_dispatch_world_size failed with error code ",
                cudaGetErrorString(status));
  });
}
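
// Illustrative caller-side sketch. All names and values below are hypothetical:
// in practice the multicast pointer, the per-rank buffer pointer array, buffer_M,
// and the flag tensor come from an MNNVL multicast-buffer setup done elsewhere.
//
//   at::Tensor in  = at::randn({num_tokens, hidden_dim},
//                              at::device(at::kCUDA).dtype(at::kHalf));
//   at::Tensor out = at::empty_like(in);
//   trtllm_mnnvl_all_reduce(in, mc_ptr, uc_ptrs_dev, buffer_M, buffer_flags,
//                           /*nranks=*/8, /*rank=*/0,
//                           /*wait_for_results=*/true,
//                           /*launch_with_pdl=*/false, out);
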
void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, at::Tensor& prenorm_output,
                          at::Tensor& normed_output, at::Tensor const& gamma, double epsilon,
                          at::Tensor const& residual, at::Tensor& buffer_flags,
                          bool launch_with_pdl) {
  const c10::cuda::OptionalCUDAGuard device_guard(prenorm_output.device());
  auto stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.scalar_type(), c_type, [&] {
    // Create the parameters struct
    RMSNormParams<c_type> params;
    params.residual_output = prenorm_output.data_ptr();
    params.output = normed_output.data_ptr();
    params.input = reinterpret_cast<void const*>(multicast_buffer_ptr);
    params.gamma = gamma.data_ptr();
    params.epsilon = epsilon;
    params.residual = residual.data_ptr();
    params.buffer_flags = reinterpret_cast<uint32_t*>(buffer_flags.data_ptr());
    params.batch = normed_output.size(0);
    params.hidden_dim = normed_output.size(1);
    params.stream = stream.stream();
    params.launch_with_pdl = launch_with_pdl;

    auto status = twoshot_rmsnorm_dispatch_hidden_dim<c_type>(params);
    TORCH_CHECK(status == cudaSuccess,
                "twoshot_rmsnorm_dispatch_hidden_dim failed with error code ",
                cudaGetErrorString(status));
  });
}
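
// Sketch of the fused all-reduce + RMSNorm pattern these two ops pair up for
// (an assumption about intended usage, not code from this file): launch the
// all-reduce with wait_for_results=false, then let the RMSNorm kernel consume
// the reduced values directly from the multicast buffer. Names are hypothetical.
//
//   trtllm_mnnvl_all_reduce(in, mc_ptr, uc_ptrs_dev, buffer_M, buffer_flags,
//                           nranks, rank, /*wait_for_results=*/false,
//                           /*launch_with_pdl=*/true, /*out=*/std::nullopt);
//   trtllm_mnnvl_rmsnorm(mc_ptr, prenorm_out, normed_out, gamma,
//                        /*epsilon=*/1e-6, residual, buffer_flags,
//                        /*launch_with_pdl=*/true);
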
TORCH_LIBRARY_FRAGMENT(TORCH_EXTENSION_NAME, m) {
  m.def("trtllm_mnnvl_all_reduce", &trtllm_mnnvl_all_reduce);
  m.def("trtllm_mnnvl_rmsnorm", &trtllm_mnnvl_rmsnorm);
}
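
// Once registered, both ops are reachable from Python through the loaded
// extension, e.g. torch.ops.<TORCH_EXTENSION_NAME>.trtllm_mnnvl_all_reduce(...),
// with the operator schemas inferred from the C++ signatures.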