# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import socket

import pynvml
import pytest
import torch

from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlMemory, MpiComm
from flashinfer.comm.trtllm_alltoall import MnnvlMoe, MoEAlltoallInfo

pynvml.nvmlInit()


@pytest.mark.skipif(
    not MnnvlMemory.supports_mnnvl(),
    reason="Mnnvl memory is not supported on this platform",
)
class TestMnnvlMemory:
    @pytest.fixture(autouse=True)
    def setup(self):
        # Determine how many ranks run on this node and pick a local GPU.
        hostname = socket.gethostname()
        self.comm = MpiComm()
        self.world_size = self.comm.Get_size()
        self.rank = self.comm.Get_rank()
        all_hostnames = self.comm.allgather(hostname)
        local_ntasks_per_node = all_hostnames.count(hostname)
        all_ntasks_per_node = self.comm.allgather(local_ntasks_per_node)
        uniform_ntasks = all(x == all_ntasks_per_node[0] for x in all_ntasks_per_node)
        assert uniform_ntasks, "Not all nodes have the same ntasks_per_node"
        self.local_world_size = local_ntasks_per_node
        self.local_rank = self.rank % self.local_world_size
        local_dev_count = torch.cuda.device_count()
        assert self.local_world_size <= local_dev_count, (
            "ntasks_per_node should not exceed the local device count"
        )
        torch.cuda.set_device(self.local_rank)
        MnnvlMemory.initialize()
        self.mapping = Mapping(
            self.world_size, self.rank, self.local_world_size, tp_size=self.world_size
        )

    @staticmethod
    def align_memory(size: int):
        align_size = 2 * 1024 * 1024
        return (size + align_size - 1) // align_size * align_size

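    # A quick worked example of the alignment above, using the sizes from
    # test_mnnvl_memory below: 4 MiB - 3 KiB (4191232 B) rounds up to 4 MiB
    # (4194304 B), 30 MiB - 2 KiB rounds up to 30 MiB, and 768 MiB is already
    # a multiple of 2 MiB, so it stays unchanged.
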
    @pytest.mark.skipif(
        not MnnvlMemory.supports_mnnvl(),
        reason="Mnnvl memory is not supported on this platform",
    )
    def test_mnnvl_memory(self):
        # Allocate an unaligned size; the pool offset should advance by the aligned size.
        allocate0_size = 4 * 1024 * 1024 - 3 * 1024
        mnnvl_memory0 = MnnvlMemory(self.mapping, allocate0_size)
        allocate0_size_aligned = TestMnnvlMemory.align_memory(allocate0_size)
        assert MnnvlMemory.current_mem_offset == allocate0_size_aligned

        tensor0 = mnnvl_memory0.as_torch_strided_tensor(torch.int32)
        numel_per_rank = allocate0_size // 4
        # Each rank writes the row belonging to its right-hand neighbour.
        tensor0[(self.rank + 1) % self.world_size] = torch.arange(
            start=self.rank, end=self.rank + numel_per_rank, device="cuda"
        )
        self.comm.Barrier()
        # Verify the row written by each peer rank (int32 reference to match tensor0).
        for r in range(self.world_size):
            assert torch.equal(
                tensor0[(r + 1) % self.world_size],
                torch.arange(
                    start=r, end=r + numel_per_rank, dtype=torch.int32, device="cuda"
                ),
            )

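        # The loop above is the cross-rank visibility check: after the barrier,
        # every rank should see the rows that its peers wrote into the shared
        # MNNVL allocation.
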
        allocate1_size = 30 * 1024 * 1024 - 2 * 1024
        mnnvl_memory1 = MnnvlMemory(self.mapping, allocate1_size)
        allocate1_size_aligned = TestMnnvlMemory.align_memory(allocate1_size)
        assert (
            MnnvlMemory.current_mem_offset
            == allocate0_size_aligned + allocate1_size_aligned
        )
        tensor1 = mnnvl_memory1.as_torch_strided_tensor(torch.float32)
        numel_per_rank = allocate1_size // 4
        tensor1[(self.rank + 5) % self.world_size] = torch.arange(
            start=self.rank,
            end=self.rank + numel_per_rank,
            dtype=torch.float32,
            device="cuda",
        )
        self.comm.Barrier()
        for r in range(self.world_size):
            assert torch.equal(
                tensor1[(r + 5) % self.world_size],
                torch.arange(
                    start=r, end=r + numel_per_rank, dtype=torch.float32, device="cuda"
                ),
            )
        self.comm.Barrier()
        del tensor0, mnnvl_memory0
        self.comm.Barrier()

        large_allocation2_size = 768 * 1024 * 1024
        large_mnnvl_memory2 = MnnvlMemory(self.mapping, large_allocation2_size)
        allocate2_size_aligned = TestMnnvlMemory.align_memory(large_allocation2_size)
        assert MnnvlMemory.current_mem_offset == allocate2_size_aligned
        assert large_mnnvl_memory2.rank_stride == (1 << 30)

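        # Note (a reading of the asserts above, not a documented guarantee): the
        # offset restarting at allocate2_size_aligned together with the 1 GiB
        # rank_stride suggests this 768 MiB request is served from a new, larger
        # segment rather than the one backing the two earlier allocations.
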
        del tensor1

    @pytest.mark.skipif(
        not MnnvlMemory.supports_mnnvl(),
        reason="Mnnvl memory is not supported on this platform",
    )
    def test_moe_alltoall_multi_rank_single_gpu(self):
        torch.cuda.set_device(self.rank)
        max_world_size = 8
        assert self.world_size <= max_world_size, (
            f"should run with world_size at most {max_world_size}"
        )
        torch.manual_seed(self.world_size)
        input_entry_per_rank, vector_dim, dtype = 128, 256, torch.float16

        # Create the random input, the reference output buffer, and the target
        # rank id for every input row (row i is destined for rank target_rank_ids[i]).
        input_tensor = torch.randn(
            input_entry_per_rank * self.world_size,
            vector_dim,
            dtype=dtype,
            device=torch.device("cuda"),
        )
        ref_output_tensor = torch.zeros(
            input_entry_per_rank * self.world_size,
            vector_dim,
            dtype=dtype,
            device=torch.device("cuda"),
        )
        target_rank_ids = torch.randint(
            0,
            self.world_size,
            (input_entry_per_rank * self.world_size,),
            dtype=torch.int32,
            device=torch.device("cuda"),
        )

        input_tensors_all_ranks = list(torch.split(input_tensor, input_entry_per_rank))
        target_rank_ids_all_ranks = list(
            torch.split(target_rank_ids, input_entry_per_rank)
        )

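        # Every rank generates the identical full-size tensors above (the seed
        # depends only on world_size), then slices out one chunk per rank. Each
        # process therefore knows the whole problem for building the reference,
        # but only feeds its own slice into the actual all-to-all below.
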
        send_ids_all_ranks = []
        send_counts_all_ranks = []
        send_cumsum_all_ranks = []
        send_start_end_all_ranks = []

        # Each rank computes locally how it will send its rows to the other ranks.
        for rank in range(self.world_size):
            send_start_end = []
            local_target_rank_ids = target_rank_ids_all_ranks[rank]
            sorted_local_target_rank_ids, local_send_id = torch.sort(
                local_target_rank_ids
            )
            local_send_id = local_send_id.to(torch.int32)
            # Pad with one entry per rank so that torch.unique reports a count
            # for every rank, even those that receive nothing from this rank.
            padded_sorted_local_target_rank_ids = torch.cat(
                (
                    sorted_local_target_rank_ids,
                    torch.arange(
                        self.world_size, dtype=torch.int32, device=torch.device("cuda")
                    ),
                )
            )
            unique_target_rank_ids, local_send_counts = torch.unique(
                padded_sorted_local_target_rank_ids, return_counts=True
            )
            local_send_counts = local_send_counts.to(torch.int32)
            assert unique_target_rank_ids.numel() == self.world_size, (
                "the number of unique target rank ids must equal world_size"
            )
            local_send_counts -= 1  # remove the padding entries
            local_send_cumsum = torch.cumsum(local_send_counts, dim=0).to(torch.int32)
            send_ids_all_ranks.append(local_send_id)
            send_counts_all_ranks.append(local_send_counts)
            send_cumsum_all_ranks.append(local_send_cumsum)
            local_send_cumsum_cpu = local_send_cumsum.cpu().tolist()
            # Convert the cumulative counts into per-destination (start, end) ranges.
            for i in range(len(local_send_cumsum_cpu)):
                send_start_end.append(
                    (
                        local_send_cumsum_cpu[i - 1] if i > 0 else 0,
                        local_send_cumsum_cpu[i],
                    )
                )
            send_start_end_all_ranks.append(send_start_end)

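        # Small worked example of the padding trick used above (hypothetical
        # numbers, world_size = 4): for local targets [2, 2, 0], padding with
        # [0, 1, 2, 3] gives counts [2, 1, 3, 1] over all four ranks, and
        # subtracting the one padding entry per rank yields the true send
        # counts [1, 0, 2, 0].
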
        recv_ids_all_ranks = []
        recv_cumsum_all_ranks = []

        ref_output_tensors_all_ranks = []

        total_recv_all_ranks_cpu = []
        output_indice_offset = 0

        output_start_current_rank = 0
        # Each rank derives, from the other ranks' send counts, how it will
        # receive data, and builds the reference output in the same order.
        for rank in range(self.world_size):
            local_recv_counts = torch.zeros(
                self.world_size, dtype=torch.int32, device=torch.device("cuda")
            )
            for other_rank in range(self.world_size):
                # What `rank` receives from `other_rank` is what `other_rank` sends to `rank`.
                local_recv_counts[other_rank] = send_counts_all_ranks[other_rank][rank]
                local_recv_count_pair = local_recv_counts[other_rank].cpu().item()
                send_rank_start_end = send_start_end_all_ranks[other_rank][rank]
                ref_output_tensor[
                    output_indice_offset : output_indice_offset + local_recv_count_pair
                ] = input_tensors_all_ranks[other_rank][
                    send_ids_all_ranks[other_rank][
                        send_rank_start_end[0] : send_rank_start_end[1]
                    ]
                ]
                output_indice_offset += local_recv_count_pair
            local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32)
            recv_cumsum_all_ranks.append(local_recv_cumsum)
            total_recv_count = local_recv_cumsum[-1].cpu()
            total_recv_all_ranks_cpu.append(total_recv_count)
            ref_output_tensors_all_ranks.append(
                ref_output_tensor[
                    output_start_current_rank : output_start_current_rank
                    + total_recv_count
                ]
            )
            output_start_current_rank += total_recv_count
            local_recv_ids = torch.arange(
                total_recv_count, dtype=torch.int32, device=torch.device("cuda")
            )
            recv_ids_all_ranks.append(local_recv_ids)

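        # The reference output for each receiving rank is thus a contiguous
        # block of ref_output_tensor, with rows ordered by source rank and,
        # within a source rank, by that sender's sorted send order.
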
        alltoall_info = MoEAlltoallInfo(
            None,
            send_cumsum_all_ranks[self.rank],
            send_ids_all_ranks[self.rank],
            recv_cumsum_all_ranks[self.rank],
            recv_ids_all_ranks[self.rank],
            None,
            ref_output_tensors_all_ranks[self.rank].shape[0],
        )

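        # Only this rank's plan is handed to the kernel: its send cumsum/ids,
        # its recv cumsum/ids, and the number of rows it expects to receive
        # (the size of its reference block). The two None arguments are left
        # unused by this test.
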
        alltoall_workspace = MnnvlMoe.get_moe_workspaces(self.mapping)

        self.comm.Barrier()

        output = MnnvlMoe.mnnvl_moe_alltoallv(
            input_tensors_all_ranks[self.rank],
            alltoall_info,
            alltoall_workspace,
            self.rank,
            self.world_size,
        )

        self.comm.Barrier()

        torch.testing.assert_close(
            output, ref_output_tensors_all_ranks[self.rank], atol=1e-5, rtol=1e-5
        )
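
# Launch sketch (an assumption about the intended setup, not part of the test):
# these tests coordinate through MPI, so they are meant to run with one process
# per rank, e.g. something like
#   mpirun -np <world_size> python -m pytest <path to this file>
# with the launcher and process count adjusted to the local cluster.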