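"""Multi-GPU correctness test for flashinfer.comm.trtllm_moe_allreduce_fusion.

Each worker process builds an fp32 reference (MoE expert reduction, FC2 add,
all-reduce, residual add, RMSNorm), then runs the fused kernel under a CUDA
graph and compares the allreduce, residual, and norm outputs against it.
"""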
import multiprocessing as mp
import socket
from typing import Any
import numpy as np
import pytest
import torch
import torch.distributed as dist
import flashinfer.comm as comm
# todo(Yingyi): add benchmark and quant test
# Test configuration constants
kOneShotMaxTokenNum = 128
MAX_TOKEN_NUM = 2048
HIDDEN_SIZE = 7168
MAX_EXPERT_NUM = 16
SF_VEC_SIZE = 16
# Range from which the random quantization scale factor is sampled
SCALE_FACTOR_RANGE = (-1, 1)
def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
dist.init_process_group(
backend="nccl",
init_method=distributed_init_method,
rank=rank,
world_size=world_size,
)
group = dist.group.WORLD
try:
device = torch.device(f"cuda:{rank}")
token_nums = [
1,
64,
128,
256,
2048,
] # subset of the full sweep (1, 2, 4, 8, ..., 2048)
candidate_active_expert_num = [8, 12, 16]
# candidate_active_expert_num = [1] # debug-only
swizzled_layout_codes = [
comm.QuantizationSFLayout.LINEAR,
comm.QuantizationSFLayout.SWIZZLED_128x4,
comm.QuantizationSFLayout.SWIZZLED_8x4,
]
launch_with_pdls = [True, False]
# create workspace for moe allreduce fusion
ipc_handles, workspace_tensor = (
comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
rank, world_size, MAX_TOKEN_NUM, HIDDEN_SIZE, group=group
)
)
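# The IPC workspace is allocated once for the largest (MAX_TOKEN_NUM, HIDDEN_SIZE)
# message and reused by every configuration in the sweep below.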
test_loop = 5
for token_num in token_nums:
for active_expert_num in candidate_active_expert_num:
for swizzled_layout_code in swizzled_layout_codes:
for launch_with_pdl in launch_with_pdls:
dist.barrier(group=group)
test_passed = True
print(
f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} start"
)
dist.barrier(group=group)
torch.cuda.synchronize()
for _ in range(test_loop):
message_size = token_num * HIDDEN_SIZE
residual_in = torch.randn(
message_size, dtype=dtype, device=device
)
residual_in_clone = residual_in.clone()
moe_allreduce_out = torch.zeros(
message_size, dtype=dtype, device=device
)
residual_out = torch.empty_like(residual_in)
norm_out = torch.empty_like(residual_in)
quant_out = torch.empty(
message_size // 4, dtype=dtype, device=device
) # quant: fp16/bf16 -> fp4, reference: cpp/tensorrt_llm/thop/allreduceOp.cpp:L487
scale_out = None
assert HIDDEN_SIZE % SF_VEC_SIZE == 0, (
"HIDDEN_SIZE must be divisible by SF_VEC_SIZE"
)
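# For the 128x4 swizzled layout the scale-factor buffer is padded to tile
# boundaries, so its size comes from compute_fp4_swizzled_layout_sf_size;
# the linear and 8x4 layouts use one scale per SF_VEC_SIZE elements.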
if (
swizzled_layout_code
== comm.QuantizationSFLayout.SWIZZLED_128x4
):
padded_message_size = (
comm.compute_fp4_swizzled_layout_sf_size(
token_num, HIDDEN_SIZE // SF_VEC_SIZE
)
)
scale_out = torch.empty(
padded_message_size, dtype=dtype, device=device
)
else:
scale_out = torch.empty(
message_size // SF_VEC_SIZE,
dtype=dtype,
device=device,
)
rms_gamma = torch.randn(
HIDDEN_SIZE, dtype=dtype, device=device
)
scale_factor = (
torch.rand(1, dtype=torch.float32, device=device)
* (SCALE_FACTOR_RANGE[1] - SCALE_FACTOR_RANGE[0])
+ SCALE_FACTOR_RANGE[0]
)
rms_eps = 1e-3
scale_factor_float = scale_factor.item()
# init moe params
# [device_num_expert, m]
moe_reduction_scale_input = torch.randn(
active_expert_num * token_num,
dtype=torch.float32,
device=device,
)
moe_reduction_scale_input_clone = (
moe_reduction_scale_input.clone()
)
# [device_num_expert, m, 7168]
moe_reduction_active_experts_token_input = torch.randn(
active_expert_num * message_size,
dtype=dtype,
device=device,
)
moe_reduction_active_experts_token_input_clone = (
moe_reduction_active_experts_token_input.clone()
)
# [m, 7168]
moe_reduction_token_input = torch.randn(
message_size, dtype=dtype, device=device
)
moe_reduction_token_input_clone = (
moe_reduction_token_input.clone()
)
# == Calculate reference output ==
# 1. MoE Reduction
moe_expert_out = (
moe_reduction_active_experts_token_input_clone.view(
active_expert_num, token_num, HIDDEN_SIZE
).to(torch.float32)
)
moe_scales = moe_reduction_scale_input_clone.view(
active_expert_num, token_num
).to(torch.float32)
moe_scales = moe_scales.unsqueeze(
2
) # [active_expert_num, token_num, 1]
scaled_expert_out = moe_expert_out * moe_scales.to(
torch.float32
) # [active_expert_num, token_num, HIDDEN_SIZE]
reduced_expert_out = torch.sum(
scaled_expert_out, dim=0
) # [token_num, HIDDEN_SIZE]
# 2. Add FC2 output
moe_out_ref = (
reduced_expert_out
+ moe_reduction_token_input_clone.view(
token_num, HIDDEN_SIZE
).to(torch.float32)
) # [token_num, HIDDEN_SIZE]
# 3. All-Reduce
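# all-reduce the reference in the working dtype (as the kernel does),
# then promote back to fp32 for comparison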
moe_allreduce_ref = moe_out_ref.clone().to(dtype)
dist.all_reduce(moe_allreduce_ref, group=group)
moe_allreduce_ref = moe_allreduce_ref.to(torch.float32)
# 4. Fused Ops
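# residual add followed by RMSNorm: norm_out = gamma * x * rsqrt(mean(x^2) + eps), all in fp32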
ref_residual_out = (
moe_allreduce_ref
+ residual_in_clone.view(token_num, HIDDEN_SIZE).to(
torch.float32
)
)
variance = (
ref_residual_out.to(torch.float32)
.pow(2)
.mean(dim=-1, keepdim=True)
)
hidden_states = ref_residual_out * torch.rsqrt(
variance + rms_eps
)
ref_norm_out = rms_gamma.to(torch.float32) * hidden_states
# 5. Run kernel
# warmup
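# a few eager launches on a side stream so any lazy initialization happens
# before CUDA graph capture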
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3): # Multiple warmup iterations
comm.trtllm_moe_allreduce_fusion(
world_size=world_size,
world_rank=rank,
token_num=token_num,
hidden_dim=HIDDEN_SIZE,
workspace_ptrs=workspace_tensor,
launch_with_pdl=launch_with_pdl,
residual_in=residual_in,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_factor_float,
moe_reduction_device_num_experts=active_expert_num,
moe_reduction_scale_input=moe_reduction_scale_input,
moe_reduction_active_experts_token_input=moe_reduction_active_experts_token_input,
moe_reduction_token_input=moe_reduction_token_input,
layout_code=swizzled_layout_code,
moe_allreduce_out=moe_allreduce_out,
residual_out=residual_out,
norm_out=norm_out,
quant_out=quant_out,
scale_out=scale_out,
)
torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize() # Ensure warmup is complete
# capture
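# record the fused op into a CUDA graph to verify it is capture/replay safe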
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for _ in range(3): # Multiple iterations in graph
comm.trtllm_moe_allreduce_fusion(
world_size=world_size,
world_rank=rank,
token_num=token_num,
hidden_dim=HIDDEN_SIZE,
workspace_ptrs=workspace_tensor,
launch_with_pdl=launch_with_pdl,
residual_in=residual_in,
rms_gamma=rms_gamma,
rms_eps=rms_eps,
scale_factor=scale_factor_float,
moe_reduction_device_num_experts=active_expert_num,
moe_reduction_scale_input=moe_reduction_scale_input,
moe_reduction_active_experts_token_input=moe_reduction_active_experts_token_input,
moe_reduction_token_input=moe_reduction_token_input,
layout_code=swizzled_layout_code,
moe_allreduce_out=moe_allreduce_out,
residual_out=residual_out,
norm_out=norm_out,
quant_out=quant_out,
scale_out=scale_out,
)
# replay
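# replay re-executes the captured launches; results land in the
# preallocated output buffers checked below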
g.replay()
# match shape
moe_allreduce_out = moe_allreduce_out.view(
token_num, HIDDEN_SIZE
)
residual_out = residual_out.view(token_num, HIDDEN_SIZE)
norm_out = norm_out.view(token_num, HIDDEN_SIZE)
torch.cuda.synchronize()
# 6. Check correctness
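# bf16 keeps only 7 mantissa bits (vs. 10 for fp16), hence the looser absolute tolerance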
tolerance = 8e-2 if dtype == torch.float16 else 8e-1
# 6.1 Check allreduce_out
if not torch.allclose(
moe_allreduce_out.to(torch.float32),
moe_allreduce_ref,
atol=tolerance,
rtol=1e-2,
):
test_passed = False
print(f"Rank {rank} moe_allreduce_out mismatch")
print(f"moe_allreduce_out: {moe_allreduce_out}")
print(f"moe_allreduce_ref: {moe_allreduce_ref}")
# Print max diff elements for allreduce_out
max_diff = torch.max(
torch.abs(
moe_allreduce_out.to(torch.float32)
- moe_allreduce_ref
)
)
max_diff_idx = torch.argmax(
torch.abs(
moe_allreduce_out.to(torch.float32)
- moe_allreduce_ref
)
)
print(
f"Rank {rank} moe_allreduce_out max diff: {max_diff}"
)
print(
f"Rank {rank} moe_allreduce_out max diff idx: {max_diff_idx}"
)
print(
f"Rank {rank} moe_allreduce_out value at max diff: {moe_allreduce_out.view(-1)[max_diff_idx]}"
)
print(
f"Rank {rank} moe_allreduce_out ref value at max diff: {moe_allreduce_ref.view(-1)[max_diff_idx]}"
)
torch.testing.assert_close(
moe_allreduce_out.to(torch.float32),
moe_allreduce_ref,
atol=tolerance,
rtol=1e-2,
)
# 6.2 Check residual_out
if not torch.allclose(
residual_out.to(torch.float32),
ref_residual_out,
atol=tolerance,
rtol=1e-2,
):
test_passed = False
print(f"Rank {rank} residual_out mismatch")
print(f"residual_out: {residual_out}")
print(f"ref_residual_out: {ref_residual_out}")
# Print max diff elements for residual_out
max_diff = torch.max(
torch.abs(
residual_out.to(torch.float32)
- ref_residual_out
)
)
max_diff_idx = torch.argmax(
torch.abs(
residual_out.to(torch.float32)
- ref_residual_out
)
)
print(f"Rank {rank} residual_out max diff: {max_diff}")
print(
f"Rank {rank} residual_out max diff idx: {max_diff_idx}"
)
print(
f"Rank {rank} residual_out value at max diff: {residual_out.view(-1)[max_diff_idx]}"
)
print(
f"Rank {rank} residual_out ref value at max diff: {ref_residual_out.view(-1)[max_diff_idx]}"
)
torch.testing.assert_close(
residual_out.to(torch.float32),
ref_residual_out,
atol=tolerance,
rtol=1e-2,
)
# 6.3 Check norm_out
if not torch.allclose(
norm_out.to(torch.float32),
ref_norm_out,
atol=tolerance,
rtol=1e-2,
):
test_passed = False
print(f"Rank {rank} norm_out mismatch")
print(f"norm_out: {norm_out}")
print(f"ref_norm_out: {ref_norm_out}")
# Print max diff elements for norm_out
max_diff = torch.max(
torch.abs(norm_out.to(torch.float32) - ref_norm_out)
)
max_diff_idx = torch.argmax(
torch.abs(norm_out.to(torch.float32) - ref_norm_out)
)
print(f"Rank {rank} norm_out max diff: {max_diff}")
print(
f"Rank {rank} norm_out max diff idx: {max_diff_idx}"
)
print(
f"Rank {rank} norm_out value at max diff: {norm_out.view(-1)[max_diff_idx]}"
)
print(
f"Rank {rank} norm_out ref value at max diff: {ref_norm_out.view(-1)[max_diff_idx]}"
)
torch.testing.assert_close(
norm_out.to(torch.float32),
ref_norm_out,
atol=tolerance,
rtol=1e-2,
)
# 6.4 Check quant_out
# todo: quant_out / scale_out verification pending (see quant-test todo at the top of the file)
dist.barrier(group=group)
if test_passed:
print(
f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} passed"
)
else:
print(
f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} failed"
)
finally:
dist.barrier(group=group)
comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group)
dist.destroy_process_group(group=group)
def get_open_port() -> int:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
except OSError:
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
s.bind(("::1", 0))
return s.getsockname()[1]
def multi_process_parallel(
world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
) -> None:
mp.set_start_method("spawn", force=True)
procs = []
distributed_init_port = get_open_port()
for i in range(world_size):
proc_args = (world_size, i, dtype, distributed_init_port) + target_args
proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
proc.start()
procs.append(proc)
for i in range(world_size):
procs[i].join()
assert procs[i].exitcode == 0, (
f"Process {i} failed with exit code {procs[i].exitcode}"
)
@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_trtllm_moe_allreduce_fusion(world_size, dtype):
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)
available_gpus = torch.cuda.device_count()
if world_size > available_gpus:
pytest.skip(
f"world_size {world_size} is greater than available_gpus {available_gpus}"
)
print(f"Running test for world_size={world_size}")
multi_process_parallel(
world_size,
dtype,
_run_correctness_worker,
target_args=(),
)
print(f"moe allreduce fusion tp = {world_size}: OK")
if __name__ == "__main__":
test_trtllm_moe_allreduce_fusion(2, torch.float16)