/* * Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved. * * NVIDIA CORPORATION and its licensors retain all intellectual property * and proprietary rights in and to this software, related documentation * and any modifications thereto. Any use, reproduction, disclosure or * distribution of this software and related documentation without an express * license agreement from NVIDIA CORPORATION is strictly prohibited. * * See COPYRIGHT.txt for license information */ #ifndef UTILS #define UTILS #include #include #include #include #include #include #include #include #include #ifdef NVSHMEMTEST_MPI_SUPPORT #include "mpi.h" #endif #ifdef NVSHMEMTEST_SHMEM_SUPPORT #include "shmem.h" #include "shmemx.h" #endif #include "nvshmem.h" #include "nvshmemx.h" #define NVSHMEMI_TEST_STRINGIFY(x) #x #define CUMODULE_LOAD(CUMODULE, CUMODULE_PATH, ERROR) \ CU_CHECK(cuModuleLoad(&CUMODULE, CUMODULE_PATH)); \ ERROR = nvshmemx_cumodule_init(CUMODULE); void init_test_case_kernel(CUfunction *kernel, const char *kernel_name); #ifdef NVSHMEMTEST_MPI_SUPPORT #define MPI_LOAD_SYM(fn_name) \ mpi_fn_table.fn_##fn_name = (fnptr_##fn_name)dlsym(nvshmemi_mpi_handle, #fn_name); \ if (mpi_fn_table.fn_##fn_name == NULL) { \ fprintf(stderr, "Unable to load MPI symbol" #fn_name "\n"); \ return -1; \ } typedef int (*fnptr_MPI_Init)(int *argc, char ***argv); typedef int (*fnptr_MPI_Bcast)(void *buffer, int count, MPI_Datatype datatype, int root, MPI_Comm comm); typedef int (*fnptr_MPI_Comm_rank)(MPI_Comm comm, int *rank); typedef int (*fnptr_MPI_Comm_size)(MPI_Comm comm, int *size); typedef int (*fnptr_MPI_Finalize)(void); struct nvshmemi_mpi_fn_table { fnptr_MPI_Init fn_MPI_Init; fnptr_MPI_Bcast fn_MPI_Bcast; fnptr_MPI_Comm_rank fn_MPI_Comm_rank; fnptr_MPI_Comm_size fn_MPI_Comm_size; fnptr_MPI_Finalize fn_MPI_Finalize; }; extern void *nvshmemi_mpi_handle; extern struct nvshmemi_mpi_fn_table mpi_fn_table; #endif #ifdef NVSHMEMTEST_SHMEM_SUPPORT #define SHMEM_LOAD_SYM(fn_name) \ shmem_fn_table.fn_##fn_name = (fnptr_##fn_name)dlsym(nvshmemi_shmem_handle, #fn_name); \ if (shmem_fn_table.fn_##fn_name == NULL) { \ fprintf(stderr, "Unable to load SHMEM symbol" #fn_name "\n"); \ return -1; \ } typedef void (*fnptr_shmem_init)(void); typedef int (*fnptr_shmem_my_pe)(void); typedef int (*fnptr_shmem_n_pes)(void); typedef void (*fnptr_shmem_fcollect64)(void *target, const void *source, size_t nlong, int PE_start, int logPE_stride, int PE_size, long *pSync); typedef void *(*fnptr_shmem_malloc)(size_t size); typedef void (*fnptr_shmem_free)(void *ptr); typedef void (*fnptr_shmem_finalize)(void); struct nvshmemi_shmem_fn_table { fnptr_shmem_init fn_shmem_init; fnptr_shmem_my_pe fn_shmem_my_pe; fnptr_shmem_n_pes fn_shmem_n_pes; fnptr_shmem_fcollect64 fn_shmem_fcollect64; fnptr_shmem_malloc fn_shmem_malloc; fnptr_shmem_free fn_shmem_free; fnptr_shmem_finalize fn_shmem_finalize; }; extern void *nvshmemi_shmem_handle; extern struct nvshmemi_shmem_fn_table shmem_fn_table; #endif using namespace std; #undef CUDA_CHECK #define CUDA_CHECK(stmt) \ do { \ cudaError_t result = (stmt); \ if (cudaSuccess != result) { \ fprintf(stderr, "[%s:%d] cuda failed with %s \n", __FILE__, __LINE__, \ cudaGetErrorString(result)); \ exit(-1); \ } \ assert(cudaSuccess == result); \ } while (0) #define CU_CHECK(stmt) \ do { \ CUresult result = (stmt); \ char str[1024]; \ if (CUDA_SUCCESS != result) { \ CUresult ret = cuGetErrorString(result, (const char **)&str); \ fprintf(stderr, "[%s:%d] cuda failed with (%d) %s \n", __FILE__, __LINE__, \ (int)result, (ret != CUDA_SUCCESS) ? "cuGetErrorString failed" : str); \ exit(-1); \ } \ assert(CUDA_SUCCESS == result); \ } while (0) #define ERROR_EXIT(...) \ do { \ fprintf(stderr, "%s:%s:%d: ", __FILE__, __FUNCTION__, __LINE__); \ fprintf(stderr, __VA_ARGS__); \ exit(-1); \ } while (0) #define ERROR_PRINT(...) \ do { \ fprintf(stderr, "%s:%s:%d: ", __FILE__, __FUNCTION__, __LINE__); \ fprintf(stderr, __VA_ARGS__); \ } while (0) #undef WARN_PRINT #define WARN_PRINT(...) \ do { \ fprintf(stdout, "%s:%s:%d: ", __FILE__, __FUNCTION__, __LINE__); \ fprintf(stdout, __VA_ARGS__); \ } while (0) #ifdef _NVSHMEM_DEBUG #define DEBUG_PRINT(...) fprintf(stderr, __VA_ARGS__); #else #define DEBUG_PRINT(...) \ do { \ } while (0) #endif #define MS_TO_S 1000 #define B_TO_GB (1000 * 1000 * 1000) enum NVSHMEM_DATATYPE_T { NVSHMEM_INT = 0, NVSHMEM_LONG, NVSHMEM_LONGLONG, NVSHMEM_ULONGLONG, NVSHMEM_SIZE, NVSHMEM_PTRDIFF, NVSHMEM_FLOAT, NVSHMEM_DOUBLE, NVSHMEM_UINT, NVSHMEM_INT32, NVSHMEM_UINT32, NVSHMEM_INT64, NVSHMEM_UINT64, NVSHMEM_FP16, NVSHMEM_BF16 }; enum NVSHMEM_REDUCE_OP_T { NVSHMEM_MIN = 0, NVSHMEM_MAX, NVSHMEM_SUM, NVSHMEM_PROD, NVSHMEM_AND, NVSHMEM_OR, NVSHMEM_XOR }; enum NVSHMEM_THREADGROUP_SCOPE_T { NVSHMEM_THREAD = 0, NVSHMEM_WARP, NVSHMEM_BLOCK, NVSHMEM_ALL_SCOPES }; enum AMO_T { AMO_INC = 0, AMO_FETCH_INC, AMO_SET, AMO_ADD, AMO_FETCH_ADD, AMO_AND, AMO_FETCH_AND, AMO_OR, AMO_FETCH_OR, AMO_XOR, AMO_FETCH_XOR, AMO_SWAP, AMO_COMPARE_SWAP, AMO_ACK }; enum PUTGET_ISSUE_T { ON_STREAM = 0, HOST }; enum DIR_T { WRITE = 0, READ }; extern size_t min_size; extern size_t max_size; extern size_t num_blocks; extern size_t threads_per_block; extern size_t iters; extern size_t warmup_iters; extern size_t step_factor; extern size_t max_size_log; extern size_t stride; extern bool bidirectional; struct datatype_t { NVSHMEM_DATATYPE_T type; size_t size; string name; }; extern datatype_t datatype; struct reduce_op_t { NVSHMEM_REDUCE_OP_T type; string name; }; extern reduce_op_t reduce_op; struct threadgroup_scope_t { NVSHMEM_THREADGROUP_SCOPE_T type; string name; }; extern threadgroup_scope_t threadgroup_scope; struct amo_t { AMO_T type; string name; }; extern amo_t test_amo; struct putget_issue_t { PUTGET_ISSUE_T type; string name; }; extern putget_issue_t putget_issue; struct dir_t { DIR_T type; string name; }; extern dir_t dir; extern __device__ int clockrate; extern bool use_cubin; void init_cumodule(const char *str); void init_wrapper(int *c, char ***v); void finalize_wrapper(); void alloc_tables(void ***table_mem, int num_tables, int num_entries_per_table); void free_tables(void **tables, int num_tables); void print_table_basic(const char *job_name, const char *subjob_name, const char *var_name, const char *output_var, const char *units, const char plus_minus, uint64_t *size, double *value, int num_entries); void print_table_v1(const char *job_name, const char *subjob_name, const char *var_name, const char *output_var, const char *units, const char plus_minus, uint64_t *size, double *value, int num_entries); void print_table_v2(const char *job_name, const char *subjob_name, const char *var_name, const char *output_var, const char *units, const char plus_minus, uint64_t *size, double **value, int num_entries, size_t num_iters); void read_args(int argc, char **argv); #endif