"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pytest
import torch
import flashinfer.comm.trtllm_alltoall as tllm_alltoall
has_setup_max_sm_count = False
@pytest.fixture(autouse=True, scope="session")
def setup_test_environment():
"""Set up test environment and warm up JIT compilation."""
global has_setup_max_sm_count
if not has_setup_max_sm_count:
# Set up SM count once for all tests
sm_count = torch.cuda.get_device_properties(0).multi_processor_count
max_sm_count = sm_count // 8 # Maximum world size is 8
tllm_alltoall.set_moe_max_usable_sm_count(max_sm_count)
has_setup_max_sm_count = True
torch.manual_seed(0x1234)
yield
# Single GPU test parameters
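# Each tuple: (input_entry_count, output_entry_count, vector_dim, send_recv_count, dtype)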
SINGLE_GPU_PARAMS = [
(902, 701, 32768, 100, torch.float16), # Large data, float16
(101, 75, 288, 10, torch.float16), # Medium data, float16
(10, 5, 8, 1, torch.float16), # Small data, float16
(902, 701, 7168, 100, torch.bfloat16), # Large data, bfloat16
(101, 75, 288, 10, torch.bfloat16), # Medium data, bfloat16
]
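# Each tuple: (world_size, input_entry_per_rank, vector_dim, dtype)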
MULTI_RANK_PARAMS = [
(2, 5, 8, torch.float16), # Small input, 2 ranks
(4, 901, 32768, torch.bfloat16), # Large input, 4 ranks
(8, 16384, 128, torch.float16), # Many small vectors, 8 ranks
]
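# Each tuple: (ep_rank, ep_size, expert_count, top_k, max_token_count_per_rank, use_real_rank_token_count_cumsum)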
PREPARE_INDICES_PARAMS = [
(0, 8, 256, 4, 3, False), # Rank 0, small config
(1, 8, 256, 4, 3, True), # Rank 1, small config with real cumsum
(7, 8, 256, 8, 1025, False), # High rank, medium config
(7, 64, 1024, 32, 1029, True), # High rank, large config with real cumsum
]
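# Each tuple: (ep_rank, ep_size, expert_count, top_k, max_token_count_per_rank)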
LOCAL_GATHER_PARAMS = [
(0, 8, 256, 4, 3), # Rank 0, small config
(7, 8, 256, 8, 32), # High rank, medium config
(7, 64, 1024, 32, 1029), # High rank, large config
]
# Real cross-GPU communication test parameters
CROSS_GPU_PARAMS = [
(2, 100, 256, torch.float16), # 2 GPUs, 2 ranks
(2, 300, 512, torch.bfloat16), # 2 GPUs, 2 ranks, larger data
(4, 150, 256, torch.float16), # 4 GPUs, 4 ranks (if available)
(4, 400, 512, torch.float16), # 4 GPUs, 4 ranks, larger data
]
def get_available_gpu_count():
"""Get the number of available GPUs."""
if not torch.cuda.is_available():
return 0
return torch.cuda.device_count()
def requires_gpus(min_gpus):
"""Decorator to skip test if insufficient GPUs are available."""
def decorator(func):
return pytest.mark.skipif(
get_available_gpu_count() < min_gpus,
reason=f"Requires at least {min_gpus} GPUs, but only {get_available_gpu_count()} available",
)(func)
return decorator
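# Illustrative usage of the decorator above (hypothetical test name):
#
#     @requires_gpus(2)
#     def test_needs_two_gpus():
#         ...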
@pytest.mark.parametrize(
"input_entry_count,output_entry_count,vector_dim,send_recv_count,dtype",
SINGLE_GPU_PARAMS,
)
def test_moe_alltoall_single_gpu(
input_entry_count, output_entry_count, vector_dim, send_recv_count, dtype
):
"""Test MOE alltoall communication on single GPU."""
torch.cuda.set_device(0)
# Create a random input tensor
input_tensor = torch.randn(
input_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda")
)
output_tensor = torch.zeros(
output_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda")
)
send_cumsum = (
torch.ones((1,), dtype=torch.int32, device=torch.device("cuda"))
* send_recv_count
)
recv_cumsum = (
torch.ones((1,), dtype=torch.int32, device=torch.device("cuda"))
* send_recv_count
)
send_indices = torch.randperm(
input_entry_count, dtype=torch.int32, device=torch.device("cuda")
)[:send_recv_count]
recv_indices = torch.randperm(
output_entry_count, dtype=torch.int32, device=torch.device("cuda")
)[:send_recv_count]
ref_output_tensor = torch.zeros(
output_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda")
)
ref_output_tensor[recv_indices] = input_tensor[send_indices]
workspace_size = tllm_alltoall.get_moe_commworkspace_size_per_rank(1)
all_workspaces = torch.zeros(
1, workspace_size, dtype=torch.uint64, device=torch.device("cuda")
)
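    # The workspace tensor holds one row per rank; with a single rank, row 0
    # is the entire communication workspace.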
tllm_alltoall.moe_comm(
input_tensor,
send_cumsum,
send_indices,
output_tensor,
recv_cumsum,
recv_indices,
all_workspaces,
0,
1,
)
torch.testing.assert_close(output_tensor, ref_output_tensor, atol=1e-5, rtol=1e-5)
@pytest.mark.parametrize(
"world_size,input_entry_per_rank,vector_dim,dtype", MULTI_RANK_PARAMS
)
def test_moe_alltoall_multi_rank_single_gpu(
world_size, input_entry_per_rank, vector_dim, dtype
):
"""Test MOE alltoall communication with multiple ranks on single GPU."""
torch.cuda.set_device(0)
max_world_size = 8
assert world_size <= max_world_size, (
f"should run with world_size at most {max_world_size}"
)
# SM count is now set up globally in the fixture
# Create a random input tensor
input_tensor = torch.randn(
input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device("cuda"),
)
output_tensor = torch.zeros(
input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device("cuda"),
)
ref_output_tensor = torch.zeros(
input_entry_per_rank * world_size,
vector_dim,
dtype=dtype,
device=torch.device("cuda"),
)
target_rank_ids = torch.randint(
0,
world_size,
(input_entry_per_rank * world_size,),
dtype=torch.int32,
device=torch.device("cuda"),
)
input_tensors_all_ranks = list(torch.split(input_tensor, input_entry_per_rank))
target_rank_ids_all_ranks = list(torch.split(target_rank_ids, input_entry_per_rank))
send_ids_all_ranks = []
send_counts_all_ranks = []
send_cumsum_all_ranks = []
send_start_end_all_ranks = []
    # Each rank does its own local computation to determine how it sends data to the other ranks.
for rank in range(world_size):
send_start_end = []
local_target_rank_ids = target_rank_ids_all_ranks[rank]
sorted_local_target_rank_ids, local_send_id = torch.sort(local_target_rank_ids)
local_send_id = local_send_id.to(torch.int32)
padded_sorted_local_target_rank_ids = torch.cat(
(
sorted_local_target_rank_ids,
torch.arange(
world_size, dtype=torch.int32, device=torch.device("cuda")
),
)
)
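        # Padding with arange(world_size) guarantees that every rank id occurs
        # at least once, so torch.unique below returns exactly world_size
        # counts; the one padded entry per rank is subtracted off afterwards.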
unique_target_rank_ids, local_send_counts = torch.unique(
padded_sorted_local_target_rank_ids, return_counts=True
)
local_send_counts = local_send_counts.to(torch.int32)
        assert unique_target_rank_ids.numel() == world_size, (
            "number of unique target rank ids must equal world_size"
        )
local_send_counts -= 1 # remove padding
local_send_cumsum = torch.cumsum(local_send_counts, dim=0).to(torch.int32)
send_ids_all_ranks.append(local_send_id)
send_counts_all_ranks.append(local_send_counts)
send_cumsum_all_ranks.append(local_send_cumsum)
local_send_cumsum_cpu = local_send_cumsum.cpu().tolist()
for i in range(len(local_send_cumsum_cpu)):
send_start_end.append(
(
local_send_cumsum_cpu[i - 1] if i > 0 else 0,
local_send_cumsum_cpu[i],
)
)
send_start_end_all_ranks.append(send_start_end)
recv_ids_all_ranks = []
recv_cumsum_all_ranks = []
output_tensors_all_ranks = []
total_recv_all_ranks_cpu = []
    output_index_offset = 0
output_start_current_rank = 0
    # Each rank computes, from the other ranks' send counts, how it receives data from them.
for rank in range(world_size):
local_recv_counts = torch.zeros(
world_size, dtype=torch.int32, device=torch.device("cuda")
)
for other_rank in range(world_size):
local_recv_counts[other_rank] = send_counts_all_ranks[other_rank][rank]
local_recv_count_pair = local_recv_counts[other_rank].cpu().item()
send_rank_start_end = send_start_end_all_ranks[other_rank][rank]
            ref_output_tensor[
                output_index_offset : output_index_offset + local_recv_count_pair
            ] = input_tensors_all_ranks[other_rank][
                send_ids_all_ranks[other_rank][
                    send_rank_start_end[0] : send_rank_start_end[1]
                ]
            ]
            output_index_offset += local_recv_count_pair
local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32)
recv_cumsum_all_ranks.append(local_recv_cumsum)
total_recv_count = local_recv_cumsum[-1].cpu()
total_recv_all_ranks_cpu.append(total_recv_count)
output_tensors_all_ranks.append(
output_tensor[
output_start_current_rank : output_start_current_rank + total_recv_count
]
)
output_start_current_rank += total_recv_count
local_recv_ids = torch.arange(
total_recv_count, dtype=torch.int32, device=torch.device("cuda")
)
recv_ids_all_ranks.append(local_recv_ids)
cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(world_size)]
workspace_size = tllm_alltoall.get_moe_commworkspace_size_per_rank(world_size)
all_workspaces = torch.zeros(
world_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda")
)
# Synchronize before starting parallel communication
torch.cuda.synchronize()
    # Run alltoall for all ranks in parallel, one CUDA stream per rank.
for rank in range(world_size):
with torch.cuda.stream(cuda_streams_all_ranks[rank]):
tllm_alltoall.moe_comm(
input_tensors_all_ranks[rank],
send_cumsum_all_ranks[rank],
send_ids_all_ranks[rank],
output_tensors_all_ranks[rank],
recv_cumsum_all_ranks[rank],
recv_ids_all_ranks[rank],
all_workspaces,
rank,
world_size,
)
for rank in range(world_size):
cuda_streams_all_ranks[rank].synchronize()
torch.testing.assert_close(output_tensor, ref_output_tensor, atol=1e-5, rtol=1e-5)
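# A minimal illustrative sketch (not part of the flashinfer API, and not used
# by the tests) of the send-side bookkeeping that
# test_moe_alltoall_multi_rank_single_gpu builds inline: given per-token
# target rank ids, derive the send order, per-rank counts, and the cumulative
# sum in the layout tllm_alltoall.moe_comm expects.
def _sketch_build_send_layout(target_rank_ids: torch.Tensor, world_size: int):
    sorted_targets, send_order = torch.sort(target_rank_ids)
    # Pad with one entry per rank so torch.unique reports a count for every
    # rank, then remove the padding from the counts.
    padded = torch.cat(
        (
            sorted_targets,
            torch.arange(
                world_size, dtype=sorted_targets.dtype, device=sorted_targets.device
            ),
        )
    )
    _, counts = torch.unique(padded, return_counts=True)
    counts = counts.to(torch.int32) - 1
    send_cumsum = torch.cumsum(counts, dim=0).to(torch.int32)
    return send_order.to(torch.int32), counts, send_cumsum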
@pytest.mark.parametrize(
"ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank,use_real_rank_token_count_cumsum",
PREPARE_INDICES_PARAMS,
)
def test_moe_alltoall_prepare_indices(
ep_rank,
ep_size,
expert_count,
top_k,
max_token_count_per_rank,
use_real_rank_token_count_cumsum,
):
"""Test MOE alltoall prepare indices functionality."""
torch.cuda.set_device(0)
def generate_references():
rank_token_count = max_token_count_per_rank
if use_real_rank_token_count_cumsum:
            # Each of the first ep_size - 1 ranks gets a random token count in
            # [1, max_token_count_per_rank]; the last rank gets the maximum.
            rank_token_counts = [
                torch.randint(1, max_token_count_per_rank + 1, (1,)).item()
                for _ in range(ep_size - 1)
            ]
            rank_token_counts.append(max_token_count_per_rank)
real_rank_token_count_cumsum = (
torch.tensor(
rank_token_counts, dtype=torch.int32, device=torch.device("cuda")
)
.cumsum(dim=0)
.to(torch.int32)
)
rank_token_count = rank_token_counts[ep_rank]
else:
real_rank_token_count_cumsum = None
# Generate target rank ids for this rank
target_rank_ids = torch.randint(
0,
ep_size,
(rank_token_count, top_k),
dtype=torch.int32,
device=torch.device("cuda"),
)
if not use_real_rank_token_count_cumsum:
gathered_target_rank_ids = torch.zeros(
ep_size * max_token_count_per_rank,
top_k,
dtype=torch.int32,
device=torch.device("cuda"),
)
gathered_target_rank_ids[
ep_rank * max_token_count_per_rank : ep_rank * max_token_count_per_rank
+ rank_token_count
] = target_rank_ids
else:
total_tokens = real_rank_token_count_cumsum[-1].item()
gathered_target_rank_ids = torch.zeros(
total_tokens, top_k, dtype=torch.int32, device=torch.device("cuda")
)
start_pos = (
0 if ep_rank == 0 else real_rank_token_count_cumsum[ep_rank - 1].item()
)
gathered_target_rank_ids[start_pos : start_pos + rank_token_count] = (
target_rank_ids
)
return gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids
gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids = (
generate_references()
)
(
local_gather_indices,
send_rank_count_cumsum,
send_rank_local_indices,
recv_rank_count_cumsum,
recv_rank_local_indices,
backward_recv_rank_local_indices,
) = tllm_alltoall.moe_comm_prepare_indices(
gathered_target_rank_ids,
real_rank_token_count_cumsum,
max_token_count_per_rank,
expert_count,
top_k,
ep_rank,
ep_size,
)
# Validate shapes
assert local_gather_indices.shape[0] <= max_token_count_per_rank * ep_size
assert send_rank_count_cumsum.shape[0] == ep_size
assert recv_rank_count_cumsum.shape[0] == ep_size
assert send_rank_local_indices.shape[0] <= max_token_count_per_rank * max(
ep_size, top_k
)
assert recv_rank_local_indices.shape[0] <= max_token_count_per_rank * ep_size
assert backward_recv_rank_local_indices.shape[0] <= max_token_count_per_rank * max(
ep_size, top_k
)
# Basic validation - cumulative sums should be non-decreasing
assert torch.all(send_rank_count_cumsum[1:] >= send_rank_count_cumsum[:-1])
assert torch.all(recv_rank_count_cumsum[1:] >= recv_rank_count_cumsum[:-1])
@pytest.mark.parametrize(
"ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank", LOCAL_GATHER_PARAMS
)
def test_moe_local_gather(
ep_rank,
ep_size,
expert_count,
top_k,
max_token_count_per_rank,
):
"""Test MOE local gather functionality."""
torch.cuda.set_device(0)
    # Draw a random token count per rank, then convert to a cumulative sum.
    rank_token_count = torch.randint(
        0,
        max_token_count_per_rank + 1,
        (ep_size,),
        dtype=torch.int32,
        device=torch.device("cuda"),
    )
    rank_token_count_cumsum = torch.cumsum(rank_token_count, dim=0).to(
        torch.int32
    )
local_token_count = rank_token_count_cumsum[ep_size - 1].cpu().item()
local_max_token_count = max_token_count_per_rank * ep_size
local_gather_indices = torch.randint(
0,
max_token_count_per_rank * ep_size,
(local_max_token_count,),
dtype=torch.int32,
device=torch.device("cuda"),
)
gathered_expert_ids = torch.randint(
0,
expert_count,
(max_token_count_per_rank * ep_size, top_k),
dtype=torch.int32,
device=torch.device("cuda"),
)
gathered_scales = torch.rand(
(max_token_count_per_rank * ep_size, top_k),
dtype=torch.float32,
device=torch.device("cuda"),
)
ref_local_expert_ids = torch.zeros(
local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda")
)
ref_local_scales = torch.zeros(
local_max_token_count,
top_k,
dtype=torch.float32,
device=torch.device("cuda"),
)
    # Compute the reference: slots beyond local_token_count keep the sentinel
    # value expert_count; valid slots gather via local_gather_indices.
    ref_local_expert_ids += expert_count
valid_local_gather_indices = local_gather_indices[:local_token_count]
ref_local_expert_ids[:local_token_count] = gathered_expert_ids[
valid_local_gather_indices
]
ref_local_scales[:local_token_count] = gathered_scales[valid_local_gather_indices]
local_expert_ids = torch.empty(
local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda")
)
local_scales = torch.empty(
local_max_token_count,
top_k,
dtype=torch.float32,
device=torch.device("cuda"),
)
tllm_alltoall.moe_local_gather(
rank_token_count_cumsum,
local_gather_indices,
gathered_expert_ids,
gathered_scales,
local_expert_ids,
local_scales,
max_token_count_per_rank,
expert_count,
top_k,
ep_rank,
ep_size,
)
assert torch.equal(local_expert_ids, ref_local_expert_ids)
assert torch.equal(local_scales, ref_local_scales)
@pytest.mark.parametrize(
"ep_rank, ep_size, expert_count, slot_count, top_k, max_token_count_per_rank",
[
(0, 2, 16, 20, 8, 512),
(0, 2, 16, 16, 3, 300),
(0, 4, 20, 24, 8, 4000),
(0, 8, 96, 96, 8, 1000),
(3, 8, 128, 128, 8, 1000),
(3, 8, 128, 144, 8, 1),
(0, 4, 72, 80, 4, 2256),
(0, 4, 72, 80, 6, 3333),
        # Hangs with stream count > 8:
# (0, 9, 90, 8, 100),
],
)
def test_moe_alltoall_prepare(
ep_rank: int,
ep_size: int,
expert_count: int,
slot_count: int,
top_k: int,
max_token_count_per_rank: int,
):
torch.cuda.set_device(0)
cpu_expert_ids_all_ranks_lists = []
cpu_token_count_lists = []
cpu_scales_all_ranks_lists = []
for _ in range(ep_size):
        # Every rank's token count is pinned to the maximum (the random draw
        # that previously stood here was always overwritten).
        token_count = max_token_count_per_rank
cpu_expert_ids_all_ranks_lists.append(
torch.randint(
0,
slot_count,
(token_count, top_k),
dtype=torch.int32,
device=torch.device("cpu"),
)
)
cpu_scales_all_ranks_lists.append(
torch.zeros(
token_count, top_k, dtype=torch.float32, device=torch.device("cpu")
)
+ 0.5
)
cpu_token_count_lists.append(token_count)
    def compute_target_rank(expert_id):
        # Expert slots are partitioned contiguously across ranks: slots
        # [rank * slots_per_rank, (rank + 1) * slots_per_rank) live on rank.
        # e.g. slot_count=16, ep_size=2 -> slots_per_rank=8, expert 11 -> rank 1.
        slots_per_rank = slot_count // ep_size
        return expert_id // slots_per_rank
def generate_references():
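        # CPU reference for moe_prepare: a token is sent once to each distinct
        # target rank among its top_k experts; the receive side replays every
        # rank's sends destined for ep_rank, in rank order, to build the
        # expected cumsums and index lists.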
ref_prepared_local_expert_ids = []
ref_prepared_local_scales = []
ref_local_send_rank_count_cumsum = [0] * ep_size
ref_local_recv_rank_count_cumsum = [0] * ep_size
ref_local_recv_rank_indices = []
local_token_count = cpu_token_count_lists[ep_rank]
send_token_count_to_ranks = [0] * ep_size
# send part
for token_id in range(local_token_count):
target_set = set()
for pos in range(top_k):
expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos])
target_rank_id = compute_target_rank(expert_id)
target_set.add(target_rank_id)
for target_rank_id in target_set:
send_token_count_to_ranks[target_rank_id] += 1
total_send_token_count = 0
for rank in range(ep_size):
# print(f'rank: {rank}, send_token_count_to_ranks[rank]: {send_token_count_to_ranks[rank]}')
base = ref_local_send_rank_count_cumsum[rank - 1] if rank > 0 else 0
ref_local_send_rank_count_cumsum[rank] = (
send_token_count_to_ranks[rank] + base
)
total_send_token_count += send_token_count_to_ranks[rank]
ref_local_backward_send_rank_indices = [0] * (total_send_token_count)
ref_local_send_rank_indices = [0] * (total_send_token_count)
current_send_token_ids = [0] * ep_size
for token_id in range(local_token_count):
target_set = set()
for pos in range(top_k):
expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos])
target_rank_id = compute_target_rank(expert_id)
if target_rank_id not in target_set:
cumsum_before = (
0
if target_rank_id == 0
else ref_local_send_rank_count_cumsum[target_rank_id - 1]
)
send_index = cumsum_before + current_send_token_ids[target_rank_id]
ref_local_send_rank_indices[send_index] = token_id
ref_local_backward_send_rank_indices[send_index] = (
token_id * top_k + pos
)
current_send_token_ids[target_rank_id] += 1
target_set.add(target_rank_id)
# receive part
total_recv_token_count = 0
for rank in range(ep_size):
token_count = cpu_token_count_lists[rank]
current_recv_token_count = 0
for token_id in range(token_count):
token_is_received = False
for pos in range(top_k):
expert_id = int(cpu_expert_ids_all_ranks_lists[rank][token_id][pos])
sf = cpu_scales_all_ranks_lists[rank][token_id][pos]
target_rank_id = compute_target_rank(expert_id)
if target_rank_id == ep_rank:
if not token_is_received:
token_is_received = True
ref_prepared_local_expert_ids.append([slot_count] * top_k)
ref_prepared_local_scales.append([0.0] * top_k)
ref_prepared_local_expert_ids[-1][pos] = expert_id
ref_prepared_local_scales[-1][pos] = sf
if token_is_received:
ref_local_recv_rank_indices.append(total_recv_token_count)
total_recv_token_count += 1
current_recv_token_count += 1
ref_local_recv_rank_count_cumsum[rank] = (
current_recv_token_count
if rank == 0
else ref_local_recv_rank_count_cumsum[rank - 1]
+ current_recv_token_count
)
return (
ref_prepared_local_expert_ids,
ref_prepared_local_scales,
ref_local_send_rank_count_cumsum,
ref_local_send_rank_indices,
ref_local_recv_rank_count_cumsum,
ref_local_recv_rank_indices,
ref_local_backward_send_rank_indices,
total_recv_token_count,
)
(
ref_prepared_local_expert_ids,
ref_prepared_local_scales,
ref_local_send_rank_count_cumsum,
ref_local_send_rank_indices,
ref_local_recv_rank_count_cumsum,
ref_local_recv_rank_indices,
ref_local_backward_send_rank_indices,
total_recv_token_count,
) = generate_references()
    cpu_expert_count_lists = []
    for rank in range(ep_size):
        local_expert_count = []
        for i in range(expert_count):
            local_expert_count.append(rank * expert_count + i)
        cpu_expert_count_lists.append(torch.IntTensor(local_expert_count))
# expert_ids_all_ranks = torch.tensor(cpu_expert_ids_all_ranks_lists).cuda()
expert_ids_all_ranks = [
cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size)
]
# scales_all_ranks = torch.FloatTensor(cpu_scales_all_ranks_lists).cuda()
scales_all_ranks = [cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size)]
    expert_count_lists = [cpu_expert_count_lists[i].cuda() for i in range(ep_size)]
cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(ep_size)]
workspace_size = tllm_alltoall.get_moe_prepare_workspace_size_per_rank(ep_size)
all_workspaces = torch.zeros(
ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda")
)
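    # Warm-up call with world size 1: trigger JIT compilation and lazy
    # initialization before the multi-stream parallel section below.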
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
tllm_alltoall.moe_prepare(
expert_ids_all_ranks[0],
scales_all_ranks[0],
            expert_count_lists[0],
all_workspaces,
max_token_count_per_rank,
0,
1,
expert_count,
slot_count,
top_k,
)
stream.wait_stream(torch.cuda.current_stream())
    # Pre-allocate (and immediately release) result-sized tensors so the CUDA
    # caching allocator is primed and the parallel section below does not hit
    # an allocation-triggered synchronization.
prepared_local_experts = []
prepared_local_scales = []
local_send_rank_count_cumsum = []
local_send_rank_indices = []
local_recv_rank_count_cumsum = []
local_recv_rank_indices = []
backward_local_recv_rank_indices = []
for _ in range(ep_size):
prepared_local_experts.append(
torch.empty(
max_token_count_per_rank * ep_size,
top_k,
dtype=torch.int32,
device=torch.device("cuda"),
)
)
prepared_local_scales.append(
torch.empty(
max_token_count_per_rank * ep_size,
top_k,
dtype=torch.float32,
device=torch.device("cuda"),
)
)
local_send_rank_count_cumsum.append(
torch.empty(ep_size, dtype=torch.int32, device=torch.device("cuda"))
)
local_send_rank_indices.append(
torch.empty(
max_token_count_per_rank * ep_size,
dtype=torch.int32,
device=torch.device("cuda"),
)
)
local_recv_rank_count_cumsum.append(
torch.empty(0, dtype=torch.int32, device=torch.device("cuda"))
)
local_recv_rank_indices.append(
torch.empty(0, dtype=torch.int32, device=torch.device("cuda"))
)
backward_local_recv_rank_indices.append(
torch.empty(0, dtype=torch.int32, device=torch.device("cuda"))
)
prepared_local_experts = None
prepared_local_scales = None
local_send_rank_count_cumsum = None
local_send_rank_indices = None
local_recv_rank_count_cumsum = None
local_recv_rank_indices = None
backward_local_recv_rank_indices = None
# reset the workspace
all_workspaces = torch.zeros(
ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda")
)
    # Run prepare for all ranks in parallel, one CUDA stream per rank.
for rank in range(ep_size):
with torch.cuda.stream(cuda_streams_all_ranks[rank]):
if rank == ep_rank:
(
prepared_local_experts,
prepared_local_scales,
local_send_rank_count_cumsum,
local_send_rank_indices,
local_recv_rank_count_cumsum,
local_recv_rank_indices,
backward_local_recv_rank_indices,
                    gathered_expert_statistics,
) = tllm_alltoall.moe_prepare(
expert_ids_all_ranks[rank],
scales_all_ranks[rank],
                    expert_count_lists[rank],
all_workspaces,
max_token_count_per_rank,
rank,
ep_size,
expert_count,
slot_count,
top_k,
)
else:
tllm_alltoall.moe_prepare(
expert_ids_all_ranks[rank],
scales_all_ranks[rank],
                    expert_count_lists[rank],
all_workspaces,
max_token_count_per_rank,
rank,
ep_size,
expert_count,
slot_count,
top_k,
)
for rank in range(ep_size):
cuda_streams_all_ranks[rank].synchronize()
prepared_local_experts_cpu = prepared_local_experts[:total_recv_token_count].cpu()
prepared_local_scales_cpu = prepared_local_scales[:total_recv_token_count].cpu()
for i in range(total_recv_token_count):
for j in range(top_k):
expert_id = int(prepared_local_experts_cpu[i][j])
assert expert_id >= 0 and expert_id <= slot_count
if expert_id < slot_count:
assert compute_target_rank(expert_id) == ep_rank
scale = float(prepared_local_scales_cpu[i][j])
assert scale > 1e-6
    gathered_expert_statistics_cpu = gathered_expert_statistics.cpu()
for rank in range(ep_size):
for i in range(expert_count):
            assert int(gathered_expert_statistics_cpu[rank][i]) == rank * expert_count + i
ref_local_send_rank_count_cumsum = torch.IntTensor(ref_local_send_rank_count_cumsum)
assert torch.equal(
local_send_rank_count_cumsum.cpu(), ref_local_send_rank_count_cumsum
)
local_send_rank_indices = local_send_rank_indices.cpu()
backward_local_recv_rank_indices = backward_local_recv_rank_indices.cpu()
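    # Every send slot j destined for rank i must map back to a (token, lane)
    # pair whose expert actually routes to rank i.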
for i in range(ep_size):
base = 0 if i == 0 else ref_local_send_rank_count_cumsum[i - 1]
for j in range(base, ref_local_send_rank_count_cumsum[i]):
token_id = local_send_rank_indices[j]
lane_id = backward_local_recv_rank_indices[j] - token_id * top_k
expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][lane_id])
assert compute_target_rank(expert_id) == i
ref_local_recv_rank_count_cumsum = torch.IntTensor(ref_local_recv_rank_count_cumsum)
assert torch.equal(
local_recv_rank_count_cumsum[: ref_local_recv_rank_count_cumsum.size(0)].cpu(),
ref_local_recv_rank_count_cumsum,
)
ref_local_recv_rank_indices = torch.IntTensor(ref_local_recv_rank_indices)
assert torch.equal(
local_recv_rank_indices[: ref_local_recv_rank_indices.size(0)].cpu(),
ref_local_recv_rank_indices,
)
if __name__ == "__main__":
pytest.main([__file__, "-v"])