# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/comm/nvshmem.py

import ctypes
import functools
import os
import shlex
from typing import Sequence
import torch
from ..jit import JitSpec
from ..jit import env as jit_env
from ..jit import gen_jit_spec
def gen_nvshmem_module() -> JitSpec:
    """Create the JitSpec for the NVSHMEM bindings extension.

    Builds linker flags from the NVSHMEM library directories, the device
    library, and any user-supplied ``NVSHMEM_LDFLAGS`` environment flags.
    """
    link_flags = [f"-L{lib_dir}" for lib_dir in jit_env.get_nvshmem_lib_dirs()]
    link_flags.append("-lnvshmem_device")
    # Allow users to inject extra linker flags via the environment.
    link_flags.extend(shlex.split(os.environ.get("NVSHMEM_LDFLAGS", "")))
    include_paths = [str(p) for p in jit_env.get_nvshmem_include_dirs()]
    return gen_jit_spec(
        "nvshmem",
        [jit_env.FLASHINFER_CSRC_DIR / "nvshmem_binding.cu"],
        extra_include_paths=include_paths,
        extra_ldflags=link_flags,
        needs_device_linking=True,
    )
@functools.cache
def get_nvshmem_module():
    """Load the NVSHMEM host library and return the built JIT module.

    Cached so the library is loaded and the module built only once per
    process.

    Raises:
        FileNotFoundError: If no NVSHMEM host shared library is found in
            the configured library directories.
    """
    lib_dirs = jit_env.get_nvshmem_lib_dirs()
    # Prefer the unversioned libnvshmem_host.so; fall back to the .so.3 soname.
    candidates = (
        lib_dir / lib_name
        for lib_dir in lib_dirs
        for lib_name in ("libnvshmem_host.so", "libnvshmem_host.so.3")
    )
    lib_path = next((c for c in candidates if c.exists()), None)
    if lib_path is None:
        raise FileNotFoundError(
            f"Could not find libnvshmem_host.so or libnvshmem_host.so.3 in {lib_dirs}"
        )
    # RTLD_GLOBAL exposes the host symbols so the JIT module can resolve them.
    ctypes.CDLL(lib_path, mode=ctypes.RTLD_GLOBAL)
    return gen_nvshmem_module().build_and_load()
def get_unique_id() -> torch.Tensor:
    """Obtain an NVSHMEM unique id tensor, used to bootstrap ``init``."""
    module = get_nvshmem_module()
    return module.nvshmem_get_unique_id()
def unique_id_size() -> int:
    """Return the size of an NVSHMEM unique id."""
    module = get_nvshmem_module()
    return module.nvshmem_unique_id_size()
def alloc_empty_unique_id() -> torch.Tensor:
    """Allocate a zeroed CPU uint8 tensor sized to hold an NVSHMEM unique id."""
    num_bytes = unique_id_size()
    return torch.zeros(num_bytes, dtype=torch.uint8, device="cpu")
def init(uid: torch.Tensor, rank: int, world_size: int) -> int:
    """Initialize NVSHMEM with the given unique id, rank, and world size.

    Returns:
        The status code reported by the underlying ``nvshmem_init`` binding.
    """
    module = get_nvshmem_module()
    status = module.nvshmem_init(uid, rank, world_size)
    # Wait for any GPU-side initialization work to finish before returning.
    torch.cuda.synchronize()
    return status
def alltoall(dest: torch.Tensor, source: torch.Tensor) -> None:
    """Run an NVSHMEM all-to-all exchange from ``source`` into ``dest``."""
    module = get_nvshmem_module()
    return module.nvshmem_alltoall(dest, source)
def finalize() -> None:
    """Tear down NVSHMEM after draining outstanding GPU work."""
    # Synchronize first so no in-flight kernels touch NVSHMEM state
    # while it is being finalized.
    torch.cuda.synchronize()
    module = get_nvshmem_module()
    module.nvshmem_finalize()
def my_pe() -> int:
    """Return the index of the calling PE (processing element)."""
    module = get_nvshmem_module()
    return module.nvshmem_my_pe()
def n_pes() -> int:
    """Return the total number of PEs participating in NVSHMEM."""
    module = get_nvshmem_module()
    return module.nvshmem_n_pes()
def malloc(
    shape: Sequence[int],
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    """Allocate a tensor via the NVSHMEM collective malloc.

    This is a collective call: every PE (Processing Element) must invoke it
    with identical arguments.

    Note: Delete the returned tensor explicitly (``del tensor``) so the
    matching ``nvshmem_free`` calls happen in a deterministic order instead
    of at garbage-collection time.

    Args:
        shape: Shape of the tensor to allocate.
        dtype: Element data type.
        device: Device on which to allocate.

    Returns:
        A tensor backed by NVSHMEM-allocated memory.

    Reference:
        https://docs.nvidia.com/nvshmem/api/gen/api/memory.html#nvshmem-malloc-nvshmem-free-nvshmem-align
    """
    module = get_nvshmem_module()
    return module.nvshmem_malloc(shape, dtype, device)
def barrier_all() -> None:
    """Execute an NVSHMEM barrier across all PEs."""
    module = get_nvshmem_module()
    module.nvshmem_barrier_all()
def barrier_all_on_current_stream() -> None:
    """Execute an NVSHMEM all-PE barrier on the current CUDA stream."""
    module = get_nvshmem_module()
    module.nvshmem_barrier_all_on_current_stream()