# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Code imported from TensorRT-LLM/tensorrt_llm/_mnnvl_utils.py import ctypes import logging import os from abc import ABC, abstractmethod from dataclasses import dataclass import platform import sys from typing import Any, Dict, List, Optional, TYPE_CHECKING import torch try: from cuda import cuda except ImportError as e: raise ImportError( "Could not import the 'cuda' module. " "Please install cuda-python that matches your CUDA version." ) from e from ..cuda_utils import checkCudaErrors from .dlpack_utils import create_dlpack_capsule, pack_strided_memory from .mapping import Mapping IS_BUILDING_DOCS = os.environ.get("FLASHINFER_BUILDING_DOCS") == "1" # mpi4py only exports MPI_COMM_TYPE_SHARED, so we define OMPI_COMM_TYPE_HOST here OMPI_COMM_TYPE_HOST = 9 # Constants from C++ header SIGNAL_PAD_SIZE = 2048 # kSIGNAL_PAD_SIZE from header MNNVL_DEBUG = False def round_up(val: int, gran: int) -> int: """Efficient implementation assuming gran is a power of 2""" return (val + gran - 1) & ~(gran - 1) def create_tensor_from_cuda_memory( ptr: int, shape: tuple, dtype: torch.dtype, device_id: int ) -> torch.Tensor: """ Create a PyTorch tensor from a CUDA memory pointer using DLPack. Args: ptr: CUDA memory pointer address as integer shape: Desired tensor shape dtype: PyTorch data type device_id: CUDA device ID Returns: PyTorch tensor that wraps the CUDA memory """ # Calculate total size in elements numel = 1 for dim in shape: numel *= dim # Get element size in bytes element_size = torch.tensor([], dtype=dtype).element_size() # Create DLPack capsule for contiguous memory (stride = element_size, num_segments = numel) capsule_wrapper = create_dlpack_capsule( ptr, element_size, element_size, numel, dtype, device_id ) # Convert to tensor and reshape tensor = torch.utils.dlpack.from_dlpack(capsule_wrapper.capsule) tensor._capsule_wrapper = capsule_wrapper # Keep reference to prevent GC # Reshape to desired shape return tensor.view(shape) def test_cuda_memory_access(ptr: int, size: int, device_id: int) -> bool: """ Test if CUDA memory at ptr is accessible by trying to read/write a small amount. 
Args: ptr: CUDA memory pointer size: Size of memory region device_id: CUDA device ID Returns: True if memory is accessible, False otherwise """ try: # Test with a small 4-byte read/write test_size = min(4, size) host_data = bytearray(test_size) # Try to copy from device to host checkCudaErrors(cuda.cuMemcpyDtoH(host_data, ptr, test_size)) # Try to copy back from host to device checkCudaErrors(cuda.cuMemcpyHtoD(ptr, host_data, test_size)) print(f"DEBUG: Memory access test PASSED for ptr=0x{ptr:x}") return True except Exception as e: print(f"DEBUG: Memory access test FAILED for ptr=0x{ptr:x}: {e}") return False def alloc_and_copy_to_cuda(host_ptr_array: List[int]) -> Optional[int]: """ A helper function that allocates memory on cuda and copies the data from the host to the device. """ if not host_ptr_array: return None ArrayType = ctypes.c_uint64 * len(host_ptr_array) c_array = ArrayType(*host_ptr_array) size_in_bytes = ctypes.sizeof(c_array) device_ptr: cuda.CUdeviceptr = checkCudaErrors(cuda.cuMemAlloc(size_in_bytes)) checkCudaErrors( cuda.cuMemcpyHtoD(device_ptr, ctypes.addressof(c_array), size_in_bytes) ) # c_array should be freed by GC return device_ptr class CommBackend(ABC): """Abstract communication backend interface""" @abstractmethod def Get_rank(self) -> int: ... @abstractmethod def Get_size(self) -> int: ... @abstractmethod def allgather(self, data: int) -> List[int]: ... @abstractmethod def Split(self, color: int, key: int) -> "CommBackend": ... if IS_BUILDING_DOCS: # Mock classes for building docs class MpiComm: # type: ignore[no-redef] @classmethod def set_mpi_comm(cls, new_comm): pass def __getattr__(self, name): return None class MnnvlMemory: # type: ignore[no-redef] initialized: bool = False current_mem_offset: int = 0 current_rank_stride: int = 0 # stride for ranks and also address space size. 
current_start_address: int = 0 # allocation granularity allocation_granularity: int = 0 # fabric address page size (512 MB) fabric_page_size: int = 1 << 29 # MPI communicator comm = None dev_id: int = None allocated_map: Dict[int, Any] = {} address_refcnt: Dict[int, Any] = {} def __init__(self, mapping: Mapping, size: int): pass def __del__(self): pass def as_torch_strided_tensor(self, dtype): return None @staticmethod def initialize(): pass @staticmethod def get_comm(mapping: Mapping): return None @staticmethod def get_allocation_prop(dev_id: int): return None @staticmethod def get_allocation_granularity(dev_id: int): return None @staticmethod def new_mnnvl_memory_address(mapping: Mapping, size: int): pass @staticmethod def open_mnnvl_memory(mapping: Mapping, size: int): return None @staticmethod def close_mnnvl_memory(ptr: int): pass @staticmethod def support_nvlink(need_all_up: bool = True): return None @staticmethod def supports_mnnvl() -> bool: return False else: import pynvml if TYPE_CHECKING: from mpi4py import MPI # noqa: F401 def lazy_import_mpi(): """Lazy import for mpi4py""" try: from mpi4py import MPI return MPI except ImportError as err: raise ImportError("mpi4py is not installed") from err # type: ignore[no-redef] class MpiComm: # type: ignore[no-redef] _comm: Any = None _MPI: Any = None @classmethod def _get_mpi(cls): if cls._MPI is None: cls._MPI = lazy_import_mpi() cls._comm = cls._MPI.COMM_WORLD return cls._MPI @classmethod def set_mpi_comm(cls, new_comm: Any): cls._get_mpi() # Optional: add type checking here cls._comm = new_comm def __getattr__(self, name): if self._comm is None: self._get_mpi() return getattr(self._comm, name) class MPIBackend(CommBackend): def __init__(self): self._mpicomm = MpiComm() def Get_rank(self) -> int: return self._mpicomm.Get_rank() def Get_size(self) -> int: return self._mpicomm.Get_size() def allgather(self, data: int) -> List[int]: return self._mpicomm.allgather(data) def Split(self, color: int, key: int) -> CommBackend: self._mpicomm = self._mpicomm.Split(color, key) return MPIBackend() # Returns new adapter @dataclass class MnnvlConfig: """Configuration for MNNVL memory management""" comm_backend: Optional[CommBackend] = None allocation_granularity: int = 0 fabric_page_size: int = 1 << 29 # 512MB class MnnvlMemory: # type: ignore[no-redef] initialized: bool = False current_mem_offset: int = 0 current_rank_stride: int = 0 # stride for ranks and also address space size. current_start_address: int = 0 # allocation granularity allocation_granularity: int = 0 # fabric address page size (512 MB) fabric_page_size: int = 1 << 29 # MPI communicator comm: Optional[CommBackend] = None dev_id: int = None allocated_map: Dict[int, Any] = {} address_refcnt: Dict[int, Any] = {} config: Optional[MnnvlConfig] = None def __init__(self, mapping: Mapping, size: int): self.mapping = mapping self.segment_size = size self.ptr, self.rank_stride = MnnvlMemory.open_mnnvl_memory( self.mapping, size ) def __del__(self): if not sys.is_finalizing(): MnnvlMemory.close_mnnvl_memory(self.ptr) def as_torch_strided_tensor(self, dtype): num_segments = MnnvlMemory.comm.Get_size() return pack_strided_memory( self.ptr, self.segment_size, self.rank_stride, num_segments, dtype, MnnvlMemory.dev_id, ) @staticmethod def initialize(): if not MnnvlMemory.initialized: # use a dummy torch CUDA tensor to trigger CUDA context initialization _ = torch.empty(1, device="cuda") # ensure nvml is initialized. 
                try:
                    pynvml.nvmlDeviceGetCount()
                except pynvml.NVMLError_Uninitialized:
                    pynvml.nvmlInit()
                MnnvlMemory.initialized = True

        @staticmethod
        def set_comm_from_config(
            mapping: Mapping, config: Optional[MnnvlConfig] = None
        ):
            MnnvlMemory.config = config or MnnvlConfig(comm_backend=MPIBackend())  # type: ignore[attr-defined]
            # Use the resolved config so the default MPI backend is picked up
            # when no config is passed in.
            comm = MnnvlMemory.config.comm_backend.Split(
                mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
            )
            MnnvlMemory.comm = comm  # type: ignore[assignment]

        @staticmethod
        def get_comm(mapping: Mapping):
            if MnnvlMemory.comm is not None:
                return MnnvlMemory.comm
            comm = MpiComm().Split(
                mapping.pp_rank * mapping.cp_size + mapping.cp_rank, mapping.tp_rank
            )
            MnnvlMemory.comm = comm
            return comm

        @staticmethod
        def get_allocation_prop(dev_id: int):
            location = cuda.CUmemLocation()
            location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
            location.id = dev_id
            allocation_prop = cuda.CUmemAllocationProp()
            allocation_prop.type = (
                cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED
            )
            # TODO: We differentiate FABRIC for GB200 (aarch64) and
            # POSIX_FILE_DESCRIPTOR for B200 (x86_64). May need to find a
            # better way to handle this.
            arch = platform.machine().lower()
            is_on_aarch64 = "aarch64" in arch
            if is_on_aarch64:
                allocation_prop.requestedHandleTypes = (
                    cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
                )
            else:
                allocation_prop.requestedHandleTypes = (
                    cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR
                )
            allocation_prop.location = location
            return allocation_prop

        @staticmethod
        def get_allocation_granularity(dev_id: int):
            if MnnvlMemory.allocation_granularity != 0:
                return MnnvlMemory.allocation_granularity
            allocation_prop = MnnvlMemory.get_allocation_prop(dev_id)
            option = cuda.CUmemAllocationGranularity_flags(
                cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_RECOMMENDED
            )
            granularity = checkCudaErrors(
                cuda.cuMemGetAllocationGranularity(prop=allocation_prop, option=option)
            )
            MnnvlMemory.allocation_granularity = granularity
            return MnnvlMemory.allocation_granularity

        @staticmethod
        def new_mnnvl_memory_address(mapping: Mapping, size: int):
            page_count = (
                size + MnnvlMemory.fabric_page_size - 1
            ) // MnnvlMemory.fabric_page_size
            current_rank_stride = page_count * MnnvlMemory.fabric_page_size
            logging.info(
                f"[MnnvlMemory] creating address with stride={current_rank_stride}"
            )
            comm = MnnvlMemory.get_comm(mapping)
            comm_size = comm.Get_size()
            address_size = current_rank_stride * comm_size
            ptr = checkCudaErrors(
                cuda.cuMemAddressReserve(
                    address_size, MnnvlMemory.fabric_page_size, 0, 0
                )
            )
            MnnvlMemory.current_start_address = int(ptr)
            MnnvlMemory.current_rank_stride = current_rank_stride
            MnnvlMemory.current_mem_offset = 0

        @staticmethod
        def open_mnnvl_memory(mapping: Mapping, size: int):
            dev = checkCudaErrors(cuda.cuCtxGetDevice())
            dev_id = int(dev)
            if MnnvlMemory.dev_id is None:
                MnnvlMemory.dev_id = dev_id
            assert dev_id == MnnvlMemory.dev_id, (
                f"Different dev_id found: dev_id={dev_id} but MnnvlMemory.dev_id={MnnvlMemory.dev_id}"
            )
            comm = MnnvlMemory.get_comm(mapping)
            comm_rank = comm.Get_rank()
            comm_size = comm.Get_size()
            # Every rank must request the same size so the per-rank stride lines up.
            all_rank_allocate_sizes = comm.allgather(size)
            assert len(all_rank_allocate_sizes) == comm_size
            assert all(x == size for x in all_rank_allocate_sizes), (
                "Not all ranks are allocating the same size."
) granularity = MnnvlMemory.get_allocation_granularity(dev_id) aligned_size = (size + granularity - 1) // granularity * granularity if ( MnnvlMemory.current_mem_offset + aligned_size > MnnvlMemory.current_rank_stride ): MnnvlMemory.new_mnnvl_memory_address(mapping, aligned_size) assert ( MnnvlMemory.current_mem_offset + aligned_size <= MnnvlMemory.current_rank_stride ) allocation_prop = MnnvlMemory.get_allocation_prop(dev_id) allocated_mem_handle = checkCudaErrors( cuda.cuMemCreate(aligned_size, allocation_prop, flags=0) ) exported_fabric_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( allocated_mem_handle, allocation_prop.requestedHandleTypes, 0 ) ) if ( allocation_prop.requestedHandleTypes == cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC ): all_handles_data = comm.allgather(exported_fabric_handle.data) else: all_handles_data = comm.allgather(exported_fabric_handle) all_pids = comm.allgather(os.getpid()) libc = ctypes.CDLL(None, use_errno=True) syscall = libc.syscall SYS_pidfd_open = 434 SYS_pidfd_getfd = 438 pidfds = [] for pid in all_pids: pidfd = syscall(SYS_pidfd_open, pid, 0) if pidfd < 0: err = ctypes.get_errno() raise RuntimeError( f"pidfd_open({pid}) failed with errno {err}: {os.strerror(err)}" ) pidfds.append(pidfd) remote_fds = [] for pidfd, fd in zip(pidfds, all_handles_data): remote_fd = syscall(SYS_pidfd_getfd, pidfd, fd, 0) if remote_fd < 0: err = ctypes.get_errno() error_msg = f"pidfd_getfd(pidfd={pidfd}, fd={fd}) failed with errno {err}: {os.strerror(err)}." if err == 1: # EPERM error_msg += ( " Permission denied. If running in a container, try adding --cap-add=SYS_PTRACE " "to your docker run command." ) else: error_msg += " This may be due to kernel version (requires Linux 5.6+)." raise RuntimeError(error_msg) remote_fds.append(remote_fd) all_handles_data = remote_fds # all_handles_data like b'\x00\x00\x00 \x00\x00\x00\x00\x8f\xec\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\t\x00\x00\x00\x00\x00\x1d\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00' # noqa: E501 # can use buf = memoryview(data) to import if using plain buffer for data. 
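            # Map each peer's allocation into the locally reserved VA range:
            # rank i's segment lands at current_start_address +
            # i * current_rank_stride + current_mem_offset, so every rank sees
            # its peers at a fixed per-rank stride.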
madesc = cuda.CUmemAccessDesc() madesc.location = allocation_prop.location madesc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE mem_handles = [None] * comm_size for i, remote_handle_data in enumerate(all_handles_data): rank_ptr = ( MnnvlMemory.current_start_address + MnnvlMemory.current_rank_stride * i + MnnvlMemory.current_mem_offset ) if i == comm_rank: # Local memory mapping mem_handles[i] = allocated_mem_handle checkCudaErrors( cuda.cuMemMap( rank_ptr, aligned_size, 0, allocated_mem_handle, 0 ) ) else: # Fabric memory mapping imported_mem_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( remote_handle_data, allocation_prop.requestedHandleTypes ) ) mem_handles[i] = imported_mem_handle checkCudaErrors( cuda.cuMemMap(rank_ptr, aligned_size, 0, imported_mem_handle, 0) ) checkCudaErrors( cuda.cuMemSetAccess(rank_ptr, aligned_size, [madesc], 1) ) ptr = MnnvlMemory.current_start_address + MnnvlMemory.current_mem_offset stride = MnnvlMemory.current_rank_stride MnnvlMemory.allocated_map[ptr] = ( mapping, aligned_size, mem_handles, MnnvlMemory.current_start_address, MnnvlMemory.current_rank_stride, MnnvlMemory.current_mem_offset, ) MnnvlMemory.address_refcnt[MnnvlMemory.current_start_address] = ( MnnvlMemory.address_refcnt.get(MnnvlMemory.current_start_address, 0) + 1 ) MnnvlMemory.current_mem_offset += aligned_size return ptr, stride @staticmethod def close_mnnvl_memory(ptr: int): ( mapping, aligned_size, mem_handles, start_address, rank_stride, address_offset, ) = MnnvlMemory.allocated_map.pop(ptr) comm = MnnvlMemory.get_comm(mapping) comm_size = comm.Get_size() for i in range(comm_size): rank_ptr = start_address + i * rank_stride + address_offset checkCudaErrors(cuda.cuMemUnmap(rank_ptr, aligned_size)) checkCudaErrors(cuda.cuMemRelease(mem_handles[i])) MnnvlMemory.address_refcnt[start_address] -= 1 if MnnvlMemory.address_refcnt[start_address] == 0: MnnvlMemory.address_refcnt.pop(start_address) device_ptr = cuda.CUdeviceptr(start_address) checkCudaErrors( cuda.cuMemAddressFree(device_ptr, comm_size * rank_stride) ) if start_address == MnnvlMemory.current_start_address: MnnvlMemory.current_start_address = 0 MnnvlMemory.current_rank_stride = 0 MnnvlMemory.current_mem_offset = 0 @staticmethod def support_nvlink(need_all_up: bool = True): dev_id = torch.cuda.current_device() handle = pynvml.nvmlDeviceGetHandleByIndex(dev_id) link_count = pynvml.NVML_NVLINK_MAX_LINKS active_links = 0 available_links = 0 for link_idx in range(link_count): try: if pynvml.nvmlDeviceGetNvLinkCapability( handle, link_idx, pynvml.NVML_NVLINK_CAP_P2P_SUPPORTED ): available_links += 1 is_active = pynvml.nvmlDeviceGetNvLinkState(handle, link_idx) if is_active: active_links += 1 except pynvml.NVMLError_NotSupported: continue return ( active_links == available_links and available_links > 0 if need_all_up else available_links > 0 ) @staticmethod def supports_mnnvl() -> bool: # TODO: # We check if it has all NVLink up now. # But it is not equivalent to MNNVL support. # May need better support check. 
support_nvlink_and_all_up = MnnvlMemory.support_nvlink(True) return support_nvlink_and_all_up class McastDeviceMemory: """Python port of McastDeviceMemory from TensorRT-LLM""" def __init__( self, buf_size: int, group_size: int, group_rank: int, device_idx: int, is_multi_node: bool = True, ): cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx)) primary_ctx = checkCudaErrors(cuda.cuDevicePrimaryCtxRetain(cu_device)) checkCudaErrors(cuda.cuCtxSetCurrent(primary_ctx)) # Set CUDA device # Check if cuda.cudart is available and import accordingly from flashinfer.utils import has_cuda_cudart if has_cuda_cudart(): # cuda-python <= 12.9 import cuda.cudart as cudart else: # cuda-python >= 13.0 import cuda.bindings.runtime as cudart checkCudaErrors(cudart.cudaSetDevice(device_idx)) self.is_multi_node = is_multi_node self.device_idx = device_idx self.group_size = group_size self.group_rank = group_rank self.buf_size = buf_size self.signal_pad_offset = 0 self.allocation_size = 0 # CUDA memory handles and pointers self.mc_ptr = 0 # CUdeviceptr mMcPtr self.uc_ptrs: List[int] = [] # std::vector mUcPtrs self.signal_pads: List[int] = [] # mSignalPads self.signal_pads_dev = 0 # std::vector mSignalPadsDev self.uc_ptrs_dev = 0 self.mc_handle = 0 # CUmemGenericAllocationHandle mMcHandle self.uc_handles: List[ int ] = [] # std::vector mUcHandles # Signal pad constants self.SIGNAL_PAD_ALIGNMENT = 16 self.SIGNAL_PAD_SIZE = SIGNAL_PAD_SIZE # Check if device supports multicasting multicast_supported = checkCudaErrors( cuda.cuDeviceGetAttribute( cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_idx, ) ) if multicast_supported == 0: raise RuntimeError( "[McastDeviceMemory] Device does not support multicasting." ) # Calculate signal pad offset with alignment (matching C++ exactly) self.signal_pad_offset = round_up(buf_size, self.SIGNAL_PAD_ALIGNMENT) logging.info( f"[McastDeviceMemory] Rank: {group_rank}, Group size: {group_size}, " f"mnNvlink: {is_multi_node}, device_idx: {device_idx}, " f"Signal pad offset: {self.signal_pad_offset}" ) if self.is_multi_node: # Check if fabric handle is supported fabric_handle_supported = checkCudaErrors( cuda.cuDeviceGetAttribute( cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, device_idx, ) ) if fabric_handle_supported == 0: raise RuntimeError( "[McastDeviceMemory] Device does not support fabric handle." ) self._alloc_mn_mcast_mem(buf_size) else: # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem raise NotImplementedError("Single-node NVLS allocation not implemented yet") # Initialize signal pads self.signal_pads = [0] * self.group_size for i in range(self.group_size): self.signal_pads[i] = self.uc_ptrs[i] + self.signal_pad_offset if i == self.group_rank: checkCudaErrors( cuda.cuMemsetD8(self.signal_pads[i], 0, self.SIGNAL_PAD_SIZE) ) # Create device pointers self.signal_pads_dev = alloc_and_copy_to_cuda(self.signal_pads) self.uc_ptrs_dev = alloc_and_copy_to_cuda(self.uc_ptrs) def __del__(self): """Destructor - cleanup allocated memory""" # Check if we're in a valid state for cleanup if not hasattr(self, "is_multi_node"): return if not self.is_multi_node: return # Skip cleanup during Python finalization to avoid segfaults # Especially cause the CUDA context could be destroyed at this point. 
if sys.is_finalizing(): return # Verify CUDA context is still valid try: cuda.cuCtxGetCurrent() except Exception as e: print(f"Destructor: CUDA context invalid, skipping cleanup: {e}") return # Free device pointers if self.signal_pads_dev: checkCudaErrors(cuda.cuMemFree(self.signal_pads_dev)) if self.uc_ptrs_dev: checkCudaErrors(cuda.cuMemFree(self.uc_ptrs_dev)) # Unmap UC regions and release their handles if hasattr(self, "uc_handles") and self.uc_handles: for rank in range(self.group_size): if self.uc_handles[rank] != 0: try: # Release the handle checkCudaErrors(cuda.cuMemRelease(self.uc_handles[rank])) # Unmap the vmem if rank < len(self.uc_ptrs) and self.uc_ptrs[rank]: checkCudaErrors( cuda.cuMemUnmap( self.uc_ptrs[rank], self.allocation_size ) ) except Exception as e: print( f"Destructor: Failed to release UC handle for rank {rank}: {e}" ) # Free the UC address space if hasattr(self, "uc_base_ptr") and self.uc_base_ptr: checkCudaErrors( cuda.cuMemAddressFree(self.uc_base_ptr, self.total_uc_size) ) # Release MC handle if hasattr(self, "mc_handle") and self.mc_handle and self.mc_handle != 0: try: checkCudaErrors(cuda.cuMemUnmap(self.mc_ptr, self.allocation_size)) checkCudaErrors( cuda.cuMemAddressFree(self.mc_ptr, self.allocation_size) ) checkCudaErrors(cuda.cuMemRelease(self.mc_handle)) except Exception as e: print(f"Destructor: Failed to release MC handle: {e}") def get_signal_pad_ptrs_host(self) -> List[int]: """Get the raw array of signal pad pointers to all ranks (including self)""" return self.signal_pads def get_buffer_ptrs_host(self) -> List[int]: """Get the raw array of unicast pointers to all ranks (including self)""" return self.uc_ptrs def get_signal_pad_ptrs_dev(self) -> int: """Get the raw array of signal pad pointers to all ranks (including self)""" return self.signal_pads_dev def get_buffer_ptrs_dev(self) -> int: """Get the raw array of unicast pointers to all ranks (including self)""" return self.uc_ptrs_dev def get_unicast_ptr(self, rank: int) -> int: """Get the raw unicast pointer to a given rank""" if rank >= len(self.uc_ptrs): raise ValueError(f"Rank {rank} out of range (0-{len(self.uc_ptrs) - 1})") data_ptr = self.uc_ptrs[rank] # Note: In C++, this would call tensorrt_llm::common::registerMcastDevMemBuffer # For Python port, we skip this registration for now return data_ptr def get_multicast_ptr(self) -> int: """Get the raw multicast pointer""" # Note: In C++, this would call tensorrt_llm::common::registerMcastDevMemBuffer # For Python port, we skip this registration for now return int(self.mc_ptr) def get_rank(self) -> int: """Get the rank of this device in the group""" return self.group_rank def get_world_size(self) -> int: """Get the total number of devices in the group""" return self.group_size def _alloc_mn_mcast_mem(self, buf_size: int): """Allocate multi-node multicast memory using MNNVL""" # Verify CUDA context try: current_device = checkCudaErrors(cuda.cuCtxGetDevice()) if int(current_device) != self.device_idx: print( f"CUDA context device mismatch! 
Current: {current_device}, Expected: {self.device_idx}" ) except Exception as e: print(f"Error checking CUDA context: {e}") # Get MPI communicator comm = MpiComm() # Set up allocation properties handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC allocation_prop = cuda.CUmemAllocationProp() allocation_prop.requestedHandleTypes = handle_type allocation_prop.type = cuda.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_PINNED allocation_prop.location = cuda.CUmemLocation() allocation_prop.location.type = ( cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE ) allocation_prop.location.id = self.device_idx allocation_prop.allocFlags.gpuDirectRDMACapable = 1 # Get allocation granularity alloc_granularity = checkCudaErrors( cuda.cuMemGetAllocationGranularity( allocation_prop, cuda.CUmemAllocationGranularity_flags.CU_MEM_ALLOC_GRANULARITY_MINIMUM, ) ) # mAllocationSize = roundUp(bufSize + kSIGNAL_PAD_SIZE, alloc_granularity); self.allocation_size = round_up( buf_size + self.SIGNAL_PAD_SIZE, alloc_granularity ) # Set up multicast properties mc_prop = cuda.CUmulticastObjectProp() mc_prop.numDevices = self.group_size mc_prop.size = self.allocation_size mc_prop.handleTypes = handle_type # Get multicast granularity mc_granularity = checkCudaErrors( cuda.cuMulticastGetGranularity( mc_prop, cuda.CUmulticastGranularity_flags.CU_MULTICAST_GRANULARITY_RECOMMENDED, ) ) self.allocation_size = round_up(self.allocation_size, mc_granularity) # Initialize UC handles list self.uc_handles = [0] * self.group_size # Allocate local GPU memory self.uc_handles[self.group_rank] = checkCudaErrors( cuda.cuMemCreate(self.allocation_size, allocation_prop, 0) ) # Export local handle to fabric handle my_fabric_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.uc_handles[self.group_rank], cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, 0, ) ) # All-gather fabric handles all_fabric_handles = comm.allgather(my_fabric_handle.data) cuda.cuCtxSynchronize() # Import remote handles for p in range(self.group_size): if p != self.group_rank: self.uc_handles[p] = checkCudaErrors( cuda.cuMemImportFromShareableHandle( all_fabric_handles[p], cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, ) ) # Initialize multicasting if self.group_rank == 0: # Create multicast object self.mc_handle = checkCudaErrors(cuda.cuMulticastCreate(mc_prop)) # Export multicast handle mc_fabric_handle = checkCudaErrors( cuda.cuMemExportToShareableHandle( self.mc_handle, cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, 0, ) ) else: mc_fabric_handle = None # Broadcast multicast handle mc_fabric_handle_data = comm.bcast( mc_fabric_handle.data if mc_fabric_handle else None, root=0 ) # Sync device to ensure broadcast is complete cuda.cuCtxSynchronize() # Import multicast handle for non-root ranks if self.group_rank != 0: self.mc_handle = checkCudaErrors( cuda.cuMemImportFromShareableHandle( mc_fabric_handle_data, cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC, ) ) # Add device to multicast checkCudaErrors(cuda.cuMulticastAddDevice(self.mc_handle, self.device_idx)) # Bind memory addresses self.uc_ptrs = [0] * self.group_size # Reserve address space for UC pointers total_uc_size = self.allocation_size * self.group_size self.total_uc_size = total_uc_size uc_base_ptr = checkCudaErrors( cuda.cuMemAddressReserve(total_uc_size, mc_granularity, 0, 0) ) self.uc_base_ptr = uc_base_ptr # Store for cleanup # Set up memory access descriptor access_desc = cuda.CUmemAccessDesc() access_desc.location = cuda.CUmemLocation() 
access_desc.location.type = cuda.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE access_desc.location.id = self.device_idx access_desc.flags = cuda.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE # Map UC memory for i in range(self.group_size): offset = self.allocation_size * i self.uc_ptrs[i] = int(uc_base_ptr) + offset checkCudaErrors( cuda.cuMemMap( self.uc_ptrs[i], self.allocation_size, 0, self.uc_handles[i], 0 ) ) # Set memory access permissions checkCudaErrors( cuda.cuMemSetAccess(uc_base_ptr, total_uc_size, [access_desc], 1) ) # Bind MC pointer self.mc_ptr = checkCudaErrors( cuda.cuMemAddressReserve(self.allocation_size, mc_granularity, 0, 0) ) checkCudaErrors( cuda.cuMemMap(self.mc_ptr, self.allocation_size, 0, self.mc_handle, 0) ) checkCudaErrors( cuda.cuMemSetAccess(self.mc_ptr, self.allocation_size, [access_desc], 1) ) # Bind memory to multicast checkCudaErrors( cuda.cuMulticastBindMem( self.mc_handle, 0, # mcOffset self.uc_handles[self.group_rank], 0, # memOffset self.allocation_size, 0, # flags ) ) def lamport_initialize(self, rank: int, dtype: torch.dtype): if dtype == torch.bfloat16 or dtype == torch.float16: neg_zero = 0x8000 dsize = 2 memset_func = cuda.cuMemsetD16 elif dtype == torch.float32: neg_zero = 0x80000000 dsize = 4 memset_func = cuda.cuMemsetD32 else: raise ValueError(f"Unsupported dtype: {dtype}") # Calculate number of elements that fit in allocation_size num_elements = self.allocation_size // dsize checkCudaErrors( memset_func(int(self.uc_ptrs[self.group_rank]), neg_zero, num_elements) ) class McastGPUBuffer: """ Wrapper class for McastDeviceMemory to facilitate PyTorch tensor creation. It manages a buffer accessible via unicast or multicast for multi-node communication. Python port of McastGPUBuffer from TensorRT-LLM """ def __init__( self, buf_size: int, group_size: int, group_rank: int, device: torch.device, mn_nvlink: bool = True, ): """ Constructor for McastGpuBuffer. Args: buf_size: The total size of the buffer in bytes group_size: The number of ranks in the communication group group_rank: The rank of the local process within the group device: The CUDA device for buffer allocation mn_nvlink: Flag indicating if multi-node NVLink is used """ self.mcast_device_memory = McastDeviceMemory( buf_size, group_size, group_rank, device.index, mn_nvlink ) self.buf_size = buf_size self.local_device = device def lamport_initialize(self, rank: int, dtype: torch.dtype): self.mcast_device_memory.lamport_initialize(rank, dtype) def get_mc_buffer( self, sizes: tuple, dtype: torch.dtype, storage_offset: int = 0 ) -> torch.Tensor: """ Returns a PyTorch tensor view of the multicast buffer portion. Args: sizes: The desired shape (dimensions) of the tensor dtype: The data type of the tensor elements storage_offset: The offset in elements from the start of the buffer Returns: A PyTorch tensor wrapping the multicast buffer section """ raise NotImplementedError("Not implemented yet") def get_multicast_ptr(self) -> int: """Get the raw multicast pointer""" return self.mcast_device_memory.get_multicast_ptr() def get_buffer_ptrs_dev(self) -> int: """Get the buffer pointers device array""" return self.mcast_device_memory.get_buffer_ptrs_dev()
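
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the imported TensorRT-LLM code).
# It assumes an MPI launch with one rank per GPU on MNNVL/NVLink-capable
# nodes, and that `Mapping` accepts `world_size`/`rank`/`tp_size` keyword
# arguments as in TensorRT-LLM; adapt the names to your setup.
#
#   comm = MpiComm()
#   rank, world_size = comm.Get_rank(), comm.Get_size()
#   torch.cuda.set_device(rank % torch.cuda.device_count())
#
#   mapping = Mapping(world_size=world_size, rank=rank, tp_size=world_size)
#   MnnvlMemory.initialize()
#   if MnnvlMemory.supports_mnnvl():
#       mem = MnnvlMemory(mapping, 2 << 20)  # 2 MiB segment per rank
#       strided = mem.as_torch_strided_tensor(torch.uint8)
#
#   buf = McastGPUBuffer(
#       buf_size=2 << 20,
#       group_size=world_size,
#       group_rank=rank,
#       device=torch.device("cuda", torch.cuda.current_device()),
#       mn_nvlink=True,
#   )
#   buf.lamport_initialize(rank, torch.bfloat16)
# ---------------------------------------------------------------------------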