113 lines
3.4 KiB
Python
113 lines
3.4 KiB
Python
import logging
|
|
import multiprocessing as mp
|
|
import socket
|
|
from typing import Any
|
|
|
|
import pytest
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
from flashinfer.comm.nvshmem_allreduce import NVSHMEMAllReduce
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _run_correctness_worker(world_size, rank, distributed_init_port):
|
|
assert rank >= 0
|
|
torch.cuda.set_device(rank)
|
|
device = torch.device("cuda", rank)
|
|
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
|
|
dist.init_process_group(
|
|
backend="cpu:gloo,cuda:nccl",
|
|
rank=rank,
|
|
world_size=world_size,
|
|
device_id=device,
|
|
init_method=distributed_init_method,
|
|
)
|
|
group = dist.group.WORLD
|
|
num_ranks = torch.distributed.get_world_size()
|
|
rank_id = torch.distributed.get_rank()
|
|
|
|
batch_sizes = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
|
|
max_batch_size = 4096
|
|
hidden_dim = 8192
|
|
test_loop = 10
|
|
tensor_dtype = torch.bfloat16
|
|
nvshmem_allreduce = NVSHMEMAllReduce(
|
|
rank_id,
|
|
num_ranks,
|
|
max_batch_size * hidden_dim,
|
|
tensor_dtype,
|
|
device,
|
|
group,
|
|
)
|
|
|
|
try:
|
|
for batch_size in batch_sizes:
|
|
for _ in range(test_loop):
|
|
tensor_size = batch_size * hidden_dim
|
|
inp1 = torch.full(
|
|
[tensor_size], rank_id, dtype=tensor_dtype, device=device
|
|
)
|
|
inp1_ref = inp1.clone()
|
|
out1 = torch.empty_like(inp1)
|
|
nvshmem_allreduce.all_reduce(inp1, out1)
|
|
torch.distributed.all_reduce(inp1_ref, group=group)
|
|
torch.cuda.synchronize()
|
|
torch.testing.assert_close(out1, inp1_ref)
|
|
torch.distributed.barrier(group)
|
|
except Exception as e:
|
|
print(f"Rank {rank_id}: Exception during test: {e}")
|
|
raise
|
|
finally:
|
|
torch.distributed.barrier(group)
|
|
nvshmem_allreduce.shutdown()
|
|
torch.distributed.destroy_process_group(group)
|
|
|
|
|
|
def get_open_port() -> int:
|
|
try:
|
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
|
s.bind(("127.0.0.1", 0))
|
|
return s.getsockname()[1]
|
|
except OSError:
|
|
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
|
s.bind(("::1", 0))
|
|
return s.getsockname()[1]
|
|
|
|
|
|
def multi_process_parallel(
|
|
world_size: int, test_target: Any, target_args: tuple = ()
|
|
) -> None:
|
|
mp.set_start_method("spawn", force=True)
|
|
|
|
procs = []
|
|
distributed_init_port = get_open_port()
|
|
for i in range(world_size):
|
|
proc_args = (world_size, i, distributed_init_port) + target_args
|
|
proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
|
|
proc.start()
|
|
procs.append(proc)
|
|
|
|
for i in range(world_size):
|
|
procs[i].join()
|
|
assert procs[i].exitcode == 0, (
|
|
f"Process {i} failed with exit code {procs[i].exitcode}"
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("world_size", [8])
|
|
def test_nvshmem_allreduce(world_size):
|
|
available_gpus = torch.cuda.device_count()
|
|
if world_size > available_gpus:
|
|
pytest.skip(
|
|
f"world_size {world_size} is greater than available_gpus {available_gpus}"
|
|
)
|
|
print(f"Running test for world_size={world_size}")
|
|
multi_process_parallel(
|
|
world_size,
|
|
_run_correctness_worker,
|
|
target_args=(),
|
|
)
|
|
print(f"NVSHMEM allreduce tp = {world_size}: OK")
|