# Owner(s): ["oncall: distributed"]

# To run:
# TORCH_SYMMMEM=NVSHMEM python test/distributed/test_nvshmem.py
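# (Assumed prerequisites, inferred from the tests below: an NVSHMEM-enabled
# build, and at least two ranks/GPUs, since several tests address a fixed
# peer `1 - rank`.)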

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed._symmetric_memory._nvshmem_triton as nvshmem
from torch._inductor.runtime.triton_compat import tl, triton
from torch.testing._internal.common_distributed import MultiProcContinousTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    skipIfRocm,
)
from torch.testing._internal.inductor_utils import requires_triton


# Decorator
def requires_nvshmem():
    return skip_but_pass_in_sandcastle_if(
        not symm_mem.is_nvshmem_available(),
        "test_nvshmem requires NVSHMEM, skipping tests",
    )


# So that tests are written in a device-agnostic way
device_type = "cuda"
device_module = torch.get_device_module(device_type)
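# (Note: for device_type "cuda", `torch.get_device_module` resolves to the
# `torch.cuda` module, so `device_module.set_device` below is
# `torch.cuda.set_device`.)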


@instantiate_parametrized_tests
@requires_nvshmem()
class NVSHMEMSymmetricMemoryTest(MultiProcContinousTest):
    def _init_device(self) -> None:
        # TODO: relieve this (the test seems to hang without it)
        device_module.set_device(self.device)
        # NOTE: required for nvshmem allocation
        torch.empty(1, device=self.device)

    @property
    def device(self) -> torch.device:
        return torch.device(device_type, self.rank)

    @skipIfRocm
    def test_alloc(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024

        def foo():
            inp = symm_mem.empty(numel, dtype=dtype, device=self.device)
            symm_mem.rendezvous(inp, group=group_name)

        foo()

        out = symm_mem.empty(numel, dtype=dtype, device=self.device)
        symm_mem.rendezvous(out, group=group_name)

    @skipIfRocm
    def test_nvshmem_put(self) -> None:
        self._init_device()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel = 1024
        tensor = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
        symm_mem.rendezvous(tensor, group=group_name)

        if self.rank == 0:
            torch.ops.symm_mem.nvshmem_put(tensor, 1)
            # TODO: remove after we have wait_signal
            dist.barrier()
        elif self.rank == 1:
            # handle.wait_signal(src_rank=0)
            # TODO: remove after we have wait_signal
            dist.barrier()
            torch.testing.assert_close(
                tensor, torch.zeros(numel, dtype=dtype, device=self.device)
            )
        else:
            dist.barrier()

    @skipIfRocm
    def test_nvshmem_all_to_all(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        numel_per_peer = 10
        numel = self.world_size * numel_per_peer
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(self.rank)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)

        symm_mem.rendezvous(inp, group=group_name)
        symm_mem.rendezvous(out, group=group_name)
        torch.ops.symm_mem.nvshmem_all_to_all(inp, out, group_name)

        expected = torch.cat(
            [
                torch.empty(numel_per_peer, dtype=dtype, device=self.device).fill_(i)
                for i in range(self.world_size)
            ]
        )
        torch.testing.assert_close(out, expected)

    @skipIfRocm
    def test_all_to_all_vdev(self) -> None:
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        # Number of elements for a peer is random in [0, k)
        k = 10
        inp_splits = torch.randint(k, (self.world_size,), device=self.device)
        inp_numel = inp_splits.sum().item()
        # Exchange input splits to get output splits
        out_splits = torch.zeros_like(inp_splits)
        dist.all_to_all_single(out_splits, inp_splits)
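        # (Illustrative example, not part of the test: with 2 ranks,
        # inp_splits = [3, 5] on rank 0 and [2, 4] on rank 1, each rank j
        # receives element j from every rank, so out_splits = [3, 2] on
        # rank 0 and [5, 4] on rank 1.)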
        out_numel = out_splits.sum().item()

        # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
        max_inp_numel = k * self.world_size
        # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
        overflow_factor = self.world_size  # worst case: one rank receives all data
        max_out_numel = max_inp_numel * overflow_factor

        inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_(
            self.rank
        )
        out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
        in_out_splits = symm_mem.empty(
            (3, self.world_size), dtype=torch.int64, device=self.device
        )
        # Row 0 is input splits
        in_out_splits[0].copy_(inp_splits)

        torch.ops.symm_mem.all_to_all_vdev(inp, out, in_out_splits, group_name)

        # Check input splits (row 0) -- should not change
        torch.testing.assert_close(in_out_splits[0], inp_splits)

        # Check output splits (row 1)
        torch.testing.assert_close(in_out_splits[1], out_splits)

        # Check output offsets (row 2)
        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
        # Output offsets from `all_to_all_vdev` are an exclusive scan
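        # (Illustrative example: splits [3, 1, 2] give an inclusive scan of
        # [3, 4, 6] and an exclusive scan of [0, 3, 4]; the checks below
        # compare the exclusive form, i.e. [0] + inclusive[:-1].)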
        self.assertEqual(in_out_splits[2][0], 0)
        torch.testing.assert_close(in_out_splits[2][1:], out_offsets[:-1])

        # Check data
        expected = torch.empty(out_numel, dtype=dtype, device=self.device)
        dist.all_to_all_single(
            expected, inp[:inp_numel], out_splits.tolist(), inp_splits.tolist()
        )
        torch.testing.assert_close(out[:out_numel], expected)

    @skipIfRocm
    @parametrize("align", [1, 8, 16])  # `major_align` of output
    def test_all_to_all_vdev_2d(self, align: int) -> None:
        torch.manual_seed(42 + self.rank)
        self._init_device()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)

        dtype = torch.float
        # Number of experts per rank
        ne = 8
        nsplits = ne * self.world_size

        # Number of elements for an expert is random in [0, k)
        k = 10
        inp_splits = torch.randint(k, (nsplits,), dtype=torch.int64, device=self.device)

        # Exchange input splits to get output splits
        out_splits = torch.zeros_like(inp_splits)
        dist.all_to_all_single(out_splits, inp_splits)
        # We do a .t() here because there is a rank-major to expert-major shuffle
        out_splits_t = out_splits.reshape(self.world_size, ne).t()
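        # (Illustrative example: with world_size=2 and ne=2, out_splits is
        # rank-major [r0e0, r0e1, r1e0, r1e1]; reshape(2, 2).t() regroups it
        # expert-major as [[r0e0, r1e0], [r0e1, r1e1]], i.e. all ranks'
        # contributions to expert 0 first, then expert 1.)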

        # Actual number of input elements
        inp_numel = inp_splits.sum().item()
        # Actual number of output elements
        out_numel = out_splits.sum().item()
        # Max number of input elements (must be a constant across ranks for symmetric memory allocation)
        max_inp_numel = k * nsplits
        # Max number of output elements (must be a constant across ranks for symmetric memory allocation)
        overflow_factor = self.world_size  # worst case: one rank receives all data
        max_out_numel = max_inp_numel * overflow_factor

        inp = symm_mem.empty(max_inp_numel, dtype=dtype, device=self.device).fill_(
            self.rank
        )
        out = symm_mem.empty(max_out_numel, dtype=dtype, device=self.device).fill_(-1)
        # 3 rows: input splits, output splits, output offsets
        # Initializing all values to -1 to check if they are updated
        in_out_splits = symm_mem.empty(
            (3, nsplits), dtype=torch.int64, device=self.device
        ).fill_(-1)
        # Row 0 is input splits
        in_out_splits[0].copy_(inp_splits)

        torch.ops.symm_mem.all_to_all_vdev_2d(
            inp, out, in_out_splits, group_name, major_align=align
        )
        received_out_splits = in_out_splits[1]
        received_out_offsets = in_out_splits[2]

        # Check input splits (row 0) -- should not change
        torch.testing.assert_close(in_out_splits[0], inp_splits)

        # Check output splits (row 1)
        torch.testing.assert_close(received_out_splits, out_splits_t.reshape(-1))

        # Check output offsets (row 2)
        out_split_list = out_splits_t.tolist()
        for i in range(ne):
            expert_sum = 0
            for j in range(self.world_size):
                expert_sum += out_split_list[i][j]
            # Align up expert_sum
            expert_sum_aligned = (expert_sum + align - 1) // align * align
            # If 0, make it at least `align` (bc cutlass currently does not support empty bins)
            expert_sum_aligned = max(expert_sum_aligned, align)
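            # (Illustrative arithmetic: with align=8, expert_sum=13 rounds up
            # to 16, expert_sum=8 stays 8, and expert_sum=0 is bumped to 8;
            # with align=1, values are unchanged except 0, which becomes 1.)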
            # last element absorbs the padding
            out_split_list[i][-1] += expert_sum_aligned - expert_sum

        out_splits_padded = torch.tensor(out_split_list, device=self.device).reshape(-1)
        out_offsets = torch.cumsum(out_splits_padded, dim=0)  # inclusive scan
        # Make it an exclusive scan because that's what `all_to_all_vdev_2d` returns
        out_offsets = torch.cat(
            [torch.zeros(1, device=self.device), out_offsets[:-1]]
        ).to(torch.int64)
        torch.testing.assert_close(received_out_offsets, out_offsets)

        # Check data
        expected = torch.empty(out_numel, dtype=dtype, device=self.device)
        inp_splits_rank = inp_splits.reshape(self.world_size, ne).sum(1)
        out_splits_rank = out_splits.reshape(self.world_size, ne).sum(1)
        dist.all_to_all_single(
            expected,
            inp[:inp_numel],
            out_splits_rank.tolist(),
            inp_splits_rank.tolist(),
        )
        # We still need to shuffle `expected` from rank-major to expert-major order
        out_offsets = torch.cumsum(out_splits, dim=0)  # inclusive scan
        result_list = []
        for j in range(ne):
            for i in range(self.world_size):
                chunk_id = i * ne + j
                offset = out_offsets[chunk_id]
                chunk = expected[offset - out_splits[chunk_id] : offset]
                result_list.append(chunk)

        # Do a chunk-wise comparison
        for c, chunk in enumerate(result_list):
            start = received_out_offsets[c].item()
            split = received_out_splits[c].item()
            received_chunk = out[start : start + split]
            torch.testing.assert_close(received_chunk, chunk)

    @skipIfRocm
    @requires_triton()
    def test_triton_put(self) -> None:
        # A Triton kernel that calls the nvshmem device-side API
        @triton.jit
        def put_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()

        # Enable NVSHMEM for Triton
        nvshmem_lib = nvshmem.enable_triton()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank

        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize

        val = 5
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)

        peer = 1 - rank
        if rank == 0:
            dst_ptr = out_hdl.buffer_ptrs[rank]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            put_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )

        dist.barrier()
        if rank == 1:
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )

    @skipIfRocm
    @requires_triton()
    def test_triton_get(self) -> None:
        # A Triton kernel that calls the nvshmem device-side API for GET
        @triton.jit
        def get_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()

        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        val = 7
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(
            val if rank == 0 else -1
        )
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        dist.barrier()
        peer = 1 - rank
        if rank == 1:
            # Rank 1 gets data from rank 0
            dst_ptr = out_hdl.buffer_ptrs[rank]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            get_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )

    @skipIfRocm
    @requires_triton()
    def test_triton_get_ring(self) -> None:
        # A Triton kernel that calls the nvshmem device-side API for GET
        # with a ring topology
        @triton.jit
        def get_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.getmem_block(dst_ptr, src_ptr, numel, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()

        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        world_size = dist.get_world_size()
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize

        # Each rank fills its input buffer with its own rank value
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(rank)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        dist.barrier()

        # Ring topology: each rank gets data from the rank to its left
        # rank 0 gets from rank (world_size-1), rank 1 gets from rank 0, etc.
        peer = (rank - 1) % world_size
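        # (Illustrative example: with world_size=4, rank 0 reads from rank 3,
        # rank 1 from rank 0, rank 2 from rank 1, and rank 3 from rank 2.)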

        # All ranks execute the get operation
        dst_ptr = out_hdl.buffer_ptrs[rank]
        src_ptr = inp_hdl.buffer_ptrs[rank]
        get_kernel[(1, 1, 1)](
            dst_ptr,
            src_ptr,
            numel=numel,
            peer=peer,
            extern_libs=nvshmem_lib,
        )

        expected_value = peer
        torch.testing.assert_close(
            out, expected_value * torch.ones(numel, dtype=dtype, device=self.device)
        )

    @skipIfRocm
    @requires_triton()
    def test_triton_put_signal_set(self) -> None:
        # A Triton kernel that calls the nvshmem device-side API for PUT with SIGNAL
        @triton.jit
        def put_signal_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            sig_ptr,
            signal_val: tl.constexpr,
            sig_op: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_signal_block(
                dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
            )

        torch.manual_seed(42 + self.rank)
        self._init_device()

        nvshmem_lib = nvshmem.enable_triton()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank

        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize

        # Data buffers
        val = 11
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)

        # Use the signal pad attached to the output symmetric memory handle
        # as the flag buffer for signaling completion.
        flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)
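        # (`flag` is a view over the same signal-pad memory that
        # `signal_pad_ptrs[rank]` points to, so a remote signal written by
        # the peer becomes visible in this tensor -- this is what the final
        # assertion on `flag` relies on.)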

        peer = 1 - rank
        NVSHMEM_SIGNAL_SET = 0  # value defined by NVSHMEM for atomic set
        SIGNAL_VAL = 1  # signal completion value
        NVSHMEM_CMP_EQ = 0  # compare-equal, for signal_wait_until
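        # (These integer values mirror constants from nvshmem.h; they are
        # assumed to match the installed NVSHMEM version rather than being
        # imported from it.)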

        # Kernel for waiting on the signal locally (rank 1).
        @triton.jit
        def signal_wait_until_kernel(
            sig_ptr, cmp_op: tl.constexpr, cmp_val: tl.constexpr
        ):
            nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)

        if rank == 0:
            # Rank 0 puts into rank 1
            dst_ptr = out_hdl.buffer_ptrs[peer]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            sig_ptr = out_hdl.signal_pad_ptrs[peer]
            put_signal_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                sig_ptr=sig_ptr,
                signal_val=SIGNAL_VAL,
                sig_op=NVSHMEM_SIGNAL_SET,
                peer=peer,
                extern_libs=nvshmem_lib,
            )

        if rank == 1:
            # Wait until the signal flag is set by rank 0
            sig_ptr_local = out_hdl.signal_pad_ptrs[rank]
            signal_wait_until_kernel[(1, 1, 1)](
                sig_ptr_local,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=SIGNAL_VAL,
                extern_libs=nvshmem_lib,
            )
            # After the wait completes, verify data and flag contents
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )
            torch.testing.assert_close(
                flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device)
            )

    @skipIfRocm
    @requires_triton()
    def test_triton_put_signal_add(self) -> None:
        # A Triton kernel that calls the nvshmem device-side API for PUT with SIGNAL
        @triton.jit
        def put_signal_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            sig_ptr,
            signal_val: tl.constexpr,
            sig_op: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_signal_block(
                dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
            )

        torch.manual_seed(42 + self.rank)
        self._init_device()

        nvshmem_lib = nvshmem.enable_triton()

        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank

        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize

        # Data buffers
        val = 11
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)

        # Use the signal pad attached to the output symmetric memory handle
        # as the flag buffer for signaling completion.
        flag = out_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)

        peer = 1 - rank
        NVSHMEM_SIGNAL_ADD = 5  # atomic add operation
        SIGNAL_VAL = 16  # the flag starts at 0, so it equals 16 after the atomic add
        NVSHMEM_CMP_EQ = 0

        @triton.jit
        def signal_wait_until_kernel(
            sig_ptr, cmp_op: tl.constexpr, cmp_val: tl.constexpr
        ):
            nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)

        if rank == 0:
            # Rank 0 puts into rank 1
            dst_ptr = out_hdl.buffer_ptrs[peer]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            sig_ptr = out_hdl.signal_pad_ptrs[peer]
            put_signal_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                sig_ptr=sig_ptr,
                signal_val=SIGNAL_VAL,
                sig_op=NVSHMEM_SIGNAL_ADD,
                peer=peer,
                extern_libs=nvshmem_lib,
            )

        if rank == 1:
            sig_ptr_local = out_hdl.signal_pad_ptrs[rank]
            signal_wait_until_kernel[(1, 1, 1)](
                sig_ptr_local,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=SIGNAL_VAL,
                extern_libs=nvshmem_lib,
            )
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )
            torch.testing.assert_close(
                flag, torch.tensor([SIGNAL_VAL], dtype=torch.int64, device=self.device)
            )

    @skipIfRocm
    @requires_triton()
    def test_triton_wait_until(self) -> None:
        # A Triton kernel that calls the nvshmem device-side API for PUT
        @triton.jit
        def put_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)

        # A Triton kernel that calls the nvshmem device-side API for WAIT_UNTIL
        @triton.jit
        def wait_until_kernel(
            ivar_ptr,
            cmp_op: tl.constexpr,
            cmp_val: tl.constexpr,
        ):
            nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)

        torch.manual_seed(42 + self.rank)
        self._init_device()
        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank

        # Data buffers
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        val = 13
        flag_val = 21
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        dist.barrier()

        peer = 1 - rank
        NVSHMEM_CMP_EQ = 0  # from nvshmem.h

        if rank == 0:
            # Rank 0 waits for the flag to be set by rank 1, then checks the data
            ivar_ptr = out_hdl.signal_pad_ptrs[rank]
            wait_until_kernel[(1, 1, 1)](
                ivar_ptr,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=flag_val,
                extern_libs=nvshmem_lib,
            )
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )

        if rank == 1:
            # Rank 1 puts data into rank 0's output buffer
            dst_ptr = out_hdl.buffer_ptrs[rank]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            put_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )
            # Rank 1 sets the flag on rank 0.
            # We use a temporary tensor for the value to put.
            flag_update_val = torch.tensor(
                [flag_val], dtype=torch.int64, device=self.device
            )
            dst_ptr = out_hdl.signal_pad_ptrs[rank]
            src_ptr = flag_update_val.data_ptr()
            put_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel=1,
                peer=peer,
                extern_libs=nvshmem_lib,
            )

    @skipIfRocm
    @requires_triton()
    def test_triton_signal_wait_until(self) -> None:
        # A Triton kernel that waits on a signal variable until it meets the compare condition.
        @triton.jit
        def signal_wait_until_kernel(
            sig_ptr,
            cmp_op: tl.constexpr,
            cmp_val: tl.constexpr,
        ):
            nvshmem.signal_wait_until(sig_ptr, cmp_op, cmp_val)

        # A Triton kernel for the producer that puts data and then signals completion.
        @triton.jit
        def put_and_signal_kernel(
            dst_ptr,
            src_ptr,
            numel: tl.constexpr,
            sig_ptr,
            signal_val: tl.constexpr,
            sig_op: tl.constexpr,
            peer: tl.constexpr,
        ):
            nvshmem.putmem_signal_block(
                dst_ptr, src_ptr, numel, sig_ptr, signal_val, sig_op, peer
            )

        self._init_device()
        # Enable NVSHMEM for Triton
        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        peer = 1 - rank

        # NVSHMEM constants from documentation
        NVSHMEM_CMP_EQ = 0  # equal comparison
        NVSHMEM_SIGNAL_SET = 0  # atomic set operation

        # Message configuration
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        val_to_put = 123  # arbitrary test value
        COMPLETION_FLAG_VAL = 1

        # Producer (rank 0) prepares the data to send
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val_to_put)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        # Consumer (rank 1) prepares the destination buffer
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        # Use the signal pad for synchronization, as in previous tests
        flag_dtype = torch.int64
        flag = out_hdl.get_signal_pad(rank, (1,), dtype=flag_dtype).fill_(0)
        # Ensure setup is complete on all ranks before proceeding
        dist.barrier()

        if rank == 0:
            # Producer (rank 0): puts data into rank 1's `out` buffer, then sets the flag
            dst_ptr = out_hdl.buffer_ptrs[peer]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            sig_ptr = out_hdl.signal_pad_ptrs[peer]
            put_and_signal_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                numel,
                sig_ptr,
                signal_val=COMPLETION_FLAG_VAL,
                sig_op=NVSHMEM_SIGNAL_SET,
                peer=peer,
                extern_libs=nvshmem_lib,
            )
        elif rank == 1:
            # Consumer (rank 1): waits on the signal variable using `signal_wait_until`
            sig_ptr = out_hdl.signal_pad_ptrs[rank]
            signal_wait_until_kernel[(1, 1, 1)](
                sig_ptr,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=COMPLETION_FLAG_VAL,
                extern_libs=nvshmem_lib,
            )
            # After the wait returns, verify data and flag
            torch.testing.assert_close(
                out, val_to_put * torch.ones(numel, dtype=dtype, device=self.device)
            )
            torch.testing.assert_close(
                flag,
                torch.tensor(
                    [COMPLETION_FLAG_VAL], dtype=flag_dtype, device=self.device
                ),
            )
        # Final barrier to ensure the test does not exit before assertions complete
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_fence(self) -> None:
        """
        Rank 0 performs two put operations into Rank 1's buffers with a fence
        between them, followed by another fence and a flag update. Rank 1 waits
        for the flag, then verifies that both destination buffers contain the
        expected values. The flag is transferred after the final fence, so
        its arrival implies that both preceding puts have been delivered in
        order.
        """

        # Triton kernel that issues two ordered puts separated by fences and
        # finally writes the completion flag.
        @triton.jit
        def put_with_fence_kernel(
            dst_ptr1,
            dst_ptr2,
            src_ptr1,
            src_ptr2,
            flag_ptr,
            flag_src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            # First put
            nvshmem.putmem_block(dst_ptr1, src_ptr1, numel, peer)
            # Ensure the first put is ordered before the next.
            nvshmem.fence()
            # Second put
            nvshmem.putmem_block(dst_ptr2, src_ptr2, numel, peer)
            # Order the second put before the flag update.
            nvshmem.fence()
            # Write the flag (single int64) to signal completion.
            nvshmem.putmem_block(flag_ptr, flag_src_ptr, 1, peer)

        # Kernel for rank 1 to wait until the flag becomes the expected value.
        @triton.jit
        def wait_until_kernel(
            ivar_ptr,
            cmp_op: tl.constexpr,
            cmp_val: tl.constexpr,
        ):
            nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)

        torch.manual_seed(42 + self.rank)
        self._init_device()
        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        peer = 1 - rank

        # Message configuration
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        val1 = 10
        val2 = 20
        flag_val = 1

        # Symmetric buffers
        inp1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val1)
        inp2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val2)
        out1 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        out2 = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp1_hdl = symm_mem.rendezvous(inp1, group=group_name)
        inp2_hdl = symm_mem.rendezvous(inp2, group=group_name)
        out1_hdl = symm_mem.rendezvous(out1, group=group_name)
        out2_hdl = symm_mem.rendezvous(out2, group=group_name)

        # Flag buffer resides in the signal pad of out2.
        flag = out2_hdl.get_signal_pad(rank, (1,), dtype=torch.int64).fill_(0)
        flag_update_val = torch.tensor(
            [flag_val], dtype=torch.int64, device=self.device
        )
        NVSHMEM_CMP_EQ = 0  # compare equal
        dist.barrier()

        if rank == 0:
            dst_ptr1 = out1_hdl.buffer_ptrs[rank]
            dst_ptr2 = out2_hdl.buffer_ptrs[rank]
            src_ptr1 = inp1_hdl.buffer_ptrs[rank]
            src_ptr2 = inp2_hdl.buffer_ptrs[rank]
            flag_ptr = out2_hdl.signal_pad_ptrs[rank]
            flag_src_ptr = flag_update_val.data_ptr()

            put_with_fence_kernel[(1, 1, 1)](
                dst_ptr1,
                dst_ptr2,
                src_ptr1,
                src_ptr2,
                flag_ptr,
                flag_src_ptr,
                numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )
        elif rank == 1:
            # Wait until the flag is set by rank 0.
            ivar_ptr = out2_hdl.signal_pad_ptrs[rank]
            wait_until_kernel[(1, 1, 1)](
                ivar_ptr,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=flag_val,
                extern_libs=nvshmem_lib,
            )

            # Verify ordered data arrival.
            torch.testing.assert_close(
                out1, val1 * torch.ones(numel, dtype=dtype, device=self.device)
            )
            torch.testing.assert_close(
                out2, val2 * torch.ones(numel, dtype=dtype, device=self.device)
            )
            torch.testing.assert_close(
                flag, torch.tensor([flag_val], dtype=torch.int64, device=self.device)
            )
        dist.barrier()

    @skipIfRocm
    @requires_triton()
    def test_triton_quiet(self) -> None:
        # A Triton kernel that uses nvshmem_quiet to ensure completion
        @triton.jit
        def put_with_quiet_kernel(
            dst_ptr,
            src_ptr,
            flag_dst_ptr,
            flag_src_ptr,
            numel: tl.constexpr,
            peer: tl.constexpr,
        ):
            # Put data
            nvshmem.putmem_block(dst_ptr, src_ptr, numel, peer)
            # Call quiet to ensure the put is complete
            nvshmem.quiet()
            # Only after quiet, set the completion flag.
            # This ensures the data put is complete before the flag is set.
            nvshmem.putmem_block(flag_dst_ptr, flag_src_ptr, 1, peer)

        torch.manual_seed(42 + self.rank)
        self._init_device()
        # Enable NVSHMEM for Triton
        nvshmem_lib = nvshmem.enable_triton()
        group_name = dist.group.WORLD.group_name
        symm_mem.enable_symm_mem_for_group(group_name)
        rank = self.rank
        msg_size_bytes = 8
        dtype = torch.int8
        numel = msg_size_bytes // dtype.itemsize
        # Data buffers
        val = 15
        inp = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(val)
        out = symm_mem.empty(numel, dtype=dtype, device=self.device).fill_(-1)
        inp_hdl = symm_mem.rendezvous(inp, group=group_name)
        out_hdl = symm_mem.rendezvous(out, group=group_name)
        # Use the signal pad as the completion flag
        flag_val = 42
        peer = 1 - rank
        NVSHMEM_CMP_EQ = 0

        @triton.jit
        def wait_until_kernel(
            ivar_ptr,
            cmp_op: tl.constexpr,
            cmp_val: tl.constexpr,
        ):
            nvshmem.wait_until(ivar_ptr, cmp_op, cmp_val)

        dist.barrier()
        if rank == 0:
            # Rank 0 waits for the flag from rank 1
            ivar_ptr = out_hdl.signal_pad_ptrs[rank]
            wait_until_kernel[(1, 1, 1)](
                ivar_ptr,
                cmp_op=NVSHMEM_CMP_EQ,
                cmp_val=flag_val,
                extern_libs=nvshmem_lib,
            )
            # After the flag is set, the data should be complete due to quiet
            torch.testing.assert_close(
                out, val * torch.ones(numel, dtype=dtype, device=self.device)
            )
        if rank == 1:
            # Rank 1 puts data and the flag, with quiet in between
            dst_ptr = out_hdl.buffer_ptrs[rank]
            src_ptr = inp_hdl.buffer_ptrs[rank]
            flag_dst_ptr = out_hdl.signal_pad_ptrs[rank]
            # Create a tensor for the flag value
            flag_update_val = torch.tensor(
                [flag_val], dtype=torch.int64, device=self.device
            )
            flag_src_ptr = flag_update_val.data_ptr()
            put_with_quiet_kernel[(1, 1, 1)](
                dst_ptr,
                src_ptr,
                flag_dst_ptr,
                flag_src_ptr,
                numel=numel,
                peer=peer,
                extern_libs=nvshmem_lib,
            )


if __name__ == "__main__":
    run_tests()