from typing import Tuple import deep_gemm import torch import triton import triton.language as tl from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor # Import shared functionality from the regular GEMM benchmark from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import ( per_block_cast_to_fp8, per_token_cast_to_fp8, ) def construct_grouped_and_flat_fp8( x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool ) -> Tuple[ Tuple[torch.Tensor, torch.Tensor], # grouped x_fp8 Tuple[torch.Tensor, torch.Tensor], # grouped y_fp8 Tuple[torch.Tensor, torch.Tensor], # flat x_fp8 Tuple[torch.Tensor, torch.Tensor], # flat y_fp8 torch.Tensor, # output torch.Tensor, # reference output ]: # Verify input shapes m, k = x.shape n, k_y = y.shape assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})" assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})" assert m % 4 == 0, f"TMA alignment error: {m}" # Reshape inputs for grouped processing m_per_group = m // num_groups x_grouped = x.view(num_groups, m_per_group, k) y_grouped = y.unsqueeze(0).expand(num_groups, n, k) # Initialize output tensors out = torch.empty((num_groups, m_per_group, n), device="cuda", dtype=torch.bfloat16) ref_out = torch.einsum("gmk,gnk->gmn", x_grouped, y_grouped) # Quantize grouped tensors x_fp8_grouped = ( torch.empty_like(x_grouped, dtype=torch.float8_e4m3fn), torch.empty( (num_groups, m_per_group, k // 128), device="cuda", dtype=torch.float ), ) y_fp8_grouped = ( torch.empty_like(y_grouped, dtype=torch.float8_e4m3fn), torch.empty( (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float ), ) for i in range(num_groups): x_fp8_grouped[0][i], x_fp8_grouped[1][i] = per_token_cast_to_fp8(x_grouped[i]) y_fp8_grouped[0][i], y_fp8_grouped[1][i] = per_block_cast_to_fp8(y_grouped[i]) # Quantize flat tensors x_fp8_flat = per_token_cast_to_fp8(x) y_fp8_flat = per_block_cast_to_fp8(y) # For non-masked input, merge the group and M dims in output if not is_masked: x_fp8_grouped = ( x_fp8_grouped[0].view(-1, k), per_token_cast_to_fp8(x_grouped.view(-1, k))[1], ) out, ref_out = out.view(-1, n), ref_out.view(-1, n) # Transpose earlier for testing x_fp8_grouped = ( x_fp8_grouped[0], get_col_major_tma_aligned_tensor(x_fp8_grouped[1]), ) x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1])) return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out # Since we don't have a group gemm kernel in SGLang/vLLM, we implemented a # custom kernel based on the Triton tutorial. # https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html @triton.jit def fp8_gemm_group_triton_kernel( # Pointers to matrices a_ptr, b_ptr, c_ptr, # Pointers to scaling factors a_scale_ptr, b_scale_ptr, # Matrix dimensions M, N, K, # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, # Strides for scaling factors stride_a_scale_m, stride_a_scale_k, stride_b_scale_n, stride_b_scale_k, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ): """Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors. A has shape (M, K), B has shape (K, N) and C has shape (M, N) Note: Block sizes must be multiples of 32 for optimal TMA performance. """ # Map program ids to the block of C it should compute pid_group = tl.program_id(axis=0) # Group ID pid_n = tl.program_id(axis=1) # N dimension ID # Compute the M block ID within this group group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M) pid_m_within_group = tl.program_id(axis=2) % group_size_m pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group # Create pointers for the first blocks of A and B offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) # Initialize accumulator accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Main loop for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)): k_offset = k_block * BLOCK_SIZE_K # Load the next block of A and B, with masks a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0) # Calculate indices for scaling factors for this K block a_scale_ptrs = a_scale_ptr + ( offs_am * stride_a_scale_m + k_block * stride_a_scale_k ) b_scale_ptrs = b_scale_ptr + ( pid_n * stride_b_scale_n + k_block * stride_b_scale_k ) # Perform matrix multiplication in FP8 res = tl.dot(a, b) # Load scaling factors for the current block a_scale = tl.load(a_scale_ptrs)[:, None] # [BLOCK_SIZE_M, 1] b_scale = tl.load(b_scale_ptrs) # Apply scaling factors to the accumulated result accumulator += res * a_scale * b_scale # Advance pointers a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk # Convert to bfloat16 for output c = accumulator.to(tl.bfloat16) # Write back the result offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups): """ Perform matrix multiplication with FP8 inputs and proper scaling. Args: a_tuple: Tuple of (quantized_tensor, scale_factors) for input A b_tuple: Tuple of (quantized_tensor, scale_factors) for input B c: Output tensor in BF16 format num_groups: Number of groups for grouped GEMM Returns: Result tensor in BF16 format """ # Unpack the tuples a, a_scale = a_tuple b, b_scale = b_tuple M, K = a.shape _, N = b.shape # Configure block sizes - must be multiples of 32 for TMA alignment BLOCK_SIZE_M = 128 BLOCK_SIZE_N = 128 BLOCK_SIZE_K = 128 # Calculate grid dimensions num_pid_m = triton.cdiv(M, BLOCK_SIZE_M) num_pid_n = triton.cdiv(N, BLOCK_SIZE_N) num_groups_grid = triton.cdiv(num_pid_m, num_groups) # 3D grid launch - (group, n_blocks, m_blocks_per_group) grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m)) fp8_gemm_group_triton_kernel[grid]( a, b, c, a_scale, b_scale, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), a_scale.stride(0), 1, # Stride in the K dimension may be 1 b_scale.stride(0), 1 if b_scale.dim() > 1 else 0, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, GROUP_SIZE_M=num_groups, ) return c def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices): deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( x_fp8_grouped, y_fp8_grouped, out, m_indices, ) return out def calculate_diff(m: int, n: int, k: int, num_groups: int): print(f"Shape (m={m}, n={n}, k={k}") x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = ( construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False) ) m_per_group = m // num_groups out_deepgemm = out.clone() m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int) m_indices = ( m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1) ) fp8_gemm_group_deepgemm( x_fp8_grouped, y_fp8_grouped, out_deepgemm, m_indices, ) torch.cuda.synchronize() # Prepare inputs for Triton a, a_scale = x_fp8_flat b, b_scale = y_fp8_flat b = b.T.contiguous() # Ensure scales are in the right format and contiguous a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous() M, _ = a.shape _, N = b.shape c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups) torch.cuda.synchronize() diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item() diff_torch_triton = torch.abs(out_torch - out_triton).mean().item() diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item() print(f"Shape m={m}, n={n}, k={k}:") print(f"Torch output: {out_torch[0, 0:5]}") print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}") print(f"Triton output: {out_triton[0, 0:5]}") print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}") print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}") print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}") deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch) triton_torch_diff = calc_diff(out_triton, out_torch) deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton) DIFF_THRESHOLD = 0.001 all_match = ( deepgemm_torch_diff < DIFF_THRESHOLD and triton_torch_diff < DIFF_THRESHOLD and deepgemm_triton_diff < DIFF_THRESHOLD ) if all_match: print("✅ All implementations match\n") else: print("❌ Some implementations differ:") print( f" - Torch vs DeepGEMM: {'✅' if deepgemm_torch_diff < DIFF_THRESHOLD else '❌'}" f" - Torch vs Triton: {'✅' if triton_torch_diff < DIFF_THRESHOLD else '❌'}" f" - DeepGEMM vs Triton: {'✅' if deepgemm_triton_diff < DIFF_THRESHOLD else '❌'}" ) def get_weight_shapes(tp_size): # cannot TP total = [ (512 + 64, 7168), ((128 + 64) * 128, 7168), (128 * (128 + 128), 512), (7168, 16384), (7168, 18432), ] # N can TP n_tp = [ (18432 * 2, 7168), ((128 + 64) * 128, 7168), (128 * (128 + 128), 512), (24576, 1536), (4096, 7168), ] # K can TP k_tp = [(7168, 18432), (7168, 16384), (7168, 2048)] weight_shapes = [] for t in total: weight_shapes.append(t) for n_t in n_tp: new_t = (n_t[0] // tp_size, n_t[1]) weight_shapes.append(new_t) for k_t in k_tp: new_t = (k_t[0], k_t[1] // tp_size) weight_shapes.append(new_t) return weight_shapes def create_benchmark_configs(tp_size): configs = [] weight_shapes = get_weight_shapes(tp_size) batch_sizes = [2048, 4096] group_sizes = [4, 8] for n, k in weight_shapes: for m in batch_sizes: for num_groups in group_sizes: configs.append((m, n, k, num_groups, tp_size)) return configs def get_benchmark(tp_size): all_configs = create_benchmark_configs(tp_size) @triton.testing.perf_report( triton.testing.Benchmark( x_names=["m", "n", "k", "num_groups", "tp_size"], x_vals=[config for config in all_configs], line_arg="provider", line_vals=["deepgemm", "triton"], line_names=["DeepGEMM", "Triton"], styles=[("blue", "-"), ("red", "-")], ylabel="ms", plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}", args={}, ) ) def benchmark(m, n, k, num_groups, tp_size, provider): print( f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}, Provider: {provider}" ) x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = ( construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False) ) m_per_group = m // num_groups m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int) m_indices = ( m_indices.unsqueeze(-1) .expand(num_groups, m_per_group) .contiguous() .view(-1) ) quantiles = [0.5, 0.2, 0.8] if provider == "deepgemm": ms, min_ms, max_ms = triton.testing.do_bench( lambda: fp8_gemm_group_deepgemm( x_fp8_grouped, y_fp8_grouped, out, m_indices, ), quantiles=quantiles, ) elif provider == "triton": # Prepare inputs for Triton # We did it outside of the lambda function to make it fair comparison like deepgemm a, a_scale = x_fp8_flat b, b_scale = y_fp8_flat b = b.T.contiguous() # Ensure scales are in the right format and contiguous a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous() M, _ = a.shape _, N = b.shape c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16) ms, min_ms, max_ms = triton.testing.do_bench( lambda: fp8_gemm_group_triton( (a, a_scale), (b, b_scale), c, num_groups, ), quantiles=quantiles, ) # Calculate TFLOPS flops = 2 * m * n * k # multiply-adds tflops = flops / (ms * 1e-3) / 1e12 print(f"Time: {ms*1000:.2f} ms, TFLOPS: {tflops:.2f}") return ms * 1000, max_ms * 1000, min_ms * 1000 # convert to ms return benchmark if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument( "--save_path", type=str, default="./configs/benchmark_ops/fp8_group_gemm/", help="Path to save deepgemm fp8 group gemm benchmark results", ) parser.add_argument( "--run_correctness", action="store_true", help="Whether to run correctness test", ) parser.add_argument( "--tp_size", type=int, default=1, help="Tensor parallelism size to benchmark (default: 1)", ) args = parser.parse_args() # Set random seed for reproducibility torch.manual_seed(0) torch.cuda.manual_seed(0) # Enable TF32, adapted from https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148 torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True # Run correctness tests on a few examples if args.run_correctness: print("Running correctness tests...") calculate_diff(8192, 7168, 4096, 4) calculate_diff(8192, 2048, 7168, 4) calculate_diff(4096, 7168, 4096, 8) calculate_diff(4096, 2048, 7168, 8) calculate_diff(4096, 576, 7168, 8) # Get the benchmark function with the specified tp_size benchmark = get_benchmark(args.tp_size) print(f"Running performance benchmark for TP size = {args.tp_size}...") benchmark.run(print_data=True, save_path=args.save_path)