# Source file: sglang_v0.5.2/sglang/sgl-kernel/benchmark/bench_activation.py
#
# NOTE: this file contains non-ASCII characters (e.g. "µ", "×", "✅", "❌")
# in the strings it prints and uses as plot labels.

# Benchmarks SGLang kernels versus vLLM across
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse
import itertools
import re
from typing import List, Tuple
import sgl_kernel
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_quick # activation-only kernel
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops
# If `vllm._custom_ops` does not expose `silu_and_mul`, use `torch.ops._C`
# as the vLLM op namespace instead; every later `getattr(vllm_ops, ...)`
# lookup goes through whichever object is selected here.
vllm_ops = vllm_ops if hasattr(vllm_ops, "silu_and_mul") else torch.ops._C
def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated list of non-negative integers for argparse.

    Whitespace around the whole string and around each item is ignored, so
    ``"1, 2,4"`` parses the same as ``"1,2,4"``.

    Args:
        arg: Text such as ``"1,4,16"``. ``None``, the empty string and
            whitespace-only strings all mean "no values".

    Returns:
        The parsed integers in input order; ``[]`` for empty input.

    Raises:
        argparse.ArgumentTypeError: If any item is not a plain run of digits
            (signs, decimals, empty items and trailing commas are rejected).
    """
    if arg is None:
        return []

    items = [item.strip() for item in arg.split(",")]
    # A blank string splits into a single empty item: treat it as "no values".
    if items == [""]:
        return []

    if not all(re.fullmatch(r"\d+", item) for item in items):
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    return [int(item) for item in items]
def calculate_diff(
    kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
) -> bool:
    """Run one kernel through vLLM and SGLang and report whether they agree.

    Args:
        kernel: Attribute name looked up on both ``vllm_ops`` and
            ``sgl_kernel``.
        dtype: Element type of the random input and of the reference output.
        batch_size: Size of the first tensor dimension.
        seq_len: Size of the second tensor dimension.
        dim: Size of the last dimension of the *output*.

    Returns:
        ``True`` when ``torch.allclose`` accepts the two outputs with
        ``rtol=1e-3`` and ``atol=1e-5``. A one-line verdict is printed
        either way.
    """
    cuda = torch.device("cuda")

    # `gelu_quick` is activation-only, so its input is as wide as its output;
    # the fused activation-and-multiply kernels take an input twice as wide.
    in_width = dim if kernel == "gelu_quick" else 2 * dim
    inp = torch.randn(batch_size, seq_len, in_width, dtype=dtype, device=cuda)
    expected = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=cuda)

    # The vLLM op is called with an explicit output tensor; the SGLang
    # kernel's return value is used as its output.
    getattr(vllm_ops, kernel)(expected, inp)
    actual = getattr(sgl_kernel, kernel)(inp)

    matched = torch.allclose(expected, actual, rtol=1e-3, atol=1e-5)
    verdict = "✅ match" if matched else "❌ mismatch"
    print(
        f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
        f"L={seq_len:3d} | D={dim:5d}] {verdict}"
    )
    return matched
# Kernel names covered by the sweep.
kernels = [
    "silu_and_mul",
    "gelu_and_mul",
    "gelu_tanh_and_mul",
    "gelu_quick",
]
# Element types covered by the sweep.
dtypes = [torch.float16, torch.bfloat16]


def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
    """Build the full benchmark grid.

    Args:
        bsizes: Batch sizes to sweep.
        slens: Sequence lengths to sweep.
        dims_: Output dimensions to sweep.

    Returns:
        Every ``(kernel, dtype, batch_size, seq_len, dim)`` combination as a
        list of tuples, with ``kernel`` varying slowest and ``dim`` fastest.
    """
    axes = (kernels, dtypes, bsizes, slens, dims_)
    return [*itertools.product(*axes)]
# Default sweep axes, used when the matching command-line flag is omitted.
default_batch_sizes = [4**k for k in range(3)]  # 1, 4, 16
default_seq_lens = [4**k for k in range(4)]  # 1, 4, 16, 64
default_dims = [1 << shift for shift in range(7, 15)]  # 128 ... 16384
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
        x_vals=[],  # placeholder; the __main__ block installs the real grid
        line_arg="provider",
        line_vals=["vllm", "sglang", "speedup"],
        line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
        styles=[("blue", "-"), ("green", "-"), ("red", "--")],
        ylabel="µs (median) or × (speed-up)",
        plot_name="activation-performance",
        args={},
    )
)
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
    """Time one activation kernel configuration for one provider.

    Args:
        kernel: Attribute name looked up on both ``vllm_ops`` and
            ``sgl_kernel``.
        dtype: Element type of the input and output tensors.
        batch_size: Size of the first tensor dimension.
        seq_len: Size of the second tensor dimension.
        dim: Size of the last dimension of the output.
        provider: ``"vllm"``, ``"sglang"`` or ``"speedup"``.

    Returns:
        For ``"vllm"`` and ``"sglang"``: ``(median, p20, p80)`` of the
        ``do_bench`` timings, each multiplied by 1000 (reported as µs).
        For ``"speedup"``: the ratio vLLM-median / SGLang-median repeated
        three times.

    Raises:
        ValueError: If ``provider == "vllm"`` and ``calculate_diff`` reports
            that the two implementations disagree for this configuration.
    """
    device = torch.device("cuda")
    # `gelu_quick` is activation-only; the fused kernels take a 2x-wide input.
    in_mult = 1 if kernel == "gelu_quick" else 2
    x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
    y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
    vllm_kernel = getattr(vllm_ops, kernel)
    sglang_kernel = getattr(sgl_kernel, kernel)

    def baseline():
        # NOTE(review): the timed region includes `y0.clone()` in addition to
        # the vLLM kernel itself; whether the SGLang path carries a comparable
        # output-allocation cost is not visible here -- confirm the
        # comparison is fair.
        tmp = y0.clone()
        vllm_kernel(tmp, x)
        return tmp

    def sglang():
        return sglang_kernel(x)

    # Correctness check, run once per configuration (on the "vllm" pass only).
    if provider == "vllm" and not calculate_diff(
        kernel, dtype, batch_size, seq_len, dim
    ):
        raise ValueError("Mismatch abort benchmark")

    def timed(fn):
        """Return (median, p20, p80) of `fn`'s timings, scaled by 1000."""
        # A few explicit warm-up calls before handing over to `do_bench`.
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        median_ms, p20_ms, p80_ms = triton.testing.do_bench(
            fn, quantiles=[0.5, 0.2, 0.8]
        )
        # `perf_report` unpacks the result as (mean, min, max). These are
        # times, so the 20th percentile is the lower bound and the 80th the
        # upper bound.
        return 1000 * median_ms, 1000 * p20_ms, 1000 * p80_ms

    if provider == "vllm":
        return timed(baseline)
    if provider == "sglang":
        return timed(sglang)

    # provider == "speedup": re-time both paths and report the ratio.
    t_ref, _, _ = timed(baseline)
    t_sgl, _, _ = timed(sglang)
    spd = t_ref / t_sgl
    return (spd, spd, spd)
if __name__ == "__main__":
    # Command-line entry point: either run a single sanity comparison
    # (`--verify_only`) or run the full perf_report sweep.
    p = argparse.ArgumentParser("Activation kernel benchmark")
    p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
    p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
    p.add_argument("--dims", type=str2int_list, default=default_dims)
    p.add_argument("--verify_only", action="store_true")
    args = p.parse_args()

    # Defensive coercion: parse any axis that is still a raw string.
    for axis in ("batch_sizes", "seq_lens", "dims"):
        value = getattr(args, axis)
        if isinstance(value, str):
            setattr(args, axis, str2int_list(value))

    # Install the sweep grid on the perf_report object. The attribute name is
    # probed because it may be `benchmarks` or `benchmark`.
    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
    if hasattr(benchmark, "benchmarks"):
        benchmark.benchmarks.x_vals = benchmark_grid
    else:
        benchmark.benchmark.x_vals = benchmark_grid

    if args.verify_only:
        # The sanity check uses the first requested dim; an empty `--dims`
        # is reported as a usage error instead of an IndexError traceback.
        if not args.dims:
            p.error("--verify_only requires at least one value in --dims")
        ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
        print("✅ sanity pass" if ok else "❌ mismatch")
    else:
        benchmark.run(print_data=True)