sglang_v0.5.2/nvshmem_src/examples/bootstrap_helper.h

106 lines
3.4 KiB
C

#ifndef _NVSHMEMI_EXAMPLES_BOOTSTRAP_HELPER_
#define _NVSHMEMI_EXAMPLES_BOOTSTRAP_HELPER_
#ifdef NVSHMEMTEST_MPI_SUPPORT
#include <dlfcn.h>
#include <errno.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "nvshmem.h"
#include "nvshmemx.h"
#include "mpi.h"
// Function-pointer types for the MPI entry points that this helper resolves at run time
// with dlsym() instead of linking against libmpi directly. One typedef per MPI function;
// the name after "fnptr_" is the symbol that MPI_LOAD_SYM looks up.
typedef int (*fnptr_MPI_Init)(int *argc, char ***argv);
typedef int (*fnptr_MPI_Bcast)(void *buffer, int count, MPI_Datatype datatype, int root,
MPI_Comm comm);
typedef int (*fnptr_MPI_Comm_rank)(MPI_Comm comm, int *rank);
typedef int (*fnptr_MPI_Comm_size)(MPI_Comm comm, int *size);
typedef int (*fnptr_MPI_Finalize)(void);
// Table of dynamically resolved MPI entry points. Each member is named "fn_" followed by
// the MPI symbol it holds, which is the naming scheme MPI_LOAD_SYM relies on when it
// token-pastes the field and typedef names.
struct nvshmemi_mpi_fn_table {
fnptr_MPI_Init fn_MPI_Init;            // MPI_Init
fnptr_MPI_Bcast fn_MPI_Bcast;          // MPI_Bcast
fnptr_MPI_Comm_rank fn_MPI_Comm_rank;  // MPI_Comm_rank
fnptr_MPI_Comm_size fn_MPI_Comm_size;  // MPI_Comm_size
fnptr_MPI_Finalize fn_MPI_Finalize;    // MPI_Finalize
};
// Shared state for the dlopen()-based MPI bootstrap.
//
// NOTE(review): these are non-static definitions inside a header, so including this file
// from more than one translation unit of the same program would presumably produce
// duplicate-definition link errors -- confirm it is only ever included once per executable.

// Handle returned by dlopen() for the MPI library; NULL until nvshmemi_load_mpi() succeeds.
void *nvshmemi_mpi_handle = NULL;
// MPI entry points resolved from nvshmemi_mpi_handle; all members start out as NULL.
struct nvshmemi_mpi_fn_table mpi_fn_table = {0};
// Set by nvshmemi_init_mpi() to the address of the "ompi_mpi_comm_world" symbol.
MPI_Comm MPI_COMM_WORLD_PLACEHOLDER;
// Set by nvshmemi_init_mpi() to the address of the "ompi_mpi_uint8_t" symbol.
MPI_Datatype MPI_UINT8_T_PLACEHOLDER;
// NOTE(review): not referenced anywhere in this header -- verify whether any includer uses it.
MPI_Datatype *mpi_uint8_ptr;
// Resolve the MPI function `fn_name` from nvshmemi_mpi_handle with dlsym() and store it in
// mpi_fn_table.fn_<fn_name>, cast to fnptr_<fn_name>.
//
// If the symbol cannot be found, a message is printed to stderr and the macro executes
// `return -1` in the ENCLOSING function, so it may only be used inside functions that
// return int.
//
// The body is wrapped in do { ... } while (0) so that `MPI_LOAD_SYM(x);` behaves as a
// single statement, including as the unbraced body of an if/else.
#define MPI_LOAD_SYM(fn_name)                                                              \
    do {                                                                                   \
        mpi_fn_table.fn_##fn_name = (fnptr_##fn_name)dlsym(nvshmemi_mpi_handle, #fn_name); \
        if (mpi_fn_table.fn_##fn_name == NULL) {                                           \
            fprintf(stderr, "Unable to load MPI symbol " #fn_name "\n");                   \
            return -1;                                                                     \
        }                                                                                  \
    } while (0)
// Open the MPI shared library ("libmpi.so.40") and fill mpi_fn_table with the MPI entry
// points used by this helper.
//
// Returns 0 on success. Returns -1 after printing a diagnostic to stderr when the library
// cannot be opened or when any of the required symbols is missing (each MPI_LOAD_SYM
// returns -1 from this function on lookup failure).
int nvshmemi_load_mpi() {
    nvshmemi_mpi_handle = dlopen("libmpi.so.40", RTLD_NOW | RTLD_GLOBAL | RTLD_DEEPBIND);
    if (nvshmemi_mpi_handle == NULL) {
        // Capture both error sources immediately, before any other library call can
        // overwrite them.
        const int saved_errno = errno;
        const char *dlerror_msg = dlerror();

        // Print the error number and description from errno. dlopen() is not documented
        // to set errno, so treat this as a hint; dlerror() below is the authoritative text.
        fprintf(stderr, "dlopen failed: errno = %d, description = %s\n", saved_errno,
                strerror(saved_errno));
        if (dlerror_msg) {
            fprintf(stderr, "dlerror: %s\n", dlerror_msg);
        }
        fprintf(stderr,
                "Unable to dlopen libmpi.so.40. "
                "Please add it to your LD_LIBRARY_PATH or run without"
                " NVSHMEMTEST_USE_MPI_LAUNCHER.\n");
        return -1;
    }

    MPI_LOAD_SYM(MPI_Init);
    MPI_LOAD_SYM(MPI_Bcast);
    MPI_LOAD_SYM(MPI_Comm_rank);
    MPI_LOAD_SYM(MPI_Comm_size);
    MPI_LOAD_SYM(MPI_Finalize);
    return 0;
}
void nvshmemi_init_mpi(int *c, char ***v) {
int status;
int rank, nranks;
status = nvshmemi_load_mpi();
if (status) exit(-1);
mpi_fn_table.fn_MPI_Init(c, v);
MPI_COMM_WORLD_PLACEHOLDER = (MPI_Comm)dlsym(nvshmemi_mpi_handle, "ompi_mpi_comm_world");
MPI_UINT8_T_PLACEHOLDER = (MPI_Datatype)dlsym(nvshmemi_mpi_handle, "ompi_mpi_uint8_t");
mpi_fn_table.fn_MPI_Comm_rank(MPI_COMM_WORLD_PLACEHOLDER, &rank);
mpi_fn_table.fn_MPI_Comm_size(MPI_COMM_WORLD_PLACEHOLDER, &nranks);
MPI_Comm mpi_comm = MPI_COMM_WORLD_PLACEHOLDER;
nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER;
attr.mpi_comm = &mpi_comm;
nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr);
}
int nvshmemi_dlclose_mpi() {
int status;
status = dlclose(nvshmemi_mpi_handle);
if (status) {
fprintf(stderr, "unable to dlclose MPI.\n");
return -1;
}
return 0;
}
void nvshmemi_finalize_mpi() {
mpi_fn_table.fn_MPI_Finalize();
nvshmemi_dlclose_mpi();
}
#endif // NVSHMEMTEST_MPI_SUPPORT
#endif // _NVSHMEMI_EXAMPLES_BOOTSTRAP_HELPER_