import ctypes
import functools
import os
import shlex
from typing import Sequence

import torch

from ..jit import JitSpec
from ..jit import env as jit_env
from ..jit import gen_jit_spec


def gen_nvshmem_module() -> JitSpec:
    """Build the JIT spec for the NVSHMEM device-side bindings.

    Link flags are derived from the detected NVSHMEM library directories,
    plus any extra flags supplied through the ``NVSHMEM_LDFLAGS``
    environment variable (parsed with shell-style splitting).

    Returns:
        A :class:`JitSpec` that compiles ``nvshmem_binding.cu`` with device
        linking enabled.
    """
    lib_dirs = jit_env.get_nvshmem_lib_dirs()
    ldflags = (
        [f"-L{lib_dir}" for lib_dir in lib_dirs]
        + ["-lnvshmem_device"]
        + shlex.split(os.environ.get("NVSHMEM_LDFLAGS", ""))
    )
    return gen_jit_spec(
        "nvshmem",
        [jit_env.FLASHINFER_CSRC_DIR / "nvshmem_binding.cu"],
        extra_include_paths=[str(p) for p in jit_env.get_nvshmem_include_dirs()],
        extra_ldflags=ldflags,
        needs_device_linking=True,
    )


@functools.cache
def get_nvshmem_module():
    """Load the NVSHMEM host library and return the built JIT module.

    The result is cached for the lifetime of the process, so the library
    is dlopen'ed and the extension is compiled at most once.

    Raises:
        FileNotFoundError: if neither ``libnvshmem_host.so`` nor
            ``libnvshmem_host.so.3`` is found in any NVSHMEM lib dir.
    """
    # Try to find libnvshmem_host.so first, fallback to libnvshmem_host.so.3.
    # Search order matters: within each directory the unversioned soname is
    # preferred over the versioned one.
    lib_dirs = jit_env.get_nvshmem_lib_dirs()
    lib_names = ["libnvshmem_host.so", "libnvshmem_host.so.3"]
    lib_path = next(
        (
            lib_dir / lib_name
            for lib_dir in lib_dirs
            for lib_name in lib_names
            if (lib_dir / lib_name).exists()
        ),
        None,
    )
    if lib_path is None:
        raise FileNotFoundError(
            f"Could not find libnvshmem_host.so or libnvshmem_host.so.3 in {lib_dirs}"
        )
    # RTLD_GLOBAL so that the JIT-compiled extension can resolve NVSHMEM
    # host symbols exported by this shared library.
    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
    return gen_nvshmem_module().build_and_load()


def get_unique_id() -> torch.Tensor:
    """Return a freshly generated NVSHMEM unique id as a CPU byte tensor."""
    return get_nvshmem_module().nvshmem_get_unique_id()


def unique_id_size() -> int:
    """Return the size, in bytes, of an NVSHMEM unique id."""
    return get_nvshmem_module().nvshmem_unique_id_size()


def alloc_empty_unique_id() -> torch.Tensor:
    """Allocate a zero-filled CPU uint8 tensor sized to hold a unique id."""
    return torch.zeros(unique_id_size(), dtype=torch.uint8, device="cpu")


def init(uid: torch.Tensor, rank: int, world_size: int) -> int:
    """Initialize NVSHMEM with the given unique id, rank, and world size.

    Synchronizes the CUDA device before returning so that initialization
    is complete from the caller's perspective.

    Returns:
        The status code reported by ``nvshmem_init``.
    """
    status = get_nvshmem_module().nvshmem_init(uid, rank, world_size)
    torch.cuda.synchronize()
    return status


def alltoall(dest: torch.Tensor, source: torch.Tensor) -> None:
    """Perform an NVSHMEM all-to-all exchange from ``source`` into ``dest``."""
    return get_nvshmem_module().nvshmem_alltoall(dest, source)


def finalize() -> None:
    """Synchronize the CUDA device, then finalize the NVSHMEM runtime."""
    torch.cuda.synchronize()
    get_nvshmem_module().nvshmem_finalize()


def my_pe() -> int:
    """Return the calling process's PE (processing element) index."""
    return get_nvshmem_module().nvshmem_my_pe()


def n_pes() -> int:
    """Return the total number of PEs in the NVSHMEM job."""
    return get_nvshmem_module().nvshmem_n_pes()


def malloc(
    shape: Sequence[int],
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Allocates memory using NVSHMEM collective malloc operation.

    This is a collective operation that requires participation
    by all PEs (Processing Elements). All participants must call
    this function with the same parameters.

    Note: This tensor should be explicitly deleted (del tensor) to ensure
    proper ordering of nvshmem_free operations rather than relying on
    garbage collection.

    Args:
        shape: The shape of the tensor to allocate.
        dtype: The data type of the tensor.
        device: The device to allocate the tensor on.

    Returns:
        A tensor allocated using NVSHMEM collective malloc.

    Reference:
        https://docs.nvidia.com/nvshmem/api/gen/api/memory.html#nvshmem-malloc-nvshmem-free-nvshmem-align
    """
    return get_nvshmem_module().nvshmem_malloc(shape, dtype, device)


def barrier_all() -> None:
    """Block until all PEs reach the barrier (full device sync semantics)."""
    get_nvshmem_module().nvshmem_barrier_all()


def barrier_all_on_current_stream() -> None:
    """Enqueue an all-PE barrier on the current CUDA stream."""
    get_nvshmem_module().nvshmem_barrier_all_on_current_stream()