""" MNNVL (Multi-Node NVLink) communication operations for FlashInfer. """ import functools import math import os from types import SimpleNamespace from typing import Optional, Tuple import torch from flashinfer.comm.mapping import Mapping from ..jit import JitSpec from ..jit import env as jit_env from ..jit import gen_jit_spec from ..utils import register_custom_op from .mnnvl import McastGPUBuffer def mpi_barrier(): from mpi4py import MPI """MPI barrier - could potentially be replaced with dist.barrier()""" MPI.COMM_WORLD.Barrier() def gen_trtllm_mnnvl_comm_module() -> JitSpec: return gen_jit_spec( "trtllm_mnnvl_comm", [ jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu", ], ) @functools.cache def get_trtllm_mnnvl_comm_module(): module = gen_trtllm_mnnvl_comm_module().build_and_load() @register_custom_op( "flashinfer::trtllm_mnnvl_all_reduce", mutates_args=[ "inp", "multicast_buffer_ptr", "buffer_ptrs_dev", "buffer_mnnvl", "buffer_flags_mnnvl", "nranks", "rank", "wait_for_results", "launch_with_pdl", "out", ], ) def trtllm_mnnvl_all_reduce( inp: torch.Tensor, multicast_buffer_ptr: int, # Pointer address as integer buffer_ptrs_dev: int, # Pointer address as integer buffer_mnnvl: torch.Tensor, buffer_flags_mnnvl: torch.Tensor, nranks: int, rank: int, wait_for_results: bool, launch_with_pdl: bool, out: Optional[torch.Tensor], ) -> None: module.trtllm_mnnvl_all_reduce( inp, multicast_buffer_ptr, buffer_ptrs_dev, buffer_mnnvl, buffer_flags_mnnvl, nranks, rank, wait_for_results, launch_with_pdl, out, ) @register_custom_op( "flashinfer::trtllm_mnnvl_rmsnorm", mutates_args=[ "mcast_buffer_input", "prenorm_output", "normed_output", "gamma", "epsilon", "residual", "buffer_flags", "launch_with_pdl", ], ) def trtllm_mnnvl_rmsnorm( mcast_buffer_input: int, prenorm_output: torch.Tensor, normed_output: torch.Tensor, gamma: torch.Tensor, epsilon: float, residual: torch.Tensor, buffer_flags: torch.Tensor, launch_with_pdl: bool, ) -> None: """Performs MNNVL TwoShot RMSNorm on the communication buffer. Args: prenorm_output: Output tensor for prenorm results normed_output: Output tensor for normalized results mcast_buffer_input: Input tensor gamma: The gamma parameter for RMSNorm epsilon: The epsilon parameter for RMSNorm residual: The residual tensor to add buffer_flags: Buffer flags for synchronization launch_with_pdl: Whether to launch with PDL """ return module.trtllm_mnnvl_rmsnorm( mcast_buffer_input, prenorm_output, normed_output, gamma, epsilon, residual, buffer_flags, launch_with_pdl, ) return SimpleNamespace( trtllm_mnnvl_all_reduce=trtllm_mnnvl_all_reduce, trtllm_mnnvl_rmsnorm=trtllm_mnnvl_rmsnorm, ) def get_allreduce_mnnvl_workspace( mapping: Mapping, dtype: torch.dtype ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: """Get workspace buffers needed for multi-node NVLink all-reduce operation. This function allocates and initializes the workspace buffers required for performing multi-node NVLink all-reduce operations. It creates: 1. A multicast GPU buffer for communication between nodes 2. A flags tensor to track buffer state 3. Maximum number of elements that can fit in the buffer The buffer size is calculated to efficiently handle common hidden dimensions (2048, 4096, 5120, 7168, 8192) by using their LCM of 286720. 

    Args:
        mapping: Tensor parallel mapping configuration containing rank info
        dtype: Data type of the tensors being reduced

    Returns:
        Tuple containing:
        - McastGPUBuffer: Multicast buffer for inter-node communication
        - torch.Tensor: Buffer flags tensor tracking state
        - int: Maximum number of elements that fit in the buffer
    """
    force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"

    # buffer shape: [3, 2, buffer_tokens, hidden_dim]
    stride = 3 * 2 * dtype.itemsize
    # LCM of common hidden_dims (2048, 4096, 5120, 7168, 8192) is 286720;
    # max_num_elements must be a multiple of it.
    lcm_hidden_dim = 286720
    TARGET_WORKSPACE_SIZE_BYTES = 12_000_000
    buffer_size_in_bytes = math.ceil(
        TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)
    ) * (lcm_hidden_dim * stride)
    max_num_elements = buffer_size_in_bytes // stride

    mcast_buffer = McastGPUBuffer(
        buffer_size_in_bytes,
        mapping.tp_size,
        mapping.tp_rank,
        torch.device("cuda", mapping.local_rank),
        mapping.is_multi_node() or force_mn,
    )

    # Initialize the unicast buffer with -0.0
    mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)

    # CPU barrier since we assume this should not be called in cuda graph
    torch.cuda.synchronize()
    mpi_barrier()

    # This buffer maintains the state of this allreduce op:
    # [buffer_ptr, clear_ptr, buffer_size, num_tokens_prev, atomic access counter]
    buffer_flags = torch.tensor(
        [0, 2, max_num_elements, 0, 0],
        dtype=torch.uint32,
        device=torch.device("cuda", mapping.local_rank),
    )

    return (
        mcast_buffer,
        buffer_flags,
        max_num_elements,
    )


def trtllm_mnnvl_all_reduce(
    inp: torch.Tensor,
    multicast_buffer_ptr: int,  # Pointer address as integer
    buffer_ptrs_dev: int,  # Pointer address as integer
    buffer_M: int,
    buffer_flags_mnnvl: torch.Tensor,
    nranks: int,
    rank: int,
    wait_for_results: bool,
    launch_with_pdl: bool,
    out: Optional[torch.Tensor] = None,
) -> None:
    """Perform a multi-node NVLink all-reduce operation across multiple GPUs.

    This function performs an all-reduce (sum) using NVIDIA's multi-node
    NVLink (MNNVL) technology to efficiently combine tensors across multiple
    GPUs and nodes. It proceeds in three steps:

    1. Scatter each GPU's input shard into the corresponding unicast buffer
    2. Perform the all-reduce on each GPU
    3. Broadcast the result to all GPUs

    Args:
        inp: Local input shard
        multicast_buffer_ptr: Pointer to the multicast buffer as an integer
        buffer_ptrs_dev: Pointer to the device buffer pointers as an integer
        buffer_M: Buffer capacity in rows, i.e. max_num_elements // hidden_dim
        buffer_flags_mnnvl: Tensor containing buffer state flags
        nranks: Total number of ranks participating in the all-reduce
        rank: Current process rank
        wait_for_results: If True, store the result to out
        launch_with_pdl: If True, launch using Programmatic Dependent Launch
        out: Optional output tensor for the result (required if wait_for_results is True)
    """
    module = get_trtllm_mnnvl_comm_module()
    module.trtllm_mnnvl_all_reduce(
        inp,
        multicast_buffer_ptr,
        buffer_ptrs_dev,
        buffer_M,
        buffer_flags_mnnvl,
        nranks,
        rank,
        wait_for_results,
        launch_with_pdl,
        out,
    )
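

# Illustrative sketch (not used by this module): one way to wire
# get_allreduce_mnnvl_workspace together with trtllm_mnnvl_all_reduce. The
# McastGPUBuffer accessor names used to obtain the raw pointers
# (get_multicast_ptr / get_buffer_ptrs_dev) are assumptions and may not match
# the actual buffer API; in real use the workspace would also be allocated
# once and reused across calls rather than rebuilt per call.
def _example_mnnvl_all_reduce(
    shard: torch.Tensor, mapping: Mapping, launch_with_pdl: bool = False
) -> torch.Tensor:
    """Sketch: allocate the MNNVL workspace and run a single all-reduce."""
    mcast_buffer, buffer_flags, max_num_elements = get_allreduce_mnnvl_workspace(
        mapping, shard.dtype
    )
    hidden_dim = shard.shape[-1]
    out = torch.empty_like(shard)
    trtllm_mnnvl_all_reduce(
        shard,
        mcast_buffer.get_multicast_ptr(),  # assumed accessor name
        mcast_buffer.get_buffer_ptrs_dev(),  # assumed accessor name
        max_num_elements // hidden_dim,  # buffer_M: buffer capacity in rows
        buffer_flags,
        mapping.tp_size,
        mapping.tp_rank,
        True,  # wait_for_results: write the reduced tensor into `out`
        launch_with_pdl,
        out,
    )
    return out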


def trtllm_mnnvl_fused_allreduce_rmsnorm(
    prenorm_output: torch.Tensor,
    normed_output: torch.Tensor,
    shard_input: torch.Tensor,
    multicast_buffer_ptr: int,  # Pointer address as integer
    buffer_ptrs_dev: int,  # Pointer address as integer
    unicast_ptr: int,  # Local unicast buffer pointer
    buffer_M: int,
    buffer_flags_mnnvl: torch.Tensor,
    nranks: int,
    rank: int,
    gamma: torch.Tensor,
    epsilon: float,
    residual: torch.Tensor,
    launch_with_pdl: bool,
) -> None:
    """Performs MNNVL TwoShot Allreduce + RMSNorm.

    This function first performs a multi-node all-reduce (sum) by calling
    trtllm_mnnvl_all_reduce on shard_input, and then applies RMSNorm to the
    all-reduced result, reading it directly from the multicast buffer.

    Note: the multicast buffer is the same as the current rank's unicast buffer.

    Args:
        prenorm_output: Output tensor for the prenorm (residual-added) results
        normed_output: Output tensor for the normalized results
        shard_input: Input tensor shard
        multicast_buffer_ptr: Pointer address as integer for the multicast buffer
        buffer_ptrs_dev: Pointer address as integer for the device buffer pointers
        unicast_ptr: Pointer address as integer for the local unicast buffer
        buffer_M: Buffer capacity in rows, i.e. max_num_elements // hidden_dim
        buffer_flags_mnnvl: Buffer flags for synchronization
        nranks: Number of ranks in the tensor parallel group
        rank: Current rank in the tensor parallel group
        gamma: The gamma (norm weight) parameter for RMSNorm
        epsilon: The epsilon parameter for RMSNorm
        residual: The residual tensor to add
        launch_with_pdl: Whether to launch with Programmatic Dependent Launch (PDL)
    """
    # allreduce_result = Σ(shard_input across all ranks)
    trtllm_mnnvl_all_reduce(
        shard_input,
        multicast_buffer_ptr,
        buffer_ptrs_dev,
        buffer_M,
        buffer_flags_mnnvl,
        nranks,
        rank,
        False,  # wait_for_results: the RMSNorm kernel reads the reduced values from the buffer
        launch_with_pdl,
        None,  # out is not needed since wait_for_results=False
    )

    # prenorm_output = AllReduce(shard_input) + residual
    # rms = sqrt(mean(prenorm_output²) + epsilon)
    # normed_output = (prenorm_output / rms) * gamma
    get_trtllm_mnnvl_comm_module().trtllm_mnnvl_rmsnorm(
        unicast_ptr,
        prenorm_output,
        normed_output,
        gamma,
        epsilon,
        residual,
        buffer_flags_mnnvl,
        launch_with_pdl,
    )
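

# Illustrative sketch (not used by this module): fused all-reduce + RMSNorm
# over a sharded hidden state. As above, the McastGPUBuffer accessor names
# (get_multicast_ptr / get_buffer_ptrs_dev / get_unicast_ptr) are assumptions
# about the buffer API; only the trtllm_mnnvl_fused_allreduce_rmsnorm call
# itself follows the signature defined in this module.
def _example_mnnvl_fused_allreduce_rmsnorm(
    shard: torch.Tensor,
    residual: torch.Tensor,
    gamma: torch.Tensor,
    mapping: Mapping,
    epsilon: float,
    launch_with_pdl: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Sketch: all-reduce a hidden-state shard, add the residual, and RMSNorm it."""
    mcast_buffer, buffer_flags, max_num_elements = get_allreduce_mnnvl_workspace(
        mapping, shard.dtype
    )
    hidden_dim = shard.shape[-1]
    prenorm_output = torch.empty_like(shard)
    normed_output = torch.empty_like(shard)
    trtllm_mnnvl_fused_allreduce_rmsnorm(
        prenorm_output,
        normed_output,
        shard,
        mcast_buffer.get_multicast_ptr(),  # assumed accessor name
        mcast_buffer.get_buffer_ptrs_dev(),  # assumed accessor name
        mcast_buffer.get_unicast_ptr(mapping.tp_rank),  # assumed accessor name
        max_num_elements // hidden_dim,  # buffer_M: buffer capacity in rows
        buffer_flags,
        mapping.tp_size,
        mapping.tp_rank,
        gamma,
        epsilon,
        residual,
        launch_with_pdl,
    )
    return prenorm_output, normed_output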