"""
|
|
Copyright (c) 2025 by FlashInfer team.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
import functools
from dataclasses import dataclass
from types import SimpleNamespace
from typing import Optional, Tuple

import torch

from ..jit import JitSpec
from ..jit import env as jit_env
from ..jit import gen_jit_spec
from ..utils import register_custom_op
from .mapping import Mapping
from .mnnvl import MnnvlMemory, MnnvlConfig


def gen_comm_alltoall_module() -> JitSpec:
    return gen_jit_spec(
        "comm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "trtllm_alltoall.cu",
            jit_env.FLASHINFER_CSRC_DIR / "trtllm_alltoall_prepare.cu",
        ],
    )


@functools.cache
def get_comm_alltoall_module():
    """Build and load the all-to-all JIT module and register its custom ops once."""
    module = gen_comm_alltoall_module().build_and_load()

    @register_custom_op(
        "flashinfer::moe_comm_prepare_indices",
        mutates_args=[],
    )
    def moe_comm_prepare_indices(
        gathered_target_rank_ids: torch.Tensor,
        real_rank_token_count_cum_sum: Optional[torch.Tensor],
        max_token_count_per_rank: int,
        expert_count: int,
        top_k: int,
        ep_rank: int,
        ep_size: int,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        device = gathered_target_rank_ids.device
        max_send_ranks_per_token = max(top_k, ep_size)
        # Index buffers filled by the CUDA kernel below.
        local_gather_indices = torch.empty(
            (max_token_count_per_rank * ep_size,), device=device, dtype=torch.int
        )
        send_rank_count_cum_sum = torch.empty(
            (ep_size,), device=device, dtype=torch.int
        )
        send_rank_local_indices = torch.empty(
            (max_token_count_per_rank * max_send_ranks_per_token,),
            device=device,
            dtype=torch.int,
        )
        recv_rank_count_cum_sum = torch.empty(
            (ep_size,), device=device, dtype=torch.int
        )
        recv_rank_local_indices = torch.empty(
            (max_token_count_per_rank * ep_size,), device=device, dtype=torch.int
        )
        backward_recv_rank_local_indices = torch.empty(
            (max_token_count_per_rank * max_send_ranks_per_token,),
            device=device,
            dtype=torch.int,
        )
        module.moe_comm_prepare_indices(
            gathered_target_rank_ids,
            real_rank_token_count_cum_sum,
            local_gather_indices,
            send_rank_count_cum_sum,
            send_rank_local_indices,
            recv_rank_count_cum_sum,
            recv_rank_local_indices,
            backward_recv_rank_local_indices,
            max_token_count_per_rank,
            expert_count,
            top_k,
            ep_rank,
            ep_size,
        )
        return (
            local_gather_indices,
            send_rank_count_cum_sum,
            send_rank_local_indices,
            recv_rank_count_cum_sum,
            recv_rank_local_indices,
            backward_recv_rank_local_indices,
        )

    @register_custom_op(
        "flashinfer::moe_local_gather",
        mutates_args=["local_expert_ids", "local_scales"],
    )
    def moe_local_gather(
        recv_rank_cum_sum: torch.Tensor,
        local_gather_indices: torch.Tensor,
        gathered_expert_ids: torch.Tensor,
        gathered_scales: torch.Tensor,
        local_expert_ids: torch.Tensor,
        local_scales: torch.Tensor,
        max_token_count_per_rank: int,
        expert_count: int,
        top_k: int,
        ep_rank: int,
        ep_size: int,
    ) -> None:
        module.moe_local_gather(
            recv_rank_cum_sum,
            local_gather_indices,
            gathered_expert_ids,
            gathered_scales,
            local_expert_ids,
            local_scales,
            max_token_count_per_rank,
            expert_count,
            top_k,
            ep_rank,
            ep_size,
        )

    @register_custom_op(
        "flashinfer::moe_comm",
        mutates_args=["output"],
    )
    def moe_comm(
        input: torch.Tensor,
        send_rank_cum_sum: torch.Tensor,
        send_indices: torch.Tensor,
        output: torch.Tensor,
        recv_rank_cum_sum: torch.Tensor,
        recv_indices: torch.Tensor,
        all_workspaces: torch.Tensor,
        ep_rank: int,
        ep_size: int,
    ) -> None:
        module.moe_comm(
            input,
            send_rank_cum_sum,
            send_indices,
            output,
            recv_rank_cum_sum,
            recv_indices,
            all_workspaces,
            ep_rank,
            ep_size,
        )

    @register_custom_op(
        "flashinfer::set_moe_max_usable_sm_count",
        mutates_args=[],
    )
    def set_moe_max_usable_sm_count(
        max_sm_count: int,
    ) -> None:
        module.set_moe_max_usable_sm_count(max_sm_count)

    @register_custom_op(
        "flashinfer::get_moe_commworkspace_size_per_rank",
        mutates_args=[],
    )
    def get_moe_commworkspace_size_per_rank(
        ep_size: int,
    ) -> int:
        return module.get_moe_commworkspace_size_per_rank(ep_size)

    @register_custom_op(
        "flashinfer::get_moe_prepare_workspace_size_per_rank",
        mutates_args=[],
    )
    def get_moe_prepare_workspace_size_per_rank(
        ep_size: int,
    ) -> int:
        return module.get_moe_prepare_workspace_size_per_rank(ep_size)

    @register_custom_op(
        "flashinfer::moe_prepare",
        mutates_args=[],
    )
    def moe_prepare(
        experts_ids: torch.Tensor,
        scales: Optional[torch.Tensor],
        experts_statics: Optional[torch.Tensor],
        workspace: torch.Tensor,
        max_token_count_per_rank: int,
        ep_rank: int,
        ep_size: int,
        expert_count: int,
        slot_count: int,
        top_k: int,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        return module.moe_prepare(
            experts_ids,
            scales,
            experts_statics,
            workspace,
            max_token_count_per_rank,
            ep_rank,
            ep_size,
            expert_count,
            slot_count,
            top_k,
        )

    return SimpleNamespace(
        moe_comm_prepare_indices=moe_comm_prepare_indices,
        moe_local_gather=moe_local_gather,
        moe_comm=moe_comm,
        set_moe_max_usable_sm_count=set_moe_max_usable_sm_count,
        get_moe_commworkspace_size_per_rank=get_moe_commworkspace_size_per_rank,
        get_moe_prepare_workspace_size_per_rank=get_moe_prepare_workspace_size_per_rank,
        moe_prepare=moe_prepare,
    )


def moe_comm_prepare_indices(
    gathered_target_rank_ids: torch.Tensor,
    real_rank_token_count_cum_sum: Optional[torch.Tensor],
    max_token_count_per_rank: int,
    expert_count: int,
    top_k: int,
    ep_rank: int,
    ep_size: int,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
    return get_comm_alltoall_module().moe_comm_prepare_indices(
        gathered_target_rank_ids,
        real_rank_token_count_cum_sum,
        max_token_count_per_rank,
        expert_count,
        top_k,
        ep_rank,
        ep_size,
    )


def moe_local_gather(
    recv_rank_cum_sum: torch.Tensor,
    local_gather_indices: torch.Tensor,
    gathered_expert_ids: torch.Tensor,
    gathered_scales: torch.Tensor,
    local_expert_ids: torch.Tensor,
    local_scales: torch.Tensor,
    max_token_count_per_rank: int,
    expert_count: int,
    top_k: int,
    ep_rank: int,
    ep_size: int,
) -> None:
    get_comm_alltoall_module().moe_local_gather(
        recv_rank_cum_sum,
        local_gather_indices,
        gathered_expert_ids,
        gathered_scales,
        local_expert_ids,
        local_scales,
        max_token_count_per_rank,
        expert_count,
        top_k,
        ep_rank,
        ep_size,
    )


def moe_comm(
    input: torch.Tensor,
    send_rank_cum_sum: torch.Tensor,
    send_indices: torch.Tensor,
    output: torch.Tensor,
    recv_rank_cum_sum: torch.Tensor,
    recv_indices: torch.Tensor,
    all_workspaces: torch.Tensor,
    ep_rank: int,
    ep_size: int,
) -> None:
    get_comm_alltoall_module().moe_comm(
        input,
        send_rank_cum_sum,
        send_indices,
        output,
        recv_rank_cum_sum,
        recv_indices,
        all_workspaces,
        ep_rank,
        ep_size,
    )


def set_moe_max_usable_sm_count(
    max_sm_count: int,
) -> None:
    get_comm_alltoall_module().set_moe_max_usable_sm_count(max_sm_count)


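# Note on set_moe_max_usable_sm_count (illustrative): callers that overlap the
# communication kernels with other compute can cap the SMs they may occupy, e.g.
#
#   set_moe_max_usable_sm_count(16)  # hypothetical cap; tune per GPU and workload

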
def get_moe_commworkspace_size_per_rank(
    ep_size: int,
) -> int:
    return get_comm_alltoall_module().get_moe_commworkspace_size_per_rank(ep_size)


def get_moe_prepare_workspace_size_per_rank(
    ep_size: int,
) -> int:
    return get_comm_alltoall_module().get_moe_prepare_workspace_size_per_rank(ep_size)


def moe_prepare(
    experts_ids: torch.Tensor,
    scales: Optional[torch.Tensor],
    experts_statics: Optional[torch.Tensor],
    workspace: torch.Tensor,
    max_token_count_per_rank: int,
    ep_rank: int,
    ep_size: int,
    expert_count: int,
    slot_count: int,
    top_k: int,
) -> Tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    return get_comm_alltoall_module().moe_prepare(
        experts_ids,
        scales,
        experts_statics,
        workspace,
        max_token_count_per_rank,
        ep_rank,
        ep_size,
        expert_count,
        slot_count,
        top_k,
    )


@dataclass
class MoEAlltoallInfo:
    local_gather_indices: torch.Tensor
    send_rank_count_cumsum: torch.Tensor
    send_rank_local_indices: torch.Tensor
    recv_rank_count_cumsum: torch.Tensor
    recv_rank_local_indices: torch.Tensor
    backward_recv_rank_local_indices: torch.Tensor
    local_token_allocation_count: int


class MnnvlMoe:
    moe_workspace: MnnvlMemory = None
    moe_prepare_workspace: MnnvlMemory = None
    moe_workspace_tensor: torch.Tensor = None
    moe_prepare_workspace_tensor: torch.Tensor = None
    moe_mapping: Mapping = None

    @staticmethod
    def get_moe_workspaces(mapping: Mapping, config: Optional[MnnvlConfig] = None):
        if MnnvlMoe.moe_workspace is not None:
            assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
            return MnnvlMoe.moe_workspace_tensor

        MnnvlMoe.moe_mapping = mapping
        workspace_size_per_rank = get_moe_commworkspace_size_per_rank(mapping.tp_size)
        if config:
            MnnvlMemory.set_comm_from_config(mapping, config)  # type: ignore[attr-defined]
        MnnvlMoe.moe_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
        MnnvlMoe.moe_workspace_tensor = MnnvlMoe.moe_workspace.as_torch_strided_tensor(
            torch.uint64
        )
        return MnnvlMoe.moe_workspace_tensor

    @staticmethod
    def get_moe_prepare_workspace(
        mapping: Mapping, config: Optional[MnnvlConfig] = None
    ):
        if MnnvlMoe.moe_prepare_workspace_tensor is not None:
            assert mapping == MnnvlMoe.moe_mapping, "only one moe mapping supported now"
            return MnnvlMoe.moe_prepare_workspace_tensor
        workspace_size_per_rank = get_moe_prepare_workspace_size_per_rank(
            mapping.tp_size
        )
        if config:
            MnnvlMemory.set_comm_from_config(mapping, config)  # type: ignore[attr-defined]
        MnnvlMoe.moe_prepare_workspace = MnnvlMemory(mapping, workspace_size_per_rank)
        MnnvlMoe.moe_prepare_workspace_tensor = (
            MnnvlMoe.moe_prepare_workspace.as_torch_strided_tensor(torch.uint64)
        )
        return MnnvlMoe.moe_prepare_workspace_tensor

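    # Typical setup (sketch): both workspaces are created once per Mapping and then
    # reused for every dispatch/combine call, e.g.
    #
    #   workspace = MnnvlMoe.get_moe_workspaces(mapping)
    #   prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(mapping)
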
    @staticmethod
    def compute_target_rank_id(
        token_selected_experts: torch.Tensor, expert_count: int, ep_size: int
    ):
        assert expert_count % ep_size == 0, (
            "expert_count should be divisible by ep_size"
        )
        expert_per_rank = expert_count // ep_size
        token_target_rank_ids = token_selected_experts // expert_per_rank
        return token_target_rank_ids

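    # Worked example (illustrative): with expert_count=8 and ep_size=4, each rank owns
    # 8 // 4 = 2 consecutive experts, so selected experts [0, 3, 5, 7] map to target
    # ranks [0, 1, 2, 3].
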
    @staticmethod
    def mnnvl_moe_alltoallv_prepare_without_allgather(
        expert_ids: torch.Tensor,
        scales: torch.Tensor,
        expert_statics: Optional[torch.Tensor],
        workspace: torch.Tensor,
        max_token_count_per_rank: int,
        ep_rank: int,
        ep_size: int,
        expert_count: int,
        slot_count: int,
        top_k: int,
    ):
        (
            prepared_local_experts,
            prepared_local_scales,
            local_send_rank_count_cumsum,
            local_send_rank_indices,
            local_recv_rank_count_cumsum,
            local_recv_rank_indices,
            backward_local_recv_rank_indices,
            gathered_expert_statics,
        ) = moe_prepare(
            expert_ids,
            scales,
            expert_statics,
            workspace,
            max_token_count_per_rank,
            ep_rank,
            ep_size,
            expert_count,
            slot_count,
            top_k,
        )

        local_token_allocation_count = max_token_count_per_rank * ep_size
        # local_gather_indices is not needed in this path.
        local_gather_indices = None

        alltoall_info = MoEAlltoallInfo(
            local_gather_indices,
            local_send_rank_count_cumsum,
            local_send_rank_indices,
            local_recv_rank_count_cumsum,
            local_recv_rank_indices,
            backward_local_recv_rank_indices,
            local_token_allocation_count,
        )

        return (
            alltoall_info,
            prepared_local_experts,
            prepared_local_scales,
            gathered_expert_statics,
        )

    @staticmethod
    def mnnvl_moe_alltoallv_prepare(
        gathered_target_rank_ids: torch.Tensor,
        real_rank_token_count_cumsum: torch.Tensor,
        gathered_expert_ids: torch.Tensor,
        gathered_scales: torch.Tensor,
        max_token_count_per_rank: int,
        expert_count: int,
        top_k: int,
        ep_rank: int,
        ep_size: int,
    ):
        (
            local_gather_indices,
            send_rank_count_cumsum,
            send_rank_local_indices,
            recv_rank_count_cumsum,
            recv_rank_local_indices,
            backward_recv_rank_local_indices,
        ) = moe_comm_prepare_indices(
            gathered_target_rank_ids,
            real_rank_token_count_cumsum,
            max_token_count_per_rank,
            expert_count,
            top_k,
            ep_rank,
            ep_size,
        )

        local_token_allocation_count = max_token_count_per_rank * ep_size

        local_expert_ids = torch.empty(
            local_token_allocation_count,
            top_k,
            dtype=torch.int32,
            device=torch.device("cuda"),
        )
        local_scales = torch.empty(
            local_token_allocation_count,
            top_k,
            dtype=torch.float32,
            device=torch.device("cuda"),
        )

        moe_local_gather(
            recv_rank_count_cumsum,
            local_gather_indices,
            gathered_expert_ids,
            gathered_scales,
            local_expert_ids,
            local_scales,
            max_token_count_per_rank,
            expert_count,
            top_k,
            ep_rank,
            ep_size,
        )

        alltoall_info = MoEAlltoallInfo(
            local_gather_indices,
            send_rank_count_cumsum,
            send_rank_local_indices,
            recv_rank_count_cumsum,
            recv_rank_local_indices,
            backward_recv_rank_local_indices,
            local_token_allocation_count,
        )
        return alltoall_info, local_expert_ids, local_scales

    @staticmethod
    def mnnvl_moe_alltoallv(
        x: torch.Tensor,
        alltoall_info: MoEAlltoallInfo,
        workspace: torch.Tensor,
        ep_rank: int,
        ep_size: int,
    ):
        assert x.dim() == 2, "only 2D tensors are supported, please reshape."
        output_tensor = torch.empty(
            alltoall_info.local_token_allocation_count,
            x.shape[1],
            dtype=x.dtype,
            device=torch.device("cuda"),
        )
        moe_comm(
            x,
            alltoall_info.send_rank_count_cumsum,
            alltoall_info.send_rank_local_indices,
            output_tensor,
            alltoall_info.recv_rank_count_cumsum,
            alltoall_info.recv_rank_local_indices,
            workspace,
            ep_rank,
            ep_size,
        )
        return output_tensor

    @staticmethod
    def mnnvl_moe_alltoallv_combine(
        x: torch.Tensor,
        alltoall_info: MoEAlltoallInfo,
        workspace: torch.Tensor,
        ep_rank: int,
        ep_size: int,
        top_k: int,
        token_count: int,
    ):
        assert x.dim() == 2, "only 2D tensors are supported, please reshape."
        output_tensor = torch.zeros(
            token_count * top_k, x.shape[1], dtype=x.dtype, device=torch.device("cuda")
        )
        # Reverse the dispatch: the recv-side indices act as the send side here, and
        # the backward indices scatter tokens back to their original positions.
        moe_comm(
            x,
            alltoall_info.recv_rank_count_cumsum,
            alltoall_info.recv_rank_local_indices,
            output_tensor,
            alltoall_info.send_rank_count_cumsum,
            alltoall_info.backward_recv_rank_local_indices,
            workspace,
            ep_rank,
            ep_size,
        )
        return torch.sum(
            output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
        )
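

# A minimal end-to-end sketch (hypothetical shapes and helper names; not part of the
# public API). Each EP rank dispatches its tokens to the ranks owning the selected
# experts, runs its local experts, and combines the results back:
#
#   workspace = MnnvlMoe.get_moe_workspaces(mapping)
#   prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(mapping)
#
#   alltoall_info, local_expert_ids, local_scales, _ = (
#       MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
#           token_selected_experts,  # e.g. [token_count, top_k] int32
#           token_scales,            # e.g. [token_count, top_k] float32
#           None,
#           prepare_workspace,
#           max_token_count_per_rank,
#           ep_rank,
#           ep_size,
#           expert_count,
#           slot_count,
#           top_k,
#       )
#   )
#   dispatched = MnnvlMoe.mnnvl_moe_alltoallv(
#       hidden_states, alltoall_info, workspace, ep_rank, ep_size
#   )
#   expert_output = run_local_experts(dispatched, local_expert_ids, local_scales)
#   combined = MnnvlMoe.mnnvl_moe_alltoallv_combine(
#       expert_output, alltoall_info, workspace, ep_rank, ep_size, top_k, token_count
#   )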