/* Copyright 2025 SGLang Team. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ // reference: https://github.com/NVIDIA/TensorRT-LLM/blob/release/0.14/cpp/tensorrt_llm/kernels/customAllReduceKernels.h #include #include #include "trt_reduce_internal.cuh" #include "utils.h" using namespace trt_llm; using fptr_t = int64_t; using IPC_KEY = std::array; class AllReduceMeta { public: AllReduceMeta( int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, const std::vector& tmp_result_buffers, const std::vector& barrier_in, const std::vector& barrier_out) { this->rank_id = (int)rank_id; this->world_size = (int)world_size; this->barrier_in = barrier_in; this->barrier_out = barrier_out; this->tmp_result_buffers = tmp_result_buffers; this->rank_data_base = reinterpret_cast(rank_data.data_ptr()); RankData data; for (int i = 0; i < world_size; i++) { data.ptrs[i] = (void*)buffers[i]; } auto d_data = this->rank_data_base++; CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); this->buffers = d_data; } ~AllReduceMeta() { for (auto [_, ptr] : ipc_handles_) { CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr)); } } public: int world_size; int rank_id; std::vector barrier_in; std::vector barrier_out; std::vector tmp_result_buffers; int barrier_flag = 1; RankData* buffers; RankData* rank_data_base; std::vector graph_unreg_buffers; std::map ipc_handles_; }; // Get the number of bits for a given data type. inline int get_bits(at::ScalarType dtype) { switch (dtype) { case at::ScalarType::Float: return 32; case at::ScalarType::Half: case at::ScalarType::BFloat16: return 16; default: assert(false && "Unsupported data type"); } } // Check if customized all-reduce kernels can be applied. inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) { // The customized all-reduce kernel has the following requirement(s). return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0; } fptr_t init_custom_ar( int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, const std::vector& tmp_result_buffers, const std::vector& barrier_in, const std::vector& barrier_out) { auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out); return (fptr_t)m; } void dispose(fptr_t _fa) { auto fa = reinterpret_cast(_fa); delete fa; } std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa) { AllReduceMeta* m = reinterpret_cast(_fa); auto num_buffers = m->graph_unreg_buffers.size(); auto handle_sz = sizeof(cudaIpcMemHandle_t); std::string handles(handle_sz * num_buffers, static_cast(0)); std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = m->graph_unreg_buffers[i]; void* base_ptr; // note: must share the base address of each allocation, or we get wrong // address if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) { assert(false && "failed to get pointer attr"); } CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); offsets[i] = ((char*)ptr) - ((char*)base_ptr); } std::vector bytes(handles.begin(), handles.end()); return std::make_pair(bytes, offsets); } char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) { auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); if (new_handle) { char* ipc_ptr; CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle( (void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), cudaIpcMemLazyEnablePeerAccess)); it->second = ipc_ptr; } return it->second; } // Note: when registering graph buffers, we intentionally choose to not // deduplicate the addresses. That means if the allocator reuses some // addresses, they will be registered again. This is to account for the remote // possibility of different allocation patterns between ranks. For example, // rank 1 may get the same input address for the second allreduce, but rank 2 // got a different address. IPC handles have internal reference counting // mechanism so overhead should be small. void register_graph_buffers( fptr_t _fa, const std::vector>& handles, const std::vector>& offsets) { AllReduceMeta* m = reinterpret_cast(_fa); std::vector handle_bytes; handle_bytes.reserve(handles.size()); for (int i = 0; i < handles.size(); i++) { handle_bytes.emplace_back(handles[i].begin(), handles[i].end()); } auto num_buffers = m->graph_unreg_buffers.size(); std::vector rank_data(num_buffers); for (int i = 0; i < num_buffers; i++) { auto self_ptr = m->graph_unreg_buffers[i]; auto& rd = rank_data[i]; for (int j = 0; j < m->world_size; j++) { if (j != m->rank_id) { char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]); handle += offsets[j][i]; rd.ptrs[j] = handle; } else { rd.ptrs[j] = self_ptr; } } } CHECK_CUDA_SUCCESS( cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice)); m->rank_data_base += num_buffers; m->graph_unreg_buffers.clear(); } void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { AllReduceMeta* m = reinterpret_cast(_fa); auto stream = c10::cuda::getCurrentCUDAStream().stream(); auto num_elements = inp.numel(); auto dtype = inp.scalar_type(); AllReduceStrategyType strategy = SelectImplementation(num_elements * ((get_bits(dtype) + 7) / 8), m->world_size); // should be gurantee in python code assert(strategy == AllReduceStrategyType::ONESHOT || strategy == AllReduceStrategyType::TWOSHOT); assert(CanApplyCustomAllReduce(num_elements, dtype)); // Initialize the all-reduce kernel arguments. int world_size = m->world_size; AllReduceParams params; params.ranks_per_node = world_size; params.rank = m->rank_id; params.local_rank = m->rank_id; params.local_input_buffer_ptr = inp.data_ptr(); params.local_output_buffer_ptr = out.data_ptr(); params.elts_total = inp.numel(); params.elts_size = inp.element_size(); params.barrier_flag = ++(m->barrier_flag); cudaStreamCaptureStatus status; CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status)); params.is_capturing = (status == cudaStreamCaptureStatusActive); if (params.is_capturing) { params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size(); m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr); } else { params.peer_comm_buffer_ptrs = m->buffers; } for (int i = 0; i < world_size; ++i) { params.tmp_result_buffers[i] = reinterpret_cast(m->tmp_result_buffers[i]); } for (int i = 0; i < world_size; ++i) { params.peer_barrier_ptrs_in[i] = reinterpret_cast(m->barrier_in[i]); } for (int i = 0; i < world_size; ++i) { params.peer_barrier_ptrs_out[i] = reinterpret_cast(m->barrier_out[i]); } auto data_type = out.scalar_type(); trtCustomAllReduce(params, data_type, strategy, stream); }