try:
    from deep_ep import Buffer

    use_deepep = True
except ImportError:
    use_deepep = False

from typing import Optional, Tuple

import torch
import torch.distributed as dist

from sglang.srt.layers.moe.ep_moe.kernels import (
    deepep_permute_triton_kernel,
    deepep_post_reorder_triton_kernel,
    deepep_run_moe_deep_preprocess,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode

_buffer_normal = None
_buffer_low_latency = None


def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
    """
    Copied from the DeepEP example usage for model training / inference prefilling.
    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
    """

    global _buffer_normal

    num_nvl_bytes, num_rdma_bytes = 0, 0
    for config in (
        Buffer.get_dispatch_config(group.size()),
        Buffer.get_combine_config(group.size()),
    ):
        num_nvl_bytes = max(
            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
        )
        num_rdma_bytes = max(
            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
        )

    if (
        _buffer_normal is None
        or _buffer_normal.group != group
        or _buffer_normal.num_nvl_bytes < num_nvl_bytes
        or _buffer_normal.num_rdma_bytes < num_rdma_bytes
    ):
        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
    return _buffer_normal


def get_buffer_low_latency(
    group: dist.ProcessGroup,
    num_max_dispatch_tokens_per_rank: int,
    hidden: int,
    num_experts: int,
):
    """
    Copied from the DeepEP example usage for inference decoding.
    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
    """

    global _buffer_low_latency
    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
    )

    if (
        _buffer_low_latency is None
        or _buffer_low_latency.group != group
        or not _buffer_low_latency.low_latency_mode
        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
    ):
        assert num_experts % group.size() == 0
        _buffer_low_latency = Buffer(
            group,
            0,
            num_rdma_bytes,
            low_latency_mode=True,
            num_qps_per_rank=num_experts // group.size(),
        )
    return _buffer_low_latency


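# Illustrative only (these calls do not appear at module level): the dispatcher
# below sizes the normal-mode buffer from the per-token payload in bytes, and
# the (currently disabled) low-latency path would be sized analogously, e.g.
#
#     buffer = get_buffer_normal(group, hidden_size * params_bytes)
#     buffer_ll = get_buffer_low_latency(
#         group, num_max_dispatch_tokens_per_rank, hidden_size * params_bytes, num_experts
#     )
#
# where `group`, `hidden_size`, `params_bytes`, and `num_experts` are supplied
# by the caller, mirroring DeepEPDispatcher.__init__ below.

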
class DeepEPDispatcher:
    """
    Copied from the Megatron-Core token dispatcher MoEFlexTokenDispatcher:
    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
    """

    def __init__(
        self,
        group: torch.distributed.ProcessGroup,
        router_topk: int,
        permute_fusion: bool = False,
        capacity_factor: Optional[float] = None,
        num_experts: Optional[int] = None,
        num_local_experts: Optional[int] = None,
        hidden_size: Optional[int] = None,
        params_dtype: Optional[torch.dtype] = None,
        async_finish: bool = False,
    ):
        self.group = group
        self.router_topk = router_topk
        self.capacity_factor = capacity_factor
        self.permute_fusion = permute_fusion
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.hidden_size = hidden_size
        self.recv_expert_count = None
        self.params_dtype = params_dtype
        self.params_bytes = 2  # bytes per element for fp16 / bf16 activations
        # Metadata
        self.token_indices = None
        self.token_probs = None
        # Handle used for combine operation
        self.handle = None
        self.async_finish = async_finish

        # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine)
        # should be less than 256:
        # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
        self.num_max_dispatch_tokens_per_rank = 128

        if not use_deepep:
            raise ImportError(
                "DeepEP is not installed. Please install the DeepEP package from "
                "https://github.com/deepseek-ai/deepep."
            )
        self.buffer_normal = get_buffer_normal(
            self.group, self.hidden_size * self.params_bytes
        )
        self.buffer_low_latency = None
        # TODO: enable low-latency dispatch
        """
        self.buffer_low_latency = get_buffer_low_latency(
            self.group,
            self.num_max_dispatch_tokens_per_rank,
            self.hidden_size * self.params_bytes,
            self.num_experts,
        )
        """

    def deepep_permute(
        self,
        hidden_states,
        fp8_dtype=None,
        use_fp8_w8a8=False,
        use_block_quant=False,
    ):
        reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
            self.topk_idx, self.num_experts
        )
        num_total_tokens = reorder_topk_ids.numel()
        gateup_input = torch.empty(
            (int(num_total_tokens), hidden_states.shape[1]),
            device=hidden_states.device,
            dtype=(
                fp8_dtype
                if (use_fp8_w8a8 and not use_block_quant)
                else hidden_states.dtype
            ),
        )
        # PreReorder
        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
            hidden_states,
            gateup_input,
            src2dst,
            self.topk_idx,
            None,
            self.router_topk,
            hidden_states.shape[1],
            BLOCK_SIZE=512,
        )
        self.src2dst = src2dst
        return reorder_topk_ids, seg_indptr, gateup_input

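    # Layout note (descriptive only, inferred from how these tensors are used in
    # this module): `gateup_input` holds num_tokens * router_topk rows grouped
    # into contiguous per-expert segments, `seg_indptr` (num_experts + 1 entries)
    # delimits those segments, `reorder_topk_ids` gives the expert id of each
    # permuted row, and `src2dst` records where each (token, top-k slot) pair was
    # placed so that `combine` can restore the original token order.
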
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
        forward_mode: ForwardMode,
        num_max_dispatch_tokens_per_rank: int = 128,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        topk_idx = topk_idx.to(torch.int64)
        # TODO: enable low-latency dispatch
        if True:  # not forward_mode.is_decode():
            (
                hidden_states,
                topk_idx,
                topk_weights,
                num_recv_tokens_per_expert_list,
                handle,
                event,
            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
            self.tokens_per_expert = torch.tensor(
                num_recv_tokens_per_expert_list,
                device=hidden_states.device,
                dtype=torch.int64,
            )
        else:
            hidden_states, recv_expert_count, handle, event, hook = (
                self.dispatch_low_latency(
                    hidden_states,
                    topk_idx,
                    num_max_dispatch_tokens_per_rank,
                    num_experts,
                )
            )
            self.recv_expert_count = recv_expert_count

        if self.async_finish:
            event.current_stream_wait()

        self.handle = handle
        self.topk_idx = topk_idx
        self.topk_weights = topk_weights
        if hidden_states.shape[0] > 0:
            reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
                hidden_states, fp8_dtype=hidden_states.dtype
            )
        else:
            reorder_topk_ids = torch.empty(
                (0,), device=hidden_states.device, dtype=torch.int64
            )
            seg_indptr = torch.zeros(
                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
            )
        return hidden_states, reorder_topk_ids, seg_indptr

    def dispatch_normal(
        self,
        x: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        num_experts: int,
    ):
        previous_event = Buffer.capture() if self.async_finish else None

        (
            num_tokens_per_rank,
            num_tokens_per_rdma_rank,
            num_tokens_per_expert,
            is_token_in_rank,
            previous_event,
        ) = self.buffer_normal.get_dispatch_layout(
            topk_idx,
            num_experts,
            previous_event=previous_event,
            async_finish=self.async_finish,
            allocate_on_comm_stream=previous_event is not None,
        )

        (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            handle,
            event,
        ) = self.buffer_normal.dispatch(
            x,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            num_tokens_per_rank=num_tokens_per_rank,
            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
            is_token_in_rank=is_token_in_rank,
            num_tokens_per_expert=num_tokens_per_expert,
            previous_event=previous_event,
            async_finish=self.async_finish,
            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
        )

        return (
            recv_x,
            recv_topk_idx,
            recv_topk_weights,
            num_recv_tokens_per_expert_list,
            handle,
            event,
        )

    def dispatch_low_latency(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        num_max_dispatch_tokens_per_rank: int,
        num_experts: int,
    ):
        """
        # On H20 this will hit a CUDA error at DeepEP/csrc/kernels/internode_ll.cu:337:
        # 'too many blocks in cooperative launch'.
        # Please make sure to patch the DeepEP dispatch / combine kernels in
        # internode_ll.cu as below first, and then reinstall DeepEP.
        # For more details, see: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782

        diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
        index f60e933..cddaabf 100644
        --- a/csrc/kernels/internode_ll.cu
        +++ b/csrc/kernels/internode_ll.cu
        @@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
                       int num_topk, int num_experts, int rank, int num_ranks,
                       void* workspace, cudaStream_t stream, int phases) {
             constexpr int kNumMaxTopK = 9;
        -    constexpr int kNumWarpsPerGroup = 10;
        -    constexpr int kNumWarpGroups = 3;
        +    constexpr int kNumWarpsPerGroup = 8;
        +    constexpr int kNumWarpGroups = 4;
             EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
        +
             const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
             const auto num_sms = cell_div(num_experts, kNumWarpGroups);
             EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
        -    EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
        +    // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
        +
             // Workspace checks
             auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
        @@ -505,8 +505,8 @@ void combine(void* combined_x,
                     int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
                     int num_topk, int num_experts, int rank, int num_ranks,
                     void* workspace, cudaStream_t stream, int phases) {
        -    constexpr int kNumWarpsPerGroup = 10;
        -    constexpr int kNumWarpGroups = 3;
        +    constexpr int kNumWarpsPerGroup = 8;
        +    constexpr int kNumWarpGroups = 4;
             constexpr int kNumMaxTopk = 9;
        +
             const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
        """

        recv_hidden_states, recv_expert_count, handle, event, hook = (
            self.buffer_low_latency.low_latency_dispatch(
                hidden_states,
                topk_idx,
                num_max_dispatch_tokens_per_rank,
                num_experts,
                async_finish=self.async_finish,
                return_recv_hook=False,  # set True for double-batch overlapping; the caller must then call hook()
            )
        )
        # hook()
        return recv_hidden_states, recv_expert_count, handle, event, hook

    def combine(
        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
    ) -> torch.Tensor:
        # TODO: enable low-latency combine
        if True:  # not forward_mode.is_decode():
            if hidden_states.shape[0] > 0:
                num_tokens = self.src2dst.shape[0] // self.router_topk
                output = torch.empty(
                    (num_tokens, hidden_states.shape[1]),
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
                deepep_post_reorder_triton_kernel[(num_tokens,)](
                    hidden_states,
                    output,
                    self.src2dst,
                    self.topk_idx,
                    self.topk_weights,
                    self.router_topk,
                    hidden_states.shape[1],
                    BLOCK_SIZE=512,
                )
            else:
                output = torch.zeros(
                    (0, hidden_states.shape[1]),
                    device=hidden_states.device,
                    dtype=hidden_states.dtype,
                )
            hidden_states, event = self.combine_normal(output, self.handle)
        else:
            hidden_states, event, hook = self.combine_low_latency(
                hidden_states, self.topk_idx, self.topk_weights, self.handle
            )

        if self.async_finish:
            event.current_stream_wait()

        self.handle = None
        return hidden_states

    def combine_normal(self, x: torch.Tensor, handle: Tuple):
        previous_event = Buffer.capture() if self.async_finish else None

        combined_x, _, event = self.buffer_normal.combine(
            x,
            handle,
            async_finish=self.async_finish,
            previous_event=previous_event,
            allocate_on_comm_stream=previous_event is not None,
        )
        return combined_x, event

    def combine_low_latency(
        self,
        hidden_states: torch.Tensor,
        topk_idx: torch.Tensor,
        topk_weights: torch.Tensor,
        handle: Tuple,
    ):
        combined_hidden_states, event_overlap, hook = (
            self.buffer_low_latency.low_latency_combine(
                hidden_states,
                topk_idx,
                topk_weights,
                handle,
                async_finish=self.async_finish,
                return_recv_hook=False,  # set True for double-batch overlapping; the caller must then call hook()
            )
        )
        # hook()
        return combined_hidden_states, event_overlap, hook
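

# Illustrative sketch (not part of the original module): how an MoE layer would
# typically drive the dispatcher, assuming torch.distributed is initialized and
# DeepEP is installed. `run_local_experts` is a hypothetical callable that runs
# the local experts (e.g. grouped GEMMs) over the permuted tokens.
def _example_moe_forward(
    dispatcher: DeepEPDispatcher,
    hidden_states: torch.Tensor,
    topk_idx: torch.Tensor,
    topk_weights: torch.Tensor,
    num_experts: int,
    forward_mode: ForwardMode,
    run_local_experts,
) -> torch.Tensor:
    # 1. Send each token to the rank(s) owning its selected experts and permute
    #    the received tokens into contiguous per-expert segments.
    permuted_states, reorder_topk_ids, seg_indptr = dispatcher.dispatch(
        hidden_states, topk_idx, topk_weights, num_experts, forward_mode
    )
    # 2. Apply the local experts to the permuted tokens.
    expert_output = run_local_experts(permuted_states, reorder_topk_ids, seg_indptr)
    # 3. Un-permute, weight by the router probabilities, and send the combined
    #    outputs back to the ranks that originally held the tokens.
    return dispatcher.combine(expert_output, forward_mode)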