# sglang_v0.5.2/flashinfer_0.3.1/tests/test_mnnvl_custom_comm.py
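"""Multi-process test of MnnvlMemory wired to a custom CommBackend.

Each spawned worker binds one GPU, joins a torch.distributed NCCL group, and
plugs CustomCommunicator into MnnvlMemory via MnnvlConfig; the worker body
then checks allocation offsets, alignment, and cross-rank visibility of the
shared tensors.
"""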

import multiprocessing as mp
import socket
from typing import Any

import pynvml
import pytest
import torch
import torch.distributed as dist

from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import CommBackend as CommBackend
from flashinfer.comm.mnnvl import MnnvlConfig, MnnvlMemory

pynvml.nvmlInit()


class CustomCommunicator(CommBackend):
    """Minimal CommBackend adapter that forwards collectives to a torch.distributed group."""

    def __init__(self, group):
        self._group = group

    def Get_rank(self) -> int:
        return dist.get_rank(self._group)

    def Get_size(self) -> int:
        return dist.get_world_size(self._group)

    def allgather(self, data: int | bytes):
        device = f"cuda:{torch.cuda.current_device()}"
        if isinstance(data, int):
            # Gather one int32 per rank and return them as plain Python ints.
            local_tensor = torch.tensor([data], device=device, dtype=torch.int32)
            world_size = self.Get_size()
            gathered = [torch.zeros_like(local_tensor) for _ in range(world_size)]
            dist.all_gather(gathered, local_tensor, group=self._group)
            return [int(x.item()) for x in gathered]
        elif isinstance(data, bytes):
            # Variable-length payloads go through the object collective; the
            # output list only needs world_size slots to be overwritten in place.
            gathered = [data] * self.Get_size()
            dist.all_gather_object(gathered, data, group=self._group)
            return gathered
        else:
            raise TypeError(f"Unsupported type for allgather: {type(data)}")

    def Split(self, color: int, key: int) -> "CustomCommunicator":
        # The test only ever uses a single group, so splitting is a no-op.
        return self


def get_open_port() -> int:
    # Ask the OS for a free TCP port, falling back to the IPv6 loopback if needed.
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
) -> None:
    # Spawn one worker process per rank and fail if any worker exits non-zero.
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, dtype, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert procs[i].exitcode == 0, (
            f"Process {i} failed with exit code {procs[i].exitcode}"
        )


def align_memory(size: int):
    # Round up to 2 MiB; the offset assertions below assume this granularity.
    align_size = 2 * 1024 * 1024
    return (size + align_size - 1) // align_size * align_size


def _init_mnnvl_memory(world_size, rank, dtype, distributed_init_port):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD
    MnnvlMemory.initialize()
    mapping = Mapping(world_size, rank, world_size, tp_size=world_size)

    # First allocation: deliberately not 2 MiB-aligned so the rounding is exercised.
    allocate0_size = 4 * 1024 * 1024 - 3 * 1024
    mnnvl_config = MnnvlConfig(
        comm_backend=CustomCommunicator(group),
        fabric_page_size=1 << 29,  # 512 MiB
        allocation_granularity=0,  # Auto-detect
    )
    MnnvlMemory.set_comm_from_config(mapping, mnnvl_config)
    mnnvl_memory0 = MnnvlMemory(mapping, allocate0_size)
    allocate0_size_aligned = align_memory(allocate0_size)
    assert MnnvlMemory.current_mem_offset == allocate0_size_aligned

    # Each rank writes into its neighbour's slot, then every rank verifies all slots.
    tensor0 = mnnvl_memory0.as_torch_strided_tensor(torch.int32)
    numel_per_rank = allocate0_size // 4
    tensor0[(rank + 1) % world_size] = torch.arange(
        start=rank, end=rank + numel_per_rank, device="cuda"
    )
    dist.barrier(group=group)
    for r in range(world_size):
        assert torch.equal(
            tensor0[(r + 1) % world_size],
            torch.arange(start=r, end=r + numel_per_rank, device="cuda"),
        )

    # Second allocation from the same fabric page; offsets must accumulate.
    allocate1_size = 30 * 1024 * 1024 - 2 * 1024
    mnnvl_memory1 = MnnvlMemory(mapping, allocate1_size)
    allocate1_size_aligned = align_memory(allocate1_size)
    assert (
        MnnvlMemory.current_mem_offset
        == allocate0_size_aligned + allocate1_size_aligned
    )
    tensor1 = mnnvl_memory1.as_torch_strided_tensor(torch.float32)
    numel_per_rank = allocate1_size // 4
    tensor1[(rank + 5) % world_size] = torch.arange(
        start=rank,
        end=rank + numel_per_rank,
        dtype=torch.float32,
        device="cuda",
    )
    dist.barrier(group=group)
    for r in range(world_size):
        assert torch.equal(
            tensor1[(r + 5) % world_size],
            torch.arange(
                start=r, end=r + numel_per_rank, dtype=torch.float32, device="cuda"
            ),
        )

    dist.barrier(group=group)
    # Release the first allocation, then grab a block larger than the 512 MiB
    # fabric page and check the resulting offset and per-rank stride.
    del tensor0, mnnvl_memory0
    dist.barrier(group=group)
    large_allocation2_size = 768 * 1024 * 1024
    large_mnnvl_memory2 = MnnvlMemory(mapping, large_allocation2_size)
    allocate2_size_aligned = align_memory(large_allocation2_size)
    assert MnnvlMemory.current_mem_offset == allocate2_size_aligned
    assert large_mnnvl_memory2.rank_stride == (1 << 30)
    del tensor1


@pytest.mark.skipif(
    not MnnvlMemory.supports_mnnvl(),
    reason="Mnnvl memory is not supported on this platform",
)
@pytest.mark.parametrize("world_size", [2, 4])
def test_mnnvl_custom_communicator(world_size):
    dtype = torch.float16
    available_gpus = torch.cuda.device_count()
    if world_size > available_gpus:
        raise ValueError(
            f"world_size {world_size} is greater than available_gpus {available_gpus}"
        )
    print(f"Running test for world_size={world_size}")
    multi_process_parallel(
        world_size,
        dtype,
        _init_mnnvl_memory,
        target_args=(),
    )
    print(f"custom mnnvl communicator world_size = {world_size}: OK")