# sglang0.4.5.post1/sgl-kernel/benchmark/bench_moe_align_block_size.py
# 386 lines, 12 KiB, Python

import argparse
import itertools
import torch
import triton
import triton.language as tl
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
from vllm import _custom_ops as ops
# When True, benchmark() builds topk_ids with get_topk_ids (topk distinct experts
# per token, one torch.randperm per row); when False it uses torch.randint, so a
# token may select the same expert more than once.
USE_RANDOM_PERM = False
def ceil_div(a, b):
    """Return ``a / b`` rounded up (exact for integers when ``b > 0``)."""
    numerator = a + b - 1
    return numerator // b
@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    # Stage 1: per-program histogram of expert ids.
    #
    # Program `pid` scans the flattened topk_ids slice
    # [pid * tokens_per_thread, (pid + 1) * tokens_per_thread), clipped to `numel`,
    # and increments tokens_cnts[pid + 1, expert], where tokens_cnts is addressed
    # as consecutive rows of `num_experts` counters. Counts are accumulated on top
    # of the buffer's initial contents, so the caller must pass a zeroed buffer.
    # Row 0 is not touched here.
    pid = tl.program_id(0)
    start_idx = pid * tokens_per_thread
    # Each program writes only its own row (pid + 1), so a plain load/store
    # read-modify-write is race-free without atomics.
    off_c = (pid + 1) * num_experts
    for i in range(tokens_per_thread):
        # The last program's slice may extend past the end of topk_ids.
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    # Stage 2: prefix-sum the stage-1 histograms, one expert column per program.
    #
    # Program `pid` walks column `pid` over rows 1..num_experts and replaces each
    # entry with the running (inclusive) sum. Afterwards row r holds how many
    # occurrences of expert `pid` fall in the slices of stage-1 programs 0..r-1,
    # and row `num_experts` holds that expert's total count. This relies on
    # stage 1 having been launched with `num_experts` programs (one row each).
    pid = tl.program_id(0)
    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    # Stage 3: block-padded per-expert offsets. The kernel does not use
    # program_id, so it is meant to run as a single program.
    #
    # Reads each expert's total count from the last row of tokens_cnts, rounds it
    # up to a multiple of `block_size`, and stores the running sum in
    # cumsum[1..num_experts]. cumsum[0] is never written and is expected to be 0,
    # so expert e's padded segment is [cumsum[e], cumsum[e + 1]).
    last_cumsum = 0
    # Offset of row `num_experts`, which holds the per-expert totals after stage 2.
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    # Total padded length across all experts.
    tl.store(total_tokens_post_pad_ptr, last_cumsum)
@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    # Stage 4: fill expert_ids and scatter flat indices into sorted_token_ids.
    pid = tl.program_id(0)
    # (a) Program `pid` acts as expert `pid`: tag every block_size-sized block of
    #     its padded segment [cumsum[pid], cumsum[pid + 1]) with `pid`.
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)
    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)
    # (b) Program `pid` re-scans the same topk_ids slice it owned in stage 1
    #     (`start_idx` is reused for the slice start from here on). After stage 2,
    #     row `pid` of tokens_cnts holds each expert's occurrence count in earlier
    #     programs' slices (row 0 stays zero), so cumsum[expert] + that count is
    #     the next free slot for the expert. The row entry is bumped after each
    #     write.
    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts
    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        # Stores the flat index `i` into topk_ids. Padding slots are never written
        # and keep whatever value the caller initialized them with.
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    """Triton reference implementation of ``moe_align_block_size``.

    Launches the four ``moe_align_block_size_stage*`` kernels in order:

    1. ``num_experts`` programs each histogram the expert ids found in one
       contiguous slice of the flattened ``topk_ids``;
    2. those histograms are prefix-summed per expert;
    3. a single program computes block-padded per-expert offsets and writes the
       total padded length into ``num_tokens_post_pad``;
    4. ``expert_ids`` is filled per block and every flat index of ``topk_ids``
       is scattered into ``sorted_token_ids``.

    All outputs are written in place; nothing is returned.
    """
    device = topk_ids.device
    numel = topk_ids.numel()
    launch_grid = (num_experts,)
    # Number of consecutive flat ids owned by each of the num_experts programs.
    tokens_per_thread = ceil_div(numel, num_experts)

    # Zero-initialized: the kernels accumulate into these buffers and read
    # row 0 of tokens_cnts / cumsum[0] without setting them first.
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=device
    )
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=device)

    moe_align_block_size_stage1[launch_grid](
        topk_ids_ptr=topk_ids,
        tokens_cnts_ptr=tokens_cnts,
        num_experts=num_experts,
        numel=numel,
        tokens_per_thread=tokens_per_thread,
    )
    moe_align_block_size_stage2[launch_grid](
        tokens_cnts_ptr=tokens_cnts,
        num_experts=num_experts,
    )
    moe_align_block_size_stage3[(1,)](
        total_tokens_post_pad_ptr=num_tokens_post_pad,
        tokens_cnts_ptr=tokens_cnts,
        cumsum_ptr=cumsum,
        num_experts=num_experts,
        block_size=block_size,
    )
    moe_align_block_size_stage4[launch_grid](
        topk_ids_ptr=topk_ids,
        sorted_token_ids_ptr=sorted_token_ids,
        expert_ids_ptr=expert_ids,
        tokens_cnts_ptr=tokens_cnts,
        cumsum_ptr=cumsum,
        num_experts=num_experts,
        block_size=block_size,
        numel=numel,
        tokens_per_thread=tokens_per_thread,
    )
def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
    """Cross-check the SGL, Triton and vLLM ``moe_align_block_size`` kernels.

    Builds a random ``topk_ids`` tensor of shape ``(num_tokens, topk)`` whose
    rows each hold ``topk`` distinct expert ids, runs the three implementations
    on it, and prints whether their ``expert_ids`` and ``num_tokens_post_pad``
    outputs agree with the SGL result (``sorted_token_ids`` is not compared).
    A ``RuntimeError`` from the vLLM op is reported and that comparison is
    skipped. Requires a CUDA device.
    """
    topk_ids = torch.stack(
        [
            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
            for _ in range(num_tokens)
        ]
    )
    device = topk_ids.device
    numel = topk_ids.numel()
    # Upper bound on the padded length: each expert may add up to
    # block_size - 1 padding slots.
    max_num_tokens_padded = numel + num_experts * (block_size - 1)
    max_num_m_blocks = max_num_tokens_padded // block_size

    def _new_outputs():
        # Sorted ids pre-filled with numel, zeroed per-block expert ids, and an
        # uninitialized one-element padded-length counter.
        sorted_ids = torch.full(
            (max_num_tokens_padded,), numel, dtype=torch.int32, device=device
        )
        block_experts = torch.zeros(
            (max_num_m_blocks,), dtype=torch.int32, device=device
        )
        post_pad = torch.empty((1,), dtype=torch.int32, device=device)
        return sorted_ids, block_experts, post_pad

    sorted_ids_cuda, expert_ids_cuda, num_tokens_post_pad_cuda = _new_outputs()
    sorted_ids_triton, expert_ids_triton, num_tokens_post_pad_triton = _new_outputs()
    sorted_ids_vllm, expert_ids_vllm, num_tokens_post_pad_vllm = _new_outputs()

    # Scratch buffers consumed by the SGL kernel.
    token_cnts_buffer = torch.zeros(
        (num_experts + 1) * num_experts, dtype=torch.int32, device=device
    )
    cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device=device)

    # Run all three implementations on the same input.
    sgl_moe_align_block_size(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_cuda,
        expert_ids_cuda,
        num_tokens_post_pad_cuda,
        token_cnts_buffer,
        cumsum_buffer,
    )
    moe_align_block_size_triton(
        topk_ids,
        num_experts,
        block_size,
        sorted_ids_triton,
        expert_ids_triton,
        num_tokens_post_pad_triton,
    )
    try:
        ops.moe_align_block_size(
            topk_ids,
            num_experts,
            block_size,
            sorted_ids_vllm,
            expert_ids_vllm,
            num_tokens_post_pad_vllm,
        )
    except RuntimeError as e:
        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
        vllm_works = False
    else:
        print(f"✅ VLLM implementation works with {num_experts} experts!")
        vllm_works = True

    def _agrees_with_sgl(other_expert_ids, other_post_pad):
        # Both outputs must match the SGL reference.
        return torch.allclose(expert_ids_cuda, other_expert_ids) and torch.allclose(
            num_tokens_post_pad_cuda, other_post_pad
        )

    def _dump(label, other_expert_ids, other_post_pad):
        # Print the SGL reference next to the other implementation's outputs.
        print("SGL expert_ids:", expert_ids_cuda)
        print(f"{label} expert_ids:", other_expert_ids)
        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
        print(f"{label} num_tokens_post_pad:", other_post_pad)

    if _agrees_with_sgl(expert_ids_triton, num_tokens_post_pad_triton):
        print("✅ SGL and Triton implementations match")
    else:
        print("❌ SGL and Triton implementations do not match")
        _dump("Triton", expert_ids_triton, num_tokens_post_pad_triton)

    if not vllm_works:
        print("⚠️ VLLM comparison skipped due to failure")
    elif _agrees_with_sgl(expert_ids_vllm, num_tokens_post_pad_vllm):
        print("✅ SGL and VLLM implementations match")
    else:
        print("❌ SGL and VLLM implementations do not match")
        _dump("VLLM", expert_ids_vllm, num_tokens_post_pad_vllm)
# Performance sweep grid: benchmark() is run for every
# (num_tokens, num_experts, topk) combination, last axis varying fastest.
num_tokens_range = [2**exp for exp in range(4, 14)]  # 16, 32, ..., 8192
num_experts_range = [8, 32, 64, 128, 256]
topk_range = [2, 4, 8]
configs = [
    (tokens, experts, k)
    for tokens in num_tokens_range
    for experts in num_experts_range
    for k in topk_range
]
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    """Return a ``(num_tokens, topk)`` int32 CUDA tensor of expert ids.

    Each row is the first ``topk`` entries of an independent random permutation
    of ``range(num_experts)``, so ids within a row are distinct (assumes
    ``topk <= num_experts``). One ``torch.randperm`` call is made per row.
    """
    out = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
    # Iterating the tensor yields row views, so copy_ writes straight into `out`.
    for row in out:
        perm = torch.randperm(num_experts, dtype=torch.int32, device="cuda")
        row.copy_(perm[:topk])
    return out
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "num_experts", "topk"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["sgl", "triton", "vllm"],
        line_names=["SGL", "Triton", "VLLM"],
        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
        ylabel="us",
        plot_name="moe-align-block-size-performance",
        args={},
    )
)
def benchmark(num_tokens, num_experts, topk, provider):
    """Time one ``moe_align_block_size`` implementation on one configuration.

    Args:
        num_tokens: number of rows of ``topk_ids``.
        num_experts: number of experts; ids are drawn from ``[0, num_experts)``.
        topk: experts selected per token (columns of ``topk_ids``).
        provider: ``"sgl"`` or ``"triton"``; any other value runs the vLLM op.

    Returns:
        ``(median, p20, p80)`` latency in microseconds, i.e. the
        ``(y, y_min, y_max)`` order that ``triton.testing.perf_report`` unpacks.
        If the vLLM op raises ``RuntimeError``, all three values are ``inf``.

    Note: the timed callables clone the output buffers on every call, so the
    clone cost is part of the measurement.
    """
    block_size = 128
    if USE_RANDOM_PERM:
        topk_ids = get_topk_ids(num_tokens, num_experts, topk)
    else:
        topk_ids = torch.randint(
            0,
            num_experts,
            (num_tokens, topk),
            dtype=torch.int32,
            device="cuda",
        )
    # Upper bound on the padded length: each expert may add up to
    # block_size - 1 padding slots.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = max_num_tokens_padded // block_size
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
    token_cnts_buffer = torch.zeros(
        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
    )
    cumsum_buffer = torch.zeros(
        num_experts + 1, dtype=torch.int32, device=topk_ids.device
    )
    # do_bench returns one value per requested quantile, in this order:
    # median, 20th percentile, 80th percentile.
    quantiles = [0.5, 0.2, 0.8]
    if provider == "sgl":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: sgl_moe_align_block_size(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
                token_cnts_buffer,
                cumsum_buffer,
            ),
            quantiles=quantiles,
        )
    elif provider == "triton":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: moe_align_block_size_triton(
                topk_ids,
                num_experts,
                block_size,
                sorted_ids.clone(),
                expert_ids.clone(),
                num_tokens_post_pad.clone(),
            ),
            quantiles=quantiles,
        )
    else:  # vllm
        try:
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: ops.moe_align_block_size(
                    topk_ids,
                    num_experts,
                    block_size,
                    sorted_ids.clone(),
                    expert_ids.clone(),
                    num_tokens_post_pad.clone(),
                ),
                quantiles=quantiles,
            )
        except RuntimeError as e:
            print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
            # Return extreme values to indicate failure in the chart
            return float("inf"), float("inf"), float("inf")
    # perf_report reads (y, y_min, y_max). For a latency metric the lower
    # quantile is the minimum, so min_ms must precede max_ms. (Swapping them
    # only makes sense for inverse metrics such as throughput.)
    return 1000 * ms, 1000 * min_ms, 1000 * max_ms
if __name__ == "__main__":
    import os  # used only to create the results directory

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/moe_align_blocks/",
        help="Path to save moe align benchmark results",
    )
    parser.add_argument(
        "--num_experts",
        type=int,
        default=256,
        choices=[8, 16, 32, 64, 128, 256],
        help="Number of experts for the correctness check (calculate_diff)",
    )
    parser.add_argument(
        "--topk",
        type=int,
        default=8,
        choices=[2, 4, 8],
        help="Top-k value for the correctness check (calculate_diff)",
    )
    parser.add_argument(
        "--skip_full_benchmark",
        action="store_true",
        help="Only run the calculate_diff function, skip full benchmarking",
    )
    args = parser.parse_args()

    # --num_experts / --topk select the configuration of the correctness check.
    calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)

    if not args.skip_full_benchmark:
        # The performance sweep always covers every entry of `configs`
        # (num_tokens_range x num_experts_range x topk_range), independent of
        # --num_experts / --topk, and the message reflects that.
        print(
            "\n📊 Running performance benchmark over all configured "
            "(num_tokens, num_experts, topk) combinations..."
        )
        # Honor --save_path: triton's perf_report writes its result files there.
        os.makedirs(args.save_path, exist_ok=True)
        benchmark.run(print_data=True, save_path=args.save_path)