import multiprocessing as mp
import socket
import unittest
from enum import IntEnum
from typing import Any

import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist


class MscclContextSelection(IntEnum):
    """Selects which mscclpp allreduce algorithm the context is built for."""

    MSCCL1SHOT1NODELL = 1
    MSCCL1SHOT2NODELL = 2


def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Per-rank worker: sets up NCCL and a mscclpp context, then checks allreduce results."""
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    # A separate gloo group is used to broadcast the CPU-side unique id.
    cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")

    if rank == 0:
        unique_id = [custom_ops.mscclpp_generate_unique_id()]
    else:
        unique_id = [None]
    dist.broadcast_object_list(
        unique_id, src=0, device=torch.device("cpu"), group=cpu_group
    )
    unique_id = unique_id[0]

    # Map each rank to its node index and to the IB/NIC index within that node
    # (8 GPUs per node); the mapping must be identical on every rank.
    rank_to_node, rank_to_ib = list(range(world_size)), list(range(world_size))
    for r in range(world_size):
        rank_to_node[r] = r // 8
        rank_to_ib[r] = r % 8

    MAX_BYTES = 2**20  # upper bound (in bytes) on the message sizes exercised below
    scratch = torch.empty(
        MAX_BYTES * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    put_buffer = torch.empty(
        MAX_BYTES, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    print(f"[{rank}] start mscclpp_context init")
    nranks_per_node = torch.cuda.device_count()
    selection = int(MscclContextSelection.MSCCL1SHOT1NODELL)
    mscclpp_context = custom_ops.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        nranks_per_node,
        rank_to_node,
        rank_to_ib,
        selection,
    )

    try:
        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                if sz * dtype.itemsize > MAX_BYTES:
                    continue
                if rank == 0:
                    print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
                for _ in range(test_loop):
                    # Small random integers keep the sums exactly representable,
                    # so the mscclpp result should match the NCCL reference.
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)
                    custom_ops.mscclpp_allreduce(
                        mscclpp_context, inp1, out1, nthreads=512, nblocks=21
                    )
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        dist.barrier(group=group)
        dist.destroy_process_group(group=group)


def get_open_port() -> int:
    """Return a free TCP port on localhost (IPv4 first, then IPv6)."""
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Spawn one process per rank and fail if any worker exits non-zero."""
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"


class TestMSCCLAllReduce(unittest.TestCase):
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
    ]
    world_sizes = [8]

    def test_correctness(self):
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}: only {available_gpus} GPUs "
                    "available (multi-node runs via ray are not supported here)"
                )
                continue

            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size,
                _run_correctness_worker,
                target_args=(self.test_sizes,),
            )
            print(f"custom allreduce tp = {world_size}: OK")


if __name__ == "__main__":
    unittest.main()
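
# Usage sketch (assumptions: a single node with 8 CUDA GPUs and an sgl_kernel
# build that exposes the mscclpp ops imported above; the file name below is
# illustrative, not fixed by this test):
#
#   python test_mscclpp_allreduce.py
#   # or, via the unittest runner
#   python -m unittest -v test_mscclpp_allreduce.py
#
# Each run spawns one worker process per rank; correctness is judged by
# comparing the mscclpp allreduce output against torch.distributed.all_reduce
# over NCCL for every size/dtype combination listed in test_sizes.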