# File viewer metadata: 213 lines, 8.3 KiB, Python
import os
|
|
import random
|
|
import socket
|
|
import unittest
|
|
from typing import Any
|
|
|
|
import ray
|
|
import torch
|
|
import torch.distributed as dist
|
|
|
|
from sglang.srt.distributed import init_distributed_environment
|
|
from sglang.srt.distributed.communication_op import ( # noqa
|
|
tensor_model_parallel_all_reduce,
|
|
)
|
|
from sglang.srt.distributed.device_communicators.quick_all_reduce import (
|
|
qr_rocm_arch_available,
|
|
)
|
|
from sglang.srt.distributed.parallel_state import (
|
|
get_tensor_model_parallel_group,
|
|
graph_capture,
|
|
initialize_model_parallel,
|
|
)
|
|
from sglang.test.test_utils import CustomTestCase
|
|
|
|
# Fix the seeds of Python's and PyTorch's random number generators as soon as
# this module is imported.
random.seed(44)  # keep the deterministic seed
torch.manual_seed(42)
|
|
|
|
|
|
def get_open_port() -> int:
    """Return a TCP port number that the OS reports as free right now.

    A socket is bound to port 0, which lets the operating system pick an
    unused port. The socket is closed again before returning, so the number
    is only a hint: nothing reserves it for the caller.

    IPv4 is tried first. If that raises ``OSError`` the same probe is repeated
    with an IPv6 socket; an error from the second attempt propagates.
    """

    def _probe(family: int) -> int:
        # Port 0 asks the OS to choose; index 1 of the bound address is the
        # port for both IPv4 and IPv6 sockets.
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    try:
        return _probe(socket.AF_INET)
    except OSError:
        return _probe(socket.AF_INET6)
|
|
|
|
|
|
def multi_process_parallel(
    world_size: int, cls: Any, test_target: Any, quant_mode: str
) -> None:
    """Run ``test_target`` once per rank as Ray tasks and wait for all of them.

    Args:
        world_size: Number of ranks; one remote task is submitted per rank.
        cls: Object forwarded as the first argument of every remote call
            (the tests in this file pass the test-case instance).
        test_target: Ray remote function; invoked through ``.remote(...)``.
        quant_mode: Quantization mode string forwarded to every task.

    Raises:
        Any exception raised by ``ray.get`` when a task fails. Ray is shut
        down before the exception propagates.
    """
    # Using ray helps debugging the error when it failed
    # as compared to multiprocessing.
    # NOTE(review): the original comment said working_dir must be set for
    # distributed tests to avoid import errors on ray workers, but no
    # working_dir / runtime_env is passed to ray.init here -- confirm.
    ray.init(log_to_driver=True)
    try:
        distributed_init_port = get_open_port()
        refs = [
            test_target.remote(
                cls, world_size, rank, distributed_init_port, quant_mode
            )
            for rank in range(world_size)
        ]
        # Blocks until every rank finishes; re-raises the first task failure.
        ray.get(refs)
    finally:
        # Always tear Ray down, even when a rank fails. Otherwise Ray stays
        # initialized and the next ray.init() in this process (the next
        # configuration or test) is rejected as a repeated initialization.
        ray.shutdown()
|
|
|
|
|
|
class TestQuickAllReduce(CustomTestCase):
    """Check ``tensor_model_parallel_all_reduce`` against ``dist.all_reduce``.

    Each test walks ``QUANT_MODE_WORLD_SIZE_PART`` and, for every pair that
    fits on the available GPUs, launches one Ray task per rank through
    ``multi_process_parallel``. The Ray tasks (``graph_allreduce`` and
    ``eager_allreduce``) do the actual comparison on the worker side.
    """

    # Element counts of the 1-D tensors that get all-reduced.
    TEST_SIZES = [scale * 1024 * 1024 for scale in (2, 4, 8, 16, 32)]
    # Number of repetitions for every (size, dtype) combination.
    TEST_LOOP = 5
    # The full grid would be
    #   WORLD_SIZES = [2, 4, 8]
    #   QUANT_MODE = ["FP", "INT8", "INT6", "INT4"]
    # i.e. 12 configurations. Tensor-parallel startup is slow, so only 4 of
    # the 12 (quant_mode, world_size) pairs are exercised.
    QUANT_MODE_WORLD_SIZE_PART = [["FP", 8], ["INT4", 4], ["INT8", 2], ["INT6", 2]]

    @unittest.skipIf(
        not qr_rocm_arch_available(),
        "Only test Quick AllReduce on ROCm architectures >= gfx94*",
    )
    def test_graph_allreduce(self):
        """Run the graph-captured comparison for every configuration that fits."""
        for quant_mode, world_size in self.QUANT_MODE_WORLD_SIZE_PART:
            # Configurations needing more GPUs than are present are passed over.
            if world_size <= torch.cuda.device_count():
                multi_process_parallel(
                    world_size, self, self.graph_allreduce, quant_mode
                )

    @unittest.skipIf(
        not qr_rocm_arch_available(),
        "Only test Quick AllReduce on ROCm architectures >= gfx94*",
    )
    def test_eager_allreduce(self):
        """Run the eager-mode comparison for every configuration that fits."""
        for quant_mode, world_size in self.QUANT_MODE_WORLD_SIZE_PART:
            # Configurations needing more GPUs than are present are passed over.
            if world_size <= torch.cuda.device_count():
                multi_process_parallel(
                    world_size, self, self.eager_allreduce, quant_mode
                )

    @staticmethod
    def _init_rank(world_size, rank, distributed_init_port, quant_mode):
        """Prepare one worker process for tensor-parallel communication.

        Clears ``CUDA_VISIBLE_DEVICES``, sets the two quick-reduce environment
        variables, selects ``cuda:{rank}`` and initializes the distributed
        environment and the model-parallel groups.

        Returns:
            ``(device, group)``: the selected ``torch.device`` and the
            ``device_group`` of the tensor-model-parallel group.
        """
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        os.environ["ROCM_QUICK_REDUCE_QUANTIZATION"] = quant_mode
        os.environ["ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16"] = "0"

        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)

        init_distributed_environment(
            world_size=world_size,
            rank=rank,
            distributed_init_method=f"tcp://localhost:{distributed_init_port}",
            local_rank=rank,
        )
        initialize_model_parallel(tensor_model_parallel_size=world_size)
        return device, get_tensor_model_parallel_group().device_group

    @ray.remote(num_gpus=1, max_calls=1)
    def graph_allreduce(self, world_size, rank, distributed_init_port, quant_mode):
        """Ray task: compare all-reduce results captured inside a CUDA graph.

        Args:
            self: Test-case instance forwarded by ``multi_process_parallel``;
                only ``TEST_SIZES`` and ``TEST_LOOP`` are read from it.
            world_size: Total number of ranks.
            rank: Rank of this task; also used as the CUDA device index.
            distributed_init_port: TCP port on localhost used for rendezvous.
            quant_mode: Value for ``ROCM_QUICK_REDUCE_QUANTIZATION``.

        Raises:
            AssertionError: If a result differs from the reference by more
                than the tolerances.
        """
        device, group = self._init_rank(
            world_size, rank, distributed_init_port, quant_mode
        )

        # A small all_reduce for warmup.
        # This is needed because device communicators might be created lazily
        # (e.g. NCCL). It ensures the communicator is initialized before any
        # communication happens, so that this group can be used for graph
        # capture immediately.
        warmup = torch.zeros(1).to(device=device)
        torch.distributed.all_reduce(warmup, group=group)
        torch.cuda.synchronize()
        del warmup

        # Tolerances depend only on world_size, so compute them once.
        atol = 1.25 * world_size
        rtol = 0.5 * world_size

        for sz in self.TEST_SIZES:
            for dtype in (torch.float16, torch.bfloat16):
                for _ in range(self.TEST_LOOP):
                    with graph_capture() as graph_capture_context:
                        # use integers so result matches NCCL exactly
                        inp1 = torch.randint(
                            1,
                            23,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        inp2 = torch.randint(
                            -23,
                            1,
                            (sz,),
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                        )
                        torch.cuda.synchronize()
                        graph = torch.cuda.CUDAGraph()
                        with torch.cuda.graph(
                            graph, stream=graph_capture_context.stream
                        ):
                            out1 = tensor_model_parallel_all_reduce(inp1)
                            # the input buffer is immediately modified to test
                            # synchronization
                            dist.all_reduce(inp1, group=group)
                            out2 = tensor_model_parallel_all_reduce(inp2)
                            dist.all_reduce(inp2, group=group)
                    # NOTE(review): replay is placed after the graph_capture()
                    # context has exited -- confirm against the canonical file.
                    graph.replay()
                    # After the in-place dist.all_reduce each input holds the
                    # reference result for the matching output.
                    torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol)
                    torch.testing.assert_close(out2, inp2, atol=atol, rtol=rtol)

    @ray.remote(num_gpus=1, max_calls=1)
    def eager_allreduce(self, world_size, rank, distributed_init_port, quant_mode):
        """Ray task: compare all-reduce results in eager (non-graph) mode.

        Args:
            self: Test-case instance forwarded by ``multi_process_parallel``;
                only ``TEST_SIZES`` and ``TEST_LOOP`` are read from it.
            world_size: Total number of ranks.
            rank: Rank of this task; also used as the CUDA device index.
            distributed_init_port: TCP port on localhost used for rendezvous.
            quant_mode: Value for ``ROCM_QUICK_REDUCE_QUANTIZATION``.

        Raises:
            AssertionError: If a result differs from the reference by more
                than the tolerances.
        """
        _device, group = self._init_rank(
            world_size, rank, distributed_init_port, quant_mode
        )

        # Tolerances depend only on world_size, so compute them once.
        atol = 1.25 * world_size
        rtol = 0.5 * world_size

        for sz in self.TEST_SIZES:
            for dtype in (torch.float16, torch.bfloat16):
                for _ in range(self.TEST_LOOP):
                    inp1 = torch.randint(
                        1,
                        23,
                        (sz,),
                        dtype=dtype,
                        device=torch.cuda.current_device(),
                    )
                    out1 = tensor_model_parallel_all_reduce(inp1)
                    # The in-place reduction turns inp1 into the reference.
                    dist.all_reduce(inp1, group=group)
                    torch.testing.assert_close(out1, inp1, atol=atol, rtol=rtol)
|
|
|
|
|
|
# Start the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|