"""
|
|
Torch-native implementation for FusedMoE. This is used for torch.compile.
|
|
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
|
|
"""

from typing import Callable, Optional

import torch
from torch.nn import functional as F

from sglang.srt.layers.moe.topk import select_experts


def fused_moe_forward_native(
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    activation: str = "silu",
    inplace: bool = True,  # unused in this torch-native path
    no_combine: bool = False,  # unused in this torch-native path
) -> torch.Tensor:
    topk_weights, topk_ids = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        correction_bias=correction_bias,
        torch_native=True,
    )
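    # Shapes (as implied by the einsums below): x is [num_tokens, hidden_size];
    # topk_weights and topk_ids are [num_tokens, top_k].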

    # Gather each token's top-k expert weights in one indexing op.
    w13_weights = layer.w13_weight[topk_ids]
    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
    w2_weights = layer.w2_weight[topk_ids]
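    # Einsum subscripts: t = token, a = activated expert slot (top-k),
    # o = expert intermediate dim, i = hidden dim. All top-k experts for all
    # tokens are evaluated in one batched contraction, with no per-expert loop.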
    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
    if activation == "silu":
        x1 = F.silu(x1)
    elif activation == "gelu":
        x1 = F.gelu(x1)
    else:
        raise ValueError(f"Unsupported activation: {activation=}")
    x3 = torch.einsum("ti,taoi -> tao", x, w3_weights)
    expert_outs = torch.einsum("tao,taio -> tai", (x1 * x3), w2_weights)
    # Weighted combine of the top-k expert outputs per token.
    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
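

# Usage sketch for fused_moe_forward_native (illustrative only; the layer
# construction and shapes below are assumptions inferred from the einsums above):
#   layer.w13_weight: [num_experts, 2 * intermediate_size, hidden_size]
#   layer.w2_weight:  [num_experts, hidden_size, intermediate_size]
#   out = fused_moe_forward_native(layer, x, use_grouped_topk=False, top_k=2,
#                                  router_logits=router_logits, renormalize=True)
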

def moe_forward_native(
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    activation: str = "silu",
) -> torch.Tensor:
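    # Unlike fused_moe_forward_native above, this path sorts tokens by expert
    # and loops over experts, mirroring the DeepSeek-V2 reference linked below.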
    from sglang.srt.layers.activation import GeluAndMul, SiluAndMul

    topk_weights, topk_ids = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        correction_bias=correction_bias,
        torch_native=True,
    )

    # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
    len_experts = layer.num_experts

    # Count tokens per expert, then sort the flattened (token, expert-slot)
    # pairs by expert id so that each expert's tokens are contiguous.
    cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
    cnts.scatter_(1, topk_ids.to(torch.int64), 1)
    tokens_per_expert = cnts.sum(dim=0)
    idxs = topk_ids.view(-1).argsort()

    # idxs // top_k recovers the source token index for each sorted slot.
    sorted_tokens = x[idxs // topk_ids.shape[1]]
    tokens_per_expert = tokens_per_expert.cpu().numpy()
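    # Tiny worked example (assumed values): with top_k = 2 and
    # topk_ids = [[1, 0], [1, 2]], the flat ids are [1, 0, 1, 2], so
    # idxs = argsort = [1, 0, 2, 3] and idxs // 2 = [0, 0, 1, 1]; sorted_tokens
    # stacks x[0] (for experts 0 and 1) then x[1] (for experts 1 and 2).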

    if activation == "silu":
        act = SiluAndMul()
    elif activation == "gelu":
        act = GeluAndMul()
    else:
        raise ValueError(f"Unsupported activation: {activation=}")

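    # Note: F.linear(x, w) computes x @ w.T, so each expert's w13_weight holds
    # the fused [gate; up] projection; SiluAndMul/GeluAndMul (from
    # sglang.srt.layers.activation) split that output in half along the last
    # dim and return act(gate) * up.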
    # Run each expert over its contiguous slice of the sorted tokens.
    outputs = []
    start_idx = 0
    for i, num_tokens in enumerate(tokens_per_expert):
        end_idx = start_idx + num_tokens
        if num_tokens == 0:
            # end_idx == start_idx here, so skipping the update is safe.
            continue
        tokens_for_this_expert = sorted_tokens[start_idx:end_idx]

        layer_w13_weight = layer.w13_weight[i]
        layer_w2_weight = layer.w2_weight[i]

        gate_up = F.linear(tokens_for_this_expert, layer_w13_weight)
        gate_up = act(gate_up)
        expert_out = F.linear(gate_up, layer_w2_weight)
        outputs.append(expert_out)
        start_idx = end_idx

    outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
    new_x = torch.empty_like(outs)

    # Invert the sort: scatter each expert output back to its original
    # (token, expert-slot) position, then do the weighted top-k combine.
    new_x[idxs] = outs
    final_out = (
        new_x.view(*topk_ids.shape, -1)
        .type(topk_weights.dtype)
        .mul_(topk_weights.unsqueeze(dim=-1))
        .sum(dim=1)
        .type(new_x.dtype)
    )
    return final_out
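

# Consistency sketch (illustrative only; building `layer`, `x`, and
# `router_logits` is left to the caller): given the same inputs, the two paths
# above should agree up to floating-point error, e.g.
#   y1 = fused_moe_forward_native(layer, x, False, 2, router_logits, True)
#   y2 = moe_forward_native(layer, x, False, 2, router_logits, True)
#   torch.testing.assert_close(y1, y2, rtol=1e-2, atol=1e-2)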