# File: sglang_v0.5.2/flashinfer_0.3.1/tests/test_create_ipc_buffer.py
# (Python; listed upstream as 91 lines, 2.6 KiB)

# adapted from vllm
import ctypes
import os
import subprocess
import sys
import pytest
import torch
import torch.distributed as dist
import flashinfer.comm as comm
from flashinfer.comm import CudaRTLibrary
def _run_ipc_test():
    """Write to and verify the shared IPC buffers from every rank.

    Meant to be executed once per rank by a distributed launcher. Each rank:

    1. obtains the buffer pointers from ``comm.create_shared_buffer``;
    2. fills its own ``buffer_size_in_bytes // world_size`` slice of every
       buffer with its rank number as the byte value;
    3. copies every rank's slice of every buffer back to the host and checks
       that each byte equals that rank's number;
    4. releases the buffers with ``comm.free_shared_buffer``.

    A process group created here is destroyed again before returning; a group
    that was already initialised by the caller is left untouched.

    Raises:
        AssertionError: if any byte read back differs from the expected rank.
    """
    # Only set up a process group if the caller has not done so already, and
    # remember that so we tear down only what we created.
    owns_process_group = not dist.is_initialized()
    if owns_process_group:
        dist.init_process_group(backend="nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    cudart = CudaRTLibrary()
    cudart.cudaSetDevice(rank)

    buffer_size_in_bytes = 1024
    # Every rank owns one equally sized slice of each buffer.
    slice_size = buffer_size_in_bytes // world_size
    byte_value = rank

    pointers = comm.create_shared_buffer(buffer_size_in_bytes)
    print(f"Rank {rank} init ipc buffer {pointers}", flush=True)
    dist.barrier()
    torch.cuda.synchronize()

    # Write phase: stamp this rank's slice of every buffer with its rank id.
    for p in pointers:
        pointer = ctypes.c_void_p(p + rank * slice_size)
        cudart.cudaMemset(pointer, byte_value, slice_size)

    # All ranks must finish writing before anyone starts reading.
    dist.barrier()
    torch.cuda.synchronize()

    # Read phase: pull each rank's slice back to the host and check its bytes.
    host_data = (ctypes.c_char * buffer_size_in_bytes)()
    for p in pointers:
        for cur_rank in range(world_size):
            offset_pointer = ctypes.c_void_p(p + cur_rank * slice_size)
            cudart.cudaMemcpy(host_data, offset_pointer, slice_size)
            # Only the first ``slice_size`` bytes of host_data were refreshed.
            for value in host_data.raw[:slice_size]:
                assert value == cur_rank, (
                    f"Rank {rank} failed to verify buffer {p}. "
                    f"Expected {cur_rank}, got {value}"
                )

    print(f"Rank {rank} verified all buffers.\n", flush=True)

    # Make sure every rank is done reading before the buffers are released.
    dist.barrier()
    torch.cuda.synchronize()
    comm.free_shared_buffer(pointers)

    if owns_process_group:
        dist.destroy_process_group()
# -------------------------------
# Pytest Entrypoint (main test)
# -------------------------------
@pytest.mark.parametrize("world_size", [2, 4])
def test_ipc_distributed(world_size):
    """Run the IPC buffer check on ``world_size`` GPUs in a child launcher.

    The test re-executes this file through ``torch.distributed.run`` (the
    module behind the ``torchrun`` script) with ``--run_ipc_test``, starting
    one worker process per rank, and passes when the launcher exits with
    status 0. It is skipped when fewer than ``world_size`` GPUs are visible.

    Args:
        world_size: Number of worker processes (one per GPU) to launch.
    """
    available_gpus = torch.cuda.device_count()
    if world_size > available_gpus:
        pytest.skip(
            f"world_size {world_size} is greater than available_gpus {available_gpus}"
        )

    script = os.path.abspath(__file__)
    # Launch through the current interpreter so the workers use the same
    # environment as the test run, independent of what is on PATH.
    command = [
        sys.executable,
        "-m",
        "torch.distributed.run",
        f"--nproc_per_node={world_size}",
        script,
        "--run_ipc_test",
    ]
    # Let the child inherit stdout/stderr: passing sys.stdout/sys.stderr
    # explicitly requires them to expose a real file descriptor, which fails
    # when the test runner has replaced them with in-memory streams.
    result = subprocess.run(command, check=False)
    assert result.returncode == 0, (
        f"IPC test launcher for world_size {world_size} exited with "
        f"return code {result.returncode}"
    )
# -------------------------------
# Subprocess entrypoint (runs the actual test logic when launched by the test)
# -------------------------------
# Run the per-rank logic only when this file is executed as a script with the
# --run_ipc_test flag, which is how test_ipc_distributed launches it.
if __name__ == "__main__" and "--run_ipc_test" in sys.argv:
    _run_ipc_test()