sglang_v0.5.2/flashinfer_0.3.1/tests/test_trtllm_allreduce.py

import multiprocessing as mp
import socket
from typing import Any

import pytest
import torch
import torch.distributed as dist

import flashinfer.comm as comm
"""
NOTE:
The assertion of result closeness is disabled for now,
since assertion fails for some cases, which breaks the tests and introduces NCCL timeout.
Trt-llm encourage using certain shapes for this custom all-reduce kernel,
hidden_size in range [256, 8192], and maxHiddenSize should be 8192.
The recommended case is [1024, 2048, 4096, 8192].
If new trt-llm source kernels are available (function name starts with "trtllm_"), we would recommend using them.
"""

maxBatchSize = 1
maxBeamWidth = 3
maxTokenNum = 128
maxHiddenSize = 4096  # max hidden size for all reduce
RANDOM_SEED = 42
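
# Illustrative only, not used by the tests below: a minimal sketch of the shape
# guidance from the NOTE above. The constant and helper names here are
# hypothetical, not part of flashinfer's API.
RECOMMENDED_HIDDEN_SIZES = (1024, 2048, 4096, 8192)


def _hidden_size_is_supported(hidden_size: int) -> bool:
    """Return True if hidden_size is in the [256, 8192] range that TRT-LLM
    recommends for this custom all-reduce kernel (see NOTE above)."""
    return 256 <= hidden_size <= 8192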


def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD

    try:
        device = torch.device(f"cuda:{rank}")
        token_nums = [64, 128]
        strategy_codes = [
            comm.AllReduceStrategyType.ONESHOT,
            comm.AllReduceStrategyType.TWOSHOT,
        ]
        # below are the recommended hidden sizes for custom all-reduce in trtllm test
        # hidden_size should be in range [256, 8192], and maxHiddenSize should be 8192
        hidden_sizes = [1024, 4096]
        config_codes = [
            0,
            comm.AllReduceStrategyConfig.USE_MEMCPY,
            comm.AllReduceStrategyConfig.PUSH_MODE,
        ]
        fusion_op_codes = [
            comm.AllReduceFusionOp.NONE,
            comm.AllReduceFusionOp.RESIDUAL_RMS_NORM,
            comm.AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM,
            # Below are not enabled for custom all-reduce in trtllm test, skip
            # comm.AllReduceFusionOp.LAST_PROCESS_FOR_UB,
            # comm.AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
            # comm.AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
            # comm.AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8,
            # comm.AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
            # comm.AllReduceFusionOp.MOE_ALLREDUCE_RESIDUAL_RMS_NORM,
        ]
        launch_with_pdls = [True, False]

        # create ipc memory
        workspace = comm.trtllm_create_ipc_workspace_for_all_reduce(
            rank=rank,
            tp_size=world_size,
            max_token_num=maxTokenNum,
            hidden_dim=maxHiddenSize,
            group=group,
        )
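
        # Workspace layout below is inferred from how it is consumed by
        # trtllm_custom_all_reduce in this test (not from a documented contract):
        # workspace[0] -> peer communication buffer pointers,
        # workspace[2]/workspace[3] -> in/out barrier pointers,
        # workspace[4..6] -> lamport peer communication buffers.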
        test_loop = 2  # could be any number
        # NOTE: the barrier flag should be initialized to 1, and incremented by 1 for each AR
        flag_value = 1
        for token_num in token_nums:
            for hidden_size in hidden_sizes:
                for strategy_code in strategy_codes:
                    for config_code in config_codes:
                        for fusion_op_code in fusion_op_codes:
                            for launch_with_pdl in launch_with_pdls:
                                pass_flag = True
                                if (
                                    strategy_code == comm.AllReduceStrategyType.TWOSHOT
                                    and fusion_op_code
                                    == comm.AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM
                                ):
                                    # skip twoshot pre-post norm: not supported in trtllm test
                                    continue

                                print(
                                    f"test RANK {rank}: {world_size}-{dtype}-{strategy_code}-{config_code}-{fusion_op_code}-{launch_with_pdl}-{hidden_size} start"
                                )
                                torch.cuda.synchronize()
                                for _ in range(test_loop):
                                    message_size = token_num * hidden_size
                                    inp1 = torch.randn(
                                        message_size, dtype=dtype, device=device
                                    )
                                    inp1_ref = inp1.clone()
                                    out1 = torch.empty_like(inp1)

                                    # init params for each fusion op
                                    bias = torch.randn(
                                        hidden_size, dtype=dtype, device=device
                                    )
                                    residual = torch.randn(
                                        message_size, dtype=dtype, device=device
                                    )
                                    weight = torch.randn(
                                        hidden_size, dtype=dtype, device=device
                                    )
                                    weight_pre_residual_norm = torch.randn(
                                        hidden_size, dtype=dtype, device=device
                                    )
                                    eps = 1e-6
                                    intermediate_buffer = torch.zeros(
                                        message_size, dtype=dtype, device=device
                                    )
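                                    # All fusion parameters (bias, residual, weights, eps,
                                    # intermediate_buffer) are passed unconditionally; only the
                                    # fused variants are checked against a reference below.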
                                    comm.trtllm_custom_all_reduce(
                                        inp=inp1,
                                        out=out1,
                                        tp_size=world_size,
                                        tp_rank=rank,
                                        token_num=token_num,
                                        fusion_op_code=fusion_op_code,
                                        strategy_code=strategy_code,
                                        config_code=config_code,
                                        launch_with_pdl=launch_with_pdl,
                                        flag_value=flag_value,
                                        peer_comm_buffer_ptrs=torch.tensor(
                                            workspace[0], dtype=torch.int64
                                        ),
                                        peer_barrier_ptrs_in=torch.tensor(
                                            workspace[2], dtype=torch.int64
                                        ),
                                        peer_barrier_ptrs_out=torch.tensor(
                                            workspace[3], dtype=torch.int64
                                        ),
                                        bias=bias,
                                        residual=residual,
                                        weight=weight,
                                        weight_pre_residual_norm=weight_pre_residual_norm,
                                        eps=eps,
                                        intermediate_buffer=intermediate_buffer,
                                        lamport_peer_comm_buffer_ptrs_0=torch.tensor(
                                            workspace[4], dtype=torch.int64
                                        ),
                                        lamport_peer_comm_buffer_ptrs_1=torch.tensor(
                                            workspace[5], dtype=torch.int64
                                        ),
                                        lamport_peer_comm_buffer_ptrs_2=torch.tensor(
                                            workspace[6], dtype=torch.int64
                                        ),
                                    )
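                                    # Reference: a plain NCCL all-reduce on the untouched copy of the input.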
                                    dist.all_reduce(inp1_ref, group=group)

                                    tolerance = 1e-2 if dtype == torch.float16 else 8e-2
                                    if fusion_op_code == comm.AllReduceFusionOp.NONE:
                                        torch.testing.assert_close(
                                            out1, inp1_ref, atol=tolerance, rtol=3e-2
                                        )
                                    elif (
                                        fusion_op_code
                                        == comm.AllReduceFusionOp.RESIDUAL_RMS_NORM
                                    ):
                                        # cache intermediate_buffer to inter_buffer
                                        inter_buffer = intermediate_buffer.clone()

                                        # residual and bias
                                        ref = inp1_ref.clone()
                                        ref_float = ref.to(torch.float32)
                                        residual_float = residual.to(torch.float32)
                                        bias_float = bias.to(torch.float32)
                                        for i in range(ref.numel()):
                                            ref_float[i] += (
                                                residual_float[i]
                                                + bias_float[i % hidden_size]
                                            )
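                                        # The loop above is just the elementwise form of a broadcast
                                        # add; an equivalent vectorized expression (shown only as a
                                        # comment to keep the original reference computation intact):
                                        #   ref_float = ref_float + residual_float + bias_float.repeat(token_num)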
                                        ref_half = ref_float.to(dtype)
                                        torch.testing.assert_close(
                                            inter_buffer,
                                            ref_half,
                                            atol=tolerance,
                                            rtol=3e-2,
                                        )

                                        # RMSNorm over hidden size
                                        ref_float = ref_float.view(
                                            token_num, hidden_size
                                        )
                                        normed_float = torch.empty_like(ref_float)
                                        mean_sq = torch.mean(
                                            ref_float * ref_float, dim=-1, keepdim=True
                                        )
                                        denom = torch.sqrt(mean_sq + eps)
                                        normed_float = ref_float / denom
                                        normed_float = normed_float * weight.to(
                                            torch.float32
                                        )
                                        normed_half = normed_float.to(dtype)
                                        torch.testing.assert_close(
                                            out1,
                                            normed_half.view(-1),
                                            atol=tolerance,
                                            rtol=3e-2,
                                        )
                                    elif (
                                        fusion_op_code
                                        == comm.AllReduceFusionOp.RESIDUAL_RMS_PREPOST_NORM
                                    ):
                                        # NOTE(yingyi): bugfix TODO, checking this case currently
                                        # triggers an NCCL timeout, so it is skipped for now
                                        pass

                                    flag_value += 1

                                if pass_flag:
                                    print(
                                        f"test RANK {rank}: {world_size}-{dtype}-{strategy_code}-{config_code}-{fusion_op_code}-{launch_with_pdl}-{hidden_size} passed"
                                    )
                                # torch.cuda.synchronize()
                                # You might want to enable this barrier for better log output,
                                # but it is not mandatory across allReduce calls.
    finally:
        dist.barrier(group=group)
        comm.trtllm_destroy_ipc_workspace_for_all_reduce(workspace, group=group)
        dist.destroy_process_group(group=group)


def get_open_port() -> int:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, dtype, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert procs[i].exitcode == 0, (
            f"Process {i} failed with exit code {procs[i].exitcode}"
        )


@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_trtllm_custom_allreduce(world_size, dtype):
    torch.manual_seed(RANDOM_SEED)
    available_gpus = torch.cuda.device_count()
    if world_size > available_gpus:
        pytest.skip(
            f"world_size {world_size} is greater than available_gpus {available_gpus}"
        )
    print(f"Running test for world_size={world_size}")

    multi_process_parallel(
        world_size,
        dtype,
        _run_correctness_worker,
        target_args=(),
    )
    print(f"custom allreduce tp = {world_size}: OK")