# sglang_v0.5.2/flashinfer_0.3.1/flashinfer/comm/trtllm_mnnvl_ar.py

"""
MNNVL (Multi-Node NVLink) communication operations for FlashInfer.
"""

import functools
import math
import os
from types import SimpleNamespace
from typing import Optional, Tuple

import torch

from flashinfer.comm.mapping import Mapping

from ..jit import JitSpec
from ..jit import env as jit_env
from ..jit import gen_jit_spec
from ..utils import register_custom_op
from .mnnvl import McastGPUBuffer


def mpi_barrier():
    """MPI barrier - could potentially be replaced with dist.barrier()."""
    from mpi4py import MPI

    MPI.COMM_WORLD.Barrier()
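
# Hedged sketch (not part of the module's API): if a torch.distributed process
# group is already initialized, the MPI barrier above could plausibly be swapped
# for a torch.distributed barrier, as the docstring suggests. `dist.is_initialized()`
# and `dist.barrier()` are standard torch.distributed calls; the helper name is
# hypothetical.
#
#   import torch.distributed as dist
#
#   def dist_or_mpi_barrier():
#       if dist.is_initialized():
#           dist.barrier()
#       else:
#           mpi_barrier()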


def gen_trtllm_mnnvl_comm_module() -> JitSpec:
    return gen_jit_spec(
        "trtllm_mnnvl_comm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "trtllm_mnnvl_allreduce.cu",
        ],
    )


@functools.cache
def get_trtllm_mnnvl_comm_module():
    module = gen_trtllm_mnnvl_comm_module().build_and_load()

    @register_custom_op(
        "flashinfer::trtllm_mnnvl_all_reduce",
        mutates_args=[
            "inp",
            "multicast_buffer_ptr",
            "buffer_ptrs_dev",
            "buffer_mnnvl",
            "buffer_flags_mnnvl",
            "nranks",
            "rank",
            "wait_for_results",
            "launch_with_pdl",
            "out",
        ],
    )
    def trtllm_mnnvl_all_reduce(
        inp: torch.Tensor,
        multicast_buffer_ptr: int,  # Pointer address as integer
        buffer_ptrs_dev: int,  # Pointer address as integer
        buffer_mnnvl: torch.Tensor,
        buffer_flags_mnnvl: torch.Tensor,
        nranks: int,
        rank: int,
        wait_for_results: bool,
        launch_with_pdl: bool,
        out: Optional[torch.Tensor],
    ) -> None:
        module.trtllm_mnnvl_all_reduce(
            inp,
            multicast_buffer_ptr,
            buffer_ptrs_dev,
            buffer_mnnvl,
            buffer_flags_mnnvl,
            nranks,
            rank,
            wait_for_results,
            launch_with_pdl,
            out,
        )

    @register_custom_op(
        "flashinfer::trtllm_mnnvl_rmsnorm",
        mutates_args=[
            "mcast_buffer_input",
            "prenorm_output",
            "normed_output",
            "gamma",
            "epsilon",
            "residual",
            "buffer_flags",
            "launch_with_pdl",
        ],
    )
    def trtllm_mnnvl_rmsnorm(
        mcast_buffer_input: int,
        prenorm_output: torch.Tensor,
        normed_output: torch.Tensor,
        gamma: torch.Tensor,
        epsilon: float,
        residual: torch.Tensor,
        buffer_flags: torch.Tensor,
        launch_with_pdl: bool,
    ) -> None:
        """Performs MNNVL TwoShot RMSNorm on the communication buffer.

        Args:
            mcast_buffer_input: Pointer (as an integer) to the multicast buffer holding the input
            prenorm_output: Output tensor for the pre-norm results
            normed_output: Output tensor for the normalized results
            gamma: The gamma (norm weight) parameter for RMSNorm
            epsilon: The epsilon parameter for RMSNorm
            residual: The residual tensor to add
            buffer_flags: Buffer flags for synchronization
            launch_with_pdl: Whether to launch with Programmatic Dependent Launch (PDL)
        """
        return module.trtllm_mnnvl_rmsnorm(
            mcast_buffer_input,
            prenorm_output,
            normed_output,
            gamma,
            epsilon,
            residual,
            buffer_flags,
            launch_with_pdl,
        )

    return SimpleNamespace(
        trtllm_mnnvl_all_reduce=trtllm_mnnvl_all_reduce,
        trtllm_mnnvl_rmsnorm=trtllm_mnnvl_rmsnorm,
    )


def get_allreduce_mnnvl_workspace(
    mapping: Mapping, dtype: torch.dtype
) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
    """Get the workspace buffers needed for a multi-node NVLink all-reduce operation.

    This function allocates and initializes the workspace required for performing
    multi-node NVLink all-reduce operations. It returns:
    1. A multicast GPU buffer for communication between nodes
    2. A flags tensor that tracks the buffer state
    3. The maximum number of elements that fit in the buffer

    The buffer size is chosen to efficiently handle common hidden dimensions
    (2048, 4096, 5120, 7168, 8192) by rounding the target size up to a multiple
    of their LCM, 286720.

    Args:
        mapping: Tensor parallel mapping configuration containing rank info
        dtype: Data type of the tensors being reduced

    Returns:
        Tuple containing:
        - McastGPUBuffer: Multicast buffer for inter-node communication
        - torch.Tensor: Buffer flags tensor tracking state
        - int: Maximum number of elements that fit in the buffer
    """
    force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1"

    # buffer shape: [3, 2, buffer_tokens, hidden_dim]
    stride = 3 * 2 * dtype.itemsize
    # LCM of the supported hidden dims (2048, 4096, 5120, 7168, 8192) is 286720,
    # so max_num_elements must be a multiple of 286720.
    lcm_hidden_dim = 286720
    TARGET_WORKSPACE_SIZE_BYTES = 12_000_000
    buffer_size_in_bytes = math.ceil(
        TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)
    ) * (lcm_hidden_dim * stride)
    max_num_elements = buffer_size_in_bytes // stride
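
    # Example of the sizing arithmetic (illustrative, assuming bfloat16 with
    # itemsize 2): stride = 3 * 2 * 2 = 12 bytes, so one LCM-sized block is
    # 286720 * 12 = 3,440,640 bytes. ceil(12,000,000 / 3,440,640) = 4 blocks,
    # giving buffer_size_in_bytes = 13,762,560 and
    # max_num_elements = 13,762,560 // 12 = 1,146,880 (= 4 * 286720).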

    mcast_buffer = McastGPUBuffer(
        buffer_size_in_bytes,
        mapping.tp_size,
        mapping.tp_rank,
        torch.device("cuda", mapping.local_rank),
        mapping.is_multi_node() or force_mn,
    )

    # Initialize the unicast buffer with -0.0
    mcast_buffer.lamport_initialize(mapping.tp_rank, dtype)

    # CPU barrier, since we assume this is not called inside a CUDA graph
    torch.cuda.synchronize()
    mpi_barrier()

    # Buffer that maintains the state of this allreduce op:
    # [buffer_ptr, clear_ptr, buffer_size, num_tokens_prev, atomic access counter]
    buffer_flags = torch.tensor(
        [0, 2, max_num_elements, 0, 0],
        dtype=torch.uint32,
        device=torch.device("cuda", mapping.local_rank),
    )

    return (
        mcast_buffer,
        buffer_flags,
        max_num_elements,
    )
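
# Illustrative usage sketch (commented out; assumes an MPI launch with one rank
# per GPU, and that `Mapping` accepts these keyword arguments -- check
# flashinfer.comm.mapping for the exact constructor signature):
#
#   from mpi4py import MPI
#
#   rank = MPI.COMM_WORLD.Get_rank()
#   world_size = MPI.COMM_WORLD.Get_size()
#   mapping = Mapping(world_size=world_size, rank=rank, gpus_per_node=8,
#                     tp_size=world_size)
#   mcast_buffer, buffer_flags, max_num_elements = get_allreduce_mnnvl_workspace(
#       mapping, torch.bfloat16
#   )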


def trtllm_mnnvl_all_reduce(
    inp: torch.Tensor,
    multicast_buffer_ptr: int,  # Pointer address as integer
    buffer_ptrs_dev: int,  # Pointer address as integer
    buffer_M: int,
    buffer_flags_mnnvl: torch.Tensor,
    nranks: int,
    rank: int,
    wait_for_results: bool,
    launch_with_pdl: bool,
    out: Optional[torch.Tensor] = None,
) -> None:
    """Perform a multi-node NVLink all-reduce operation across multiple GPUs.

    This function performs an all-reduce (sum) using NVIDIA's multi-node NVLink
    (MNNVL) technology to efficiently combine tensors across multiple GPUs and nodes.

    There are three steps:
    1. Scatter each GPU's input shard to the corresponding unicast buffer.
    2. Perform the all-reduce on each GPU.
    3. Broadcast the result to all GPUs.

    Args:
        inp: Local input shard
        multicast_buffer_ptr: Pointer to the multicast buffer as an integer
        buffer_ptrs_dev: Pointer to the device buffer pointers as an integer
        buffer_M: Maximum number of elements // hidden_dim
        buffer_flags_mnnvl: Tensor containing buffer state flags
        nranks: Total number of ranks participating in the all-reduce
        rank: Current process rank
        wait_for_results: If True, store the result to out
        launch_with_pdl: If True, launch using Programmatic Dependent Launch
        out: (optional) Output tensor to store the result (required if wait_for_results is True)
    """
    module = get_trtllm_mnnvl_comm_module()
    module.trtllm_mnnvl_all_reduce(
        inp,
        multicast_buffer_ptr,
        buffer_ptrs_dev,
        buffer_M,
        buffer_flags_mnnvl,
        nranks,
        rank,
        wait_for_results,
        launch_with_pdl,
        out,
    )
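
# Illustrative call sketch (commented out; the pointer accessors used below are
# assumptions -- consult McastGPUBuffer in .mnnvl for the actual accessor names):
#
#   hidden_dim = shard.shape[-1]
#   out = torch.empty_like(shard)
#   trtllm_mnnvl_all_reduce(
#       shard,
#       mcast_buffer.get_multicast_ptr(),    # assumed accessor
#       mcast_buffer.get_buffer_ptrs_dev(),  # assumed accessor
#       max_num_elements // hidden_dim,      # buffer_M
#       buffer_flags,
#       nranks=mapping.tp_size,
#       rank=mapping.tp_rank,
#       wait_for_results=True,
#       launch_with_pdl=False,
#       out=out,
#   )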


def trtllm_mnnvl_fused_allreduce_rmsnorm(
    prenorm_output: torch.Tensor,
    normed_output: torch.Tensor,
    shard_input: torch.Tensor,
    multicast_buffer_ptr: int,  # Pointer address as integer
    buffer_ptrs_dev: int,  # Pointer address as integer
    unicast_ptr: int,  # Local unicast buffer pointer
    buffer_M: int,
    buffer_flags_mnnvl: torch.Tensor,
    nranks: int,
    rank: int,
    gamma: torch.Tensor,
    epsilon: float,
    residual: torch.Tensor,
    launch_with_pdl: bool,
) -> None:
    """Performs MNNVL TwoShot Allreduce + RMSNorm.

    This function performs a multi-node all-reduce (sum) by first calling
    trtllm_mnnvl_all_reduce on shard_input. It then performs RMSNorm on the
    all-reduced result, reading it directly from the multicast buffer.

    Note: the multicast buffer is the same as the unicast buffer for the current rank.

    Args:
        prenorm_output: Output tensor for the pre-norm results
        normed_output: Output tensor for the normalized results
        shard_input: Input tensor shard
        multicast_buffer_ptr: Pointer address as integer for the multicast buffer
        buffer_ptrs_dev: Pointer address as integer for the device buffer pointers
        unicast_ptr: Pointer address as integer for the local unicast buffer
        buffer_M: Maximum number of elements // hidden_dim
        buffer_flags_mnnvl: Buffer flags for synchronization
        nranks: Number of ranks in the tensor parallel group
        rank: Current rank in the tensor parallel group
        gamma: The gamma (norm weight) parameter for RMSNorm
        epsilon: The epsilon parameter for RMSNorm
        residual: The residual tensor to add
        launch_with_pdl: Whether to launch with Programmatic Dependent Launch (PDL)
    """
    # allreduce_result = Σ(shard_input across all ranks)
    trtllm_mnnvl_all_reduce(
        shard_input,
        multicast_buffer_ptr,
        buffer_ptrs_dev,
        buffer_M,
        buffer_flags_mnnvl,
        nranks,
        rank,
        False,  # wait_for_results=False: the AR result stays in the comm buffer
        launch_with_pdl,
        None,  # out is None since wait_for_results=False
    )

    # prenorm_output = AllReduce(shard_input) + residual
    # rms = sqrt(mean(prenorm_output²) + epsilon)
    # normed_output = (prenorm_output / rms) * gamma
    get_trtllm_mnnvl_comm_module().trtllm_mnnvl_rmsnorm(
        unicast_ptr,
        prenorm_output,
        normed_output,
        gamma,
        epsilon,
        residual,
        buffer_flags_mnnvl,
        launch_with_pdl,
    )
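
# Illustrative fused-call sketch (commented out; as above, the pointer accessors
# and the unicast-pointer lookup are assumptions about the McastGPUBuffer API):
#
#   trtllm_mnnvl_fused_allreduce_rmsnorm(
#       prenorm_output,
#       normed_output,
#       shard_input,
#       mcast_buffer.get_multicast_ptr(),    # assumed accessor
#       mcast_buffer.get_buffer_ptrs_dev(),  # assumed accessor
#       unicast_ptr,                         # this rank's unicast buffer pointer
#       max_num_elements // shard_input.shape[-1],
#       buffer_flags,
#       nranks=mapping.tp_size,
#       rank=mapping.tp_rank,
#       gamma=rmsnorm_weight,
#       epsilon=1e-5,
#       residual=residual,
#       launch_with_pdl=False,
#   )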