import multiprocessing as mp
import socket
from typing import Any

import numpy as np
import pytest
import torch
import torch.distributed as dist

import flashinfer.comm as comm

# todo(Yingyi): add benchmark and quant test

# Test configuration constants
kOneShotMaxTokenNum = 128
MAX_TOKEN_NUM = 2048
HIDDEN_SIZE = 7168
MAX_EXPERT_NUM = 16
SF_VEC_SIZE = 16  # number of elements sharing one scale factor

# Range for the randomly sampled scale factor
SCALE_FACTOR_RANGE = (-1, 1)


def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    dist.init_process_group(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=world_size,
    )
    group = dist.group.WORLD

    try:
        device = torch.device(f"cuda:{rank}")
        token_nums = [
            1,
            64,
            128,
            256,
            2048,
        ]  # 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048
        candidate_active_expert_num = [8, 12, 16]
        # candidate_active_expert_num = [1]  # debug-only
        swizzled_layout_codes = [
            comm.QuantizationSFLayout.LINEAR,
            comm.QuantizationSFLayout.SWIZZLED_128x4,
            comm.QuantizationSFLayout.SWIZZLED_8x4,
        ]
        launch_with_pdls = [True, False]

        # create workspace for moe allreduce fusion
        ipc_handles, workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank, world_size, MAX_TOKEN_NUM, HIDDEN_SIZE, group=group
            )
        )
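        # The workspace is allocated once for the largest configuration in the sweep
        # (MAX_TOKEN_NUM tokens x HIDDEN_SIZE) and reused for every smaller token_num.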

        test_loop = 5

        for token_num in token_nums:
            for active_expert_num in candidate_active_expert_num:
                for swizzled_layout_code in swizzled_layout_codes:
                    for launch_with_pdl in launch_with_pdls:
                        dist.barrier(group=group)
                        test_passed = True
                        print(
                            f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} start"
                        )
                        dist.barrier(group=group)
                        torch.cuda.synchronize()
                        for _ in range(test_loop):
                            message_size = token_num * HIDDEN_SIZE

                            residual_in = torch.randn(
                                message_size, dtype=dtype, device=device
                            )
                            residual_in_clone = residual_in.clone()

                            moe_allreduce_out = torch.zeros(
                                message_size, dtype=dtype, device=device
                            )
                            residual_out = torch.empty_like(residual_in)
                            norm_out = torch.empty_like(residual_in)
                            quant_out = torch.empty(
                                message_size // 4, dtype=dtype, device=device
                            )  # quant: fp16/bf16 -> fp4, reference: cpp/tensorrt_llm/thop/allreduceOp.cpp:L487

                            scale_out = None
                            assert HIDDEN_SIZE % SF_VEC_SIZE == 0, (
                                "HIDDEN_SIZE must be divisible by SF_VEC_SIZE"
                            )
                            if (
                                swizzled_layout_code
                                == comm.QuantizationSFLayout.SWIZZLED_128x4
                            ):
                                padded_message_size = (
                                    comm.compute_fp4_swizzled_layout_sf_size(
                                        token_num, HIDDEN_SIZE // SF_VEC_SIZE
                                    )
                                )
                                scale_out = torch.empty(
                                    padded_message_size, dtype=dtype, device=device
                                )
                            else:
                                scale_out = torch.empty(
                                    message_size // SF_VEC_SIZE,
                                    dtype=dtype,
                                    device=device,
                                )

                            rms_gamma = torch.randn(
                                HIDDEN_SIZE, dtype=dtype, device=device
                            )
                            scale_factor = (
                                torch.rand(1, dtype=torch.float32, device=device)
                                * (SCALE_FACTOR_RANGE[1] - SCALE_FACTOR_RANGE[0])
                                + SCALE_FACTOR_RANGE[0]
                            )
                            rms_eps = 1e-3
                            scale_factor_float = scale_factor.item()

                            # init moe params
                            # [device_num_expert, m]
                            moe_reduction_scale_input = torch.randn(
                                active_expert_num * token_num,
                                dtype=torch.float32,
                                device=device,
                            )

                            moe_reduction_scale_input_clone = (
                                moe_reduction_scale_input.clone()
                            )

                            # [device_num_expert, m, 7168]
                            moe_reduction_active_experts_token_input = torch.randn(
                                active_expert_num * message_size,
                                dtype=dtype,
                                device=device,
                            )
                            moe_reduction_active_experts_token_input_clone = (
                                moe_reduction_active_experts_token_input.clone()
                            )
                            # [m, 7168]
                            moe_reduction_token_input = torch.randn(
                                message_size, dtype=dtype, device=device
                            )
                            moe_reduction_token_input_clone = (
                                moe_reduction_token_input.clone()
                            )

                            # == Calculate reference output ==
                            # 1. MoE Reduction
                            moe_expert_out = (
                                moe_reduction_active_experts_token_input_clone.view(
                                    active_expert_num, token_num, HIDDEN_SIZE
                                ).to(torch.float32)
                            )
                            moe_scales = moe_reduction_scale_input_clone.view(
                                active_expert_num, token_num
                            ).to(torch.float32)
                            moe_scales = moe_scales.unsqueeze(
                                2
                            )  # [active_expert_num, token_num, 1]
                            scaled_expert_out = moe_expert_out * moe_scales.to(
                                torch.float32
                            )  # [active_expert_num, token_num, HIDDEN_SIZE]
                            reduced_expert_out = torch.sum(
                                scaled_expert_out, dim=0
                            )  # [token_num, HIDDEN_SIZE]

                            # 2. Add FC2 output
                            moe_out_ref = (
                                reduced_expert_out
                                + moe_reduction_token_input_clone.view(
                                    token_num, HIDDEN_SIZE
                                ).to(torch.float32)
                            )  # [token_num, HIDDEN_SIZE]

                            # 3. All-Reduce
                            moe_allreduce_ref = moe_out_ref.clone().to(dtype)
                            dist.all_reduce(moe_allreduce_ref, group=group)
                            moe_allreduce_ref = moe_allreduce_ref.to(torch.float32)

                            # 4. Fused Ops
                            ref_residual_out = (
                                moe_allreduce_ref
                                + residual_in_clone.view(token_num, HIDDEN_SIZE).to(
                                    torch.float32
                                )
                            )

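                            # Reference RMSNorm below follows y = gamma * x / sqrt(mean(x^2) + eps),
                            # computed in float32 so the reference is independent of kernel precision.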
                            variance = (
                                ref_residual_out.to(torch.float32)
                                .pow(2)
                                .mean(dim=-1, keepdim=True)
                            )
                            hidden_states = ref_residual_out * torch.rsqrt(
                                variance + rms_eps
                            )
                            ref_norm_out = rms_gamma.to(torch.float32) * hidden_states

                            # 5. Run kernel
                            # warmup
                            s = torch.cuda.Stream()
                            s.wait_stream(torch.cuda.current_stream())
                            with torch.cuda.stream(s):
                                for _ in range(3):  # Multiple warmup iterations
                                    comm.trtllm_moe_allreduce_fusion(
                                        world_size=world_size,
                                        world_rank=rank,
                                        token_num=token_num,
                                        hidden_dim=HIDDEN_SIZE,
                                        workspace_ptrs=workspace_tensor,
                                        launch_with_pdl=launch_with_pdl,
                                        residual_in=residual_in,
                                        rms_gamma=rms_gamma,
                                        rms_eps=rms_eps,
                                        scale_factor=scale_factor_float,
                                        moe_reduction_device_num_experts=active_expert_num,
                                        moe_reduction_scale_input=moe_reduction_scale_input,
                                        moe_reduction_active_experts_token_input=moe_reduction_active_experts_token_input,
                                        moe_reduction_token_input=moe_reduction_token_input,
                                        layout_code=swizzled_layout_code,
                                        moe_allreduce_out=moe_allreduce_out,
                                        residual_out=residual_out,
                                        norm_out=norm_out,
                                        quant_out=quant_out,
                                        scale_out=scale_out,
                                    )
                            torch.cuda.current_stream().wait_stream(s)
                            torch.cuda.synchronize()  # Ensure warmup is complete

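                            # Capturing the fused op in a CUDA graph also exercises graph
                            # compatibility: every input/output buffer is pre-allocated, so
                            # replaying the graph reuses the same memory.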
                            # capture
                            g = torch.cuda.CUDAGraph()
                            with torch.cuda.graph(g):
                                for _ in range(3):  # Multiple iterations in graph
                                    comm.trtllm_moe_allreduce_fusion(
                                        world_size=world_size,
                                        world_rank=rank,
                                        token_num=token_num,
                                        hidden_dim=HIDDEN_SIZE,
                                        workspace_ptrs=workspace_tensor,
                                        launch_with_pdl=launch_with_pdl,
                                        residual_in=residual_in,
                                        rms_gamma=rms_gamma,
                                        rms_eps=rms_eps,
                                        scale_factor=scale_factor_float,
                                        moe_reduction_device_num_experts=active_expert_num,
                                        moe_reduction_scale_input=moe_reduction_scale_input,
                                        moe_reduction_active_experts_token_input=moe_reduction_active_experts_token_input,
                                        moe_reduction_token_input=moe_reduction_token_input,
                                        layout_code=swizzled_layout_code,
                                        moe_allreduce_out=moe_allreduce_out,
                                        residual_out=residual_out,
                                        norm_out=norm_out,
                                        quant_out=quant_out,
                                        scale_out=scale_out,
                                    )

                            # replay
                            g.replay()

                            # match shape
                            moe_allreduce_out = moe_allreduce_out.view(
                                token_num, HIDDEN_SIZE
                            )
                            residual_out = residual_out.view(token_num, HIDDEN_SIZE)
                            norm_out = norm_out.view(token_num, HIDDEN_SIZE)

                            torch.cuda.synchronize()

                            # 6. Check correctness
                            tolerance = 8e-2 if dtype == torch.float16 else 8e-1
                            # 6.1 Check allreduce_out
                            if not torch.allclose(
                                moe_allreduce_out.to(torch.float32),
                                moe_allreduce_ref,
                                atol=tolerance,
                                rtol=1e-2,
                            ):
                                test_passed = False
                                print(f"Rank {rank} moe_allreduce_out mismatch")
                                print(f"moe_allreduce_out: {moe_allreduce_out}")
                                print(f"moe_allreduce_ref: {moe_allreduce_ref}")
                                # Print max diff elements for allreduce_out
                                max_diff = torch.max(
                                    torch.abs(
                                        moe_allreduce_out.to(torch.float32)
                                        - moe_allreduce_ref
                                    )
                                )
                                max_diff_idx = torch.argmax(
                                    torch.abs(
                                        moe_allreduce_out.to(torch.float32)
                                        - moe_allreduce_ref
                                    )
                                )
                                print(
                                    f"Rank {rank} moe_allreduce_out max diff: {max_diff}"
                                )
                                print(
                                    f"Rank {rank} moe_allreduce_out max diff idx: {max_diff_idx}"
                                )
                                print(
                                    f"Rank {rank} moe_allreduce_out value at max diff: {moe_allreduce_out.view(-1)[max_diff_idx]}"
                                )
                                print(
                                    f"Rank {rank} moe_allreduce_out ref value at max diff: {moe_allreduce_ref.view(-1)[max_diff_idx]}"
                                )

                            torch.testing.assert_close(
                                moe_allreduce_out.to(torch.float32),
                                moe_allreduce_ref,
                                atol=tolerance,
                                rtol=1e-2,
                            )

                            # 6.2 Check residual_out
                            if not torch.allclose(
                                residual_out.to(torch.float32),
                                ref_residual_out,
                                atol=tolerance,
                                rtol=1e-2,
                            ):
                                test_passed = False
                                print(f"Rank {rank} residual_out mismatch")
                                print(f"residual_out: {residual_out}")
                                print(f"ref_residual_out: {ref_residual_out}")
                                # Print max diff elements for residual_out
                                max_diff = torch.max(
                                    torch.abs(
                                        residual_out.to(torch.float32)
                                        - ref_residual_out
                                    )
                                )
                                max_diff_idx = torch.argmax(
                                    torch.abs(
                                        residual_out.to(torch.float32)
                                        - ref_residual_out
                                    )
                                )
                                print(f"Rank {rank} residual_out max diff: {max_diff}")
                                print(
                                    f"Rank {rank} residual_out max diff idx: {max_diff_idx}"
                                )
                                print(
                                    f"Rank {rank} residual_out value at max diff: {residual_out.view(-1)[max_diff_idx]}"
                                )
                                print(
                                    f"Rank {rank} residual_out ref value at max diff: {ref_residual_out.view(-1)[max_diff_idx]}"
                                )
                            torch.testing.assert_close(
                                residual_out.to(torch.float32),
                                ref_residual_out,
                                atol=tolerance,
                                rtol=1e-2,
                            )
                            # 6.3 Check norm_out
                            if not torch.allclose(
                                norm_out.to(torch.float32),
                                ref_norm_out,
                                atol=tolerance,
                                rtol=1e-2,
                            ):
                                test_passed = False
                                print(f"Rank {rank} norm_out mismatch")
                                print(f"norm_out: {norm_out}")
                                print(f"ref_norm_out: {ref_norm_out}")
                                # Print max diff elements for norm_out
                                max_diff = torch.max(
                                    torch.abs(norm_out.to(torch.float32) - ref_norm_out)
                                )
                                max_diff_idx = torch.argmax(
                                    torch.abs(norm_out.to(torch.float32) - ref_norm_out)
                                )
                                print(f"Rank {rank} norm_out max diff: {max_diff}")
                                print(
                                    f"Rank {rank} norm_out max diff idx: {max_diff_idx}"
                                )
                                print(
                                    f"Rank {rank} norm_out value at max diff: {norm_out.view(-1)[max_diff_idx]}"
                                )
                                print(
                                    f"Rank {rank} norm_out ref value at max diff: {ref_norm_out.view(-1)[max_diff_idx]}"
                                )

                            torch.testing.assert_close(
                                norm_out.to(torch.float32),
                                ref_norm_out,
                                atol=tolerance,
                                rtol=1e-2,
                            )
                            # 6.4 Check quant_out
                            # todo
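                            # A possible future check (sketch only): dequantize quant_out
                            # using scale_out and scale_factor_float, assuming quant_out
                            # holds the packed FP4 encoding of norm_out, then compare
                            # against ref_norm_out with a loose tolerance.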

                        dist.barrier(group=group)
                        if test_passed:
                            print(
                                f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} passed"
                            )
                        else:
                            print(
                                f"test RANK {rank}: token{token_num}-expert{active_expert_num}-tp{world_size}-{dtype}-layout{swizzled_layout_code}-pdl{launch_with_pdl} failed"
                            )
    finally:
        dist.barrier(group=group)

        comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group)

        dist.destroy_process_group(group=group)


def get_open_port() -> int:
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))
            return s.getsockname()[1]
    except OSError:
        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
            s.bind(("::1", 0))
            return s.getsockname()[1]


def multi_process_parallel(
    world_size: int, dtype: torch.dtype, test_target: Any, target_args: tuple = ()
) -> None:
    mp.set_start_method("spawn", force=True)

    procs = []
    distributed_init_port = get_open_port()
    for i in range(world_size):
        proc_args = (world_size, i, dtype, distributed_init_port) + target_args
        proc = mp.Process(target=test_target, args=proc_args, name=f"Worker-{i}")
        proc.start()
        procs.append(proc)

    for i in range(world_size):
        procs[i].join()
        assert procs[i].exitcode == 0, (
            f"Process {i} failed with exit code {procs[i].exitcode}"
        )


@pytest.mark.parametrize("world_size", [2, 4])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_trtllm_moe_allreduce_fusion(world_size, dtype):
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    available_gpus = torch.cuda.device_count()
    if world_size > available_gpus:
        pytest.skip(
            f"world_size {world_size} is greater than available_gpus {available_gpus}"
        )
    print(f"Running test for world_size={world_size}")

    multi_process_parallel(
        world_size,
        dtype,
        _run_correctness_worker,
        target_args=(),
    )
    print(f"moe allreduce fusion tp = {world_size}: OK")


if __name__ == "__main__":
    test_trtllm_moe_allreduce_fusion(2, torch.float16)
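# Example invocation (hypothetical path; adjust to where this test lives in your checkout):
#   pytest tests/test_trtllm_moe_allreduce_fusion.py -q
# At least 2 (or 4 for the larger case) visible CUDA GPUs are required; otherwise the
# corresponding parametrized cases are skipped.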