# Benchmarks SGLang kernels versus vLLM across
# (kernel, dtype, batch_size, seq_len, dim) and prints speed-up.
import argparse
import itertools
import re
from typing import List, Tuple

import sgl_kernel
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_quick  # activation-only kernel
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops

# Fall back to the raw torch op namespace when the imported vLLM module does
# not expose the activation ops directly.
if not hasattr(vllm_ops, "silu_and_mul"):
    vllm_ops = torch.ops._C


def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated list of non-negative integers.

    Args:
        arg: Text such as ``"1,4,16"``. ``""`` and ``None`` yield ``[]``.

    Returns:
        The parsed integers, in input order.

    Raises:
        argparse.ArgumentTypeError: If the text is not of the form ``1,2,3``.
    """
    if arg in ("", None):
        return []
    if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    return [int(x) for x in arg.split(",")]


def _input_multiplier(kernel: str) -> int:
    """Return the input/output last-dim ratio used for ``kernel``."""
    # gelu_quick is activation-only (same width in and out); the fused
    # "*_and_mul" kernels are fed an input twice as wide as their output.
    return 1 if kernel == "gelu_quick" else 2


def calculate_diff(
    kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
) -> bool:
    """Compare vLLM with SGLang for one shape.

    Runs both implementations of ``kernel`` on the same random CUDA input,
    prints a one-line match/mismatch report and returns whether the outputs
    agree under ``torch.allclose(rtol=1e-3, atol=1e-5)``.
    """
    device = torch.device("cuda")
    in_mult = _input_multiplier(kernel)

    x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
    # The vLLM op writes into a caller-provided output buffer.
    ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
    getattr(vllm_ops, kernel)(ref_out, x)
    test_out = getattr(sgl_kernel, kernel)(x)

    ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
    tag = "✅ match" if ok else "❌ mismatch"
    print(
        f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
        f"L={seq_len:3d} | D={dim:5d}]  {tag}"
    )
    return ok


kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
dtypes = [torch.float16, torch.bfloat16]


def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
    """Return the full (kernel, dtype, batch_size, seq_len, dim) grid."""
    return list(itertools.product(kernels, dtypes, bsizes, slens, dims_))


default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1,4,16,64
default_dims = [2**i for i in range(7, 15)]  # 128...16384


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
        x_vals=[],  # filled in from the command line before running
        line_arg="provider",
        line_vals=["vllm", "sglang", "speedup"],
        line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
        styles=[("blue", "-"), ("green", "-"), ("red", "--")],
        ylabel="µs (median) or × (speed-up)",
        plot_name="activation-performance",
        args={},
    )
)
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
    """Time one (kernel, dtype, shape) configuration for ``provider``.

    Returns a ``(y, y_min, y_max)`` tuple: median / 20th / 80th percentile
    latency for the ``"vllm"`` and ``"sglang"`` providers, or the speed-up
    ratio repeated three times for ``"speedup"``.

    Raises:
        ValueError: If the correctness check run for the ``"vllm"`` provider
            reports a mismatch.
    """
    device = torch.device("cuda")
    in_mult = _input_multiplier(kernel)
    x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
    y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)

    vllm_kernel = getattr(vllm_ops, kernel)
    sglang_kernel = getattr(sgl_kernel, kernel)

    def baseline():
        # Clone so each call writes into a fresh output buffer.
        tmp = y0.clone()
        vllm_kernel(tmp, x)
        return tmp

    def sglang():
        return sglang_kernel(x)

    # one-time correctness check (only on the first provider of each config)
    if provider == "vllm" and not calculate_diff(
        kernel, dtype, batch_size, seq_len, dim
    ):
        raise ValueError("Mismatch – abort benchmark")

    # timing helper
    def timed(fn):
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        ms, q_lo, q_hi = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
        # perf_report unpacks the result as (y, y_min, y_max). For a latency
        # metric the low quantile is the minimum, so it must come before the
        # high quantile. Values are scaled by 1000 to match the µs y-label.
        return 1000 * ms, 1000 * q_lo, 1000 * q_hi

    if provider == "vllm":
        return timed(baseline)
    if provider == "sglang":
        return timed(sglang)

    # provider == "speedup"
    t_ref, _, _ = timed(baseline)
    t_sgl, _, _ = timed(sglang)
    spd = t_ref / t_sgl
    return (spd, spd, spd)


def main() -> None:
    """Parse command-line options, then run the sanity check or the benchmark."""
    p = argparse.ArgumentParser("Activation kernel benchmark")
    p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
    p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
    p.add_argument("--dims", type=str2int_list, default=default_dims)
    p.add_argument("--verify_only", action="store_true")
    args = p.parse_args()

    # coerce lists: defensive normalisation of any value that is still a string
    for name in ("batch_sizes", "seq_lens", "dims"):
        value = getattr(args, name)
        if isinstance(value, str):
            setattr(args, name, str2int_list(value))

    # patch perf_report grid (the attribute name differs between the two
    # layouts probed here)
    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
    if hasattr(benchmark, "benchmarks"):
        benchmark.benchmarks.x_vals = benchmark_grid
    else:
        benchmark.benchmark.x_vals = benchmark_grid

    if args.verify_only:
        # `--dims ""` parses to an empty list; report that as a usage error
        # instead of failing with an IndexError below.
        if not args.dims:
            p.error("--dims must contain at least one value for --verify_only")
        ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
        print("✅ sanity pass" if ok else "❌ mismatch")
    else:
        benchmark.run(print_data=True)


if __name__ == "__main__":
    main()