# sglang_v0.5.2/sglang/test/srt/test_quick_allreduce.py

import os
import random
import socket
import unittest
from typing import Any

import ray
import torch
import torch.distributed as dist

from sglang.srt.distributed import init_distributed_environment
from sglang.srt.distributed.communication_op import (  # noqa
    tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
    qr_rocm_arch_available,
)
from sglang.srt.distributed.parallel_state import (
    get_tensor_model_parallel_group,
    graph_capture,
    initialize_model_parallel,
)
from sglang.test.test_utils import CustomTestCase

torch.manual_seed(42)
random.seed(44)  # keep the deterministic seed


def get_open_port() -> int:
    # try ipv4
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]
    except OSError:
        # try ipv6
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, cls: Any, test_target: Any, quant_mode: str
) -> None:
    # Ray makes it easier to debug failures than multiprocessing does.
    # NOTE: We need to set working_dir for distributed tests,
    # otherwise we may get import errors on ray workers.
    ray.init(log_to_driver=True)

    distributed_init_port = get_open_port()
    refs = []
    for rank in range(world_size):
        refs.append(
            test_target.remote(cls, world_size, rank, distributed_init_port, quant_mode)
        )
    ray.get(refs)
    ray.shutdown()
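

# Each test below launches one Ray actor per rank. Every actor initializes the
# SGLang distributed environment, runs tensor_model_parallel_all_reduce (which
# is expected to take the Quick AllReduce path on supported ROCm GPUs), and
# compares the result against a plain torch.distributed.all_reduce.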
class TestQuickAllReduce(CustomTestCase):
    TEST_SIZES = [
        2 * 1024 * 1024,
        4 * 1024 * 1024,
        8 * 1024 * 1024,
        16 * 1024 * 1024,
        32 * 1024 * 1024,
    ]
    TEST_LOOP = 5
    # The full grid of WORLD_SIZES = [2, 4, 8] and
    # QUANT_MODE = ["FP", "INT8", "INT6", "INT4"] is too large, and booting
    # tensor parallelism for every configuration takes too long, so only
    # 4 of the 12 combinations are tested.
    QUANT_MODE_WORLD_SIZE_PART = [["FP", 8], ["INT4", 4], ["INT8", 2], ["INT6", 2]]

    @unittest.skipIf(
        not qr_rocm_arch_available(),
        "Only test Quick AllReduce on ROCm architectures >= gfx94*",
    )
    def test_graph_allreduce(self):
        for quant_mode, world_size in self.QUANT_MODE_WORLD_SIZE_PART:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.graph_allreduce, quant_mode)

    @unittest.skipIf(
        not qr_rocm_arch_available(),
        "Only test Quick AllReduce on ROCm architectures >= gfx94*",
    )
    def test_eager_allreduce(self):
        for quant_mode, world_size in self.QUANT_MODE_WORLD_SIZE_PART:
            if world_size > torch.cuda.device_count():
                continue
            multi_process_parallel(world_size, self, self.eager_allreduce, quant_mode)

    @ray.remote(num_gpus=1, max_calls=1)
    def graph_allreduce(self, world_size, rank, distributed_init_port, quant_mode):
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode
        os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0"
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group

        # A small all_reduce for warmup. This is needed because device
        # communicators might be created lazily (e.g. NCCL); it ensures the
        # communicator is initialized before any communication happens, so
        # that this group can be used for graph capture immediately.
        data = torch.zeros(1)
        data = data.to(device=device)
        torch.distributed.all_reduce(data, group=group)
        torch.cuda.synchronize()
        del data
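
        # Sweep message sizes and dtypes. Each iteration captures two pairs of
        # tensor_model_parallel_all_reduce + dist.all_reduce calls in a CUDA
        # graph, replays the graph, and then checks the results.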
        for sz in self.TEST_SIZES:
            for dtype in [torch.float16, torch.bfloat16]:
                for _ in range(self.TEST_LOOP):
                    with graph_capture() as graph_capture_context:
                        # use integers so the result matches NCCL exactly
                        inp1 = torch.randint(
                            1,
                            23,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        inp2 = torch.randint(
                            -23,
                            1,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        torch.cuda.synchronize()
                        graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(
                            graph, stream=graph_capture_context.stream
                        ):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                    graph.replay()
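                    # Quick AllReduce quantization is lossy, so compare with
                    # tolerances that scale with the number of ranks.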
                    atol = 1.25 * world_size
                    rtol = 0.5 * world_size
                    for inp, out in [[inp1, out1], [inp2, out2]]:
                        torch.testing.assert_close(out, inp, atol=atol, rtol=rtol)
                        # try:
                        #     torch.testing.assert_close(out, inp, atol=atol, rtol=rtol)
                        # except AssertionError as e:
                        #     print("Max abs diff:", (out - inp).abs().max())
                        #     print("Max rel diff:", ((out - inp).abs() / inp.abs().clamp(min=1e-5)).max())

    @ray.remote(num_gpus=1, max_calls=1)
    def eager_allreduce(self, world_size, rank, distributed_init_port, quant_mode):
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode
        os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0"
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
        distributed_init_method = f"tcp://localhost:{distributed_init_port}"
        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        group = get_tensor_model_parallel_group().device_group
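
        # Same correctness check as graph_allreduce, but in eager mode: call
        # the all-reduce directly (no CUDA graph capture) and compare against
        # torch.distributed.all_reduce on the same input.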
        for sz in self.TEST_SIZES:
            for dtype in [torch.float16, torch.bfloat16]:
                for _ in range(self.TEST_LOOP):
                    inp1 = torch.randint(
                        1,
                        23,
                        (sz,),
                        dtype=dtype,
                        device=torch.cuda.current_device(),
                    )
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    dist.all_reduce(inp1, group=group)
                    atol = 1.25 * world_size
                    rtol = 0.5 * world_size
                    torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol)
                    # try:
                    #     torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol)
                    # except AssertionError as e:
                    #     print("Max abs diff:", (out1 - inp1).abs().max())
                    #     print("Max rel diff:", ((out1 - inp1).abs() / inp1.abs().clamp(min=1e-5)).max())


if __name__ == "__main__":
    unittest.main()