/*
 * Copyright (c) 2017-2020, NVIDIA CORPORATION. All rights reserved.
 *
 * See COPYRIGHT for license information
 */

#include <assert.h>
#include <cuda_runtime.h>
#include <stdio.h>

#ifdef __clang_llvm_bitcode_lib__
#define assert(...)
#include "nvshmem.h"
#endif

#include "non_abi/nvshmem_build_options.h"
#include "non_abi/nvshmem_version.h"
#include "non_abi/nvshmemx_error.h"
#include "internal/device/nvshmemi_device.h"
#include "non_abi/device/pt-to-pt/proxy_device.cuh"
#include "device_host/nvshmem_common.cuh"
#include "device_host/nvshmem_types.h"

#ifdef NVSHMEM_IBGDA_SUPPORT
#include "device_host_transport/nvshmem_common_ibgda.h"
__constant__ __attribute__((used)) nvshmemi_ibgda_device_state_t nvshmemi_ibgda_device_state_d;
#endif

/* State owned solely by the device library (not shared with the host library). */
nvshmemi_device_state_t nvshmemi_device_only_state;

#ifdef __clang__
__constant__ __attribute__((address_space(4), used)) nvshmemi_device_host_state_t
    nvshmemi_device_state_d = {};
const nvshmemi_version_t nvshmemi_device_lib_version = {
    NVSHMEM_VENDOR_MAJOR_VERSION, NVSHMEM_VENDOR_MINOR_VERSION, NVSHMEM_VENDOR_PATCH_VERSION};
__constant__ __attribute__((address_space(4), used)) nvshmemi_version_t
    nvshmemi_device_lib_version_d = {NVSHMEM_VENDOR_MAJOR_VERSION, NVSHMEM_VENDOR_MINOR_VERSION,
                                     NVSHMEM_VENDOR_PATCH_VERSION};
#else
__constant__ nvshmemi_device_host_state_t nvshmemi_device_state_d;
const nvshmemi_version_t nvshmemi_device_lib_version = {
    NVSHMEM_VENDOR_MAJOR_VERSION, NVSHMEM_VENDOR_MINOR_VERSION, NVSHMEM_VENDOR_PATCH_VERSION};
__constant__ nvshmemi_version_t nvshmemi_device_lib_version_d = {
    NVSHMEM_VENDOR_MAJOR_VERSION, NVSHMEM_VENDOR_MINOR_VERSION, NVSHMEM_VENDOR_PATCH_VERSION};
#endif

#ifdef __CUDA_ARCH__
#ifdef __cplusplus
extern "C" {
#endif
NVSHMEMI_DEVICE_PREFIX void nvshmem_global_exit(int status);
#ifdef __cplusplus
}
#endif

NVSHMEMI_DEVICE_PREFIX void nvshmem_global_exit(int status) {
    if (nvshmemi_device_state_d.proxy > NVSHMEMI_PROXY_NONE) {
        nvshmemi_proxy_global_exit(status);
    } else {
        /* TODO: Add device side printing macros */
        printf(
            "Device side proxy was called, but is not supported under your configuration. "
            "Please unset NVSHMEM_DISABLE_LOCAL_ONLY_PROXY, or set it to false.\n");
        assert(0);
    }
}
#endif

#ifdef __cplusplus
extern "C" {
void nvshmemi_get_mem_handle(void **dev_state_ptr, void **transport_dev_state_ptr);
}
#endif

static int _nvshmemi_init_device_only_state() {
    int status = 0;

    status = nvshmemi_setup_collective_launch();
    NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
                          "_nvshmemi_init_device_only_state failed\n");
    nvshmemi_device_only_state.is_initialized = true;

out:
    return status;
}

/* Verify that NVSHMEM has been initialized and lazily set up the device-only state on first use. */
void nvshmemi_check_state_and_init_d() {
    int status;
    int ret;
    if (nvshmemid_init_status() == NVSHMEM_STATUS_NOT_INITIALIZED)
        NVSHMEMI_ERROR_EXIT("nvshmem API called before nvshmem_init \n");
    if (nvshmemid_init_status() == NVSHMEM_STATUS_IS_BOOTSTRAPPED) {
        /* The fact that we can pass NVSHMEM_THREAD_SERIALIZED
         * here is an implementation detail. It should be fixed
         * if/when NVSHMEM_THREAD_* becomes significant.
         */
        status = nvshmemid_hostlib_init_attr(NVSHMEM_THREAD_SERIALIZED, &ret, 0, NULL,
                                             nvshmemi_device_lib_version, NULL);
        if (status) {
            NVSHMEMI_ERROR_EXIT("nvshmem initialization failed, exiting \n");
        }

        status = cudaGetDevice(&nvshmemi_device_only_state.cuda_device_id);
        if (status) {
            NVSHMEMI_ERROR_EXIT("nvshmem cuda device query failed, exiting \n");
        }

        nvshmemid_hostlib_finalize(NULL, NULL);
    }

    if (!nvshmemi_device_only_state.is_initialized) {
        status = _nvshmemi_init_device_only_state();
        if (status) {
            NVSHMEMI_ERROR_EXIT("nvshmem device initialization failed, exiting \n");
        }
    }
}

/* Callback passed to the host library; returns the addresses of the device-side state symbols
 * defined in this library. */
void nvshmemi_get_mem_handle(void **dev_state_ptr, void **transport_dev_state_ptr) {
    int status = 0;

    status = cudaGetSymbolAddress(dev_state_ptr, nvshmemi_device_state_d);
    if (status) {
        NVSHMEMI_ERROR_PRINT("Unable to access device state. %d\n", status);
        *dev_state_ptr = NULL;
    }

#ifdef NVSHMEM_IBGDA_SUPPORT
    status = cudaGetSymbolAddress(transport_dev_state_ptr, nvshmemi_ibgda_device_state_d);
    if (status) {
        NVSHMEMI_ERROR_PRINT("Unable to access ibgda device state. %d\n", status);
        *transport_dev_state_ptr = NULL;
    }
#endif
}

int nvshmemi_init_thread(int requested_thread_support, int *provided_thread_support,
                         unsigned int bootstrap_flags, nvshmemx_init_attr_t *bootstrap_attr,
                         nvshmemi_version_t nvshmem_app_version) {
    int status = 0;

#ifdef _NVSHMEM_DEBUG
    printf(" %-28s %d\n", "DEVICE CUDA API", CUDART_VERSION);
#endif

    status = nvshmemid_hostlib_init_attr(requested_thread_support, provided_thread_support,
                                         bootstrap_flags, bootstrap_attr,
                                         nvshmemi_device_lib_version, &nvshmemi_get_mem_handle);
    NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
                          "nvshmem_internal_init_thread failed \n");

    if (nvshmemid_init_status() > NVSHMEM_STATUS_IS_BOOTSTRAPPED) {
        status = _nvshmemi_init_device_only_state();
        NVSHMEMI_NZ_ERROR_JMP(status, NVSHMEMX_ERROR_INTERNAL, out,
                              "nvshmem_internal_init_thread failed at init_device_only_state.\n");
        status = cudaGetDevice(&nvshmemi_device_only_state.cuda_device_id);
        if (status) {
            NVSHMEMI_ERROR_EXIT("nvshmem cuda device query failed, exiting \n");
        }
    }

out:
    return status;
}

#ifdef __cplusplus
extern "C" {
#endif
/* Unregister the device-side state symbols and finalize the host library. */
void nvshmemi_finalize() {
    int status;
    void *dev_state_ptr, *transport_dev_state_ptr = NULL;

    status = cudaGetSymbolAddress(&dev_state_ptr, nvshmemi_device_state_d);
    if (status) {
        NVSHMEMI_ERROR_PRINT("Unable to properly unregister device state.\n");
        nvshmemid_hostlib_finalize(NULL, NULL);
        return;
    }

#ifdef NVSHMEM_IBGDA_SUPPORT
    status = cudaGetSymbolAddress(&transport_dev_state_ptr, nvshmemi_ibgda_device_state_d);
    if (status) {
        NVSHMEMI_ERROR_PRINT("Unable to properly unregister device state.\n");
        nvshmemid_hostlib_finalize(NULL, NULL);
        return;
    }
#endif

    nvshmemid_hostlib_finalize(dev_state_ptr, transport_dev_state_ptr);
}
#ifdef __cplusplus
}
#endif