# sglang_v0.5.2/flashinfer_0.3.1/tests/test_mnnvl_memory.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket

import pynvml
import pytest
import torch

from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import MnnvlMemory, MpiComm
from flashinfer.comm.trtllm_alltoall import MnnvlMoe, MoEAlltoallInfo

pynvml.nvmlInit()
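

# These tests exercise MNNVL (multi-node NVLink) shared memory and the MoE
# all-to-all path built on top of it. They are written to run one process per
# GPU under an MPI launcher (MpiComm provides Get_rank/allgather/Barrier),
# e.g. `mpirun -np <ranks> python -m pytest tests/test_mnnvl_memory.py`;
# the exact launch command depends on the local MPI setup.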
@pytest.mark.skipif(
not MnnvlMemory.supports_mnnvl(),
reason="Mnnvl memory is not supported on this platform",
)
class TestMnnvlMemory:
@pytest.fixture(autouse=True)
def setup(self):
        # Determine the number of tasks per node from the gathered hostnames.
hostname = socket.gethostname()
self.comm = MpiComm()
self.world_size = self.comm.Get_size()
self.rank = self.comm.Get_rank()
all_hostnames = self.comm.allgather(hostname)
local_ntasks_per_node = all_hostnames.count(hostname)
all_ntasks_per_node = self.comm.allgather(local_ntasks_per_node)
uniform_ntasks = all(x == all_ntasks_per_node[0] for x in all_ntasks_per_node)
        assert uniform_ntasks, "Not all nodes have the same ntasks_per_node"
self.local_world_size = local_ntasks_per_node
self.local_rank = self.rank % self.local_world_size
local_dev_count = torch.cuda.device_count()
        assert self.local_world_size <= local_dev_count, (
            "ntasks_per_node should not exceed the local device count"
        )
torch.cuda.set_device(self.local_rank)
MnnvlMemory.initialize()
self.mapping = Mapping(
self.world_size, self.rank, self.local_world_size, tp_size=self.world_size
)

    @staticmethod
def align_memory(size: int):
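        # Round the requested size up to the 2 MiB granularity that the MNNVL
        # allocator is expected to use (the offset asserts below rely on it).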
align_size = 2 * 1024 * 1024
return (size + align_size - 1) // align_size * align_size

    @pytest.mark.skipif(
not MnnvlMemory.supports_mnnvl(),
reason="Mnnvl memory is not supported on this platform",
)
def test_mnnvl_memory(self):
        # Allocate a size that is not a multiple of the 2 MiB alignment.
allocate0_size = 4 * 1024 * 1024 - 3 * 1024
mnnvl_memory0 = MnnvlMemory(self.mapping, allocate0_size)
allocate0_size_aligned = TestMnnvlMemory.align_memory(allocate0_size)
assert MnnvlMemory.current_mem_offset == allocate0_size_aligned
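        # The strided view exposes one row per rank in the shared allocation.
        # Rank r writes into row (r + 1) % world_size so that, after the barrier,
        # every rank can verify the data its peers wrote remotely.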
tensor0 = mnnvl_memory0.as_torch_strided_tensor(torch.int32)
numel_per_rank = allocate0_size // 4
tensor0[(self.rank + 1) % self.world_size] = torch.arange(
start=self.rank, end=self.rank + numel_per_rank, device="cuda"
)
self.comm.Barrier()
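        # Check the rows written by every peer.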
        for r in range(self.world_size):
            assert torch.equal(
                tensor0[(r + 1) % self.world_size],
                torch.arange(
                    start=r, end=r + numel_per_rank, dtype=torch.int32, device="cuda"
                ),
            )
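        # A second unaligned allocation should be placed immediately after the
        # first one, at the next 2 MiB-aligned offset.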
allocate1_size = 30 * 1024 * 1024 - 2 * 1024
mnnvl_memory1 = MnnvlMemory(self.mapping, allocate1_size)
allocate1_size_aligned = TestMnnvlMemory.align_memory(allocate1_size)
assert (
MnnvlMemory.current_mem_offset
== allocate0_size_aligned + allocate1_size_aligned
)
tensor1 = mnnvl_memory1.as_torch_strided_tensor(torch.float32)
numel_per_rank = allocate1_size // 4
tensor1[(self.rank + 5) % self.world_size] = torch.arange(
start=self.rank,
end=self.rank + numel_per_rank,
dtype=torch.float32,
device="cuda",
)
self.comm.Barrier()
        for r in range(self.world_size):
            assert torch.equal(
                tensor1[(r + 5) % self.world_size],
                torch.arange(
                    start=r, end=r + numel_per_rank, dtype=torch.float32, device="cuda"
                ),
            )
self.comm.Barrier()
del tensor0, mnnvl_memory0
self.comm.Barrier()
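        # A large 768 MiB request; the asserts below expect the allocator to
        # start a new backing segment (the offset restarts at this allocation's
        # aligned size) with a 1 GiB per-rank stride.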
large_allocation2_size = 768 * 1024 * 1024
large_mnnvl_memory2 = MnnvlMemory(self.mapping, large_allocation2_size)
allocate2_size_aligned = TestMnnvlMemory.align_memory(large_allocation2_size)
assert MnnvlMemory.current_mem_offset == allocate2_size_aligned
assert large_mnnvl_memory2.rank_stride == (1 << 30)
del tensor1

    @pytest.mark.skipif(
not MnnvlMemory.supports_mnnvl(),
reason="Mnnvl memory is not supported on this platform",
)
def test_moe_alltoall_multi_rank_single_gpu(self):
torch.cuda.set_device(self.rank)
max_world_size = 8
assert self.world_size <= max_world_size, (
f"should run with world_size at most {max_world_size}"
)
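        # Every rank seeds identically (with world_size), so all ranks generate
        # the same global input tensor and routing table and can compute the
        # full reference result locally.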
torch.manual_seed(self.world_size)
input_entry_per_rank, vector_dim, dtype = 128, 256, torch.float16
# Create a random input tensor
input_tensor = torch.randn(
input_entry_per_rank * self.world_size,
vector_dim,
dtype=dtype,
device=torch.device("cuda"),
)
ref_output_tensor = torch.zeros(
input_entry_per_rank * self.world_size,
vector_dim,
dtype=dtype,
device=torch.device("cuda"),
)
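        # Assign every input row a random destination rank.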
target_rank_ids = torch.randint(
0,
self.world_size,
(input_entry_per_rank * self.world_size,),
dtype=torch.int32,
device=torch.device("cuda"),
)
input_tensors_all_ranks = list(torch.split(input_tensor, input_entry_per_rank))
target_rank_ids_all_ranks = list(
torch.split(target_rank_ids, input_entry_per_rank)
)
send_ids_all_ranks = []
send_counts_all_ranks = []
send_cumsum_all_ranks = []
send_start_end_all_ranks = []
        # Each rank does its own local compute to determine how it sends data to
        # the other ranks.
for rank in range(self.world_size):
send_start_end = []
local_target_rank_ids = target_rank_ids_all_ranks[rank]
sorted_local_target_rank_ids, local_send_id = torch.sort(
local_target_rank_ids
)
local_send_id = local_send_id.to(torch.int32)
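            # Pad with one entry per destination rank so that torch.unique
            # reports a count for every rank; the padding is subtracted below.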
padded_sorted_local_target_rank_ids = torch.cat(
(
sorted_local_target_rank_ids,
torch.arange(
self.world_size, dtype=torch.int32, device=torch.device("cuda")
),
)
)
unique_target_rank_ids, local_send_counts = torch.unique(
padded_sorted_local_target_rank_ids, return_counts=True
)
local_send_counts = local_send_counts.to(torch.int32)
            assert unique_target_rank_ids.numel() == self.world_size, (
                "number of unique target rank ids must equal world_size"
            )
local_send_counts -= 1 # remove padding
local_send_cumsum = torch.cumsum(local_send_counts, dim=0).to(torch.int32)
send_ids_all_ranks.append(local_send_id)
send_counts_all_ranks.append(local_send_counts)
send_cumsum_all_ranks.append(local_send_cumsum)
local_send_cumsum_cpu = local_send_cumsum.cpu().tolist()
for i in range(len(local_send_cumsum_cpu)):
send_start_end.append(
(
local_send_cumsum_cpu[i - 1] if i > 0 else 0,
local_send_cumsum_cpu[i],
)
)
send_start_end_all_ranks.append(send_start_end)
recv_ids_all_ranks = []
recv_cumsum_all_ranks = []
ref_output_tensors_all_ranks = []
total_recv_all_ranks_cpu = []
output_indice_offset = 0
output_start_current_rank = 0
        # Each rank computes, from the other ranks' send counts, how it will
        # receive data from them.
for rank in range(self.world_size):
local_recv_counts = torch.zeros(
self.world_size, dtype=torch.int32, device=torch.device("cuda")
)
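            # Received data is grouped by source rank: copy the rows each
            # other_rank sends to this rank into the reference output.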
for other_rank in range(self.world_size):
local_recv_counts[other_rank] = send_counts_all_ranks[other_rank][rank]
local_recv_count_pair = local_recv_counts[other_rank].cpu().item()
send_rank_start_end = send_start_end_all_ranks[other_rank][rank]
ref_output_tensor[
output_indice_offset : output_indice_offset + local_recv_count_pair
] = input_tensors_all_ranks[other_rank][
send_ids_all_ranks[other_rank][
send_rank_start_end[0] : send_rank_start_end[1]
]
]
output_indice_offset += local_recv_count_pair
local_recv_cumsum = torch.cumsum(local_recv_counts, dim=0).to(torch.int32)
recv_cumsum_all_ranks.append(local_recv_cumsum)
total_recv_count = local_recv_cumsum[-1].cpu()
total_recv_all_ranks_cpu.append(total_recv_count)
ref_output_tensors_all_ranks.append(
ref_output_tensor[
output_start_current_rank : output_start_current_rank
+ total_recv_count
]
)
output_start_current_rank += total_recv_count
local_recv_ids = torch.arange(
total_recv_count, dtype=torch.int32, device=torch.device("cuda")
)
recv_ids_all_ranks.append(local_recv_ids)
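        # Pack this rank's send/recv metadata for the all-to-all; the fields
        # passed as None are not needed by this test.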
alltoall_info = MoEAlltoallInfo(
None,
send_cumsum_all_ranks[self.rank],
send_ids_all_ranks[self.rank],
recv_cumsum_all_ranks[self.rank],
recv_ids_all_ranks[self.rank],
None,
ref_output_tensors_all_ranks[self.rank].shape[0],
)
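        # Allocate the shared MNNVL workspace, run the all-to-all, and compare
        # the kernel output against this rank's reference slice.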
alltoall_workspace = MnnvlMoe.get_moe_workspaces(self.mapping)
self.comm.Barrier()
output = MnnvlMoe.mnnvl_moe_alltoallv(
input_tensors_all_ranks[self.rank],
alltoall_info,
alltoall_workspace,
self.rank,
self.world_size,
)
self.comm.Barrier()
torch.testing.assert_close(
output, ref_output_tensors_all_ranks[self.rank], atol=1e-5, rtol=1e-5
)