sglang.0.4.8.post1/sglang/sgl-kernel/tests/test_mscclpp.py

147 lines
4.5 KiB
Python

import multiprocessing as mp
import os
import socket
import unittest
from enum import IntEnum
from typing import Any
import sgl_kernel.allreduce as custom_ops
import torch
import torch.distributed as dist
class MscclContextSelection(IntEnum):
    """Integer selector handed (as a plain ``int``) to ``mscclpp_init_context``.

    It chooses which mscclpp all-reduce variant the context is built for.
    Judging by the names these are one-shot LL variants for one node and two
    nodes respectively; the exact semantics live in sgl_kernel (unverified
    here).
    """

    # Selector value 1 (the single-node variant, per its name).
    MSCCL1SHOT1NODELL = 1
    # Selector value 2 (the two-node variant, per its name).
    MSCCL1SHOT2NODELL = 2
def _run_correctness_worker(world_size, rank, distributed_init_port, test_sizes):
    """Per-rank worker that checks ``mscclpp_allreduce`` against NCCL.

    Intended as the ``test_target`` of ``multi_process_parallel``. The worker:

    1. Joins an NCCL process group over TCP on localhost.
    2. Receives rank 0's mscclpp unique id through a separate gloo group.
    3. Builds an mscclpp context.
    4. For every (size, dtype) pair whose payload fits in a 1 MiB budget,
       runs several random all-reduces. Each mscclpp result must match
       ``dist.all_reduce`` under ``torch.testing.assert_close``.

    Args:
        world_size: total number of ranks.
        rank: global rank of this process.
        distributed_init_port: TCP port used by ``init_process_group``.
        test_sizes: element counts of the input tensors to try.
    """
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)
    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://localhost:{distributed_init_port}",
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    # The unique id is a Python object, broadcast on CPU over a gloo group.
    cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")

    # Only rank 0 generates the id; every other rank receives it.
    holder = [custom_ops.mscclpp_generate_unique_id() if rank == 0 else None]
    dist.broadcast_object_list(
        holder, src=0, device=torch.device("cpu"), group=cpu_group
    )
    unique_id = holder[0]

    # Peer r is placed on node r // 8 (assumes 8 ranks per node).
    # NOTE(review): rank_to_ib stores *this* rank's local index (rank % 8)
    # for every peer r, unlike rank_to_node, which depends on r. This may be
    # deliberate (each rank always uses its own IB device) or a typo for
    # ``r % 8``. Confirm against the sgl_kernel mscclpp implementation.
    rank_to_node = [r // 8 for r in range(world_size)]
    rank_to_ib = [rank % 8 for _ in range(world_size)]

    byte_budget = 2**20
    scratch = torch.empty(
        byte_budget * 8, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )
    put_buffer = torch.empty(
        byte_budget, dtype=torch.bfloat16, device=torch.cuda.current_device()
    )

    print(f"[{rank}] start mscclpp_context init")
    mscclpp_context = custom_ops.mscclpp_init_context(
        unique_id,
        rank,
        world_size,
        scratch,
        put_buffer,
        torch.cuda.device_count(),  # nranks_per_node
        rank_to_node,
        rank_to_ib,
        int(MscclContextSelection.MSCCL1SHOT1NODELL),
    )

    repeats = 10
    dtypes = (torch.float32, torch.float16, torch.bfloat16)
    try:
        cases = ((sz, dt) for sz in test_sizes for dt in dtypes)
        for sz, dtype in cases:
            # Skip combinations whose payload exceeds the byte budget.
            if sz * dtype.itemsize > byte_budget:
                continue
            if rank == 0:
                print(f"mscclpp allreduce test sz {sz}, dtype {dtype}")
            for _ in range(repeats):
                x = torch.randint(1, 16, (sz,), dtype=dtype, device=device)
                expected = x.clone()
                actual = torch.empty_like(x)
                custom_ops.mscclpp_allreduce(
                    mscclpp_context, x, actual, nthreads=512, nblocks=21
                )
                dist.all_reduce(expected, group=group)
                torch.testing.assert_close(actual, expected)
    finally:
        # Keep ranks in lock-step before tearing the process group down.
        dist.barrier(group=group)
        dist.destroy_process_group(group=group)
def get_open_port() -> int:
    """Ask the OS for a currently unused loopback TCP port and return it.

    Binding to port 0 lets the kernel choose a free ephemeral port. The socket
    is closed before returning, so the port is only *likely* to still be free
    when the caller uses it.

    IPv4 (127.0.0.1) is tried first. IPv6 (::1) is used only when the IPv4
    attempt raises ``OSError``.
    """

    def probe(family, host):
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind((host, 0))
            return sock.getsockname()[1]

    try:
        return probe(socket.AF_INET, "127.0.0.1")
    except OSError:
        return probe(socket.AF_INET6, "::1")
def multi_process_parallel(
    world_size: int, test_target: Any, target_args: tuple = ()
) -> None:
    """Run ``test_target`` in ``world_size`` spawned processes and check them.

    Worker ``i`` is called as
    ``test_target(world_size, i, distributed_init_port, *target_args)``.
    ``distributed_init_port`` comes from ``get_open_port()`` and is shared by
    all workers.

    Processes come from a local "spawn" context rather than forcing the
    interpreter-wide start method, so other multiprocessing users in the same
    process are unaffected.

    As soon as any worker exits with a non-zero code, the remaining workers
    are terminated. Peers of a failed rank usually block forever inside a
    collective, and waiting for them would hang the run. Leftover workers are
    also terminated if starting a worker fails or the parent is interrupted.

    Raises:
        AssertionError: if any worker exits with a non-zero code. The message
            lists every failed worker. Workers terminated here report a
            negative (signal) exit code.
    """
    from multiprocessing.connection import wait

    ctx = mp.get_context("spawn")
    distributed_init_port = get_open_port()
    procs = []
    try:
        for i in range(world_size):
            proc_args = (world_size, i, distributed_init_port) + target_args
            proc = ctx.Process(
                target=test_target, args=proc_args, name=f"Worker-{i}"
            )
            proc.start()
            procs.append(proc)

        # Wait for workers in completion order; stop early on the first failure.
        pending = {proc.sentinel: proc for proc in procs}
        while pending:
            failed = False
            for sentinel in wait(list(pending)):
                finished = pending.pop(sentinel)
                finished.join()  # reap it so exitcode is reliable
                failed = failed or finished.exitcode != 0
            if failed:
                break
    finally:
        # Never leave workers behind: non-daemon children block interpreter exit.
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
            proc.join()

    failures = [
        f"Process {i} failed with exit code {proc.exitcode}"
        for i, proc in enumerate(procs)
        if proc.exitcode != 0
    ]
    assert not failures, "; ".join(failures)
class TestMSCCLAllReduce(unittest.TestCase):
    """Multi-process correctness test for the mscclpp all-reduce kernel."""

    # Element counts of the all-reduce inputs exercised by each worker.
    test_sizes = [
        512,
        2560,
        4096,
        5120,
        7680,
        32768,
        262144,
        524288,
    ]
    # Numbers of ranks (one worker process each) to test.
    world_sizes = [8]

    def test_correctness(self):
        """Compare mscclpp all-reduce with NCCL for every feasible world size.

        World sizes larger than the number of visible CUDA devices are
        skipped. If none of them can run, the test is reported as skipped
        instead of silently passing.
        """
        available_gpus = torch.cuda.device_count()
        ran_any = False
        for world_size in self.world_sizes:
            if world_size > available_gpus:
                print(
                    f"Skipping world_size={world_size}, found {available_gpus} and now ray is not supported here"
                )
                continue
            print(f"Running test for world_size={world_size}")
            multi_process_parallel(
                world_size, _run_correctness_worker, target_args=(self.test_sizes,)
            )
            print(f"custom allreduce tp = {world_size}: OK")
            ran_any = True
        if not ran_any:
            # Report "skipped" rather than a misleading pass on small machines.
            self.skipTest(
                f"no world size in {self.world_sizes} fits on "
                f"{available_gpus} visible GPU(s)"
            )
# Run this module's unittest test cases when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()