// sglang.0.4.8.post1/sglang/sgl-kernel/csrc/cpu/shm.cpp
#include "shm.h"
#include <ATen/ATen.h>
#include <errno.h>
#include <fcntl.h>
#include <immintrin.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <unistd.h>
// states for collectives
enum coll_state {
coll_begin = 0,
coll_allreduce_naive__copy_in_done,
coll_allreduce_naive__reduce_done,
// alternate states used when the allreduce works on the other buffer
// of the double buffer
coll_alt1_allreduce_naive__copy_in_done,
coll_alt2_allreduce_naive__copy_in_done,
coll_alt1_allreduce_naive__reduce_done,
coll_allgather_naive__copy_in_done,
coll_alt1_allgather_naive__copy_in_done,
coll_alt2_allgather_naive__copy_in_done,
};
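// The *_copy_in_done states rotate per round (see state_idx in the
// collectives below): the symmetric path cycles base -> alt1 -> alt2 -> base,
// so a waiting rank can tell "peer is in my round" apart from "peer has
// already advanced to the next round".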
// SHM building blocks
struct SharedData {
const char* name;
int descriptor;
void* bytes;
size_t nbytes;
};
void shared_open(SharedData* data, const char* name, size_t nbytes) {
int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
if (d != -1) {
void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0);
if (bytes == MAP_FAILED) {
printf("shared_open %s mmap failed, errno=%d\n", name, errno);
close(d);
data->descriptor = -1;
return;
}
data->name = name;
data->descriptor = d;
data->bytes = bytes;
data->nbytes = nbytes;
} else {
if (errno != ENOENT) {
// ENOENT is expected here: the caller retries in a loop until the other
// ranks have created their shm segments, so don't report it
printf("shared_open %s failed, errno=%d\n", name, errno);
}
data->descriptor = -1;
}
}
void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) {
int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
if (d != -1) {
// write() returns ssize_t; comparing against the requested size catches both
// errors (-1) and short writes, which the old `nbytes > 0` test on a size_t
// could not (a -1 return wraps to a huge unsigned value)
ssize_t written = write(d, bytes, nbytes);
if (written == (ssize_t)nbytes) {
shared_open(data, name, nbytes);
}
} else {
printf("shared_create %s failed\n", name);
}
}
static int world_size;
// SHM based allreduce helper functions
// size of the buffer that holds the shm segment name
#define NAME_BUF_SIZE 1000
#define MAX_BUF_SIZE (1048576 * 32)
#define NAIVE_ALLREDUCE_THRESHOLD 1048576
#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
struct allreduce_workspace {
enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce
// idx=1 -- state for distributed_naive_all_reduce
// `buffer` is double buffered to avoid syncing between rounds:
// offset 0 .. 2*NAIVE_ALLREDUCE_THRESHOLD : two buffers for symmetric_naive_all_reduce
// after that : two buffers for distributed_naive_all_reduce
char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
};
#define BUFFER0_OFFSET(current_buffer) ((current_buffer) * NAIVE_ALLREDUCE_THRESHOLD)
#define BUFFER1_OFFSET(current_buffer) (2 * NAIVE_ALLREDUCE_THRESHOLD + (current_buffer) * MAX_BUF_SIZE)
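// Worked example of the resulting layout inside allreduce_workspace::buffer
// (NAIVE_ALLREDUCE_THRESHOLD = 1 MiB, MAX_BUF_SIZE = 32 MiB):
//   [ 0 MiB,  1 MiB)  symmetric buffer 0    = BUFFER0_OFFSET(0)
//   [ 1 MiB,  2 MiB)  symmetric buffer 1    = BUFFER0_OFFSET(1)
//   [ 2 MiB, 34 MiB)  distributed buffer 0  = BUFFER1_OFFSET(0)
//   [34 MiB, 66 MiB)  distributed buffer 1  = BUFFER1_OFFSET(1)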
struct allreduce_workspace** workspace;
// buffer for small messages, double buffer
char** symmetric_buffer[2];
// buffer for large messages, double buffer
char** distributed_buffer[2];
// spin until the peer's state becomes one of the two expected values; the
// volatile read forces a reload from shared memory on every iteration
void wait_buffer_state_until_2(int index, enum coll_state state0, enum coll_state state1, int state_group) {
volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]);
while (1) {
volatile enum coll_state cur_state = *state_ptr;
if (cur_state == state0 || cur_state == state1) break;
}
}
inline __m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_bf16_to_fp32(const __m256i src) {
// zero-extend each bf16 lane to 32 bits, then shift left by two bytes so the
// bf16 payload lands in the upper half of the fp32 bit pattern
auto y = _mm512_cvtepu16_epi32(src);
return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
}
inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
__m512i value = _mm512_castps_si512(src);
__m512i nan = _mm512_set1_epi32(0xffff);
auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
__m512i ones = _mm512_set1_epi32(0x1);
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
// uint32_t lsb = (input >> 16) & 1;
auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
// uint32_t rounding_bias = 0x7fff + lsb;
t_value = _mm512_add_epi32(t_value, vec_bias);
// input += rounding_bias;
t_value = _mm512_add_epi32(t_value, value);
// input = input >> 16;
t_value = _mm512_srli_epi32(t_value, 16);
// Check NaN before converting back to bf16
t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
return _mm512_cvtusepi32_epi16(t_value);
}
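// A scalar reference for the vector round-to-nearest-even conversion above --
// purely illustrative and not called by the kernels; `bf16_rne_scalar` is our
// own name, not part of the original file:
static inline uint16_t bf16_rne_scalar(float f) {
uint32_t u;
memcpy(&u, &f, sizeof(u)); // bit-cast fp32 -> uint32
if ((u & 0x7fffffffu) > 0x7f800000u) return 0xffff; // NaN -> canonical bf16 NaN
u += 0x7fff + ((u >> 16) & 1); // rounding bias: round to nearest, ties to even
return (uint16_t)(u >> 16); // keep the upper 16 bits
}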
__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_fp16_to_fp32(const __m256i src) {
return _mm512_cvtph_ps(src);
}
inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_fp16(const __m512 src) {
return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_all_buffers(
int start_elements,
int num_elements,
c10::ScalarType scalar_type,
int to_buffer_idx,
char* to_buffer,
char** buffers) {
switch (scalar_type) {
case c10::ScalarType::BFloat16:
reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
break;
case c10::ScalarType::Half:
reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
break;
case c10::ScalarType::Float:
reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
break;
default:
assert(!"Should not get here");
}
}
#define CVT_ADD_BF16(x) \
do { \
auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
} while (0)
// The reduce functions below are vectorized; the number of bytes processed
// per iteration depends on the vector length: a 256-bit vector ==> 32 bytes,
// a 512-bit vector ==> 64 bytes. If you change the implementation of
// reduce_bf16_buffers etc., check whether this constant needs to change.
#define VECTOR_LENGTH_IN_BYTES 32
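// The switch statements below are manual unrolls that rely on fall-through:
// with world_size == 4, execution enters at `case 4:` and accumulates
// buffers[3], buffers[2] and buffers[1] on top of buffers[0]. For instance,
// CVT_ADD_BF16(3) expands to roughly:
//   auto in3_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[3] + i)));
//   inout_val = _mm512_add_ps(inout_val, in3_val);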
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
const int element_size = 2;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
int main_elements = num_elements - (num_elements % vector_length);
int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
switch (world_size) {
case 16:
CVT_ADD_BF16(15);
case 15:
CVT_ADD_BF16(14);
case 14:
CVT_ADD_BF16(13);
case 13:
CVT_ADD_BF16(12);
case 12:
CVT_ADD_BF16(11);
case 11:
CVT_ADD_BF16(10);
case 10:
CVT_ADD_BF16(9);
case 9:
CVT_ADD_BF16(8);
case 8:
CVT_ADD_BF16(7);
case 7:
CVT_ADD_BF16(6);
case 6:
CVT_ADD_BF16(5);
case 5:
CVT_ADD_BF16(4);
case 4:
CVT_ADD_BF16(3);
case 3:
CVT_ADD_BF16(2);
case 2:
CVT_ADD_BF16(1);
case 1:
break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
inout_val = _mm512_add_ps(inout_val, in_val);
}
}
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
}
// process remaining part
int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) {
float val = 0.0f;
for (int j = 0; j < world_size; j++) {
val += *(at::BFloat16*)(buffers[j] + i);
}
*(at::BFloat16*)(to_buffer + i) = val;
remain_elements--;
i += element_size;
}
}
#define CVT_ADD_FP16(x) \
do { \
auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
} while (0)
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
const int element_size = 2;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
int main_elements = num_elements - (num_elements % vector_length);
int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
switch (world_size) {
case 16:
CVT_ADD_FP16(15);
case 15:
CVT_ADD_FP16(14);
case 14:
CVT_ADD_FP16(13);
case 13:
CVT_ADD_FP16(12);
case 12:
CVT_ADD_FP16(11);
case 11:
CVT_ADD_FP16(10);
case 10:
CVT_ADD_FP16(9);
case 9:
CVT_ADD_FP16(8);
case 8:
CVT_ADD_FP16(7);
case 7:
CVT_ADD_FP16(6);
case 6:
CVT_ADD_FP16(5);
case 5:
CVT_ADD_FP16(4);
case 4:
CVT_ADD_FP16(3);
case 3:
CVT_ADD_FP16(2);
case 2:
CVT_ADD_FP16(1);
case 1:
break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
inout_val = _mm512_add_ps(inout_val, in_val);
}
}
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
}
// process remaining part
int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) {
float val = 0.0f;
for (int j = 0; j < world_size; j++) {
val += *(at::Half*)(buffers[j] + i);
}
*(at::Half*)(to_buffer + i) = val;
remain_elements--;
i += element_size;
}
}
#define CVT_ADD_F32(x) \
do { \
auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
inout_val = _mm256_add_ps(inout_val, in##x##_val); \
} while (0)
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
const int element_size = 4;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
int main_elements = num_elements - (num_elements % vector_length);
int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
switch (world_size) {
case 16:
CVT_ADD_F32(15);
case 15:
CVT_ADD_F32(14);
case 14:
CVT_ADD_F32(13);
case 13:
CVT_ADD_F32(12);
case 12:
CVT_ADD_F32(11);
case 11:
CVT_ADD_F32(10);
case 10:
CVT_ADD_F32(9);
case 9:
CVT_ADD_F32(8);
case 8:
CVT_ADD_F32(7);
case 7:
CVT_ADD_F32(6);
case 6:
CVT_ADD_F32(5);
case 5:
CVT_ADD_F32(4);
case 4:
CVT_ADD_F32(3);
case 3:
CVT_ADD_F32(2);
case 2:
CVT_ADD_F32(1);
case 1:
break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
inout_val = _mm256_add_ps(inout_val, in_val);
}
}
_mm256_storeu_ps((float*)(to_buffer + i), inout_val);
}
// process remaining part
int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) {
float val = 0.0f;
for (int j = 0; j < world_size; j++) {
val += *(float*)(buffers[j] + i);
}
*(float*)(to_buffer + i) = val;
remain_elements--;
i += element_size;
}
}
static bool is_initialized = false;
static int world_rank;
void shm_initialize(int size, int rank, const char* addr_string, const char* port_string) {
if (is_initialized) {
return;
}
is_initialized = true;
world_size = size;
world_rank = rank;
char shm_name_prefix[NAME_BUF_SIZE];
char shm_name[NAME_BUF_SIZE];
snprintf(shm_name_prefix, NAME_BUF_SIZE, "%s_%d_%s_%s", SHM_BUFFER_NAME, (int)getuid(), addr_string, port_string);
// create shared workspace for SHM based allreduce
SharedData allreduce_buffer;
// allocate a staging workspace_buf for the current rank; its bytes seed the
// shm segment via shared_create, after which the pointer is replaced by the
// mmap'd mapping (the malloc'd block is simply left to the OS)
struct allreduce_workspace* workspace_buf;
struct allreduce_workspace* workspace_buf_other;
workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace));
snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, rank);
shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
workspace_buf->states[1] = coll_begin;
// create the workspace pointer list
workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*));
symmetric_buffer[0] = (char**)malloc(size * sizeof(char*));
symmetric_buffer[1] = (char**)malloc(size * sizeof(char*));
distributed_buffer[0] = (char**)malloc(size * sizeof(char*));
distributed_buffer[1] = (char**)malloc(size * sizeof(char*));
// map shm of all ranks
for (int i = 0; i < size; i++) {
if (i != rank) {
snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, i);
// printf("open %s, %d\n", shm_name, rank);
do {
shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace));
} while (allreduce_buffer.descriptor == -1 && errno == ENOENT);
workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes;
workspace[i] = workspace_buf_other;
} else {
workspace[i] = workspace_buf;
}
symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
}
}
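// For illustration (values are made up): rank 0 of a 2-rank job running as
// uid 1000 with addr "127.0.0.1" and port "29500" creates
//   /dev/shm/deepspeed_allreduce_buffer_1000_127.0.0.1_29500_0
// and then spins in shared_open until the peer's ..._29500_1 segment appears.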
static void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw")));
static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
// process aligned part
#pragma omp parallel for
for (size_t i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
_mm256_storeu_si256((__m256i*)((char*)to + i), val);
}
// process remaining part
for (size_t i = aligned_bytes; i < n_bytes; i++) {
*((char*)to + i) = *((char*)from + i);
}
}
#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
#define rank_mod(rank) positive_mod(rank, world_size)
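// e.g. positive_mod(-1, 4) == 3, so rank_mod(world_rank - 1) wraps safely to
// the previous rank even though C's % may yield a negative remainder.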
size_t slice_size(size_t chunk_el, int slice_idx) {
size_t slice_size = chunk_el / world_size;
return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) : slice_size;
}
char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) {
size_t slice_size = chunk_el / world_size;
size_t el_offset = slice_size * slice_idx;
return data_ptr + el_offset * el_size;
}
size_t slice_el_start(size_t chunk_el, int slice_idx) {
size_t slice_size = chunk_el / world_size;
return slice_size * slice_idx;
}
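// Worked example: with world_size = 4 and chunk_el = 10, ranks 0..2 own
// 2-element slices starting at elements 0, 2 and 4, while the last rank owns
// 4 elements (2 plus the remainder 10 % 4) starting at element 6.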
void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
const int state_group = 0;
static int current_buffer = 0;
static int state_idx = 0;
// init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;
switch (state_idx) {
case 0:
copy_current = coll_allreduce_naive__copy_in_done;
copy_next = coll_alt1_allreduce_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allreduce_naive__copy_in_done;
copy_next = coll_alt2_allreduce_naive__copy_in_done;
break;
case 2:
copy_current = coll_alt2_allreduce_naive__copy_in_done;
copy_next = coll_allreduce_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 3;
parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
// wait until every other rank has copied in its chunk; accepting copy_next
// as well lets a peer that already advanced to the next round pass the check
if (i != world_rank) {
wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
}
}
// each rank reduces the buffer independently, so there is no need for
// synchronization afterwards
reduce_all_buffers(0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);
// switch buffer
current_buffer = 1 - current_buffer;
}
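// A minimal sketch of how this path is exercised (the tensor here is an
// assumption for illustration; all_reduce_outer_loop below is the real entry
// point):
//   torch::Tensor t = torch::ones({1024}, torch::kFloat); // 4 KiB < NAIVE_ALLREDUCE_THRESHOLD
//   all_reduce_outer_loop(t, t.numel(), t.numel() * t.element_size());
// routes through symmetric_naive_all_reduce in a single round.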
// distributed naive allreduce: each rank does a naive reduce on its own slice
void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
const int state_group = 1;
static int current_buffer = 0;
static int state_idx = 0;
// init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
enum coll_state reduce_current = coll_allreduce_naive__reduce_done;
enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;
// similar to symmetric_naive_all_reduce, but here we only need two sets of
// states because distributed naive reduce has two barriers in the algorithm
switch (state_idx) {
case 0:
copy_current = coll_allreduce_naive__copy_in_done;
reduce_current = coll_allreduce_naive__reduce_done;
copy_next = coll_alt1_allreduce_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allreduce_naive__copy_in_done;
reduce_current = coll_alt1_allreduce_naive__reduce_done;
copy_next = coll_allreduce_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 2;
int data_size = chunk_size / chunk_el; // bytes per element
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
// wait until all the other ranks have copied in their chunk (a peer may
// already have finished its reduce, so reduce_current is also accepted)
if (i != world_rank) wait_buffer_state_until_2(i, copy_current, reduce_current, state_group);
}
// reduce scatter
reduce_all_buffers(
slice_el_start(chunk_el, world_rank),
slice_size(chunk_el, world_rank),
scalar_type,
world_rank,
distributed_buffer[current_buffer][world_rank],
distributed_buffer[current_buffer]);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = reduce_current;
for (int i = 0; i < world_size; i++) {
// wait until all the other ranks have reduced their slice (a peer may already
// have begun the next round, so copy_next is also accepted)
if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
}
for (int i = 0; i < world_size; i++) {
int rank = (i + world_rank) % world_size;
parallel_memcpy(
slice_data(data_ptr, chunk_el, data_size, rank),
slice_data(distributed_buffer[current_buffer][rank], chunk_el, data_size, rank),
slice_size(chunk_el, rank) * data_size);
}
current_buffer = 1 - current_buffer;
}
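// Summary of the phases above: (1) every rank copies its chunk into its own
// shm buffer and posts copy_current; (2) after the first barrier each rank
// reduces only its own slice in place (a reduce-scatter) and posts
// reduce_current; (3) after the second barrier each rank copies every reduced
// slice back into data_ptr, starting from its own slice so that the ranks'
// reads are spread across the shared buffers.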
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) {
for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) {
auto data_ptr = ((char*)(data.data_ptr()) + offset);
size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset;
size_t chunk_el = chunk_size / (data_size / numel);
if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) {
symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
} else {
distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
}
}
}
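// Chunking example: a 70 MiB tensor is processed as 32 MiB + 32 MiB + 6 MiB
// chunks; all three are at least NAIVE_ALLREDUCE_THRESHOLD (1 MiB), so all
// take distributed_naive_reduce. Only chunks smaller than 1 MiB -- small
// tensors or a small final remainder -- use symmetric_naive_all_reduce.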
void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_t chunk_size, size_t chunk_el) {
const int state_group = 1;
static int current_buffer = 0;
static int state_idx = 0;
// init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allgather_naive__copy_in_done;
enum coll_state copy_next = coll_alt1_allgather_naive__copy_in_done;
switch (state_idx) {
case 0:
copy_current = coll_allgather_naive__copy_in_done;
copy_next = coll_alt1_allgather_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allgather_naive__copy_in_done;
copy_next = coll_alt2_allgather_naive__copy_in_done;
break;
case 2:
copy_current = coll_alt2_allgather_naive__copy_in_done;
copy_next = coll_allgather_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 3;
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
// wait until all the other ranks have copied in their chunk
if (i != world_rank) wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
}
for (int i = 0; i < world_size; i++) {
parallel_memcpy(result_ptr + i * res_stride, distributed_buffer[current_buffer][i], chunk_size);
}
current_buffer = 1 - current_buffer;
}
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size) {
size_t dim_el = data.stride(dim) * data.size(dim);
int dtype_size = data_size / numel;
size_t dim_size = dim_el * dtype_size;
int dim_count = data_size / dim_size;
auto data_ptr = (char*)(data.data_ptr());
auto result_ptr = (char*)(result.data_ptr());
for (int i = 0; i < dim_count; i++) {
for (size_t offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : dim_size - offset;
size_t chunk_el = chunk_size / dtype_size;
naive_all_gather(
result_ptr + i * dim_size * world_size + offset,
data_ptr + i * dim_size + offset,
dim_size,
chunk_size,
chunk_el);
}
}
return result;
}
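// Worked example (shapes assumed for illustration): gathering a contiguous
// [2, 3] fp32 tensor along dim 1 with world_size = 2 gives dim_el = 3,
// dim_size = 12 bytes and dim_count = 2; each output row then holds rank 0's
// three floats followed by rank 1's, i.e. the result viewed as [2, 6].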