"""
|
|
Copyright (c) 2024 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""

import pytest
import torch

import flashinfer.comm.trtllm_alltoall as tllm_alltoall

has_setup_max_sm_count = False


@pytest.fixture(autouse=True, scope="session")
def setup_test_environment():
    """Set up test environment and warm up JIT compilation."""
    global has_setup_max_sm_count
    if not has_setup_max_sm_count:
        # Set up SM count once for all tests
        sm_count = torch.cuda.get_device_properties(0).multi_processor_count
        max_sm_count = sm_count // 8  # Maximum world size is 8
        tllm_alltoall.set_moe_max_usable_sm_count(max_sm_count)
        has_setup_max_sm_count = True

    torch.manual_seed(0x1234)
    yield


# Single GPU test parameters
SINGLE_GPU_PARAMS = [
    (902, 701, 32768, 100, torch.float16),  # Large data, float16
    (101, 75, 288, 10, torch.float16),  # Medium data, float16
    (10, 5, 8, 1, torch.float16),  # Small data, float16
    (902, 701, 7168, 100, torch.bfloat16),  # Large data, bfloat16
    (101, 75, 288, 10, torch.bfloat16),  # Medium data, bfloat16
]

MULTI_RANK_PARAMS = [
    (2, 5, 8, torch.float16),  # Small input, 2 ranks
    (4, 901, 32768, torch.bfloat16),  # Large input, 4 ranks
    (8, 16384, 128, torch.float16),  # Many small vectors, 8 ranks
]

PREPARE_INDICES_PARAMS = [
    (0, 8, 256, 4, 3, False),  # Rank 0, small config
    (1, 8, 256, 4, 3, True),  # Rank 1, small config with real cumsum
    (7, 8, 256, 8, 1025, False),  # High rank, medium config
    (7, 64, 1024, 32, 1029, True),  # High rank, large config with real cumsum
]

LOCAL_GATHER_PARAMS = [
    (0, 8, 256, 4, 3),  # Rank 0, small config
    (7, 8, 256, 8, 32),  # High rank, medium config
    (7, 64, 1024, 32, 1029),  # High rank, large config
]


# Real cross-GPU communication test parameters
CROSS_GPU_PARAMS = [
    (2, 100, 256, torch.float16),  # 2 GPUs, 2 ranks
    (2, 300, 512, torch.bfloat16),  # 2 GPUs, 2 ranks, larger data
    (4, 150, 256, torch.float16),  # 4 GPUs, 4 ranks (if available)
    (4, 400, 512, torch.float16),  # 4 GPUs, 4 ranks, larger data
]


def get_available_gpu_count():
    """Get the number of available GPUs."""
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.device_count()


def requires_gpus(min_gpus):
    """Decorator to skip a test if insufficient GPUs are available."""

    def decorator(func):
        return pytest.mark.skipif(
            get_available_gpu_count() < min_gpus,
            reason=f"Requires at least {min_gpus} GPUs, but only {get_available_gpu_count()} available",
        )(func)

    return decorator


@pytest.mark.parametrize(
    "input_entry_count,output_entry_count,vector_dim,send_recv_count,dtype",
    SINGLE_GPU_PARAMS,
)
def test_moe_alltoall_single_gpu(
    input_entry_count, output_entry_count, vector_dim, send_recv_count, dtype
):
    """Test MOE alltoall communication on a single GPU."""
    torch.cuda.set_device(0)
    # Create a random input tensor
    input_tensor = torch.randn(
        input_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda")
    )
    output_tensor = torch.zeros(
        output_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda")
    )

    send_cumsum = (
        torch.ones((1,), dtype=torch.int32, device=torch.device("cuda"))
        * send_recv_count
    )
    recv_cumsum = (
        torch.ones((1,), dtype=torch.int32, device=torch.device("cuda"))
        * send_recv_count
    )
    send_indices = torch.randperm(
        input_entry_count, dtype=torch.int32, device=torch.device("cuda")
    )[:send_recv_count]
    recv_indices = torch.randperm(
        output_entry_count, dtype=torch.int32, device=torch.device("cuda")
    )[:send_recv_count]

    ref_output_tensor = torch.zeros(
        output_entry_count, vector_dim, dtype=dtype, device=torch.device("cuda")
    )
    ref_output_tensor[recv_indices] = input_tensor[send_indices]

    workspace_size = tllm_alltoall.get_moe_commworkspace_size_per_rank(1)
    all_workspaces = torch.zeros(
        1, workspace_size, dtype=torch.uint64, device=torch.device("cuda")
    )
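
    # Arguments to moe_comm, in order: input tensor, send cumsum and indices,
    # output tensor, recv cumsum and indices, the workspace buffer, this
    # rank's id, and the world size (a single rank in this test).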
    tllm_alltoall.moe_comm(
        input_tensor,
        send_cumsum,
        send_indices,
        output_tensor,
        recv_cumsum,
        recv_indices,
        all_workspaces,
        0,
        1,
    )

    torch.testing.assert_close(output_tensor, ref_output_tensor, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize(
    "world_size,input_entry_per_rank,vector_dim,dtype", MULTI_RANK_PARAMS
)
def test_moe_alltoall_multi_rank_single_gpu(
    world_size, input_entry_per_rank, vector_dim, dtype
):
    """Test MOE alltoall communication with multiple ranks on a single GPU."""
    torch.cuda.set_device(0)
    max_world_size = 8
    assert world_size <= max_world_size, (
        f"should run with world_size at most {max_world_size}"
    )

    # SM count is now set up globally in the fixture

    # Create a random input tensor
    input_tensor = torch.randn(
        input_entry_per_rank * world_size,
        vector_dim,
        dtype=dtype,
        device=torch.device("cuda"),
    )
    output_tensor = torch.zeros(
        input_entry_per_rank * world_size,
        vector_dim,
        dtype=dtype,
        device=torch.device("cuda"),
    )
    ref_output_tensor = torch.zeros(
        input_entry_per_rank * world_size,
        vector_dim,
        dtype=dtype,
        device=torch.device("cuda"),
    )
    target_rank_ids = torch.randint(
        0,
        world_size,
        (input_entry_per_rank * world_size,),
        dtype=torch.int32,
        device=torch.device("cuda"),
    )

    input_tensors_all_ranks = list(torch.split(input_tensor, input_entry_per_rank))
    target_rank_ids_all_ranks = list(torch.split(target_rank_ids, input_entry_per_rank))

    send_ids_all_ranks = []
    send_counts_all_ranks = []
    send_cumsum_all_ranks = []
    send_start_end_all_ranks = []

    # Each rank does its own local compute to work out how to send data to the other ranks.
    for rank in range(world_size):
        send_start_end = []
        local_target_rank_ids = target_rank_ids_all_ranks[rank]
        sorted_local_target_rank_ids, local_send_id = torch.sort(local_target_rank_ids)
        local_send_id = local_send_id.to(torch.int32)
        padded_sorted_local_target_rank_ids = torch.cat(
            (
                sorted_local_target_rank_ids,
                torch.arange(
                    world_size, dtype=torch.int32, device=torch.device("cuda")
                ),
            )
        )
        unique_target_rank_ids, local_send_counts = torch.unique(
            padded_sorted_local_target_rank_ids, return_counts=True
        )
        local_send_counts = local_send_counts.to(torch.int32)
        assert unique_target_rank_ids.numel() == world_size, (
            "unique_target_rank_ids must be equal to world_size"
        )
        local_send_counts -= 1  # remove padding
        local_send_cumsum = torch.cumsum(local_send_counts, dim=0).to(torch.int32)
        send_ids_all_ranks.append(local_send_id)
        send_counts_all_ranks.append(local_send_counts)
        send_cumsum_all_ranks.append(local_send_cumsum)
        local_send_cumsum_cpu = local_send_cumsum.cpu().tolist()
        for i in range(len(local_send_cumsum_cpu)):
            send_start_end.append(
                (
                    local_send_cumsum_cpu[i - 1] if i > 0 else 0,
                    local_send_cumsum_cpu[i],
                )
            )
        send_start_end_all_ranks.append(send_start_end)

    recv_ids_all_ranks = []
    recv_cumsum_all_ranks = []

    output_tensors_all_ranks = []

    total_recv_all_ranks_cpu = []
    output_indice_offset = 0

    output_start_current_rank = 0
    # Each rank computes, from the other ranks' send counts, how to receive data from them.
    for rank in range(world_size):
        local_recv_counts = torch.zeros(
            world_size, dtype=torch.int32, device=torch.device("cuda")
        )
        for other_rank in range(world_size):
            local_recv_counts[other_rank] = send_counts_all_ranks[other_rank][rank]
            local_recv_count_pair = local_recv_counts[other_rank].cpu().item()
            send_rank_start_end = send_start_end_all_ranks[other_rank][rank]
            ref_output_tensor[
                output_indice_offset : output_indice_offset + local_recv_count_pair
            ] = input_tensors_all_ranks[other_rank][
                send_ids_all_ranks[other_rank][
                    send_rank_start_end[0] : send_rank_start_end[1]
                ]
            ]
            output_indice_offset += local_recv_count_pair
        local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32)
        recv_cumsum_all_ranks.append(local_recv_cumsum)
        total_recv_count = local_recv_cumsum[-1].cpu()
        total_recv_all_ranks_cpu.append(total_recv_count)
        output_tensors_all_ranks.append(
            output_tensor[
                output_start_current_rank : output_start_current_rank + total_recv_count
            ]
        )
        output_start_current_rank += total_recv_count
        local_recv_ids = torch.arange(
            total_recv_count, dtype=torch.int32, device=torch.device("cuda")
        )
        recv_ids_all_ranks.append(local_recv_ids)

    cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(world_size)]

    workspace_size = tllm_alltoall.get_moe_commworkspace_size_per_rank(world_size)
    all_workspaces = torch.zeros(
        world_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda")
    )
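
    # Every emulated rank passes the same all_workspaces tensor (one row per
    # rank), so the kernels can exchange data between the emulated ranks on one
    # GPU.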

    # Synchronize before starting parallel communication
    torch.cuda.synchronize()

    # Do alltoall on all ranks in parallel, one CUDA stream per rank
    for rank in range(world_size):
        with torch.cuda.stream(cuda_streams_all_ranks[rank]):
            tllm_alltoall.moe_comm(
                input_tensors_all_ranks[rank],
                send_cumsum_all_ranks[rank],
                send_ids_all_ranks[rank],
                output_tensors_all_ranks[rank],
                recv_cumsum_all_ranks[rank],
                recv_ids_all_ranks[rank],
                all_workspaces,
                rank,
                world_size,
            )
    for rank in range(world_size):
        cuda_streams_all_ranks[rank].synchronize()

    torch.testing.assert_close(output_tensor, ref_output_tensor, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize(
    "ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank,use_real_rank_token_count_cumsum",
    PREPARE_INDICES_PARAMS,
)
def test_moe_alltoall_prepare_indices(
    ep_rank,
    ep_size,
    expert_count,
    top_k,
    max_token_count_per_rank,
    use_real_rank_token_count_cumsum,
):
    """Test MOE alltoall prepare indices functionality."""
    torch.cuda.set_device(0)

    def generate_references():
        rank_token_count = max_token_count_per_rank
        if use_real_rank_token_count_cumsum:
            # Make sure we have at least 1 token in each rank except the last rank
            rank_token_counts = [
                max(1, torch.randint(1, max_token_count_per_rank + 1, (1,)).item())
                for _ in range(ep_size - 1)
            ]
            rank_token_counts.append(
                max_token_count_per_rank
            )  # last rank has max tokens
            real_rank_token_count_cumsum = (
                torch.tensor(
                    rank_token_counts, dtype=torch.int32, device=torch.device("cuda")
                )
                .cumsum(dim=0)
                .to(torch.int32)
            )
            rank_token_count = rank_token_counts[ep_rank]
        else:
            real_rank_token_count_cumsum = None

        # Generate target rank ids for this rank
        target_rank_ids = torch.randint(
            0,
            ep_size,
            (rank_token_count, top_k),
            dtype=torch.int32,
            device=torch.device("cuda"),
        )

        if not use_real_rank_token_count_cumsum:
            gathered_target_rank_ids = torch.zeros(
                ep_size * max_token_count_per_rank,
                top_k,
                dtype=torch.int32,
                device=torch.device("cuda"),
            )
            gathered_target_rank_ids[
                ep_rank * max_token_count_per_rank : ep_rank * max_token_count_per_rank
                + rank_token_count
            ] = target_rank_ids
        else:
            total_tokens = real_rank_token_count_cumsum[-1].item()
            gathered_target_rank_ids = torch.zeros(
                total_tokens, top_k, dtype=torch.int32, device=torch.device("cuda")
            )
            start_pos = (
                0 if ep_rank == 0 else real_rank_token_count_cumsum[ep_rank - 1].item()
            )
            gathered_target_rank_ids[start_pos : start_pos + rank_token_count] = (
                target_rank_ids
            )

        return gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids

    gathered_target_rank_ids, real_rank_token_count_cumsum, target_rank_ids = (
        generate_references()
    )
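
    # moe_comm_prepare_indices computes, for this rank, the gather indices and
    # the send/recv count cumsums and local index arrays used by the alltoall.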
    (
        local_gather_indices,
        send_rank_count_cumsum,
        send_rank_local_indices,
        recv_rank_count_cumsum,
        recv_rank_local_indices,
        backward_recv_rank_local_indices,
    ) = tllm_alltoall.moe_comm_prepare_indices(
        gathered_target_rank_ids,
        real_rank_token_count_cumsum,
        max_token_count_per_rank,
        expert_count,
        top_k,
        ep_rank,
        ep_size,
    )

    # Validate shapes
    assert local_gather_indices.shape[0] <= max_token_count_per_rank * ep_size
    assert send_rank_count_cumsum.shape[0] == ep_size
    assert recv_rank_count_cumsum.shape[0] == ep_size
    assert send_rank_local_indices.shape[0] <= max_token_count_per_rank * max(
        ep_size, top_k
    )
    assert recv_rank_local_indices.shape[0] <= max_token_count_per_rank * ep_size
    assert backward_recv_rank_local_indices.shape[0] <= max_token_count_per_rank * max(
        ep_size, top_k
    )

    # Basic validation - cumulative sums should be non-decreasing
    assert torch.all(send_rank_count_cumsum[1:] >= send_rank_count_cumsum[:-1])
    assert torch.all(recv_rank_count_cumsum[1:] >= recv_rank_count_cumsum[:-1])


@pytest.mark.parametrize(
    "ep_rank,ep_size,expert_count,top_k,max_token_count_per_rank", LOCAL_GATHER_PARAMS
)
def test_moe_local_gather(
    ep_rank,
    ep_size,
    expert_count,
    top_k,
    max_token_count_per_rank,
):
    """Test MOE local gather functionality."""
    torch.cuda.set_device(0)

    # Generate test data
    rank_token_count_cumsum = torch.randint(
        0,
        max_token_count_per_rank + 1,
        (ep_size,),
        dtype=torch.int32,
        device=torch.device("cuda"),
    )
    rank_token_count_cumsum = torch.cumsum(rank_token_count_cumsum, dim=0).to(
        torch.int32
    )
    local_token_count = rank_token_count_cumsum[ep_size - 1].cpu().item()
    local_max_token_count = max_token_count_per_rank * ep_size
    local_gather_indices = torch.randint(
        0,
        max_token_count_per_rank * ep_size,
        (local_max_token_count,),
        dtype=torch.int32,
        device=torch.device("cuda"),
    )

    gathered_expert_ids = torch.randint(
        0,
        expert_count,
        (max_token_count_per_rank * ep_size, top_k),
        dtype=torch.int32,
        device=torch.device("cuda"),
    )
    gathered_scales = torch.rand(
        (max_token_count_per_rank * ep_size, top_k),
        dtype=torch.float32,
        device=torch.device("cuda"),
    )

    ref_local_expert_ids = torch.zeros(
        local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda")
    )
    ref_local_scales = torch.zeros(
        local_max_token_count,
        top_k,
        dtype=torch.float32,
        device=torch.device("cuda"),
    )

    # Compute the reference: entries beyond local_token_count keep the sentinel
    # expert id (expert_count) and zero scales.
    ref_local_expert_ids += expert_count
    valid_local_gather_indices = local_gather_indices[:local_token_count]
    ref_local_expert_ids[:local_token_count] = gathered_expert_ids[
        valid_local_gather_indices
    ]
    ref_local_scales[:local_token_count] = gathered_scales[valid_local_gather_indices]

    local_expert_ids = torch.empty(
        local_max_token_count, top_k, dtype=torch.int32, device=torch.device("cuda")
    )
    local_scales = torch.empty(
        local_max_token_count,
        top_k,
        dtype=torch.float32,
        device=torch.device("cuda"),
    )

    tllm_alltoall.moe_local_gather(
        rank_token_count_cumsum,
        local_gather_indices,
        gathered_expert_ids,
        gathered_scales,
        local_expert_ids,
        local_scales,
        max_token_count_per_rank,
        expert_count,
        top_k,
        ep_rank,
        ep_size,
    )

    assert torch.equal(local_expert_ids, ref_local_expert_ids)
    assert torch.equal(local_scales, ref_local_scales)


@pytest.mark.parametrize(
    "ep_rank, ep_size, expert_count, slot_count, top_k, max_token_count_per_rank",
    [
        (0, 2, 16, 20, 8, 512),
        (0, 2, 16, 16, 3, 300),
        (0, 4, 20, 24, 8, 4000),
        (0, 8, 96, 96, 8, 1000),
        (3, 8, 128, 128, 8, 1000),
        (3, 8, 128, 144, 8, 1),
        (0, 4, 72, 80, 4, 2256),
        (0, 4, 72, 80, 6, 3333),
        # Hangs with stream count > 8
        # (0, 9, 90, 8, 100),
    ],
)
def test_moe_alltoall_prepare(
    ep_rank: int,
    ep_size: int,
    expert_count: int,
    slot_count: int,
    top_k: int,
    max_token_count_per_rank: int,
):
    """Test MOE alltoall prepare (expert id/scale exchange and routing metadata)."""
    torch.cuda.set_device(0)

    cpu_expert_ids_all_ranks_lists = []
    cpu_token_count_lists = []
    cpu_scales_all_ranks_lists = []
    for _ in range(ep_size):
        # Every rank uses the maximum token count.
        token_count = max_token_count_per_rank

        cpu_expert_ids_all_ranks_lists.append(
            torch.randint(
                0,
                slot_count,
                (token_count, top_k),
                dtype=torch.int32,
                device=torch.device("cpu"),
            )
        )

        cpu_scales_all_ranks_lists.append(
            torch.zeros(
                token_count, top_k, dtype=torch.float32, device=torch.device("cpu")
            )
            + 0.5
        )

        cpu_token_count_lists.append(token_count)
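
    # Experts (slots) are assigned to EP ranks in contiguous blocks of
    # slot_count // ep_size, so slot id e maps to rank e // (slot_count // ep_size).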
    def compute_target_rank(expert_id):
        ep_per_rank = slot_count // ep_size
        return expert_id // ep_per_rank

    def generate_references():
        ref_prepared_local_expert_ids = []
        ref_prepared_local_scales = []
        ref_local_send_rank_count_cumsum = [0] * ep_size
        ref_local_recv_rank_count_cumsum = [0] * ep_size
        ref_local_recv_rank_indices = []

        local_token_count = cpu_token_count_lists[ep_rank]
        send_token_count_to_ranks = [0] * ep_size

        # send part
        for token_id in range(local_token_count):
            target_set = set()
            for pos in range(top_k):
                expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos])
                target_rank_id = compute_target_rank(expert_id)
                target_set.add(target_rank_id)

            for target_rank_id in target_set:
                send_token_count_to_ranks[target_rank_id] += 1

        total_send_token_count = 0
        for rank in range(ep_size):
            base = ref_local_send_rank_count_cumsum[rank - 1] if rank > 0 else 0
            ref_local_send_rank_count_cumsum[rank] = (
                send_token_count_to_ranks[rank] + base
            )
            total_send_token_count += send_token_count_to_ranks[rank]

        ref_local_backward_send_rank_indices = [0] * total_send_token_count
        ref_local_send_rank_indices = [0] * total_send_token_count

        current_send_token_ids = [0] * ep_size
        for token_id in range(local_token_count):
            target_set = set()
            for pos in range(top_k):
                expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][pos])
                target_rank_id = compute_target_rank(expert_id)
                if target_rank_id not in target_set:
                    cumsum_before = (
                        0
                        if target_rank_id == 0
                        else ref_local_send_rank_count_cumsum[target_rank_id - 1]
                    )
                    send_index = cumsum_before + current_send_token_ids[target_rank_id]
                    ref_local_send_rank_indices[send_index] = token_id
                    ref_local_backward_send_rank_indices[send_index] = (
                        token_id * top_k + pos
                    )
                    current_send_token_ids[target_rank_id] += 1
                    target_set.add(target_rank_id)

        # receive part
        total_recv_token_count = 0
        for rank in range(ep_size):
            token_count = cpu_token_count_lists[rank]
            current_recv_token_count = 0
            for token_id in range(token_count):
                token_is_received = False
                for pos in range(top_k):
                    expert_id = int(cpu_expert_ids_all_ranks_lists[rank][token_id][pos])
                    sf = cpu_scales_all_ranks_lists[rank][token_id][pos]
                    target_rank_id = compute_target_rank(expert_id)
                    if target_rank_id == ep_rank:
                        if not token_is_received:
                            token_is_received = True
                            ref_prepared_local_expert_ids.append([slot_count] * top_k)
                            ref_prepared_local_scales.append([0.0] * top_k)
                        ref_prepared_local_expert_ids[-1][pos] = expert_id
                        ref_prepared_local_scales[-1][pos] = sf
                if token_is_received:
                    ref_local_recv_rank_indices.append(total_recv_token_count)
                    total_recv_token_count += 1
                    current_recv_token_count += 1
            ref_local_recv_rank_count_cumsum[rank] = (
                current_recv_token_count
                if rank == 0
                else ref_local_recv_rank_count_cumsum[rank - 1]
                + current_recv_token_count
            )

        return (
            ref_prepared_local_expert_ids,
            ref_prepared_local_scales,
            ref_local_send_rank_count_cumsum,
            ref_local_send_rank_indices,
            ref_local_recv_rank_count_cumsum,
            ref_local_recv_rank_indices,
            ref_local_backward_send_rank_indices,
            total_recv_token_count,
        )

    (
        ref_prepared_local_expert_ids,
        ref_prepared_local_scales,
        ref_local_send_rank_count_cumsum,
        ref_local_send_rank_indices,
        ref_local_recv_rank_count_cumsum,
        ref_local_recv_rank_indices,
        ref_local_backward_send_rank_indices,
        total_recv_token_count,
    ) = generate_references()

    cpu_expert_count_lists = []
    for rank in range(ep_size):
        local_expert_count = []
        for i in range(expert_count):
            local_expert_count.append(rank * expert_count + i)
        cpu_expert_count_lists.append(torch.IntTensor(local_expert_count))

    expert_ids_all_ranks = [
        cpu_expert_ids_all_ranks_lists[i].cuda() for i in range(ep_size)
    ]
    scales_all_ranks = [cpu_scales_all_ranks_lists[i].cuda() for i in range(ep_size)]

    expert_count_lists = [cpu_expert_count_lists[i].cuda() for i in range(ep_size)]

    cuda_streams_all_ranks = [torch.cuda.Stream() for _ in range(ep_size)]

    workspace_size = tllm_alltoall.get_moe_prepare_workspace_size_per_rank(ep_size)

    all_workspaces = torch.zeros(
        ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda")
    )

    # Single-rank call on a side stream first, likely to warm up JIT compilation
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        tllm_alltoall.moe_prepare(
            expert_ids_all_ranks[0],
            scales_all_ranks[0],
            expert_count_lists[0],
            all_workspaces,
            max_token_count_per_rank,
            0,
            1,
            expert_count,
            slot_count,
            top_k,
        )
    stream.wait_stream(torch.cuda.current_stream())
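
    # Make torch allocate tensors up front to avoid a CUDA sync: the block
    # below creates, then releases, tensors with the shapes moe_prepare
    # returns, which presumably primes the caching allocator so the parallel
    # per-rank calls do not hit a synchronizing cudaMalloc.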
    prepared_local_experts = []
    prepared_local_scales = []
    local_send_rank_count_cumsum = []
    local_send_rank_indices = []
    local_recv_rank_count_cumsum = []
    local_recv_rank_indices = []
    backward_local_recv_rank_indices = []
    for _ in range(ep_size):
        prepared_local_experts.append(
            torch.empty(
                max_token_count_per_rank * ep_size,
                top_k,
                dtype=torch.int32,
                device=torch.device("cuda"),
            )
        )
        prepared_local_scales.append(
            torch.empty(
                max_token_count_per_rank * ep_size,
                top_k,
                dtype=torch.float32,
                device=torch.device("cuda"),
            )
        )
        local_send_rank_count_cumsum.append(
            torch.empty(ep_size, dtype=torch.int32, device=torch.device("cuda"))
        )
        local_send_rank_indices.append(
            torch.empty(
                max_token_count_per_rank * ep_size,
                dtype=torch.int32,
                device=torch.device("cuda"),
            )
        )
        local_recv_rank_count_cumsum.append(
            torch.empty(0, dtype=torch.int32, device=torch.device("cuda"))
        )
        local_recv_rank_indices.append(
            torch.empty(0, dtype=torch.int32, device=torch.device("cuda"))
        )
        backward_local_recv_rank_indices.append(
            torch.empty(0, dtype=torch.int32, device=torch.device("cuda"))
        )

    prepared_local_experts = None
    prepared_local_scales = None
    local_send_rank_count_cumsum = None
    local_send_rank_indices = None
    local_recv_rank_count_cumsum = None
    local_recv_rank_indices = None
    backward_local_recv_rank_indices = None

    # Reset the workspace
    all_workspaces = torch.zeros(
        ep_size, workspace_size, dtype=torch.uint64, device=torch.device("cuda")
    )

    # Do prepare on all ranks in parallel, one CUDA stream per rank
    for rank in range(ep_size):
        with torch.cuda.stream(cuda_streams_all_ranks[rank]):
            if rank == ep_rank:
                (
                    prepared_local_experts,
                    prepared_local_scales,
                    local_send_rank_count_cumsum,
                    local_send_rank_indices,
                    local_recv_rank_count_cumsum,
                    local_recv_rank_indices,
                    backward_local_recv_rank_indices,
                    gathered_expert_statics,
                ) = tllm_alltoall.moe_prepare(
                    expert_ids_all_ranks[rank],
                    scales_all_ranks[rank],
                    expert_count_lists[rank],
                    all_workspaces,
                    max_token_count_per_rank,
                    rank,
                    ep_size,
                    expert_count,
                    slot_count,
                    top_k,
                )
            else:
                tllm_alltoall.moe_prepare(
                    expert_ids_all_ranks[rank],
                    scales_all_ranks[rank],
                    expert_count_lists[rank],
                    all_workspaces,
                    max_token_count_per_rank,
                    rank,
                    ep_size,
                    expert_count,
                    slot_count,
                    top_k,
                )
    for rank in range(ep_size):
        cuda_streams_all_ranks[rank].synchronize()

    prepared_local_experts_cpu = prepared_local_experts[:total_recv_token_count].cpu()
    prepared_local_scales_cpu = prepared_local_scales[:total_recv_token_count].cpu()
    for i in range(total_recv_token_count):
        for j in range(top_k):
            expert_id = int(prepared_local_experts_cpu[i][j])
            assert expert_id >= 0 and expert_id <= slot_count
            if expert_id < slot_count:
                assert compute_target_rank(expert_id) == ep_rank
                scale = float(prepared_local_scales_cpu[i][j])
                assert scale > 1e-6

    gathered_expert_statics_cpu = gathered_expert_statics.cpu()
    for rank in range(ep_size):
        for i in range(expert_count):
            assert int(gathered_expert_statics_cpu[rank][i]) == rank * expert_count + i

    ref_local_send_rank_count_cumsum = torch.IntTensor(ref_local_send_rank_count_cumsum)
    assert torch.equal(
        local_send_rank_count_cumsum.cpu(), ref_local_send_rank_count_cumsum
    )

    local_send_rank_indices = local_send_rank_indices.cpu()
    backward_local_recv_rank_indices = backward_local_recv_rank_indices.cpu()
    for i in range(ep_size):
        base = 0 if i == 0 else ref_local_send_rank_count_cumsum[i - 1]
        for j in range(base, ref_local_send_rank_count_cumsum[i]):
            token_id = local_send_rank_indices[j]
            lane_id = backward_local_recv_rank_indices[j] - token_id * top_k
            expert_id = int(cpu_expert_ids_all_ranks_lists[ep_rank][token_id][lane_id])
            assert compute_target_rank(expert_id) == i

    ref_local_recv_rank_count_cumsum = torch.IntTensor(ref_local_recv_rank_count_cumsum)
    assert torch.equal(
        local_recv_rank_count_cumsum[: ref_local_recv_rank_count_cumsum.size(0)].cpu(),
        ref_local_recv_rank_count_cumsum,
    )

    ref_local_recv_rank_indices = torch.IntTensor(ref_local_recv_rank_indices)
    assert torch.equal(
        local_recv_rank_indices[: ref_local_recv_rank_indices.size(0)].cpu(),
        ref_local_recv_rank_indices,
    )


if __name__ == "__main__":
    pytest.main([__file__, "-v"])