#include "shm.h"
|
|
|
|
#include <ATen/ATen.h>
|
|
#include <errno.h>
|
|
#include <fcntl.h>
|
|
#include <immintrin.h>
|
|
#include <stddef.h>
|
|
#include <stdio.h>
|
|
#include <stdlib.h>
|
|
#include <sys/mman.h>
|
|
#include <unistd.h>
|
|
|
|
// states for collectives
|
|
enum coll_state {
|
|
coll_begin = 0,
|
|
coll_allreduce_naive__copy_in_done,
|
|
coll_allreduce_naive__reduce_done,
|
|
// alternative state when allreduce is working on alternative buffer
|
|
// of the double buffer.
|
|
coll_alt1_allreduce_naive__copy_in_done,
|
|
coll_alt2_allreduce_naive__copy_in_done,
|
|
coll_alt1_allreduce_naive__reduce_done,
|
|
coll_allgather_naive__copy_in_done,
|
|
coll_alt1_allgather_naive__copy_in_done,
|
|
coll_alt2_allgather_naive__copy_in_done,
|
|
};

// SHM building blocks
struct SharedData {
    const char* name;
    int descriptor;
    void* bytes;
    size_t nbytes;
};

void shared_open(SharedData* data, const char* name, size_t nbytes) {
    int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
    if (d != -1) {
        void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0);
        data->name = name;
        data->descriptor = d;
        data->bytes = bytes;
        data->nbytes = nbytes;
    } else {
        if (errno != ENOENT) {
            // stay silent if the shm does not exist yet; the caller loops and
            // retries until the other ranks have created it
            printf("shared_open %s failed, errno=%d\n", name, errno);
        }
        data->descriptor = -1;
    }
}

void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) {
    int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
    if (d != -1) {
        ssize_t written = write(d, bytes, nbytes);
        if (written > 0) {
            shared_open(data, name, (size_t)written);
        }
    } else {
        printf("shared_create %s failed\n", name);
    }
}
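
// Usage sketch (illustrative only, names are hypothetical): one rank creates
// the segment, the others poll shared_open() until it shows up, mirroring the
// retry loop in shm_initialize() below.
//
//   SharedData seg;
//   if (rank == 0) {
//       static char init[4096] = {0};
//       shared_create(&seg, "/example_seg", init, sizeof(init));
//   } else {
//       do {
//           shared_open(&seg, "/example_seg", 4096);
//       } while (seg.descriptor == -1 && errno == ENOENT);
//   }
//   // all ranks now share the same 4096 bytes through seg.bytes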

static int world_size;

// SHM based allreduce helper functions
// buffer that holds shm name
#define NAME_BUF_SIZE 1000
#define MAX_BUF_SIZE (1048576 * 32)
#define NAIVE_ALLREDUCE_THRESHOLD 1048576
#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
struct allreduce_workspace {
    enum coll_state states[2];  // idx=0 -- state for symmetric_naive_all_reduce
                                // idx=1 -- state for distributed_naive_all_reduce
    // double buffer to avoid syncing between rounds
    // offset 0 .. 2*NAIVE_ALLREDUCE_THRESHOLD: buffers for symmetric_naive_all_reduce
    // after that: buffers for distributed_naive_all_reduce
    char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
};

#define BUFFER0_OFFSET(current_buffer) ((current_buffer) * NAIVE_ALLREDUCE_THRESHOLD)
#define BUFFER1_OFFSET(current_buffer) (2 * NAIVE_ALLREDUCE_THRESHOLD + (current_buffer) * MAX_BUF_SIZE)

struct allreduce_workspace** workspace;

// buffer for small messages, double buffer
char** symmetric_buffer[2];
// buffer for large messages, double buffer
char** distributed_buffer[2];
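
// Concrete layout per rank with the sizes above (NAIVE_ALLREDUCE_THRESHOLD =
// 1 MiB, MAX_BUF_SIZE = 32 MiB): symmetric buffers 0/1 start at offsets 0 and
// 1 MiB, distributed buffers 0/1 start at offsets 2 MiB and 34 MiB, for a
// total of 66 MiB of payload in allreduce_workspace::buffer.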

// Spin until the state of rank `index` in the given state group becomes either
// of the two expected states.
void wait_buffer_state_until_2(int index, enum coll_state state0, enum coll_state state1, int state_group) {
    volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]);

    while (1) {
        volatile enum coll_state cur_state = *state_ptr;
        if (cur_state == state0 || cur_state == state1) break;
    }
}

inline __m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_bf16_to_fp32(const __m256i src) {
    auto y = _mm512_cvtepu16_epi32(src);
    return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
}

inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
    __m512i value = _mm512_castps_si512(src);
    __m512i nan = _mm512_set1_epi32(0xffff);
    auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
    __m512i ones = _mm512_set1_epi32(0x1);
    __m512i vec_bias = _mm512_set1_epi32(0x7fff);
    // uint32_t lsb = (input >> 16) & 1;
    auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
    // uint32_t rounding_bias = 0x7fff + lsb;
    t_value = _mm512_add_epi32(t_value, vec_bias);
    // input += rounding_bias;
    t_value = _mm512_add_epi32(t_value, value);
    // input = input >> 16;
    t_value = _mm512_srli_epi32(t_value, 16);
    // Check NaN before converting back to bf16
    t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
    return _mm512_cvtusepi32_epi16(t_value);
}
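
// For reference, the scalar equivalent of the rounding done above
// (round-to-nearest-even, NaN forced to 0xffff); an illustrative sketch, the
// vectorized path is what the kernels actually use:
//
//   static inline uint16_t fp32_to_bf16_scalar(float f) {
//       uint32_t input;
//       memcpy(&input, &f, sizeof(input));
//       if (f != f) return 0xffff;             // NaN
//       uint32_t lsb = (input >> 16) & 1;      // parity of the kept mantissa bit
//       uint32_t rounding_bias = 0x7fff + lsb;
//       input += rounding_bias;
//       return (uint16_t)(input >> 16);
//   }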

inline __m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_fp16_to_fp32(const __m256i src) {
    return _mm512_cvtph_ps(src);
}

inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_fp16(const __m512 src) {
    return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("avx512bw")));

void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("avx512bw")));

void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
    __attribute__((target("avx512bw")));

void reduce_all_buffers(
    int start_elements,
    int num_elements,
    c10::ScalarType scalar_type,
    int to_buffer_idx,
    char* to_buffer,
    char** buffers) {
    switch (scalar_type) {
        case c10::ScalarType::BFloat16:
            reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
            break;
        case c10::ScalarType::Half:
            reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
            break;
        case c10::ScalarType::Float:
            reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
            break;
        default:
            assert(!"Should not get here");
    }
}

#define CVT_ADD_BF16(x)                                                                      \
    do {                                                                                     \
        auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
        inout_val = _mm512_add_ps(inout_val, in##x##_val);                                   \
    } while (0)
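
// For example, CVT_ADD_BF16(3) expands (ignoring the do/while wrapper) to:
//   auto in3_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[3] + i)));
//   inout_val = _mm512_add_ps(inout_val, in3_val);
// so the fall-through switch in reduce_bf16_buffers below unrolls the sum over
// ranks world_size-1 down to 1.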

// The reduce functions below use a vectorized algorithm; the number of bytes
// processed per iteration depends on the vector length: a 256-bit vector
// covers 32 bytes, a 512-bit vector covers 64 bytes. If you change the
// implementation of reduce_bf16_buffers etc., check whether this number needs
// to change as well.
#define VECTOR_LENGTH_IN_BYTES 32

void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
    const int element_size = 2;
    const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
    int main_elements = num_elements - (num_elements % vector_length);
    int remain_elements = num_elements % vector_length;

    // process aligned part
#pragma omp parallel for
    for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
         i += VECTOR_LENGTH_IN_BYTES) {
        auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
        // each case adds one more rank's buffer and intentionally falls
        // through, unrolling the sum over ranks world_size-1 .. 1
        switch (world_size) {
            case 16: CVT_ADD_BF16(15);
            case 15: CVT_ADD_BF16(14);
            case 14: CVT_ADD_BF16(13);
            case 13: CVT_ADD_BF16(12);
            case 12: CVT_ADD_BF16(11);
            case 11: CVT_ADD_BF16(10);
            case 10: CVT_ADD_BF16(9);
            case 9: CVT_ADD_BF16(8);
            case 8: CVT_ADD_BF16(7);
            case 7: CVT_ADD_BF16(6);
            case 6: CVT_ADD_BF16(5);
            case 5: CVT_ADD_BF16(4);
            case 4: CVT_ADD_BF16(3);
            case 3: CVT_ADD_BF16(2);
            case 2: CVT_ADD_BF16(1);
            case 1: break;
            default:
                for (int j = 1; j < world_size; j++) {
                    auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
                    inout_val = _mm512_add_ps(inout_val, in_val);
                }
        }
        _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
    }

    // process remaining part
    int i = (start_elements + main_elements) * element_size;
    while (remain_elements > 0) {
        float val = 0.0f;
        for (int j = 0; j < world_size; j++) { val += *(at::BFloat16*)(buffers[j] + i); }
        *(at::BFloat16*)(to_buffer + i) = val;
        remain_elements--;
        i += element_size;
    }
}

#define CVT_ADD_FP16(x)                                                                      \
    do {                                                                                     \
        auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
        inout_val = _mm512_add_ps(inout_val, in##x##_val);                                   \
    } while (0)

void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
    const int element_size = 2;
    const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
    int main_elements = num_elements - (num_elements % vector_length);
    int remain_elements = num_elements % vector_length;

    // process aligned part
#pragma omp parallel for
    for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
         i += VECTOR_LENGTH_IN_BYTES) {
        auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
        // intentional fall-through, as in reduce_bf16_buffers
        switch (world_size) {
            case 16: CVT_ADD_FP16(15);
            case 15: CVT_ADD_FP16(14);
            case 14: CVT_ADD_FP16(13);
            case 13: CVT_ADD_FP16(12);
            case 12: CVT_ADD_FP16(11);
            case 11: CVT_ADD_FP16(10);
            case 10: CVT_ADD_FP16(9);
            case 9: CVT_ADD_FP16(8);
            case 8: CVT_ADD_FP16(7);
            case 7: CVT_ADD_FP16(6);
            case 6: CVT_ADD_FP16(5);
            case 5: CVT_ADD_FP16(4);
            case 4: CVT_ADD_FP16(3);
            case 3: CVT_ADD_FP16(2);
            case 2: CVT_ADD_FP16(1);
            case 1: break;
            default:
                for (int j = 1; j < world_size; j++) {
                    auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
                    inout_val = _mm512_add_ps(inout_val, in_val);
                }
        }
        _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
    }

    // process remaining part
    int i = (start_elements + main_elements) * element_size;
    while (remain_elements > 0) {
        float val = 0.0f;
        for (int j = 0; j < world_size; j++) { val += *(at::Half*)(buffers[j] + i); }
        *(at::Half*)(to_buffer + i) = val;
        remain_elements--;
        i += element_size;
    }
}

#define CVT_ADD_F32(x)                                                \
    do {                                                              \
        auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
        inout_val = _mm256_add_ps(inout_val, in##x##_val);            \
    } while (0)

void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
    const int element_size = 4;
    const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
    int main_elements = num_elements - (num_elements % vector_length);
    int remain_elements = num_elements % vector_length;

    // process aligned part
#pragma omp parallel for
    for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
         i += VECTOR_LENGTH_IN_BYTES) {
        auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
        // intentional fall-through, as in reduce_bf16_buffers
        switch (world_size) {
            case 16: CVT_ADD_F32(15);
            case 15: CVT_ADD_F32(14);
            case 14: CVT_ADD_F32(13);
            case 13: CVT_ADD_F32(12);
            case 12: CVT_ADD_F32(11);
            case 11: CVT_ADD_F32(10);
            case 10: CVT_ADD_F32(9);
            case 9: CVT_ADD_F32(8);
            case 8: CVT_ADD_F32(7);
            case 7: CVT_ADD_F32(6);
            case 6: CVT_ADD_F32(5);
            case 5: CVT_ADD_F32(4);
            case 4: CVT_ADD_F32(3);
            case 3: CVT_ADD_F32(2);
            case 2: CVT_ADD_F32(1);
            case 1: break;
            default:
                for (int j = 1; j < world_size; j++) {
                    auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
                    inout_val = _mm256_add_ps(inout_val, in_val);
                }
        }
        _mm256_storeu_ps((float*)(to_buffer + i), inout_val);
    }

    // process remaining part
    int i = (start_elements + main_elements) * element_size;
    while (remain_elements > 0) {
        float val = 0.0f;
        for (int j = 0; j < world_size; j++) { val += *(float*)(buffers[j] + i); }
        *(float*)(to_buffer + i) = val;
        remain_elements--;
        i += element_size;
    }
}

static bool is_initialized = false;
static int world_rank;

void shm_initialize(int size, int rank, const char* addr_string, const char* port_string) {
    if (is_initialized) {
        return;
    }
    is_initialized = true;

    world_size = size;
    world_rank = rank;

    char shm_name_prefix[NAME_BUF_SIZE];
    char shm_name[NAME_BUF_SIZE];
    snprintf(shm_name_prefix, NAME_BUF_SIZE, "%s_%d_%s_%s", SHM_BUFFER_NAME, getuid(), addr_string, port_string);
    // create shared workspace for SHM based allreduce
    SharedData allreduce_buffer;
    // allocate a temporary workspace_buf used only to size and populate the
    // current rank's shm segment
    struct allreduce_workspace* workspace_buf;
    struct allreduce_workspace* workspace_buf_other;
    workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace));
    snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, rank);
    shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
    free(workspace_buf);  // from here on only the mapped shm is used
    workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
    workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
    workspace_buf->states[1] = coll_begin;

    // create the workspace pointer list
    workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*));
    symmetric_buffer[0] = (char**)malloc(size * sizeof(char*));
    symmetric_buffer[1] = (char**)malloc(size * sizeof(char*));
    distributed_buffer[0] = (char**)malloc(size * sizeof(char*));
    distributed_buffer[1] = (char**)malloc(size * sizeof(char*));

    // map shm of all ranks
    for (int i = 0; i < size; i++) {
        if (i != rank) {
            snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, i);
            // printf("open %s, %d\n", shm_name, rank);
            do {
                shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace));
            } while (allreduce_buffer.descriptor == -1 && errno == ENOENT);
            workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes;
            workspace[i] = workspace_buf_other;
        } else {
            workspace[i] = workspace_buf;
        }
        symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
        symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
        distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
        distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
    }
}
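
// Illustrative call site (hypothetical; in practice the bindings pass these
// values in from the distributed launcher):
//
//   shm_initialize(world_size, rank, getenv("MASTER_ADDR"), getenv("MASTER_PORT"));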

static void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw")));
static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
    auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
    // process aligned part
#pragma omp parallel for
    for (size_t i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
        auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
        _mm256_storeu_si256((__m256i*)((char*)to + i), val);
    }

    // process remaining part
    for (size_t i = aligned_bytes; i < n_bytes; i++) { *((char*)to + i) = *((char*)from + i); }
}

#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
#define rank_mod(rank) positive_mod(rank, world_size)
size_t slice_size(size_t chunk_el, int slice_idx) {
    size_t slice_size = chunk_el / world_size;
    return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) : slice_size;
}

char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) {
    size_t slice_size = chunk_el / world_size;
    size_t el_offset = slice_size * slice_idx;
    return data_ptr + el_offset * el_size;
}

size_t slice_el_start(size_t chunk_el, int slice_idx) {
    size_t slice_size = chunk_el / world_size;
    return slice_size * slice_idx;
}
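
// Worked example: with world_size = 4 and chunk_el = 10, slice_size() returns
// 2, 2, 2 and 4 elements for slices 0..3 (the last slice absorbs the
// remainder), and slice_el_start() returns element offsets 0, 2, 4 and 6.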

void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
    const int state_group = 0;
    static int current_buffer = 0;
    static int state_idx = 0;

    // init states to case 0 to get rid of "maybe-uninitialized" warning.
    enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
    enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;

    switch (state_idx) {
        case 0:
            copy_current = coll_allreduce_naive__copy_in_done;
            copy_next = coll_alt1_allreduce_naive__copy_in_done;
            break;
        case 1:
            copy_current = coll_alt1_allreduce_naive__copy_in_done;
            copy_next = coll_alt2_allreduce_naive__copy_in_done;
            break;
        case 2:
            copy_current = coll_alt2_allreduce_naive__copy_in_done;
            copy_next = coll_allreduce_naive__copy_in_done;
            break;
        default:
            assert(!"Should not get here.");
    }
    state_idx = (state_idx + 1) % 3;

    parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
    std::atomic_thread_fence(std::memory_order_release);
    workspace[world_rank]->states[state_group] = copy_current;

    for (int i = 0; i < world_size; i++) {
        // wait until the other ranks have copied their buffers in
        if (i != world_rank) {
            wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
        }
    }

    // each rank reduces the buffer independently, so there is no need for
    // synchronization afterward
    reduce_all_buffers(0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);

    // switch buffer
    current_buffer = 1 - current_buffer;
}

// distributed naive allreduce: each rank does a naive reduce on its own slice
void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
    const int state_group = 1;
    static int current_buffer = 0;
    static int state_idx = 0;

    // init states to case 0 to get rid of "maybe-uninitialized" warning.
    enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
    enum coll_state reduce_current = coll_allreduce_naive__reduce_done;
    enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;

    // similar to symmetric_naive_all_reduce, but only two sets of states are
    // needed here because distributed naive reduce has two barriers per round
    switch (state_idx) {
        case 0:
            copy_current = coll_allreduce_naive__copy_in_done;
            reduce_current = coll_allreduce_naive__reduce_done;
            copy_next = coll_alt1_allreduce_naive__copy_in_done;
            break;
        case 1:
            copy_current = coll_alt1_allreduce_naive__copy_in_done;
            reduce_current = coll_alt1_allreduce_naive__reduce_done;
            copy_next = coll_allreduce_naive__copy_in_done;
            break;
        default:
            assert(!"Should not get here.");
    }
    state_idx = (state_idx + 1) % 2;

    int data_size = chunk_size / chunk_el;
    parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
    std::atomic_thread_fence(std::memory_order_release);
    workspace[world_rank]->states[state_group] = copy_current;

    for (int i = 0; i < world_size; i++) {
        // wait until all the other ranks have copied their buffers in
        if (i != world_rank) wait_buffer_state_until_2(i, copy_current, reduce_current, state_group);
    }

    // reduce scatter
    reduce_all_buffers(
        slice_el_start(chunk_el, world_rank),
        slice_size(chunk_el, world_rank),
        scalar_type,
        world_rank,
        distributed_buffer[current_buffer][world_rank],
        distributed_buffer[current_buffer]);
    std::atomic_thread_fence(std::memory_order_release);
    workspace[world_rank]->states[state_group] = reduce_current;

    for (int i = 0; i < world_size; i++) {
        // wait until all the other ranks have reduced their slices
        if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
    }

    // all gather: copy each rank's reduced slice back into data_ptr
    for (int i = 0; i < world_size; i++) {
        int rank = (i + world_rank) % world_size;
        parallel_memcpy(
            slice_data(data_ptr, chunk_el, data_size, rank),
            slice_data(distributed_buffer[current_buffer][rank], chunk_el, data_size, rank),
            slice_size(chunk_el, rank) * data_size);
    }

    current_buffer = 1 - current_buffer;
}

void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) {
    for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) {
        auto data_ptr = ((char*)(data.data_ptr()) + offset);
        size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset;
        size_t chunk_el = chunk_size / (data_size / numel);
        if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) {
            symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
        } else {
            distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
        }
    }
}
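
// For example, a 100 MiB fp32 tensor (data_size = 100 MiB) is processed as
// chunks of 32, 32, 32 and 4 MiB, all of which reach NAIVE_ALLREDUCE_THRESHOLD
// (1 MiB) and therefore go through distributed_naive_reduce; a 256 KiB tensor
// would instead take the symmetric_naive_all_reduce path.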

void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_t chunk_size, size_t chunk_el) {
    const int state_group = 1;
    static int current_buffer = 0;
    static int state_idx = 0;

    // init states to case 0 to get rid of "maybe-uninitialized" warning.
    enum coll_state copy_current = coll_allgather_naive__copy_in_done;
    enum coll_state copy_next = coll_alt1_allgather_naive__copy_in_done;

    switch (state_idx) {
        case 0:
            copy_current = coll_allgather_naive__copy_in_done;
            copy_next = coll_alt1_allgather_naive__copy_in_done;
            break;
        case 1:
            copy_current = coll_alt1_allgather_naive__copy_in_done;
            copy_next = coll_alt2_allgather_naive__copy_in_done;
            break;
        case 2:
            copy_current = coll_alt2_allgather_naive__copy_in_done;
            copy_next = coll_allgather_naive__copy_in_done;
            break;
        default:
            assert(!"Should not get here.");
    }
    state_idx = (state_idx + 1) % 3;

    parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
    std::atomic_thread_fence(std::memory_order_release);
    workspace[world_rank]->states[state_group] = copy_current;

    for (int i = 0; i < world_size; i++) {
        // wait until all the other ranks have copied their buffers in
        if (i != world_rank) wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
    }
    for (int i = 0; i < world_size; i++) {
        parallel_memcpy(result_ptr + i * res_stride, distributed_buffer[current_buffer][i], chunk_size);
    }
    current_buffer = 1 - current_buffer;
}

torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size) {
    size_t dim_el = data.stride(dim) * data.size(dim);
    int dtype_size = data_size / numel;
    size_t dim_size = dim_el * dtype_size;
    int dim_count = data_size / dim_size;
    auto data_ptr = (char*)(data.data_ptr());
    auto result_ptr = (char*)(result.data_ptr());
    for (int i = 0; i < dim_count; i++) {
        for (size_t offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
            size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : dim_size - offset;
            size_t chunk_el = chunk_size / dtype_size;
            naive_all_gather(
                result_ptr + i * dim_size * world_size + offset,
                data_ptr + i * dim_size + offset,
                dim_size,
                chunk_size,
                chunk_el);
        }
    }
    return result;
}
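// Layout produced by all_gather, as implemented above: for each outer block of
// the gather dim, the ranks' slices are concatenated along that dim. E.g. a
// [4, 6] fp32 tensor gathered along dim=1 across 2 ranks fills a [4, 12]
// result in which columns 0-5 of every row come from rank 0 and columns 6-11
// from rank 1.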