151 lines
6.5 KiB
C++
151 lines
6.5 KiB
C++
/*************************************************************************
|
|
* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
|
*
|
|
* See LICENSE.txt for license information
|
|
************************************************************************/
|
|
|
|
#include <cuda.h> // for CUresult
|
|
#if CUDART_VERSION >= 11030
|
|
#include <cudaTypedefs.h>
|
|
#else
|
|
// IWYU pragma: no_include <cudaTypedefs.h>
|
|
#endif
|
|
|
|
#include <cuda_runtime.h> // for cudaDriverGetVersion
|
|
#include <dlfcn.h> // for dlsym, dlopen, RTLD...
|
|
#include <driver_types.h> // for cudaError_t
|
|
#include <stdio.h> // for NULL, snprintf
|
|
#include "internal/host/debug.h" // for WARN, INFO, NVSHMEM...
|
|
#include "internal/host/error_codes_internal.h" // for NVSHMEMI_SYSTEM_ERROR
|
|
#include "internal/host/util.h" // for nvshmemi_options
|
|
#include "internal/host_transport/cudawrap.h" // for nvshmemi_cuda_fn_table
|
|
#include "bootstrap_host_transport/env_defs_internal.h" // for nvshmemi_options_s
|
|
|
|
/* Load state of the CUDA driver library, as tracked by nvshmemi_cuda_library_init(). */
static enum {
    cudaUninitialized = 0, /* initial value: no load attempt has completed yet */
    cudaInitializing = 1,  /* declared, but not assigned anywhere in the visible code */
    cudaInitialized = 2,   /* load succeeded: later init calls return success immediately */
    cudaError = 3          /* load failed: later init calls return an error immediately */
} cudaState = cudaUninitialized;

/* Handle returned by dlopen() for the CUDA driver library; NULL until it is opened. */
static void *cudaLib = NULL;

/* Driver version filled in by cudaDriverGetVersion(); 0 until it has been queried. */
static int cudaDriverVersion = 0;
|
|
|
|
static int cudaPfnFuncLoader(struct nvshmemi_cuda_fn_table *table) {
|
|
CUresult res;
|
|
|
|
#define LOAD_SYM(table, symbol, version, sym_suffix, ignore) \
|
|
do { \
|
|
bool not_found = false; \
|
|
if (table->pfn_cuGetProcAddress) { \
|
|
res = \
|
|
table->pfn_cuGetProcAddress(#symbol, (void **)(&table->pfn_##symbol), version, 0); \
|
|
if (res != 0) not_found = true; \
|
|
} else { \
|
|
table->pfn_##symbol = (PFN_##symbol##_v##version)dlsym(cudaLib, #symbol #sym_suffix); \
|
|
if (table->pfn_##symbol == NULL) not_found = true; \
|
|
} \
|
|
if (not_found) { \
|
|
if (!ignore) { \
|
|
WARN("Retrieve %s version %d failed", #symbol #sym_suffix, cudaDriverVersion); \
|
|
return NVSHMEMI_SYSTEM_ERROR; \
|
|
} \
|
|
} \
|
|
} while (0)
|
|
|
|
LOAD_SYM(table, cuCtxGetDevice, 2000, , 0);
|
|
LOAD_SYM(table, cuCtxSynchronize, 2000, , 0);
|
|
LOAD_SYM(table, cuDeviceGet, 2000, , 0);
|
|
LOAD_SYM(table, cuDeviceGetAttribute, 2000, , 0);
|
|
LOAD_SYM(table, cuPointerSetAttribute, 6000, , 0);
|
|
LOAD_SYM(table, cuModuleGetGlobal, 3020, _v2, 0);
|
|
LOAD_SYM(table, cuGetErrorString, 6000, , 0);
|
|
LOAD_SYM(table, cuGetErrorName, 6000, , 0);
|
|
LOAD_SYM(table, cuCtxSetCurrent, 4000, , 0);
|
|
LOAD_SYM(table, cuDevicePrimaryCtxRetain, 7000, , 0);
|
|
LOAD_SYM(table, cuCtxGetCurrent, 4000, , 0);
|
|
LOAD_SYM(table, cuCtxGetFlags, 7000, , 0);
|
|
LOAD_SYM(table, cuCtxSetFlags, 12010, , 1);
|
|
LOAD_SYM(table, cuFlushGPUDirectRDMAWrites, 11030, , 1);
|
|
LOAD_SYM(table, cuMemGetHandleForAddressRange, 11070, , 1); // DMA-BUF support
|
|
LOAD_SYM(table, cuMemCreate, 10020, , 1);
|
|
LOAD_SYM(table, cuMemMap, 10020, , 1);
|
|
LOAD_SYM(table, cuMemAddressReserve, 10020, , 1);
|
|
LOAD_SYM(table, cuMemAddressFree, 10020, , 1);
|
|
LOAD_SYM(table, cuMemGetAllocationGranularity, 10020, , 1);
|
|
LOAD_SYM(table, cuMemImportFromShareableHandle, 10020, , 1);
|
|
LOAD_SYM(table, cuMemExportToShareableHandle, 10020, , 1);
|
|
LOAD_SYM(table, cuMemRelease, 10020, , 1);
|
|
LOAD_SYM(table, cuMemSetAccess, 10020, , 1);
|
|
LOAD_SYM(table, cuMemUnmap, 10020, , 1);
|
|
LOAD_SYM(table, cuMulticastCreate, 12010, , 1);
|
|
LOAD_SYM(table, cuMulticastAddDevice, 12010, , 1);
|
|
LOAD_SYM(table, cuMulticastBindMem, 12010, , 1);
|
|
LOAD_SYM(table, cuMulticastUnbind, 12010, , 1);
|
|
LOAD_SYM(table, cuMulticastGetGranularity, 12010, , 1);
|
|
LOAD_SYM(table, cuStreamWriteValue64, 11070, _v2, 1);
|
|
LOAD_SYM(table, cuStreamWaitValue64, 11070, _v2, 1);
|
|
return NVSHMEMI_SUCCESS;
|
|
}
|
|
|
|
int nvshmemi_cuda_library_init(struct nvshmemi_cuda_fn_table *table) {
|
|
cudaError_t cuda_err;
|
|
|
|
if (cudaState == cudaInitialized) return NVSHMEMI_SUCCESS;
|
|
if (cudaState == cudaError) return NVSHMEMI_SYSTEM_ERROR;
|
|
|
|
/*
|
|
* Load CUDA driver library
|
|
*/
|
|
char path[1024];
|
|
if (!nvshmemi_options.CUDA_PATH_provided)
|
|
snprintf(path, 1024, "%s", "libcuda.so.1");
|
|
else
|
|
snprintf(path, 1024, "%s/%s", nvshmemi_options.CUDA_PATH, "libcuda.so.1");
|
|
|
|
cudaLib = dlopen(path, RTLD_LAZY);
|
|
if (cudaLib == NULL) {
|
|
WARN("Failed to find CUDA library in %s (NVSHMEM_CUDA_PATH=%s)", path,
|
|
nvshmemi_options.CUDA_PATH);
|
|
goto error;
|
|
}
|
|
|
|
/*
|
|
* Load initial CUDA functions
|
|
*/
|
|
|
|
table->pfn_cuInit = (PFN_cuInit_v2000)dlsym(cudaLib, "cuInit");
|
|
if (table->pfn_cuInit == NULL) {
|
|
WARN("Failed to load CUDA missing symbol cuInit");
|
|
goto error;
|
|
}
|
|
|
|
cuda_err = cudaDriverGetVersion(&cudaDriverVersion);
|
|
if (cuda_err != 0) {
|
|
WARN("cudaDriverGetVersion failed with %d", cuda_err);
|
|
goto error;
|
|
}
|
|
INFO(NVSHMEM_INIT, "cudaDriverVersion %d", cudaDriverVersion);
|
|
|
|
table->pfn_cuGetProcAddress = (PFN_cuGetProcAddress_v11030)dlsym(cudaLib, "cuGetProcAddress");
|
|
|
|
/*
|
|
* Required to initialize the CUDA Driver.
|
|
* Multiple calls of cuInit() will return immediately
|
|
* without making any relevant change
|
|
*/
|
|
table->pfn_cuInit(0);
|
|
|
|
if (cudaPfnFuncLoader(table)) {
|
|
WARN("CUDA some PFN functions not found in the library");
|
|
goto error;
|
|
}
|
|
|
|
cudaState = cudaInitialized;
|
|
return NVSHMEMI_SUCCESS;
|
|
|
|
error:
|
|
cudaState = cudaError;
|
|
return NVSHMEMI_SYSTEM_ERROR;
|
|
}
|