# sglang_v0.5.2/flashinfer_0.3.1/tests/test_trtllm_mnnvl_allreduce.py

# Check torch version:
from typing import Tuple

import pytest
import torch
from mpi4py import MPI  # Added MPI import

import flashinfer.comm.trtllm_mnnvl_ar as trtllm_mnnvl_ar
from flashinfer.comm.mapping import Mapping

# Use flashinfer.norm.rmsnorm as reference implementation.
from flashinfer.norm import rmsnorm
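
# NOTE: this test is meant to be launched under MPI so that every rank runs the same
# pytest process, e.g. (assuming OpenMPI and a single node with at least 2 GPUs):
#   mpirun -np 2 python -m pytest tests/test_trtllm_mnnvl_allreduce.py -s
# The exact launcher and rank count depend on the cluster setup.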


@torch.inference_mode()
def row_linear_residual_norm_fusion_forward(
    x: torch.Tensor,
    residual: torch.Tensor,
    norm_weight: torch.Tensor,
    eps: float,
    hidden_size: int,
    dtype: torch.dtype,
    mapping: Mapping,
    fusion: bool,
    reference_output: tuple[torch.Tensor, ...],
    multicast_ptr: int,
    buffer_ptrs_dev: int,
    unicast_ptr: int,
    max_num_elements_mnnvl: int,
    buffer_flags_mnnvl: torch.Tensor,
):
    x = x.cuda()
    residual = residual.cuda()
    norm_weight = norm_weight.cuda()
    reference_output = tuple(t.cuda() for t in reference_output)

    tensor_parallel_size = mapping.tp_size
    tensor_parallel_rank = mapping.tp_rank

    MPI.COMM_WORLD.barrier()

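    # func() wraps the kernel under test: it calls the fused allreduce+rmsnorm kernel
    # when enable_fusion is set, and the plain allreduce kernel otherwise.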
    def func(
        input,
        residual,
        norm_weight,
        eps,
        enable_fusion,
        multicast_ptr,
        buffer_ptrs_dev,
        unicast_ptr,
        max_num_elements_mnnvl,
    ):
        # For both fused and unfused cases:
        shape = input.shape
        assert max_num_elements_mnnvl % hidden_size == 0
        input = input.view(-1, shape[-1])
        buffer_M = max_num_elements_mnnvl // hidden_size

        if enable_fusion:
            use_pdl = True
            prenorm_output = torch.empty_like(residual)
            normed_output = torch.empty_like(residual)

            trtllm_mnnvl_ar.mpi_barrier()
            trtllm_mnnvl_ar.trtllm_mnnvl_fused_allreduce_rmsnorm(
                prenorm_output,
                normed_output,
                input,
                multicast_ptr,
                buffer_ptrs_dev,
                unicast_ptr,
                buffer_M,
                buffer_flags_mnnvl,
                tensor_parallel_size,
                tensor_parallel_rank,
                norm_weight,
                eps,
                residual,
                use_pdl,
            )
            return normed_output.view(shape), prenorm_output.view(shape)
        else:
            output = torch.empty_like(input)
            trtllm_mnnvl_ar.trtllm_mnnvl_all_reduce(
                input,
                multicast_ptr,
                buffer_ptrs_dev,
                buffer_M,
                buffer_flags_mnnvl,
                tensor_parallel_size,
                tensor_parallel_rank,
                True,  # wait_for_results
                False,  # launch_with_pdl
                output,  # output tensor must be provided since results are written out
            )
            return (output.view(shape),)

    output = func(
        x.clone(),
        residual.clone(),
        norm_weight,
        eps,
        fusion,
        multicast_ptr,
        buffer_ptrs_dev,
        unicast_ptr,
        max_num_elements_mnnvl,
    )

    assert output[0].shape == reference_output[0].shape

    if tensor_parallel_rank == 0:
        print("output[0] (first 10 values):", output[0].flatten()[:10])
        print(
            "reference_output[0] (first 10 values):",
            reference_output[0].flatten()[:10],
        )
        if fusion:
            print("output[1] (first 10 values):", output[1].flatten()[:10])
            print(
                "reference_output[1] (first 10 values):",
                reference_output[1].flatten()[:10],
            )

    torch.testing.assert_close(
        output[0],
        reference_output[0],
        rtol=0.05,
        atol=0.15,
    )

    if fusion:
        torch.testing.assert_close(
            output[1],
            reference_output[1],
            rtol=0.05,
            atol=0.15,
        )

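
# The test below builds its reference on the host: the unfused path should match a plain
# sum of the per-rank inputs, and the fused path should additionally match residual add
# followed by rmsnorm applied to that sum (see the reference computation in the test body).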
"""Main test function that runs on each MPI rank"""
@pytest.mark.parametrize(
"seq_lens",
[
[1],
[4],
[15],
[27, 11, 24],
[127],
],
) # Test with different sequence length lists
@pytest.mark.parametrize("fusion", [False, True])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [2048, 4096, 5120, 7168, 8192])
def test_mnnvl_allreduce_full(
monkeypatch, seq_lens: list[int], fusion: bool, dtype: torch.dtype, hidden_size: int
):
    monkeypatch.setenv("TRTLLM_FORCE_MNNVL_AR", "1")  # force multi-node allreduce.

    # Get MPI info
    rank = MPI.COMM_WORLD.Get_rank()
    world_size = MPI.COMM_WORLD.Get_size()
    gpus_per_node = torch.cuda.device_count()
    if gpus_per_node == 0:
        pytest.skip("MNNVL allreduce test requires at least one CUDA device per node")

    # Ensure we have at least 2 ranks for this test
    if world_size < 2:
        pytest.skip(f"This test requires at least 2 MPI ranks, got {world_size}")

    mapping = Mapping(
        world_size=world_size,
        rank=rank,
        gpus_per_node=gpus_per_node,
        tp_size=world_size,
    )

    # Set CUDA device based on rank
    torch.cuda.set_device(mapping.local_rank)

    if mapping.local_rank == 0:
        print(
            f"[Node {mapping.node_rank}] Running MNNVL AllReduce test with {world_size} ranks"
        )
        print(
            f"[Node {mapping.node_rank}] Rank {rank} using GPU {torch.cuda.current_device()}"
        )

    tensor_parallel_size = world_size
    eps = 1e-5
    torch.manual_seed(42)

    # Track if this rank failed
    rank_failed = False
    failure_message = ""

    try:
        # Get workspace buffers using MPI rank - allocate once per seq_lens list and reuse within the list
        # This workspace is sized for the maximum expected sequence length and can be reused within each list
        # Each parameterized list gets its own fresh workspace allocation
        mcast_buffer_mnnvl, buffer_flags_mnnvl, max_num_elements_mnnvl = (
            trtllm_mnnvl_ar.get_allreduce_mnnvl_workspace(mapping, dtype)
        )

        multicast_ptr = mcast_buffer_mnnvl.get_multicast_ptr()
        buffer_ptrs_dev = mcast_buffer_mnnvl.get_buffer_ptrs_dev()
        unicast_ptr = mcast_buffer_mnnvl.mcast_device_memory.get_unicast_ptr(
            mapping.tp_rank
        )

        # Test each sequence length with the same workspace (reusing allocated buffers within this list)
        for seq_len in seq_lens:
            if rank == 0:
                print(
                    f"Testing seq_len={seq_len}, hidden_size={hidden_size}, fusion={fusion}, dtype={dtype}"
                )

            # Generate test data (same on all ranks due to same seed)
            x_full = torch.randn(
                (tensor_parallel_size, seq_len, hidden_size),
                dtype=dtype,
                device=torch.device("cuda"),
            )
            residual = torch.randn(
                (seq_len, hidden_size), dtype=dtype, device=torch.device("cuda")
            )
            norm_weight = torch.randn(
                (hidden_size,), dtype=dtype, device=torch.device("cuda")
            )

            # Each rank gets its slice of the input
            x = x_full[rank, :, :]

            # Compute reference output based on fusion mode
            reference_output: Tuple[torch.Tensor, ...]
            if fusion:
                # Fused case: AllReduce + Residual Add + RMS Norm
                allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
                residual_out = allreduce_result + residual  # Add residual
                print(
                    "Device of residual_out: {}, norm_weight: {}".format(
                        residual_out.device, norm_weight.device
                    )
                )
                norm_out = rmsnorm(residual_out, norm_weight, eps, enable_pdl=False)
                reference_output = (norm_out, residual_out)
            else:
                # Non-fused case: Only AllReduce
                allreduce_result = torch.sum(x_full, dim=0)  # AllReduce result
                reference_output = (allreduce_result,)

            # Run the test with the same workspace
            row_linear_residual_norm_fusion_forward(
                x,
                residual,
                norm_weight,
                eps,
                hidden_size,
                dtype,
                mapping,
                fusion,
                reference_output,
                multicast_ptr,
                buffer_ptrs_dev,
                unicast_ptr,
                max_num_elements_mnnvl,
                buffer_flags_mnnvl,
            )

            # Synchronize before next test
            trtllm_mnnvl_ar.mpi_barrier()

            print(
                f"PASSED[rank={rank}]: seq_len={seq_len}, fusion={fusion}, dtype={dtype}"
            )

    except Exception as e:
        rank_failed = True
        failure_message = f"FAILED[rank={rank}]: seq_lens={seq_lens}, fusion={fusion}, dtype={dtype} failed: {e}"
        print(failure_message)

        # Gather failure status from all ranks
        all_failures = MPI.COMM_WORLD.allgather(rank_failed)

        # If any rank failed, fail the test
        if any(all_failures):
            failed_ranks = [i for i, failed in enumerate(all_failures) if failed]
            if rank == 0:
                print(f"Test failed on ranks: {failed_ranks}")

            # Fail the test on all ranks
            pytest.fail(f"Test failed on ranks {failed_ranks}")

        trtllm_mnnvl_ar.mpi_barrier()

    finally:
        # Ensure cleanup happens for this list's workspace
        if "mcast_buffer_mnnvl" in locals():
            del mcast_buffer_mnnvl

        # Final synchronization across all ranks
        trtllm_mnnvl_ar.mpi_barrier()
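

# Optional convenience entry point (a sketch, not part of the upstream test): it lets each
# MPI rank invoke pytest on this file directly, e.g.
#   mpirun -np 2 python tests/test_trtllm_mnnvl_allreduce.py
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-s", "-x"]))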