import ctypes
import multiprocessing as mp
import random
import socket
import unittest
from typing import Any, List, Optional

import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary


def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    # Each worker pins itself to one GPU and joins a NCCL process group.
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD

    # Initialize up front so the finally block is safe even if setup fails early.
    custom_ptr = None
    buffer_ptrs = None
    meta_ptrs = None
    try:
        max_size = 8192 * 1024
        meta_ptrs = TestCustomAllReduce.create_shared_buffer(
            custom_ops.meta_size() + max_size, group=group
        )
        rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=device)
        buffer_ptrs = TestCustomAllReduce.create_shared_buffer(max_size, group=group)
        custom_ptr = custom_ops.init_custom_ar(meta_ptrs, rank_data, rank, True)
        custom_ops.register_buffer(custom_ptr, buffer_ptrs)

        test_loop = 10
        for sz in test_sizes:
            for dtype in [torch.float32, torch.float16, torch.bfloat16]:
                for _ in range(test_loop):
                    inp1 = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                    inp1_ref = inp1.clone()
                    out1 = torch.empty_like(inp1)

                    custom_ops.all_reduce(
                        custom_ptr, inp1, out1, buffer_ptrs[rank], max_size
                    )
                    # Reference result from NCCL for comparison.
                    dist.all_reduce(inp1_ref, group=group)
                    torch.testing.assert_close(out1, inp1_ref)
    finally:
        dist.barrier(group=group)
        if custom_ptr is not None:
            custom_ops.dispose(custom_ptr)
        if buffer_ptrs:
            TestCustomAllReduce.free_shared_buffer(buffer_ptrs, group)
        if meta_ptrs:
            TestCustomAllReduce.free_shared_buffer(meta_ptrs, group)
        dist.destroy_process_group(group=group)


def get_open_port() -> int:
    # Ask the OS for a free port; fall back to IPv6 if IPv4 binding fails.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert (
            procs[i].exitcode == 0
        ), f"Process {i} failed with exit code {procs[i].exitcode}"


class TestCustomAllReduce(unittest.TestCase):
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
        1048576,
        2097152,
    ]
    world_sizes = [2, 4, 8]

    @staticmethod
    def create_shared_buffer(
        size_in_bytes: int, group: Optional[ProcessGroup] = None
    ) -> List[int]:
        # Allocate a device buffer on this rank and exchange CUDA IPC handles so
        # every rank ends up with pointers to all ranks' buffers.
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        if group is None:
            group = dist.group.WORLD
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)

        handle_bytes = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
        input_tensor = torch.ByteTensor(list(handle_bytes)).to(f"cuda:{rank}")
        gathered_tensors = [torch.empty_like(input_tensor) for _ in range(world_size)]
        dist.all_gather(gathered_tensors, input_tensor, group=group)

        handles = []
        handle_type = type(handle)
        for tensor in gathered_tensors:
            bytes_list = tensor.cpu().tolist()
            bytes_data = bytes(bytes_list)
            handle_obj = handle_type()
            ctypes.memmove(ctypes.addressof(handle_obj), bytes_data, len(bytes_data))
            handles.append(handle_obj)

        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                # Local buffer: use the pointer returned by cudaMalloc directly.
                pointers.append(pointer.value)
            else:
                try:
                    opened_ptr = lib.cudaIpcOpenMemHandle(h)
                    pointers.append(opened_ptr.value)
                except Exception as e:
                    print(f"Rank {rank}: Failed to open IPC handle from rank {i}: {e}")
                    raise

        dist.barrier(group=group)
        return pointers

    @staticmethod
    def free_shared_buffer(
        pointers: List[int], group: Optional[ProcessGroup] = None
    ) -> None:
        if group is None:
            group = dist.group.WORLD
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        # Each rank frees only the buffer it allocated locally.
        if pointers and len(pointers) > rank and pointers[rank] is not None:
            lib.cudaFree(ctypes.c_void_p(pointers[rank]))
        dist.barrier(group=group)

    def test_correctness(self):
        for world_size in self.world_sizes:
            available_gpus = torch.cuda.device_count()
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, requires {world_size} GPUs, found {available_gpus}"
                )
                continue

            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")


if __name__ == "__main__":
    unittest.main()