# sglang 0.4.5.post1: python/sglang/srt/layers/moe/router.py
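
"""Fused MoE router: Triton kernels that fuse the router GEMM, tanh logit
softcapping, and top-k expert selection (with softmax weights computed over
all experts) into a single kernel launch per batch.
"""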


from typing import Tuple

import torch
import triton
import triton.language as tl

from sglang.srt.layers.moe.topk import fused_topk


@triton.jit
def fused_moe_router_kernel(
    input_ptr,  # input (bs, hidden_dim)
    moe_router_weight_ptr,  # input (num_experts, hidden_dim)
    topk_weights_ptr,  # output (bs, topk)
    topk_ids_ptr,  # output (bs, topk)
    num_experts: tl.constexpr,
    topk: tl.constexpr,
    moe_softcapping: tl.constexpr,
    moe_renormalize: tl.constexpr,  # not supported
    hidden_dim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)

    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < hidden_dim

    # moe_router_weight is k-major
    expert_offsets = tl.arange(0, num_experts)[:, None]
    router_mask = mask[None, :]
    w_router = tl.load(
        moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
        mask=router_mask,
        other=0.0,
    )

    x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)

    # TODO: tl.dot?
    logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)

    # Logit softcap: moe_softcapping * tanh(logits / moe_softcapping), written
    # via the identity tanh(z) = (exp(2z) - 1) / (exp(2z) + 1).
    logits_scaled = logits / moe_softcapping
    exped = tl.exp(2 * logits_scaled)
    top = exped - 1
    bottom = exped + 1
    logits_softcapped = top / bottom * moe_softcapping

    # topk
    # assert 1 <= topk <= num_experts

    # 5.38 us
    top1 = tl.argmax(logits_softcapped, axis=0)
    tl.store(topk_ids_ptr + pid * topk + 0, top1)  # 5.63 us

    top1_v = tl.max(logits_softcapped, axis=0)
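    # Full softmax over all experts, with the max logit subtracted for
    # numerical stability; the top-1 weight is exp(top1_v - top1_v) * invsumexp,
    # i.e. invsumexp itself.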
    invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
    tl.store(
        topk_weights_ptr + pid * topk + 0,
        invsumexp,
    )  # 5.73 us

    if topk >= 2:
        top2 = tl.argmax(
            tl.where(
                tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
            ),
            axis=0,
        )
        tl.store(topk_ids_ptr + pid * topk + 1, top2)
        top2_v = tl.sum(
            logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0
        )
        tl.store(
            topk_weights_ptr + pid * topk + 1,
            tl.exp(top2_v - top1_v) * invsumexp,
        )  # 5.95 us

    # Iterative argmax with -inf masking for the remaining ranks; probably slow.
    if topk > 2:
        # The uniform +1.0 base of the mask shifts every unmasked logit
        # equally, so it does not change which expert the argmax selects.
        topk_mask = tl.full(
            logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype
        )
        topk_mask = tl.where(
            tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
        )
        topk_mask = tl.where(
            tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
        )
        for i in range(2, topk):
            topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
            topk_mask = tl.where(
                tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
            )
            tl.store(topk_ids_ptr + pid * topk + i, topi)
            topi_v = tl.sum(
                logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
            )
            tl.store(
                topk_weights_ptr + pid * topk + i,
                tl.exp(topi_v - top1_v) * invsumexp,
            )

    # assert not moe_renormalize, "moe weight renormalization not implemented"


def fused_moe_router_impl(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    topk: int,
    moe_softcapping: float,
):
    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
    bs, hidden_dim = x.shape
    num_experts = router_weight.shape[0]

    # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)

    # One program per token; BLOCK_SIZE covers the whole hidden dimension.
    grid = lambda meta: (bs,)
    config = {
        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
        "num_warps": max(
            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
        ),
    }

    fused_moe_router_kernel[grid](
        x,
        router_weight,
        topk_weights,
        topk_ids,
        num_experts=num_experts,
        topk=topk,
        moe_softcapping=moe_softcapping,
        moe_renormalize=False,
        hidden_dim=hidden_dim,
        **config,
    )

    return topk_weights, topk_ids
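

# The kernel above runs one program per token and reduces over the entire
# hidden dimension at once, which suits small and medium batch sizes. The
# variant below instead tiles the (bs, num_experts) router GEMM with tl.dot
# for large batches, and currently supports only topk == 1.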


@triton.jit
def fused_moe_router_large_bs_kernel(
    a_ptr,  # input (bs, hidden_dim)
    b_ptr,  # input (num_experts, hidden_dim)
    topk_weights_ptr,  # output (bs, topk)
    topk_ids_ptr,  # output (bs, topk)
    bs,
    num_experts: tl.constexpr,
    topk: tl.constexpr,  # only support topk == 1
    moe_softcapping: tl.constexpr,
    moe_renormalize: tl.constexpr,  # not supported
    K: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_bn: tl.constexpr,
):
    # 1. get block id
    pid = tl.program_id(axis=0)

    # 2. create pointers for the first blocks of A and B
    # 2.1. set up a_ptrs with offsets in m and k
    offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
    bs_mask = offs_m < bs
    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
    a_ptrs = a_ptr + (offs_m * stride_am + offs_k)

    # 2.2. set up b_ptrs with offsets in k and n.
    # Note: the b matrix is k-major.
    offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
    expert_mask = offs_n < num_experts
    b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)

    # 3. create a float32 accumulator of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
    # 3.1. iterate over the K dimension
    # 3.2. transpose tile B
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, K // BLOCK_SIZE_K):  # hidden_dim % BLOCK_SIZE_K == 0
        a = tl.load(
            a_ptrs,
            mask=bs_mask,
            other=0.0,
        ).to(tl.float32)
        b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
        acc += tl.dot(a, b)

        # advance the ptrs to the next K block
        a_ptrs += BLOCK_SIZE_K
        b_ptrs += BLOCK_SIZE_K

    # 4. logit softcap (tanh, same formulation as the per-token kernel above)
    logits_scaled = acc / moe_softcapping
    exped = tl.exp(2 * logits_scaled)
    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

    # 5. top1: mask out the padding columns beyond num_experts
    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
    top1_v = tl.max(
        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
    )
    invsumexp = 1.0 / tl.sum(
        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
    )

    # 6. store to output (topk == 1, so the top-1 weight is just invsumexp)
    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    topk_mask = offs_topk < bs
    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
    tl.store(
        topk_weights_ptr + offs_topk,
        invsumexp,
        mask=topk_mask,
    )


def fused_moe_router_large_bs_impl(
    x: torch.Tensor,
    router_weight: torch.Tensor,
    topk: int,
    moe_softcapping: float,
    BLOCK_SIZE_M: int,
    BLOCK_SIZE_N: int,
    BLOCK_SIZE_K: int,
):
    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
    bs, hidden_dim = x.shape
    num_experts = router_weight.shape[0]

    assert num_experts <= BLOCK_SIZE_N
    assert hidden_dim % BLOCK_SIZE_K == 0
    assert topk == 1

    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)

    grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)

    fused_moe_router_large_bs_kernel[grid](
        a_ptr=x,
        b_ptr=router_weight,
        topk_weights_ptr=topk_weights,
        topk_ids_ptr=topk_ids,
        bs=bs,
        num_experts=num_experts,
        topk=topk,
        moe_softcapping=moe_softcapping,
        moe_renormalize=False,
        K=hidden_dim,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        stride_am=hidden_dim,
        stride_bn=hidden_dim,
    )

    return topk_weights, topk_ids
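

# Example usage (a minimal sketch; 30.0 is just an example softcap value and
# the shapes are arbitrary). The large-batch path requires topk == 1,
# num_experts <= BLOCK_SIZE_N, and hidden_dim divisible by BLOCK_SIZE_K:
#
#   x = torch.randn(1024, 2048, device="cuda", dtype=torch.bfloat16)
#   w = torch.randn(8, 2048, device="cuda", dtype=torch.bfloat16)
#   weights, ids = fused_moe_router_large_bs_impl(
#       x, w, topk=1, moe_softcapping=30.0,
#       BLOCK_SIZE_M=32, BLOCK_SIZE_N=16, BLOCK_SIZE_K=256,
#   )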


def fused_moe_router_shim(
    moe_softcapping,
    hidden_states,
    gating_output,
    topk,
    renormalize,
):
    # NOTE: despite its name, gating_output here is the router weight matrix
    # of shape (num_experts, hidden_dim); see forward_cuda below.
    assert not renormalize
    assert (
        len(hidden_states.shape) == 2
        and hidden_states.shape[1] == gating_output.shape[1]
    )
    bs, hidden_dim = hidden_states.shape
    num_experts = gating_output.shape[0]
    BLOCK_SIZE_M = 32
    BLOCK_SIZE_N = 16
    BLOCK_SIZE_K = 256
    if (
        bs >= 512
        and topk == 1
        and num_experts <= BLOCK_SIZE_N
        and hidden_dim % BLOCK_SIZE_K == 0
    ):
        return fused_moe_router_large_bs_impl(
            x=hidden_states,
            router_weight=gating_output,
            topk=topk,
            moe_softcapping=moe_softcapping,
            BLOCK_SIZE_M=BLOCK_SIZE_M,
            BLOCK_SIZE_N=BLOCK_SIZE_N,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
        )
    else:
        return fused_moe_router_impl(
            x=hidden_states,
            router_weight=gating_output,
            topk=topk,
            moe_softcapping=moe_softcapping,
        )


class FusedMoeRouter:
    def __init__(self, router_linear, topk, moe_softcapping) -> None:
        self.router_linear = router_linear
        self.topk = topk
        self.moe_softcapping = moe_softcapping

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def forward(
        self, x: torch.Tensor, residual: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # residual is unused by both paths; forward_vllm accepts only x.
        if x.is_cuda:
            return self.forward_cuda(x)
        else:
            return self.forward_vllm(x)

    def forward_cuda(
        self, x: torch.Tensor, autotune=False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return fused_moe_router_shim(
            moe_softcapping=self.moe_softcapping,
            hidden_states=x,
            gating_output=self.router_linear.weight,
            topk=self.topk,
            renormalize=False,
        )

    def forward_vllm(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Reference path: compute softcapped logits in fp32, then reuse the
        # generic fused_topk for selection.
        # g, _ = self.router_linear.forward(x)
        g = x.float() @ self.router_linear.weight.T.float()
        g = torch.tanh(g / self.moe_softcapping) * self.moe_softcapping
        return fused_topk(x, g, self.topk, False)
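

if __name__ == "__main__":
    # Minimal smoke test (a sketch, assuming a CUDA device is available;
    # 30.0 is just an example softcap value, and the shapes are arbitrary).
    if torch.cuda.is_available():
        bs, hidden_dim, num_experts = 16, 1024, 8
        x = torch.randn(bs, hidden_dim, device="cuda", dtype=torch.bfloat16)
        w = torch.randn(num_experts, hidden_dim, device="cuda", dtype=torch.bfloat16)
        # bs < 512, so this dispatches to the per-token kernel path.
        topk_weights, topk_ids = fused_moe_router_shim(
            moe_softcapping=30.0,
            hidden_states=x,
            gating_output=w,
            topk=2,
            renormalize=False,
        )
        # Expected: torch.Size([16, 2]) float32, torch.Size([16, 2]) int32
        print(topk_weights.shape, topk_ids.shape)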