sglang.0.4.8.post1/sglang/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cu

141 lines
5.1 KiB
Plaintext

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
#include <torch/library.h>
#include "mscclpp_allreduce.cuh"
enum MscclContextSelection {
MSCCL1NODELL = 1,
MSCCL2NODELL = 2,
};
class MscclContext {
public:
MscclContextSelection selection_;
std::shared_ptr<sglang::Msccl1NodeLLcontext> msccl_1nodeLL_context;
std::shared_ptr<sglang::Msccl2NodeLLcontext> msccl_2nodeLL_context;
MscclContext(MscclContextSelection selection) : selection_(selection) {}
template <typename T>
void allreduce(
cudaStream_t stream, T* input, T* output, const size_t input_numel, int threads = 512, int block_limit = 21) {
if (selection_ == MSCCL1NODELL) {
msccl_1nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
} else if (selection_ == MSCCL2NODELL) {
msccl_2nodeLL_context->allreduce<T>(stream, input, output, input_numel, threads, block_limit);
}
}
};
using fptr_t = int64_t;
static_assert(sizeof(void*) == sizeof(fptr_t));
torch::Tensor _unique_id2tensor(const mscclpp::UniqueId& unique_id) {
auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCPU);
auto tensor = torch::empty({static_cast<int64_t>(unique_id.size())}, options);
std::memcpy(tensor.data_ptr<uint8_t>(), unique_id.data(), unique_id.size());
return tensor;
}
// Function to convert vector of int32_t back to array of uint8_t
mscclpp::UniqueId _tensor2unique_id(const torch::Tensor& tensor) {
mscclpp::UniqueId unique_id;
std::memcpy(unique_id.data(), tensor.data_ptr<uint8_t>(), unique_id.size());
return unique_id;
}
torch::Tensor mscclpp_generate_unique_id() {
mscclpp::UniqueId unique_id = mscclpp::TcpBootstrap::createUniqueId();
return _unique_id2tensor(unique_id);
}
fptr_t mscclpp_init_context(
const torch::Tensor& unique_id,
const int64_t rank,
const int64_t world_size,
torch::Tensor& scratch,
torch::Tensor& put_buffer,
const int64_t nranks_per_node,
const std::vector<int64_t>& rank_to_node,
const std::vector<int64_t>& rank_to_ib,
const int64_t context_selection) {
MscclContext* context_ptr = new MscclContext(static_cast<MscclContextSelection>(context_selection));
mscclpp::UniqueId uid = _tensor2unique_id(unique_id);
if (context_selection == MSCCL1NODELL) {
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
context_ptr->msccl_1nodeLL_context = std::make_shared<sglang::Msccl1NodeLLcontext>(
uid, rank, world_size, scratch_ptr, scratch_bytes, nranks_per_node, rank_to_node, rank_to_ib);
} else if (context_selection == MSCCL2NODELL) {
void* scratch_ptr = reinterpret_cast<void*>(scratch.data_ptr());
const size_t scratch_bytes = scratch.numel() * scratch.element_size();
void* put_buffer_ptr = reinterpret_cast<void*>(put_buffer.data_ptr());
const size_t put_buffer_bytes = put_buffer.numel() * put_buffer.element_size();
context_ptr->msccl_2nodeLL_context = std::make_shared<sglang::Msccl2NodeLLcontext>(
uid,
rank,
world_size,
scratch_ptr,
scratch_bytes,
put_buffer_ptr,
put_buffer_bytes,
nranks_per_node,
rank_to_node,
rank_to_ib);
} else {
throw std::runtime_error("invalid context selection");
}
return (fptr_t)context_ptr;
}
bool _mscclpp_is_weak_contiguous(torch::Tensor& t) {
return t.is_contiguous() ||
(t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size());
}
void mscclpp_allreduce(fptr_t _context, torch::Tensor& inp, torch::Tensor& out, int64_t nthreads, int64_t nblocks) {
MscclContext* context = reinterpret_cast<MscclContext*>(_context);
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
auto stream = c10::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
TORCH_CHECK_EQ(inp.numel(), out.numel());
TORCH_CHECK(_mscclpp_is_weak_contiguous(out));
TORCH_CHECK(_mscclpp_is_weak_contiguous(inp));
switch (out.scalar_type()) {
case at::ScalarType::Float: {
context->allreduce<float>(
stream,
reinterpret_cast<float*>(inp.data_ptr()),
reinterpret_cast<float*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
case at::ScalarType::Half: {
context->allreduce<half>(
stream,
reinterpret_cast<half*>(inp.data_ptr()),
reinterpret_cast<half*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case at::ScalarType::BFloat16: {
context->allreduce<__nv_bfloat16>(
stream,
reinterpret_cast<__nv_bfloat16*>(inp.data_ptr()),
reinterpret_cast<__nv_bfloat16*>(out.data_ptr()),
inp.numel(),
nthreads,
nblocks);
break;
}
#endif
default:
throw std::runtime_error("custom allreduce only supports float32, float16 and bfloat16");
}
}