# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/comm/mnnvl.py
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code imported from TensorRT-LLM/tensorrt_llm/_mnnvl_utils.py
import ctypes
import logging
import os
import platform
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch

try:
    from cuda import cuda
except ImportError as e:
    raise ImportError(
        "Could not import the 'cuda' module. "
        "Please install cuda-python that matches your CUDA version."
    ) from e

from ..cuda_utils import checkCudaErrors
from .dlpack_utils import create_dlpack_capsule, pack_strided_memory
from .mapping import Mapping

IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1"

# mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here
OMPI_COMM_TYPE_HOST = 9

# Constants from C++ header
SIGNAL_PAD_SIZE = 2048  # kSIGNAL_PAD_SIZE from header

MNNVL_DEBUG = False


def round_up(val: int, gran: int) -> int:
    """Efficient implementation assuming gran is a power of 2"""
    return (val + gran - 1) & ~(gran - 1)
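
# Example: round_up(5000, 4096) == 8192. Because gran is a power of two,
# ~(gran - 1) is a mask that clears the low-order bits, so this is
# equivalent to ((val + gran - 1) // gran) * gran without the division.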


def create_tensor_from_cuda_memory(
    ptr: int, shape: tuple, dtype: torch.dtype, device_id: int
) -> torch.Tensor:
    """
    Create a PyTorch tensor from a CUDA memory pointer using DLPack.

    Args:
        ptr: CUDA memory pointer address as integer
        shape: Desired tensor shape
        dtype: PyTorch data type
        device_id: CUDA device ID

    Returns:
        PyTorch tensor that wraps the CUDA memory
    """
    # Calculate total size in elements
    numel = 1
    for dim in shape:
        numel *= dim

    # Get element size in bytes
    element_size = torch.tensor([], dtype=dtype).element_size()

    # Create DLPack capsule for contiguous memory (stride = element_size, num_segments = numel)
    capsule_wrapper = create_dlpack_capsule(
        ptr, element_size, element_size, numel, dtype, device_id
    )

    # Convert to tensor and reshape
    tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule)
    tensor._capsule_wrapper = capsule_wrapper  # Keep reference to prevent GC

    # Reshape to desired shape
    return tensor.view(shape)
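
# Usage sketch (assumes a live CUDA context on device 0):
#   ptr = checkCudaErrors(cuda.cuMemAlloc(32 * 32 * 4))
#   t = create_tensor_from_cuda_memory(int(ptr), (32, 32), torch.float32, 0)
#   t.zero_()  # writes go straight to the underlying CUDA allocation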


def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool:
    """
    Test if CUDA memory at ptr is accessible by trying to read/write a small amount.

    Args:
        ptr: CUDA memory pointer
        size: Size of memory region
        device_id: CUDA device ID

    Returns:
        True if memory is accessible, False otherwise
    """
    try:
        # Test with a small 4-byte read/write
        test_size = min(4, size)
        host_data = bytearray(test_size)

        # Try to copy from device to host
        checkCudaErrors(cuda.cuMemcpyDtoH(host_data, ptr, test_size))
        # Try to copy back from host to device
        checkCudaErrors(cuda.cuMemcpyHtoD(ptr, host_data, test_size))

        print(f"DEBUG: Memory access test PASSED for ptr=0x{ptr:x}")
        return True
    except Exception as e:
        print(f"DEBUG: Memory access test FAILED for ptr=0x{ptr:x}: {e}")
        return False


def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]:
    """
    A helper that allocates device memory and copies the given host array into it.
    """
    if not host_ptr_array:
        return None

    ArrayType = ctypes.c_uint64 * len(host_ptr_array)
    c_array = ArrayType(*host_ptr_array)
    size_in_bytes = ctypes.sizeof(c_array)

    device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes))
    checkCudaErrors(
        cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes)
    )
    # c_array should be freed by GC
    return device_ptr
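
# Note: the caller owns the returned device pointer and must eventually
# release it, e.g. checkCudaErrors(cuda.cuMemFree(ptrs_dev)).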


class CommBackend(ABC):
    """Abstract communication backend interface"""

    @abstractmethod
    def Get_rank(self) -> int: ...

    @abstractmethod
    def Get_size(self) -> int: ...

    @abstractmethod
    def allgather(self, data: int) -> List[int]: ...

    @abstractmethod
    def Split(self, color: int, key: int) -> "CommBackend": ...
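
# Any object implementing these four methods can serve as the backend (see
# MnnvlConfig.comm_backend below). A minimal sketch for a single-process
# "group" (hypothetical, for illustration only):
#
#   class SingleProcessBackend(CommBackend):
#       def Get_rank(self) -> int:
#           return 0
#       def Get_size(self) -> int:
#           return 1
#       def allgather(self, data: int) -> List[int]:
#           return [data]
#       def Split(self, color: int, key: int) -> "CommBackend":
#           return self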


if IS_BUILDING_DOCS:
    # Mock classes for building docs
    class MpiComm:  # type: ignore[no-redef]
        @classmethod
        def set_mpi_comm(cls, new_comm):
            pass

        def __getattr__(self, name):
            return None

    class MnnvlMemory:  # type: ignore[no-redef]
        initialized: bool = False

        current_mem_offset: int = 0
        current_rank_stride: int = 0  # stride for ranks and also address space size.
        current_start_address: int = 0

        # allocation granularity
        allocation_granularity: int = 0

        # fabric address page size (512 MB)
        fabric_page_size: int = 1 << 29

        # MPI communicator
        comm = None

        dev_id: Optional[int] = None

        allocated_map: Dict[int, Any] = {}
        address_refcnt: Dict[int, Any] = {}

        def __init__(self, mapping: Mapping, size: int):
            pass

        def __del__(self):
            pass

        def as_torch_strided_tensor(self, dtype):
            return None

        @staticmethod
        def initialize():
            pass

        @staticmethod
        def get_comm(mapping: Mapping):
            return None

        @staticmethod
        def get_allocation_prop(dev_id: int):
            return None

        @staticmethod
        def get_allocation_granularity(dev_id: int):
            return None

        @staticmethod
        def new_mnnvl_memory_address(mapping: Mapping, size: int):
            pass

        @staticmethod
        def open_mnnvl_memory(mapping: Mapping, size: int):
            return None

        @staticmethod
        def close_mnnvl_memory(ptr: int):
            pass

        @staticmethod
        def support_nvlink(need_all_up: bool = True):
            return None

        @staticmethod
        def supports_mnnvl() -> bool:
            return False

else:
    import pynvml

    if TYPE_CHECKING:
        from mpi4py import MPI  # noqa: F401

    def lazy_import_mpi():
        """Lazy import for mpi4py"""
        try:
            from mpi4py import MPI

            return MPI
        except ImportError as err:
            raise ImportError("mpi4py is not installed") from err

    class MpiComm:  # type: ignore[no-redef]
        _comm: Any = None
        _MPI: Any = None

        @classmethod
        def _get_mpi(cls):
            if cls._MPI is None:
                cls._MPI = lazy_import_mpi()
                cls._comm = cls._MPI.COMM_WORLD
            return cls._MPI

        @classmethod
        def set_mpi_comm(cls, new_comm: Any):
            cls._get_mpi()
            # Optional: add type checking here
            cls._comm = new_comm

        def __getattr__(self, name):
            if self._comm is None:
                self._get_mpi()
            return getattr(self._comm, name)

    class MPIBackend(CommBackend):
        def __init__(self):
            self._mpicomm = MpiComm()

        def Get_rank(self) -> int:
            return self._mpicomm.Get_rank()

        def Get_size(self) -> int:
            return self._mpicomm.Get_size()

        def allgather(self, data: int) -> List[int]:
            return self._mpicomm.allgather(data)

        def Split(self, color: int, key: int) -> CommBackend:
            self._mpicomm = self._mpicomm.Split(color, key)
            return MPIBackend()  # Returns a new adapter

    @dataclass
    class MnnvlConfig:
        """Configuration for MNNVL memory management"""

        comm_backend: Optional[CommBackend] = None
        allocation_granularity: int = 0
        fabric_page_size: int = 1 << 29  # 512MB
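
    # Example (sketch): choose the backend and paging behavior up front, then
    # hand the config to MnnvlMemory before any allocation:
    #   config = MnnvlConfig(comm_backend=MPIBackend(), fabric_page_size=1 << 29)
    #   MnnvlMemory.set_comm_from_config(mapping, config)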

    class MnnvlMemory:  # type: ignore[no-redef]
        initialized: bool = False

        current_mem_offset: int = 0
        current_rank_stride: int = 0  # stride for ranks and also address space size.
        current_start_address: int = 0

        # allocation granularity
        allocation_granularity: int = 0

        # fabric address page size (512 MB)
        fabric_page_size: int = 1 << 29

        # MPI communicator
        comm: Optional[CommBackend] = None

        dev_id: Optional[int] = None

        allocated_map: Dict[int, Any] = {}
        address_refcnt: Dict[int, Any] = {}

        config: Optional[MnnvlConfig] = None

        def __init__(self, mapping: Mapping, size: int):
            self.mapping = mapping
            self.segment_size = size
            self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory(
                self.mapping, size
            )

        def __del__(self):
            if not sys.is_finalizing():
                MnnvlMemory.close_mnnvl_memory(self.ptr)

        def as_torch_strided_tensor(self, dtype):
            num_segments = MnnvlMemory.comm.Get_size()
            return pack_strided_memory(
                self.ptr,
                self.segment_size,
                self.rank_stride,
                num_segments,
                dtype,
                MnnvlMemory.dev_id,
            )
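
        # Usage sketch (requires MNNVL-capable GPUs and one process per GPU,
        # e.g. under mpirun; `mapping` is a flashinfer Mapping):
        #   MnnvlMemory.initialize()
        #   mem = MnnvlMemory(mapping, 16 << 20)   # 16 MiB visible to every rank
        #   t = mem.as_torch_strided_tensor(torch.float32)
        #   # t is strided over ranks: segment i is rank i's 16 MiB slice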

        @staticmethod
        def initialize():
            if not MnnvlMemory.initialized:
                # use a dummy torch CUDA tensor to trigger CUDA context initialization
                _ = torch.empty(1, device="cuda")
                # ensure nvml is initialized.
                try:
                    pynvml.nvmlDeviceGetCount()
                except pynvml.NVMLError_Uninitialized:
                    pynvml.nvmlInit()
                MnnvlMemory.initialized = True

        @staticmethod
        def set_comm_from_config(mapping: Mapping, config: Optional[MnnvlConfig] = None):
            MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
            comm = MnnvlMemory.config.comm_backend.Split(
                mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
            )
            MnnvlMemory.comm = comm  # type: ignore[assignment]

        @staticmethod
        def get_comm(mapping: Mapping):
            if MnnvlMemory.comm is not None:
                return MnnvlMemory.comm
            comm = MpiComm().Split(
                mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
            )
            MnnvlMemory.comm = comm
            return comm

        @staticmethod
        def get_allocation_prop(dev_id: int):
            location = cuda.CUmemLocation()
            location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            location.id = dev_id

            allocation_prop = cuda.CUmemAllocationProp()
            allocation_prop.type = (
                cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
            )
            # TODO: We differentiate FABRIC for GB200 (aarch64) and
            # POSIX_FILE_DESCRIPTOR for B200 (x86_64).
            # May need to find a better way to handle this.
            arch = platform.machine().lower()
            is_on_aarch64 = "aarch64" in arch
            if is_on_aarch64:
                allocation_prop.requestedHandleTypes = (
                    cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
                )
            else:
                allocation_prop.requestedHandleTypes = (
                    cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
                )
            allocation_prop.location = location
            return allocation_prop

        @staticmethod
        def get_allocation_granularity(dev_id: int):
            if MnnvlMemory.allocation_granularity != 0:
                return MnnvlMemory.allocation_granularity
            allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
            option = cuda.CUmemAllocationGranularity_flags(
                cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED
            )
            granularity = checkCudaErrors(
                cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)
            )
            MnnvlMemory.allocation_granularity = granularity
            return MnnvlMemory.allocation_granularity

        @staticmethod
        def new_mnnvl_memory_address(mapping: Mapping, size: int):
            page_count = (
                size + MnnvlMemory.fabric_page_size - 1
            ) // MnnvlMemory.fabric_page_size
            current_rank_stride = page_count * MnnvlMemory.fabric_page_size
            logging.info(
                f"[MnnvlMemory] creating address with stride={current_rank_stride}"
            )
            comm = MnnvlMemory.get_comm(mapping)
            comm_size = comm.Get_size()
            address_size = current_rank_stride * comm_size
            ptr = checkCudaErrors(
                cuda.cuMemAddressReserve(
                    address_size, MnnvlMemory.fabric_page_size, 0, 0
                )
            )
            MnnvlMemory.current_start_address = int(ptr)
            MnnvlMemory.current_rank_stride = current_rank_stride
            MnnvlMemory.current_mem_offset = 0
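
        # Example of the math above: size = 600 MiB with the 512 MiB fabric page
        # gives page_count = 2, i.e. a per-rank stride of 1 GiB; the reservation
        # then spans comm_size * 1 GiB of contiguous virtual address space.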

        @staticmethod
        def open_mnnvl_memory(mapping: Mapping, size: int):
            dev = checkCudaErrors(cuda.cuCtxGetDevice())
            dev_id = int(dev)
            if MnnvlMemory.dev_id is None:
                MnnvlMemory.dev_id = dev_id
            assert dev_id == MnnvlMemory.dev_id, (
                f"Different dev_id found: dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}"
            )
            comm = MnnvlMemory.get_comm(mapping)
            comm_rank = comm.Get_rank()
            comm_size = comm.Get_size()
            all_rank_allocate_sizes = comm.allgather(size)
            assert len(all_rank_allocate_sizes) == comm_size
            assert all(x == size for x in all_rank_allocate_sizes), (
                "Not all ranks are allocating the same size."
            )
            granularity = MnnvlMemory.get_allocation_granularity(dev_id)
            aligned_size = (size + granularity - 1) // granularity * granularity

            if (
                MnnvlMemory.current_mem_offset + aligned_size
                > MnnvlMemory.current_rank_stride
            ):
                MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size)

            assert (
                MnnvlMemory.current_mem_offset + aligned_size
                <= MnnvlMemory.current_rank_stride
            )

            allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
            allocated_mem_handle = checkCudaErrors(
                cuda.cuMemCreate(aligned_size, allocation_prop, flags=0)
            )
            exported_fabric_handle = checkCudaErrors(
                cuda.cuMemExportToShareableHandle(
                    allocated_mem_handle, allocation_prop.requestedHandleTypes, 0
                )
            )
            if (
                allocation_prop.requestedHandleTypes
                == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
            ):
                all_handles_data = comm.allgather(exported_fabric_handle.data)
            else:
                all_handles_data = comm.allgather(exported_fabric_handle)
                all_pids = comm.allgather(os.getpid())
                libc = ctypes.CDLL(None, use_errno=True)
                syscall = libc.syscall
                SYS_pidfd_open = 434
                SYS_pidfd_getfd = 438
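
                # pidfd_open (Linux 5.3+) gives a handle to the exporting
                # process; pidfd_getfd (Linux 5.6+) then duplicates that
                # process's fd into this one, since POSIX fds are process-local.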
                pidfds = []
                for pid in all_pids:
                    pidfd = syscall(SYS_pidfd_open, pid, 0)
                    if pidfd < 0:
                        err = ctypes.get_errno()
                        raise RuntimeError(
                            f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}"
                        )
                    pidfds.append(pidfd)

                remote_fds = []
                for pidfd, fd in zip(pidfds, all_handles_data):
                    remote_fd = syscall(SYS_pidfd_getfd, pidfd, fd, 0)
                    if remote_fd < 0:
                        err = ctypes.get_errno()
                        error_msg = f"pidfd_getfd(pidfd={pidfd}, fd={fd}) failed with errno {err}: {os.strerror(err)}."
                        if err == 1:  # EPERM
                            error_msg += (
                                " Permission denied. If running in a container, try adding --cap-add=SYS_PTRACE "
                                "to your docker run command."
                            )
                        else:
                            error_msg += " This may be due to kernel version (requires Linux 5.6+)."
                        raise RuntimeError(error_msg)
                    remote_fds.append(remote_fd)
                all_handles_data = remote_fds
            # all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00'  # noqa: E501
            # can use buf = memoryview(data) to import if using plain buffer for data.
            madesc = cuda.CUmemAccessDesc()
            madesc.location = allocation_prop.location
            madesc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

            mem_handles = [None] * comm_size
            for i, remote_handle_data in enumerate(all_handles_data):
                rank_ptr = (
                    MnnvlMemory.current_start_address
                    + MnnvlMemory.current_rank_stride * i
                    + MnnvlMemory.current_mem_offset
                )
                if i == comm_rank:
                    # Local memory mapping
                    mem_handles[i] = allocated_mem_handle
                    checkCudaErrors(
                        cuda.cuMemMap(
                            rank_ptr, aligned_size, 0, allocated_mem_handle, 0
                        )
                    )
                else:
                    # Import and map the remote rank's memory
                    imported_mem_handle = checkCudaErrors(
                        cuda.cuMemImportFromShareableHandle(
                            remote_handle_data, allocation_prop.requestedHandleTypes
                        )
                    )
                    mem_handles[i] = imported_mem_handle
                    checkCudaErrors(
                        cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0)
                    )

                checkCudaErrors(
                    cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1)
                )

            ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_mem_offset
            stride = MnnvlMemory.current_rank_stride
            MnnvlMemory.allocated_map[ptr] = (
                mapping,
                aligned_size,
                mem_handles,
                MnnvlMemory.current_start_address,
                MnnvlMemory.current_rank_stride,
                MnnvlMemory.current_mem_offset,
            )
            MnnvlMemory.address_refcnt[MnnvlMemory.current_start_address] = (
                MnnvlMemory.address_refcnt.get(MnnvlMemory.current_start_address, 0) + 1
            )

            MnnvlMemory.current_mem_offset += aligned_size
            return ptr, stride

        @staticmethod
        def close_mnnvl_memory(ptr: int):
            (
                mapping,
                aligned_size,
                mem_handles,
                start_address,
                rank_stride,
                address_offset,
            ) = MnnvlMemory.allocated_map.pop(ptr)
            comm = MnnvlMemory.get_comm(mapping)
            comm_size = comm.Get_size()
            for i in range(comm_size):
                rank_ptr = start_address + i * rank_stride + address_offset
                checkCudaErrors(cuda.cuMemUnmap(rank_ptr, aligned_size))
                checkCudaErrors(cuda.cuMemRelease(mem_handles[i]))
            MnnvlMemory.address_refcnt[start_address] -= 1

            if MnnvlMemory.address_refcnt[start_address] == 0:
                MnnvlMemory.address_refcnt.pop(start_address)
                device_ptr = cuda.CUdeviceptr(start_address)
                checkCudaErrors(
                    cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride)
                )
                if start_address == MnnvlMemory.current_start_address:
                    MnnvlMemory.current_start_address = 0
                    MnnvlMemory.current_rank_stride = 0
                    MnnvlMemory.current_mem_offset = 0

        @staticmethod
        def support_nvlink(need_all_up: bool = True):
            dev_id = torch.cuda.current_device()
            handle = pynvml.nvmlDeviceGetHandleByIndex(dev_id)
            link_count = pynvml.NVML_NVLINK_MAX_LINKS
            active_links = 0
            available_links = 0
            for link_idx in range(link_count):
                try:
                    if pynvml.nvmlDeviceGetNvLinkCapability(
                        handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED
                    ):
                        available_links += 1
                        is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx)
                        if is_active:
                            active_links += 1
                except pynvml.NVMLError_NotSupported:
                    continue
            return (
                active_links == available_links and available_links > 0
                if need_all_up
                else available_links > 0
            )

        @staticmethod
        def supports_mnnvl() -> bool:
            # TODO: We currently only check that all NVLink links are up, which
            # is not equivalent to MNNVL support. A better check may be needed.
            support_nvlink_and_all_up = MnnvlMemory.support_nvlink(True)
            return support_nvlink_and_all_up


class McastDeviceMemory:
    """Python port of McastDeviceMemory from TensorRT-LLM"""

    def __init__(
        self,
        buf_size: int,
        group_size: int,
        group_rank: int,
        device_idx: int,
        is_multi_node: bool = True,
    ):
        cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
        primary_ctx = checkCudaErrors(cuda.cuDevicePrimaryCtxRetain(cu_device))
        checkCudaErrors(cuda.cuCtxSetCurrent(primary_ctx))

        # Set CUDA device.
        # Check if cuda.cudart is available and import accordingly.
        from flashinfer.utils import has_cuda_cudart

        if has_cuda_cudart():
            # cuda-python <= 12.9
            import cuda.cudart as cudart
        else:
            # cuda-python >= 13.0
            import cuda.bindings.runtime as cudart
        checkCudaErrors(cudart.cudaSetDevice(device_idx))

        self.is_multi_node = is_multi_node
        self.device_idx = device_idx
        self.group_size = group_size
        self.group_rank = group_rank
        self.buf_size = buf_size
        self.signal_pad_offset = 0
        self.allocation_size = 0

        # CUDA memory handles and pointers
        self.mc_ptr = 0  # CUdeviceptr mMcPtr
        self.uc_ptrs: List[int] = []  # std::vector<CUdeviceptr> mUcPtrs
        self.signal_pads: List[int] = []  # mSignalPads
        self.signal_pads_dev = 0  # mSignalPadsDev (device array)
        self.uc_ptrs_dev = 0  # mUcPtrsDev (device array)
        self.mc_handle = 0  # CUmemGenericAllocationHandle mMcHandle
        self.uc_handles: List[
            int
        ] = []  # std::vector<CUmemGenericAllocationHandle> mUcHandles

        # Signal pad constants
        self.SIGNAL_PAD_ALIGNMENT = 16
        self.SIGNAL_PAD_SIZE = SIGNAL_PAD_SIZE

        # Check if device supports multicasting
        multicast_supported = checkCudaErrors(
            cuda.cuDeviceGetAttribute(
                cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
                device_idx,
            )
        )
        if multicast_supported == 0:
            raise RuntimeError(
                "[McastDeviceMemory] Device does not support multicasting."
            )

        # Calculate signal pad offset with alignment (matching C++ exactly)
        self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT)

        logging.info(
            f"[McastDeviceMemory] Rank: {group_rank}, Group size: {group_size}, "
            f"mnNvlink: {is_multi_node}, device_idx: {device_idx}, "
            f"Signal pad offset: {self.signal_pad_offset}"
        )

        if self.is_multi_node:
            # Check if fabric handle is supported
            fabric_handle_supported = checkCudaErrors(
                cuda.cuDeviceGetAttribute(
                    cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED,
                    device_idx,
                )
            )
            if fabric_handle_supported == 0:
                raise RuntimeError(
                    "[McastDeviceMemory] Device does not support fabric handle."
                )
            self._alloc_mn_mcast_mem(buf_size)
        else:
            # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
            raise NotImplementedError("Single-node NVLS allocation not implemented yet")

        # Initialize signal pads
        self.signal_pads = [0] * self.group_size
        for i in range(self.group_size):
            self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset
            if i == self.group_rank:
                checkCudaErrors(
                    cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE)
                )

        # Create device pointers
        self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads)
        self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs)

    def __del__(self):
        """Destructor - clean up allocated memory"""
        # Check if we're in a valid state for cleanup
        if not hasattr(self, "is_multi_node"):
            return
        if not self.is_multi_node:
            return

        # Skip cleanup during Python finalization to avoid segfaults,
        # especially because the CUDA context could already be destroyed.
        if sys.is_finalizing():
            return

        # Verify CUDA context is still valid
        try:
            cuda.cuCtxGetCurrent()
        except Exception as e:
            print(f"Destructor: CUDA context invalid, skipping cleanup: {e}")
            return

        # Free device pointers
        if self.signal_pads_dev:
            checkCudaErrors(cuda.cuMemFree(self.signal_pads_dev))
        if self.uc_ptrs_dev:
            checkCudaErrors(cuda.cuMemFree(self.uc_ptrs_dev))

        # Unmap UC regions and release their handles
        if hasattr(self, "uc_handles") and self.uc_handles:
            for rank in range(self.group_size):
                if self.uc_handles[rank] != 0:
                    try:
                        # Release the handle
                        checkCudaErrors(cuda.cuMemRelease(self.uc_handles[rank]))
                        # Unmap the virtual memory
                        if rank < len(self.uc_ptrs) and self.uc_ptrs[rank]:
                            checkCudaErrors(
                                cuda.cuMemUnmap(
                                    self.uc_ptrs[rank], self.allocation_size
                                )
                            )
                    except Exception as e:
                        print(
                            f"Destructor: Failed to release UC handle for rank {rank}: {e}"
                        )

        # Free the UC address space
        if hasattr(self, "uc_base_ptr") and self.uc_base_ptr:
            checkCudaErrors(
                cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size)
            )

        # Release MC handle
        if hasattr(self, "mc_handle") and self.mc_handle and self.mc_handle != 0:
            try:
                checkCudaErrors(cuda.cuMemUnmap(self.mc_ptr, self.allocation_size))
                checkCudaErrors(
                    cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size)
                )
                checkCudaErrors(cuda.cuMemRelease(self.mc_handle))
            except Exception as e:
                print(f"Destructor: Failed to release MC handle: {e}")

    def get_signal_pad_ptrs_host(self) -> List[int]:
        """Get the host-side list of signal pad pointers for all ranks (including self)"""
        return self.signal_pads

    def get_buffer_ptrs_host(self) -> List[int]:
        """Get the host-side list of unicast pointers for all ranks (including self)"""
        return self.uc_ptrs

    def get_signal_pad_ptrs_dev(self) -> int:
        """Get the device array of signal pad pointers for all ranks (including self)"""
        return self.signal_pads_dev

    def get_buffer_ptrs_dev(self) -> int:
        """Get the device array of unicast pointers for all ranks (including self)"""
        return self.uc_ptrs_dev

    def get_unicast_ptr(self, rank: int) -> int:
        """Get the raw unicast pointer to a given rank"""
        if rank >= len(self.uc_ptrs):
            raise ValueError(f"Rank {rank} out of range (0-{len(self.uc_ptrs) - 1})")
        data_ptr = self.uc_ptrs[rank]
        # Note: In C++, this would call tensorrt_llm::common::registerMcastDevMemBuffer.
        # For the Python port, we skip this registration for now.
        return data_ptr

    def get_multicast_ptr(self) -> int:
        """Get the raw multicast pointer"""
        # Note: In C++, this would call tensorrt_llm::common::registerMcastDevMemBuffer.
        # For the Python port, we skip this registration for now.
        return int(self.mc_ptr)

    def get_rank(self) -> int:
        """Get the rank of this device in the group"""
        return self.group_rank

    def get_world_size(self) -> int:
        """Get the total number of devices in the group"""
        return self.group_size

    def _alloc_mn_mcast_mem(self, buf_size: int):
        """Allocate multi-node multicast memory using MNNVL"""
        # Verify CUDA context
        try:
            current_device = checkCudaErrors(cuda.cuCtxGetDevice())
            if int(current_device) != self.device_idx:
                print(
                    f"CUDA context device mismatch! Current: {current_device}, "
                    f"Expected: {self.device_idx}"
                )
        except Exception as e:
            print(f"Error checking CUDA context: {e}")

        # Get MPI communicator
        comm = MpiComm()

        # Set up allocation properties
        handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
        allocation_prop = cuda.CUmemAllocationProp()
        allocation_prop.requestedHandleTypes = handle_type
        allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
        allocation_prop.location = cuda.CUmemLocation()
        allocation_prop.location.type = (
            cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
        )
        allocation_prop.location.id = self.device_idx
        allocation_prop.allocFlags.gpuDirectRDMACapable = 1

        # Get allocation granularity
        alloc_granularity = checkCudaErrors(
            cuda.cuMemGetAllocationGranularity(
                allocation_prop,
                cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM,
            )
        )
        # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity);
        self.allocation_size = round_up(
            buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity
        )

        # Set up multicast properties
        mc_prop = cuda.CUmulticastObjectProp()
        mc_prop.numDevices = self.group_size
        mc_prop.size = self.allocation_size
        mc_prop.handleTypes = handle_type

        # Get multicast granularity
        mc_granularity = checkCudaErrors(
            cuda.cuMulticastGetGranularity(
                mc_prop,
                cuda.CUmulticastGranularity_flags.CU_MULTICAST_GRANULARITY_RECOMMENDED,
            )
        )
        self.allocation_size = round_up(self.allocation_size, mc_granularity)

        # Initialize UC handles list
        self.uc_handles = [0] * self.group_size

        # Allocate local GPU memory
        self.uc_handles[self.group_rank] = checkCudaErrors(
            cuda.cuMemCreate(self.allocation_size, allocation_prop, 0)
        )

        # Export local handle to fabric handle
        my_fabric_handle = checkCudaErrors(
            cuda.cuMemExportToShareableHandle(
                self.uc_handles[self.group_rank],
                cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC,
                0,
            )
        )

        # All-gather fabric handles
        all_fabric_handles = comm.allgather(my_fabric_handle.data)
        checkCudaErrors(cuda.cuCtxSynchronize())

        # Import remote handles
        for p in range(self.group_size):
            if p != self.group_rank:
                self.uc_handles[p] = checkCudaErrors(
                    cuda.cuMemImportFromShareableHandle(
                        all_fabric_handles[p],
                        cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC,
                    )
                )

        # Initialize multicasting
        if self.group_rank == 0:
            # Create multicast object
            self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop))
            # Export multicast handle
            mc_fabric_handle = checkCudaErrors(
                cuda.cuMemExportToShareableHandle(
                    self.mc_handle,
                    cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC,
                    0,
                )
            )
        else:
            mc_fabric_handle = None

        # Broadcast multicast handle
        mc_fabric_handle_data = comm.bcast(
            mc_fabric_handle.data if mc_fabric_handle else None, root=0
        )
        # Sync device to ensure the broadcast is complete
        checkCudaErrors(cuda.cuCtxSynchronize())

        # Import multicast handle on non-root ranks
        if self.group_rank != 0:
            self.mc_handle = checkCudaErrors(
                cuda.cuMemImportFromShareableHandle(
                    mc_fabric_handle_data,
                    cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC,
                )
            )

        # Add device to multicast
        checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx))

        # Bind memory addresses
        self.uc_ptrs = [0] * self.group_size

        # Reserve address space for UC pointers
        total_uc_size = self.allocation_size * self.group_size
        self.total_uc_size = total_uc_size
        uc_base_ptr = checkCudaErrors(
            cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0)
        )
        self.uc_base_ptr = uc_base_ptr  # Store for cleanup

        # Set up memory access descriptor
        access_desc = cuda.CUmemAccessDesc()
        access_desc.location = cuda.CUmemLocation()
        access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
        access_desc.location.id = self.device_idx
        access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE

        # Map UC memory
        for i in range(self.group_size):
            offset = self.allocation_size * i
            self.uc_ptrs[i] = int(uc_base_ptr) + offset
            checkCudaErrors(
                cuda.cuMemMap(
                    self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0
                )
            )

        # Set memory access permissions
        checkCudaErrors(
            cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1)
        )

        # Bind MC pointer
        self.mc_ptr = checkCudaErrors(
            cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0)
        )
        checkCudaErrors(
            cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0)
        )
        checkCudaErrors(
            cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1)
        )

        # Bind memory to multicast
        checkCudaErrors(
            cuda.cuMulticastBindMem(
                self.mc_handle,
                0,  # mcOffset
                self.uc_handles[self.group_rank],
                0,  # memOffset
                self.allocation_size,
                0,  # flags
            )
        )
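
    # Handle-exchange protocol used above, in summary: every rank cuMemCreate's
    # its own slice and allgathers the exported fabric handles; rank 0 creates
    # the multicast object and broadcasts its handle; each rank then imports all
    # remote slices, maps them into one contiguous UC reservation, maps the MC
    # object, and binds its local slice to the multicast group.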

    def lamport_initialize(self, rank: int, dtype: torch.dtype):
        if dtype == torch.bfloat16 or dtype == torch.float16:
            neg_zero = 0x8000
            dsize = 2
            memset_func = cuda.cuMemsetD16
        elif dtype == torch.float32:
            neg_zero = 0x80000000
            dsize = 4
            memset_func = cuda.cuMemsetD32
        else:
            raise ValueError(f"Unsupported dtype: {dtype}")
        # Calculate number of elements that fit in allocation_size
        num_elements = self.allocation_size // dsize
        checkCudaErrors(
            memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements)
        )


class McastGPUBuffer:
    """
    Wrapper around McastDeviceMemory that facilitates PyTorch tensor creation.
    It manages a buffer accessible via unicast or multicast for multi-node
    communication.

    Python port of McastGPUBuffer from TensorRT-LLM.
    """

    def __init__(
        self,
        buf_size: int,
        group_size: int,
        group_rank: int,
        device: torch.device,
        mn_nvlink: bool = True,
    ):
        """
        Constructor for McastGPUBuffer.

        Args:
            buf_size: The total size of the buffer in bytes
            group_size: The number of ranks in the communication group
            group_rank: The rank of the local process within the group
            device: The CUDA device for buffer allocation
            mn_nvlink: Flag indicating if multi-node NVLink is used
        """
        self.mcast_device_memory = McastDeviceMemory(
            buf_size, group_size, group_rank, device.index, mn_nvlink
        )
        self.buf_size = buf_size
        self.local_device = device

    def lamport_initialize(self, rank: int, dtype: torch.dtype):
        self.mcast_device_memory.lamport_initialize(rank, dtype)

    def get_mc_buffer(
        self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0
    ) -> torch.Tensor:
        """
        Returns a PyTorch tensor view of the multicast buffer portion.

        Args:
            sizes: The desired shape (dimensions) of the tensor
            dtype: The data type of the tensor elements
            storage_offset: The offset in elements from the start of the buffer

        Returns:
            A PyTorch tensor wrapping the multicast buffer section
        """
        raise NotImplementedError("Not implemented yet")

    def get_multicast_ptr(self) -> int:
        """Get the raw multicast pointer"""
        return self.mcast_device_memory.get_multicast_ptr()

    def get_buffer_ptrs_dev(self) -> int:
        """Get the device array of per-rank buffer pointers"""
        return self.mcast_device_memory.get_buffer_ptrs_dev()
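
# Usage sketch (hypothetical values; requires fabric-capable GPUs with one
# MPI rank per GPU):
#   buf = McastGPUBuffer(
#       buf_size=1 << 20,
#       group_size=world_size,
#       group_rank=rank,
#       device=torch.device("cuda", local_rank),
#   )
#   buf.lamport_initialize(rank, torch.bfloat16)
#   mc_ptr = buf.get_multicast_ptr()   # multicast pointer for device kernels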