# Owner(s): ["oncall: distributed"]
# To run:
# TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
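# NOTE: these tests assume a multi-GPU host. Several point-to-point tests below
# compute `peer = 1 - rank` and are therefore written for a 2-rank layout; the
# all-to-all and ring tests use the full world size.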
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch._inductor.runtime.triton_compat import tl, triton
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skip_but_pass_in_sandcastle_if,
skipIfRocm,
)
from torch.testing._internal.inductor_utils import requires_triton
# Decorator
def requires_nvshmem():
return skip_but_pass_in_sandcastle_if(
not symm_mem.is_nvshmem_available(),
"test_nvshmem requires NVSHMEM, skipping tests",
)
# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)
@instantiate_parametrized_tests
@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
def _init_device(self) -> None:
# TODO: relax this requirement (tests seem to hang without it)
device_module.set_device(self.device)
# NOTE: required for nvshmem allocation
torch.empty(1, device=self.device)
@property
def device(self) -> torch.device:
return torch.device(device_type, self.rank)
@skipIfRocm
def test_alloc(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
numel = 1024
def foo():
inp = symm_mem.empty(numel, dtype=dtype, device=self.device)
symm_mem.rendezvous(inp, group=group_name)
foo()
out = symm_mem.empty(numel, dtype=dtype, device=self.device)
symm_mem.rendezvous(out, group=group_name)
@skipIfRocm
def test_nvshmem_put(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
numel = 1024
tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
symm_mem.rendezvous(tensor, group=group_name)
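# Rank 0's tensor is filled with its rank value (0). The put below copies that
# buffer into rank 1's symmetric `tensor`, so rank 1 expects all zeros after
# the barrier.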
if self.rank == 0:
torch.ops.symm_mem.nvshmem_put(tensor, 1)
# TODO: remove after we have wait_signal
dist.barrier()
elif self.rank == 1:
# handle.wait_signal(src_rank=0)
# TODO: remove after we have wait_signal
dist.barrier()
torch.testing.assert_close(
tensor, torch.zeros(numel, dtype=dtype, device=self.device)
)
else:
dist.barrier()
@skipIfRocm
def test_nvshmem_all_to_all(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
numel_per_peer = 10
numel = self.world_size * numel_per_peer
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
symm_mem.rendezvous(inp, group=group_name)
symm_mem.rendezvous(out, group=group_name)
torch.ops.symm_mem.nvshmem_all_to_all(inp, out, group_name)
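# Each rank sent `numel_per_peer` copies of its own rank value to every peer,
# so chunk i of `out` should now hold the value i (the source rank).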
expected = torch.cat(
[
torch.empty(numel_per_peer, dtype=dtype, device=self.device).fill_(i)
for i in range(self.world_size)
]
)
torch.testing.assert_close(out, expected)
@skipIfRocm
def test_all_to_all_vdev(self) -> None:
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
# Number of elements for a peer is randomly drawn from [0, k)
k = 10
inp_splits = torch.randint(k, (self.world_size,), device=self.device)
inp_numel = inp_splits.sum().item()
# Exchange input splits to get output splits
out_splits = torch.zeros_like(inp_splits)
dist.all_to_all_single(out_splits, inp_splits)
out_numel = out_splits.sum().item()
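# Example with world_size = 2: if rank 0 draws inp_splits = [3, 1] and rank 1
# draws [2, 4], the exchange above gives out_splits = [3, 2] on rank 0 and
# [1, 4] on rank 1 (entry j comes from rank j's split for this rank).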
# Max number of input elements (must be a constant across ranks for symmetric memory allocation)
max_inp_numel = k * self.world_size
# Max number of output elements (must be a constant across ranks for symmetric memory allocation)
overflow_factor = self.world_size # worst case: one rank receives all data
max_out_numel = max_inp_numel * overflow_factor
inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_(
self.rank
)
out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
in_out_splits = symm_mem.empty(
(3, self.world_size), dtype=torch.int64, device=self.device
)
# Row 0 is input splits
in_out_splits[0].copy_(inp_splits)
torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)
# Check input splits (row 0) -- should not change
torch.testing.assert_close(in_out_splits[0], inp_splits)
# Check output splits (row 1)
torch.testing.assert_close(in_out_splits[1], out_splits)
# Check output offsets (row 2)
out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan
# output offsets from `all_to_all_vdev` are an exclusive scan
self.assertEqual(in_out_splits[2][0], 0)
torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])
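# e.g. out_splits = [3, 1, 4] -> inclusive scan [3, 4, 8]; the exclusive-scan
# offsets reported by the op are therefore [0, 3, 4].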
# Check data
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
dist.all_to_all_single(
expected, inp[:inp_numel], out_splits.tolist(), inp_splits.tolist()
)
torch.testing.assert_close(out[:out_numel], expected)
@skipIfRocm
@parametrize("align", [1, 8, 16]) # `major_align` of output
def test_all_to_all_vdev_2d(self, align: int) -> None:
torch.manual_seed(42 + self.rank)
self._init_device()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
dtype = torch.float
# Number of experts per rank
ne = 8
nsplits = ne * self.world_size
# Number of elements for an expert is randomly drawn from [0, k)
k = 10
inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=self.device)
# Exchange input splits to get output splits
out_splits = torch.zeros_like(inp_splits)
dist.all_to_all_single(out_splits, inp_splits)
# We do a .t() here because there is a rank-major to expert-major shuffle
out_splits_t = out_splits.reshape(self.world_size, ne).t()
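# e.g. with world_size = 2 and ne = 2, out_splits arrives source-rank-major as
# [r0e0, r0e1, r1e0, r1e1]; the transpose regroups it expert-major as
# [r0e0, r1e0, r0e1, r1e1], which is the layout `all_to_all_vdev_2d` produces.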
# Actual number of input elements
inp_numel = inp_splits.sum().item()
# Actual number of output elements
out_numel = out_splits.sum().item()
# Max number of input elements (must be a constant across ranks for symmetric memory allocation)
max_inp_numel = k * nsplits
# Max number of output elements (must be a constant across ranks for symmetric memory allocation)
overflow_factor = self.world_size # worst case: one rank receives all data
max_out_numel = max_inp_numel * overflow_factor
inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_(
self.rank
)
out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
# 3 rows: input splits, output splits, output offsets
# Initializing all values to -1 to check that they are updated
in_out_splits = symm_mem.empty(
(3, nsplits), dtype=torch.int64, device=self.device
).fill_(-1)
# Row 0 is input splits
in_out_splits[0].copy_(inp_splits)
torch.ops.symm_mem.all_to_all_vdev_2d(
inp, out, in_out_splits, group_name, major_align=align
)
received_out_splits = in_out_splits[1]
received_out_offsets = in_out_splits[2]
# Check input splits (row 0) -- should not change
torch.testing.assert_close(in_out_splits[0], inp_splits)
# Check output splits (row 1)
torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1))
# Check output offsets (row 2)
out_split_list = out_splits_t.tolist()
for i in range(ne):
expert_sum = 0
for j in range(self.world_size):
expert_sum += out_split_list[i][j]
# Align up expert_sum
expert_sum_aligned = (expert_sum + align - 1) // align * align
# If 0, make it at least `align` (because cutlass currently does not support empty bins)
expert_sum_aligned = max(expert_sum_aligned, align)
# last element absorbs the padding
out_split_list[i][-1] += expert_sum_aligned - expert_sum
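# e.g. with align = 8: an expert receiving 13 elements is padded up to 16, and
# an empty expert (0 elements) is padded up to 8; the padding is absorbed by
# the last source rank's split so the offsets below stay consistent.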
out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
out_offsets = torch.cumsum(out_splits_padded, dim=0) # inclusive scan
# Make it an exclusive scan because that's what `all_to_all_vdev_2d` returns
out_offsets = torch.cat(
[torch.zeros(1, device=self.device), out_offsets[:-1]]
).to(torch.int64)
torch.testing.assert_close(received_out_offsets, out_offsets)
# Check data
expected = torch.empty(out_numel, dtype=dtype, device=self.device)
inp_splits_rank = inp_splits.reshape(self.world_size, ne).sum(1)
out_splits_rank = out_splits.reshape(self.world_size, ne).sum(1)
dist.all_to_all_single(
expected,
inp[:inp_numel],
out_splits_rank.tolist(),
inp_splits_rank.tolist(),
)
# We still need to shuffle `expected`
out_offsets = torch.cumsum(out_splits, dim=0) # inclusive scan
result_list = []
for j in range(ne):
for i in range(self.world_size):
chunk_id = i * ne + j
offset = out_offsets[chunk_id]
chunk = expected[offset - out_splits[chunk_id] : offset]
result_list.append(chunk)
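# `result_list` is now ordered expert-major (expert 0 from all ranks, then
# expert 1, ...), matching the layout of `out` checked chunk-by-chunk below.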
# Do a chunk-wise comparison
for c, chunk in enumerate(result_list):
start = received_out_offsets[c].item()
split = received_out_splits[c].item()
received_chunk = out[start : start + split]
torch.testing.assert_close(received_chunk, chunk)
@skipIfRocm
@requires_triton()
def test_triton_put(self) -> None:
# A Triton kernel that calls nvshmem device side API
@triton.jit
def put_kernel(
dst_ptr,
src_ptr,
numel: tl.constexpr,
peer: tl.constexpr,
):
nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)
torch.manual_seed(42 + self.rank)
self._init_device()
# Enable NVSHMEM for Triton
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
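# With int8, the element count equals the byte count (8 elements == 8 bytes),
# so it makes no difference here whether the size argument of the NVSHMEM
# device API is interpreted as elements or bytes.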
val = 5
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
peer = 1 - rank
if rank == 0:
dst_ptr = out_hdl.buffer_ptrs[rank]
src_ptr = inp_hdl.buffer_ptrs[rank]
put_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=numel,
peer=peer,
extern_libs=nvshmem_lib,
)
dist.barrier()
if rank == 1:
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
@skipIfRocm
@requires_triton()
def test_triton_get(self) -> None:
# A Triton kernel that calls nvshmem device side API for GET
@triton.jit
def get_kernel(
dst_ptr,
src_ptr,
numel: tl.constexpr,
peer: tl.constexpr,
):
nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
val = 7
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
val if rank == 0 else -1
)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
dist.barrier()
peer = 1 - rank
if rank == 1:
# Rank 1 gets data from rank 0
dst_ptr = out_hdl.buffer_ptrs[rank]
src_ptr = inp_hdl.buffer_ptrs[rank]
get_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=numel,
peer=peer,
extern_libs=nvshmem_lib,
)
if rank == 1:
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
@skipIfRocm
@requires_triton()
def test_triton_get_ring(self) -> None:
# A Triton kernel that calls nvshmem device side API for GET
# with ring topology
@triton.jit
def get_kernel(
dst_ptr,
src_ptr,
numel: tl.constexpr,
peer: tl.constexpr,
):
nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
world_size = dist.get_world_size()
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
# Each rank fills its input buffer with its own rank value
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
dist.barrier()
# Ring topology: each rank gets data from the rank to its left
# rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc.
peer = (rank - 1) % world_size
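# e.g. with world_size = 4: rank 0 reads from rank 3, rank 1 from rank 0,
# rank 2 from rank 1, and rank 3 from rank 2.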
# All ranks execute the get operation
dst_ptr = out_hdl.buffer_ptrs[rank]
src_ptr = inp_hdl.buffer_ptrs[rank]
get_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=numel,
peer=peer,
extern_libs=nvshmem_lib,
)
expected_value = peer
torch.testing.assert_close(
out, expected_value * torch.ones(numel, dtype=dtype, device=self.device)
)
@skipIfRocm
@requires_triton()
def test_triton_put_signal_set(self) -> None:
# A Triton kernel that calls nvshmem device side API for PUT with SIGNAL
@triton.jit
def put_signal_kernel(
dst_ptr,
src_ptr,
numel: tl.constexpr,
sig_ptr,
signal_val: tl.constexpr,
sig_op: tl.constexpr,
peer: tl.constexpr,
):
nvshmem.putmem_signal_block(
dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
)
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
# Data buffers
val = 11
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
# Use the signal pad attached to the output symmetric memory handle
# as the flag buffer for signaling completion.
flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)
peer = 1 - rank
NVSHMEM_SIGNAL_SET = 0 # value defined by NVSHMEM for atomic set
SIGNAL_VAL = 1 # Signal completion value
NVSHMEM_CMP_EQ = 0 # compare equal for signal wait until
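# Handshake sketch: rank 0 issues a fused put-with-signal, which (per NVSHMEM
# semantics) makes the data visible at the target before the signal word is
# updated; rank 1 spins in signal_wait_until until the word equals SIGNAL_VAL,
# after which reading `out` is safe.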
# Kernel for waiting on the signal locally (Rank 1).
@triton.jit
def signal_wait_until_kernel(
sig_ptr, cmp_op: tl.constexpr, cmp_val: tl.constexpr
):
nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)
if rank == 0:
# Rank 0 puts into Rank 1
dst_ptr = out_hdl.buffer_ptrs[peer]
src_ptr = inp_hdl.buffer_ptrs[rank]
sig_ptr = out_hdl.signal_pad_ptrs[peer]
put_signal_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=numel,
sig_ptr=sig_ptr,
signal_val=SIGNAL_VAL,
sig_op=NVSHMEM_SIGNAL_SET,
peer=peer,
extern_libs=nvshmem_lib,
)
if rank == 1:
# Wait until signal flag is set by Rank 0
sig_ptr_local = out_hdl.signal_pad_ptrs[rank]
signal_wait_until_kernel[(1,)](
sig_ptr_local,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=SIGNAL_VAL,
extern_libs=nvshmem_lib,
)
# After wait completes, verify data and flag contents
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device)
)
@skipIfRocm
@requires_triton()
def test_triton_put_signal_add(self) -> None:
# A Triton kernel that calls nvshmem device side API for PUT with SIGNAL
@triton.jit
def put_signal_kernel(
dst_ptr,
src_ptr,
numel: tl.constexpr,
sig_ptr,
signal_val: tl.constexpr,
sig_op: tl.constexpr,
peer: tl.constexpr,
):
nvshmem.putmem_signal_block(
dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
)
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
# Data buffers
val = 11
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
# Use the signal pad attached to the output symmetric memory handle
# as the flag buffer for signaling completion.
flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)
peer = 1 - rank
NVSHMEM_SIGNAL_ADD = 5 # atomic add operation
SIGNAL_VAL = 16 # val + NVSHMEM_SIGNAL_ADD
NVSHMEM_CMP_EQ = 0
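# With the ADD signal op the flag is incremented atomically: it starts at 0
# (fill_(0) above), so after a single put-with-signal it reads 0 + SIGNAL_VAL
# = 16, which is the value the waiter compares against with NVSHMEM_CMP_EQ.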
@triton.jit
def signal_wait_until_kernel(
sig_ptr, cmp_op: tl.constexpr, cmp_val: tl.constexpr
):
nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)
if rank == 0:
# Rank 0 puts into Rank 1
dst_ptr = out_hdl.buffer_ptrs[peer]
src_ptr = inp_hdl.buffer_ptrs[rank]
sig_ptr = out_hdl.signal_pad_ptrs[peer]
put_signal_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=numel,
sig_ptr=sig_ptr,
signal_val=SIGNAL_VAL,
sig_op=NVSHMEM_SIGNAL_ADD,
peer=peer,
extern_libs=nvshmem_lib,
)
if rank == 1:
sig_ptr_local = out_hdl.signal_pad_ptrs[rank]
signal_wait_until_kernel[(1, 1, 1)](
sig_ptr_local,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=SIGNAL_VAL,
extern_libs=nvshmem_lib,
)
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device)
)
@skipIfRocm
@requires_triton()
def test_triton_wait_until(self) -> None:
# A Triton kernel that calls nvshmem device side API for PUT
@triton.jit
def put_kernel(
dst_ptr,
src_ptr,
numel: tl.constexpr,
peer: tl.constexpr,
):
nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)
# A Triton kernel that calls nvshmem device side API for WAIT_UNTIL
@triton.jit
def wait_until_kernel(
ivar_ptr,
cmp_op: tl.constexpr,
cmp_val: tl.constexpr,
):
nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
# Data buffers
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
val = 13
flag_val = 21
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
dist.barrier()
peer = 1 - rank
NVSHMEM_CMP_EQ = 0 # from nvshmem.h
if rank == 0:
# Rank 0 waits for the flag to be set by Rank 1, then checks the data
ivar_ptr = out_hdl.signal_pad_ptrs[rank]
wait_until_kernel[(1, 1, 1)](
ivar_ptr,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=flag_val,
extern_libs=nvshmem_lib,
)
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
if rank == 1:
# Rank 1 puts data into Rank 0's output buffer
dst_ptr = out_hdl.buffer_ptrs[rank]
src_ptr = inp_hdl.buffer_ptrs[rank]
put_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=numel,
peer=peer,
extern_libs=nvshmem_lib,
)
# Rank 1 sets the flag on Rank 0
# We use a temporary tensor for the value to put.
flag_update_val = torch.tensor(
[flag_val], dtype=torch.int64, device=self.device
)
dst_ptr = out_hdl.signal_pad_ptrs[rank]
src_ptr = flag_update_val.data_ptr()
put_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel=1,
peer=peer,
extern_libs=nvshmem_lib,
)
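# Note: here the data put and the flag put are two separate transfers; the
# fence/quiet variants below (`test_triton_fence`, `test_triton_quiet`) insert
# an explicit ordering primitive between the two.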
@skipIfRocm
@requires_triton()
def test_triton_signal_wait_until(self) -> None:
# A Triton kernel that waits on a signal variable until it meets the compare condition.
@triton.jit
def signal_wait_until_kernel(
sig_ptr,
cmp_op: tl.constexpr,
cmp_val: tl.constexpr,
):
nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)
# A Triton kernel for the producer that puts data and then signals completion.
@triton.jit
def put_and_signal_kernel(
dst_ptr,
src_ptr,
numel: tl.constexpr,
sig_ptr,
signal_val: tl.constexpr,
sig_op: tl.constexpr,
peer: tl.constexpr,
):
nvshmem.putmem_signal_block(
dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
)
self._init_device()
# Enable NVSHMEM for Triton
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
peer = 1 - rank
# NVSHMEM constants from documentation
NVSHMEM_CMP_EQ = 0 # equal comparison
NVSHMEM_SIGNAL_SET = 0 # atomic set operation
# Message configuration
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
val_to_put = 123 # arbitrary test value
COMPLETION_FLAG_VAL = 1
# Producer (rank 0) prepares the data to send
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val_to_put)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
# Consumer (rank 1) prepares the destination buffer
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
out_hdl = symm_mem.rendezvous(out, group=group_name)
# Use the signal pad for synchronization, as in previous tests
flag_dtype = torch.int64
flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0)
# Ensure setup is complete on all ranks before proceeding
dist.barrier()
if rank == 0:
# Producer (rank 0): Puts data into rank 1's `out` buffer and then sets the flag
dst_ptr = out_hdl.buffer_ptrs[peer]
src_ptr = inp_hdl.buffer_ptrs[rank]
sig_ptr = out_hdl.signal_pad_ptrs[peer]
put_and_signal_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
numel,
sig_ptr,
signal_val=COMPLETION_FLAG_VAL,
sig_op=NVSHMEM_SIGNAL_SET,
peer=peer,
extern_libs=nvshmem_lib,
)
elif rank == 1:
# Consumer (rank 1): Waits on the signal variable using `signal_wait_until`.
sig_ptr = out_hdl.signal_pad_ptrs[rank]
signal_wait_until_kernel[(1, 1, 1)](
sig_ptr,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=COMPLETION_FLAG_VAL,
extern_libs=nvshmem_lib,
)
# After the wait returns, verify data and flag
torch.testing.assert_close(
out, val_to_put * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
flag,
torch.tensor(
[COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device
),
)
# Final barrier to ensure the test does not exit before assertions complete
dist.barrier()
@skipIfRocm
@requires_triton()
def test_triton_fence(self) -> None:
"""
Rank 0 performs two put operations into Rank 1's buffers with a fence
between them, followed by another fence and a flag update. Rank 1 waits
for the flag, then verifies that both destination buffers contain the
expected values. The flag is transferred after the final fence, so
its arrival implies that both preceding puts have been delivered in
order.
"""
# Triton kernel that issues two ordered puts separated by fences and
# finally writes the completion flag.
@triton.jit
def put_with_fence_kernel(
dst_ptr1,
dst_ptr2,
src_ptr1,
src_ptr2,
flag_ptr,
flag_src_ptr,
numel: tl.constexpr,
peer: tl.constexpr,
):
# First put
nvshmem.putmem_block(dst_ptr1, src_ptr1, numel, peer)
# Ensure the first put is ordered before the next.
nvshmem.fence()
# Second put
nvshmem.putmem_block(dst_ptr2, src_ptr2, numel, peer)
# Order the second put before flag update.
nvshmem.fence()
# Write the flag (single int64) to signal completion.
nvshmem.putmem_block(flag_ptr, flag_src_ptr, 1, peer)
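# (Per the NVSHMEM ordering model, fence() orders prior puts to a given PE
# ahead of later ones but, unlike quiet(), does not wait for their completion
# at the target.)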
# Kernel for Rank 1 to wait until the flag becomes the expected value.
@triton.jit
def wait_until_kernel(
ivar_ptr,
cmp_op: tl.constexpr,
cmp_val: tl.constexpr,
):
nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)
torch.manual_seed(42 + self.rank)
self._init_device()
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
peer = 1 - rank
# Message configuration
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
val1 = 10
val2 = 20
flag_val = 1
# Symmetric buffers
inp1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val1)
inp2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val2)
out1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
out2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp1_hdl = symm_mem.rendezvous(inp1, group=group_name)
inp2_hdl = symm_mem.rendezvous(inp2, group=group_name)
out1_hdl = symm_mem.rendezvous(out1, group=group_name)
out2_hdl = symm_mem.rendezvous(out2, group=group_name)
# Flag buffer resides in the signal pad of out2.
flag = out2_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)
flag_update_val = torch.tensor(
[flag_val], dtype=torch.int64, device=self.device
)
NVSHMEM_CMP_EQ = 0 # compare equal
dist.barrier()
if rank == 0:
dst_ptr1 = out1_hdl.buffer_ptrs[rank]
dst_ptr2 = out2_hdl.buffer_ptrs[rank]
src_ptr1 = inp1_hdl.buffer_ptrs[rank]
src_ptr2 = inp2_hdl.buffer_ptrs[rank]
flag_ptr = out2_hdl.signal_pad_ptrs[rank]
flag_src_ptr = flag_update_val.data_ptr()
put_with_fence_kernel[(1, 1, 1)](
dst_ptr1,
dst_ptr2,
src_ptr1,
src_ptr2,
flag_ptr,
flag_src_ptr,
numel,
peer=peer,
extern_libs=nvshmem_lib,
)
elif rank == 1:
# Wait until flag is set by Rank 0.
ivar_ptr = out2_hdl.signal_pad_ptrs[rank]
wait_until_kernel[(1, 1, 1)](
ivar_ptr,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=flag_val,
extern_libs=nvshmem_lib,
)
# Verify ordered data arrival.
torch.testing.assert_close(
out1, val1 * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
out2, val2 * torch.ones(numel, dtype=dtype, device=self.device)
)
torch.testing.assert_close(
flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device)
)
dist.barrier()
@skipIfRocm
@requires_triton()
def test_triton_quiet(self) -> None:
# A Triton kernel that uses nvshmem_quiet to ensure completion
@triton.jit
def put_with_quiet_kernel(
dst_ptr,
src_ptr,
flag_dst_ptr,
flag_src_ptr,
numel: tl.constexpr,
peer: tl.constexpr,
):
# Put data
nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)
# Call quiet to ensure put is complete
nvshmem.quiet()
# Only after quiet, set the completion flag
# This ensures the data put is complete before flag is set
nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer)
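# (quiet() blocks until all outstanding puts from this PE have completed at
# their targets, a stronger guarantee than fence()'s per-PE ordering, so the
# flag put cannot be observed before the data put.)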
torch.manual_seed(42 + self.rank)
self._init_device()
# Enable NVSHMEM for Triton
nvshmem_lib = nvshmem.enable_triton()
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)
rank = self.rank
msg_size_bytes = 8
dtype = torch.int8
numel = msg_size_bytes // dtype.itemsize
# Data buffers
val = 15
inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
inp_hdl = symm_mem.rendezvous(inp, group=group_name)
out_hdl = symm_mem.rendezvous(out, group=group_name)
# Use signal pad as completion flag
flag_val = 42
peer = 1 - rank
NVSHMEM_CMP_EQ = 0
@triton.jit
def wait_until_kernel(
ivar_ptr,
cmp_op: tl.constexpr,
cmp_val: tl.constexpr,
):
nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)
dist.barrier()
if rank == 0:
# Rank 0 waits for flag from Rank 1
ivar_ptr = out_hdl.signal_pad_ptrs[rank]
wait_until_kernel[(1, 1, 1)](
ivar_ptr,
cmp_op=NVSHMEM_CMP_EQ,
cmp_val=flag_val,
extern_libs=nvshmem_lib,
)
# After flag is set, data should be complete due to quiet
torch.testing.assert_close(
out, val * torch.ones(numel, dtype=dtype, device=self.device)
)
if rank == 1:
# Rank 1 puts data and flag with quiet in between
dst_ptr = out_hdl.buffer_ptrs[rank]
src_ptr = inp_hdl.buffer_ptrs[rank]
flag_dst_ptr = out_hdl.signal_pad_ptrs[rank]
# Create a tensor for the flag value
flag_update_val = torch.tensor(
[flag_val], dtype=torch.int64, device=self.device
)
flag_src_ptr = flag_update_val.data_ptr()
put_with_quiet_kernel[(1, 1, 1)](
dst_ptr,
src_ptr,
flag_dst_ptr,
flag_src_ptr,
numel=numel,
peer=peer,
extern_libs=nvshmem_lib,
)
if __name__ == "__main__":
run_tests()