import itertools
from typing import Optional, Tuple

import torch
import triton
import triton.testing
from sgl_kernel import sgl_per_tensor_quant_fp8
from vllm import _custom_ops as ops

from sglang.srt.utils import is_hip

_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn


def vllm_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ops.scaled_fp8_quant(input, scale)


def sglang_scaled_fp8_quant(
    input: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Use the module-level fp8 dtype so the HIP variant (e4m3fnuz) is respected;
    # a local override here would silently break AMD GPUs.
    output = torch.empty_like(input, device=input.device, dtype=fp8_type_)
    is_static = True
    if scale is None:
        # Dynamic quantization: pass a zeroed scale buffer for the kernel to fill.
        scale = torch.zeros(1, device=input.device, dtype=torch.float32)
        is_static = False
    sgl_per_tensor_quant_fp8(input, output, scale, is_static)
    return output, scale


def calculate_diff(batch_size: int, seq_len: int):
    """Compare the vLLM and SGLang per-tensor FP8 quantization implementations."""
    device = torch.device("cuda")
    x = torch.rand((batch_size, seq_len), dtype=torch.float16, device=device)

    vllm_out, vllm_scale = vllm_scaled_fp8_quant(x)
    sglang_out, sglang_scale = sglang_scaled_fp8_quant(x)

    scale_diff = torch.abs(vllm_scale - sglang_scale).item()
    output_diff = torch.abs(vllm_out.float() - sglang_out.float()).mean().item()
    print(f"Scale difference: {scale_diff:.6e}")
    print(f"Mean output difference: {output_diff:.6e}")

    if torch.allclose(
        vllm_out.to(torch.float32), sglang_out.to(torch.float32), rtol=1e-3, atol=1e-5
    ) and torch.allclose(vllm_scale, sglang_scale, rtol=1e-3, atol=1e-5):
        print("✅ All implementations match")
    else:
        print("❌ Implementations differ")


batch_size_range = [16, 32, 64, 128]
seq_len_range = [64, 128, 256, 512, 1024, 2048]

configs = list(itertools.product(batch_size_range, seq_len_range))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["vllm", "sglang"],
        line_names=["vLLM", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-tensor-quant-fp8-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, provider):
    dtype = torch.float16
    device = torch.device("cuda")
    x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]
    if provider == "vllm":
        fn = lambda: vllm_scaled_fp8_quant(x.clone())
    elif provider == "sglang":
        fn = lambda: sglang_scaled_fp8_quant(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
    # do_bench reports milliseconds; convert to microseconds to match the ylabel.
    return 1000 * ms, 1000 * max_ms, 1000 * min_ms


if __name__ == "__main__":
    calculate_diff(batch_size=4, seq_len=4096)
    benchmark.run(print_data=True)