487 lines
16 KiB
Python
487 lines
16 KiB
Python
from typing import Tuple
|
|
|
|
import deep_gemm
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
|
|
|
|
# Import shared functionality from the regular GEMM benchmark
|
|
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
|
|
per_block_cast_to_fp8,
|
|
per_token_cast_to_fp8,
|
|
)
|
|
|
|
|
|
def construct_grouped_and_flat_fp8(
    x: torch.Tensor, y: torch.Tensor, num_groups: int, is_masked: bool
) -> Tuple[
    Tuple[torch.Tensor, torch.Tensor],  # grouped x_fp8
    Tuple[torch.Tensor, torch.Tensor],  # grouped y_fp8
    Tuple[torch.Tensor, torch.Tensor],  # flat x_fp8
    Tuple[torch.Tensor, torch.Tensor],  # flat y_fp8
    torch.Tensor,  # output
    torch.Tensor,  # reference output
]:
    """Quantize ``x`` and ``y`` to FP8 in both a grouped and a flat layout.

    ``x`` (m, k) is split into ``num_groups`` equal row groups; ``y`` (n, k) is
    shared by every group. Each group is quantized separately (per-token for
    ``x``, per-block for ``y``), and the un-grouped tensors are quantized as
    well. A BF16 einsum reference result and an uninitialised BF16 output
    buffer are returned alongside.

    When ``is_masked`` is False, the group and M dimensions of the grouped
    ``x`` data, ``out`` and ``ref_out`` are merged into one row dimension, and
    the grouped ``x`` scales are recomputed on the merged tensor. The ``x``
    scale tensors (grouped and flat) are finally passed through
    ``get_col_major_tma_aligned_tensor``.

    Returns:
        (x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out),
        each FP8 entry being a ``(data, scales)`` pair.
    """
    m, k = x.shape
    n, k_y = y.shape
    assert k == k_y, f"Incompatible shapes: x({m}, {k}), y({n}, {k_y})"
    assert m % num_groups == 0, f"m({m}) must be divisible by num_groups({num_groups})"
    assert m % 4 == 0, f"TMA alignment error: {m}"

    rows_per_group = m // num_groups
    x_by_group = x.view(num_groups, rows_per_group, k)
    # All groups share one weight matrix; expand() is a view, not a copy.
    y_by_group = y.unsqueeze(0).expand(num_groups, n, k)

    out = torch.empty(
        (num_groups, rows_per_group, n), device="cuda", dtype=torch.bfloat16
    )
    ref_out = torch.einsum("gmk,gnk->gmn", x_by_group, y_by_group)

    # Destination buffers for the per-group quantized data and scales
    # (scales have one entry per 128-wide K slice).
    x_data = torch.empty_like(x_by_group, dtype=torch.float8_e4m3fn)
    x_scales = torch.empty(
        (num_groups, rows_per_group, k // 128), device="cuda", dtype=torch.float
    )
    y_data = torch.empty_like(y_by_group, dtype=torch.float8_e4m3fn)
    y_scales = torch.empty(
        (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
    )
    for g in range(num_groups):
        x_data[g], x_scales[g] = per_token_cast_to_fp8(x_by_group[g])
        y_data[g], y_scales[g] = per_block_cast_to_fp8(y_by_group[g])

    x_flat_data, x_flat_scales = per_token_cast_to_fp8(x)
    y_fp8_flat = per_block_cast_to_fp8(y)

    if is_masked:
        grouped_x_data, grouped_x_scales = x_data, x_scales
    else:
        # Contiguous layout: fold (group, m) into a single row dimension.
        grouped_x_data = x_data.view(-1, k)
        grouped_x_scales = per_token_cast_to_fp8(x_by_group.view(-1, k))[1]
        out = out.view(-1, n)
        ref_out = ref_out.view(-1, n)

    # Apply the column-major TMA-aligned scale layout up front.
    x_fp8_grouped = (
        grouped_x_data,
        get_col_major_tma_aligned_tensor(grouped_x_scales),
    )
    x_fp8_flat = (x_flat_data, get_col_major_tma_aligned_tensor(x_flat_scales))

    return x_fp8_grouped, (y_data, y_scales), x_fp8_flat, y_fp8_flat, out, ref_out
|
|
|
|
|
|
# SGLang/vLLM do not provide a grouped FP8 GEMM kernel, so this benchmark uses
# a custom kernel adapted from the Triton matrix-multiplication tutorial:
# https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html
@triton.jit
def fp8_gemm_group_triton_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    # Pointers to scaling factors
    a_scale_ptr,
    b_scale_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension.
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Strides for scaling factors
    stride_a_scale_m,
    stride_a_scale_k,
    stride_b_scale_n,
    stride_b_scale_k,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    """Kernel for computing the matmul C = A x B with FP8 inputs and scaling factors.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N). Each program
    computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C and writes it as bf16.

    Scaling: for every K block the raw partial product is multiplied by one A
    scale per row (read at [row, k_block]) and by a single B scale read at
    [pid_n, k_block]. The scales are therefore only correct when the
    quantization block size equals BLOCK_SIZE_N along N and BLOCK_SIZE_K
    along K.

    Grid: axis 0 selects a group of GROUP_SIZE_M consecutive M tiles, axis 1
    the N tile, axis 2 the M tile within the group.
    """
    # Map program ids to the block of C it should compute
    pid_group = tl.program_id(axis=0)  # Group ID
    pid_n = tl.program_id(axis=1)  # N dimension ID

    # Compute the M block ID within this group.
    # NOTE(review): the Triton tutorial clamps with the number of M *tiles*
    # (cdiv(M, BLOCK_SIZE_M)); this uses the row count M. When M is large
    # relative to the tile count, group_size_m == GROUP_SIZE_M, and tiles whose
    # pid_m runs past the last M tile are discarded by the masked store below.
    group_size_m = min(M - pid_group * GROUP_SIZE_M, GROUP_SIZE_M)
    pid_m_within_group = tl.program_id(axis=2) % group_size_m
    pid_m = pid_group * GROUP_SIZE_M + pid_m_within_group

    # Create pointers for the first blocks of A and B.
    # The `% M` / `% N` wrap keeps loads in bounds for partial edge tiles;
    # rows/columns produced from wrapped indices are dropped by c_mask.
    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    # Initialize accumulator (fp32 for the whole K reduction)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # Main loop over K blocks
    for k_block in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        k_offset = k_block * BLOCK_SIZE_K

        # Load the next block of A and B; only the K tail needs masking
        # (M/N are already kept in range by the wrap above).
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k_offset, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k_offset, other=0.0)

        # Calculate indices for scaling factors for this K block:
        # one A scale per row, and one B scale for the whole (pid_n, k_block) tile.
        a_scale_ptrs = a_scale_ptr + (
            offs_am * stride_a_scale_m + k_block * stride_a_scale_k
        )
        b_scale_ptrs = b_scale_ptr + (
            pid_n * stride_b_scale_n + k_block * stride_b_scale_k
        )

        # Raw (unscaled) product of the FP8 tiles for this K block
        res = tl.dot(a, b)

        # Load scaling factors for the current block
        a_scale = tl.load(a_scale_ptrs)[:, None]  # [BLOCK_SIZE_M, 1]
        b_scale = tl.load(b_scale_ptrs)

        # Scales differ per K block, so they are applied before accumulating
        accumulator += res * a_scale * b_scale

        # Advance pointers to the next K block
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Convert to bfloat16 for output
    c = accumulator.to(tl.bfloat16)

    # Write back the result; offsets are NOT wrapped here, so the mask drops
    # any out-of-range rows/columns (including tiles past the last M tile).
    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
|
|
|
|
|
|
def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
    """
    Perform matrix multiplication with FP8 inputs and proper scaling.

    Launches ``fp8_gemm_group_triton_kernel`` to compute ``c = A @ B`` from
    block-scaled FP8 operands. The kernel applies one A scale per
    (row, K block) and one B scale per (N block, K block). Because of this,
    ``BLOCK_SIZE_N`` and ``BLOCK_SIZE_K`` below must stay equal to the
    quantization block size of 128.

    The real strides of both scale tensors are forwarded to the kernel. They
    may be row-major or column-major, for example the layout produced by
    ``get_col_major_tma_aligned_tensor``.

    Args:
        a_tuple: Tuple of (quantized_tensor, scale_factors) for input A. The
            tensor is (M, K); scales are indexed as [row, k_block].
        b_tuple: Tuple of (quantized_tensor, scale_factors) for input B. The
            tensor is (K, N); scales are indexed as [n_block, k_block].
        c: Output tensor in BF16 format, shape (M, N); written in place.
        num_groups: Number of consecutive M tiles scheduled together in the
            launch grid (passed to the kernel as ``GROUP_SIZE_M``).

    Returns:
        ``c``, holding the result in BF16 format.

    Raises:
        ValueError: If the inner (K) dimensions of A and B differ.
    """
    # Unpack the tuples
    a, a_scale = a_tuple
    b, b_scale = b_tuple

    M, K = a.shape
    K_b, N = b.shape
    if K_b != K:
        raise ValueError(
            f"Incompatible shapes: a({M}, {K}) and b({K_b}, {N}); expected b as (K, N)"
        )

    # Tile sizes. N and K tiles must match the 128x128 quantization blocks,
    # since the kernel reads one scale per tile along those dimensions.
    BLOCK_SIZE_M = 128
    BLOCK_SIZE_N = 128
    BLOCK_SIZE_K = 128

    # Calculate grid dimensions
    num_pid_m = triton.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = triton.cdiv(N, BLOCK_SIZE_N)
    num_groups_grid = triton.cdiv(num_pid_m, num_groups)

    # 3D grid launch - (group, n_blocks, m_blocks_per_group)
    grid = (num_groups_grid, num_pid_n, min(num_groups, num_pid_m))

    # Use the tensors' actual K strides instead of assuming row-major scales;
    # a 1-D scale tensor has no K dimension, so its K stride is 0.
    a_scale_stride_k = a_scale.stride(1) if a_scale.dim() > 1 else 0
    b_scale_stride_k = b_scale.stride(1) if b_scale.dim() > 1 else 0

    fp8_gemm_group_triton_kernel[grid](
        a,
        b,
        c,
        a_scale,
        b_scale,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        a_scale.stride(0),
        a_scale_stride_k,
        b_scale.stride(0),
        b_scale_stride_k,
        BLOCK_SIZE_M=BLOCK_SIZE_M,
        BLOCK_SIZE_N=BLOCK_SIZE_N,
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        GROUP_SIZE_M=num_groups,
    )

    return c
|
|
|
|
|
|
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
    """Run DeepGEMM's contiguous M-grouped FP8 GEMM, writing into ``out``.

    Thin wrapper around ``deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous``.
    The callers in this file pass ``m_indices`` as the group id of every row
    of the grouped LHS.

    Returns:
        ``out``, after DeepGEMM has filled it.
    """
    grouped_gemm = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous
    grouped_gemm(x_fp8_grouped, y_fp8_grouped, out, m_indices)
    return out
|
|
|
|
|
|
def calculate_diff(m: int, n: int, k: int, num_groups: int):
    """Check DeepGEMM and the Triton kernel against a Torch reference.

    Random BF16 inputs x (m, k) and y (n, k) are quantized with
    ``construct_grouped_and_flat_fp8`` (non-masked layout). DeepGEMM consumes
    the grouped tensors, and the Triton kernel consumes the flat ones. Both
    outputs are compared with the BF16 einsum reference. Results are only
    printed; nothing is returned.

    Args:
        m: Rows of x; must be divisible by ``num_groups``.
        n: Rows of y (columns of the output).
        k: Shared reduction dimension.
        num_groups: Number of equal-size row groups for DeepGEMM.
    """
    print(f"Shape (m={m}, n={n}, k={k}, num_groups={num_groups})")
    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
    x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
        construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
    )
    m_per_group = m // num_groups
    out_deepgemm = out.clone()
    # Group id for every row of x: m_per_group zeros, then ones, and so on.
    m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
    m_indices = (
        m_indices.unsqueeze(-1).expand(num_groups, m_per_group).contiguous().view(-1)
    )

    fp8_gemm_group_deepgemm(
        x_fp8_grouped,
        y_fp8_grouped,
        out_deepgemm,
        m_indices,
    )
    torch.cuda.synchronize()

    # Prepare inputs for Triton: B is transposed to (K, N), and contiguous
    # copies of the scale tensors are passed.
    a, a_scale = x_fp8_flat
    b, b_scale = y_fp8_flat
    b = b.T.contiguous()
    a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
    M, _ = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    out_triton = fp8_gemm_group_triton((a, a_scale), (b, b_scale), c, num_groups)
    torch.cuda.synchronize()

    diff_torch_deepgemm = torch.abs(out_torch - out_deepgemm).mean().item()
    diff_torch_triton = torch.abs(out_torch - out_triton).mean().item()
    diff_deepgemm_triton = torch.abs(out_deepgemm - out_triton).mean().item()

    print(f"Shape m={m}, n={n}, k={k}:")
    print(f"Torch output: {out_torch[0, 0:5]}")
    print(f"DeepGEMM output: {out_deepgemm[0, 0:5]}")
    print(f"Triton output: {out_triton[0, 0:5]}")
    print(f"Mean absolute difference (Torch-DeepGEMM): {diff_torch_deepgemm}")
    print(f"Mean absolute difference (Torch-Triton): {diff_torch_triton}")
    print(f"Mean absolute difference (DeepGEMM-Triton): {diff_deepgemm_triton}")

    deepgemm_torch_diff = calc_diff(out_deepgemm, out_torch)
    triton_torch_diff = calc_diff(out_triton, out_torch)
    deepgemm_triton_diff = calc_diff(out_deepgemm, out_triton)

    DIFF_THRESHOLD = 0.001

    def _mark(diff):
        # One status glyph per pairwise comparison.
        return "✅" if diff < DIFF_THRESHOLD else "❌"

    all_match = (
        deepgemm_torch_diff < DIFF_THRESHOLD
        and triton_torch_diff < DIFF_THRESHOLD
        and deepgemm_triton_diff < DIFF_THRESHOLD
    )
    if all_match:
        print("✅ All implementations match\n")
    else:
        print("❌ Some implementations differ:")
        # Explicit newlines: adjacent f-strings are concatenated with no
        # separator, so without them all three statuses print on one line.
        print(
            f" - Torch vs DeepGEMM: {_mark(deepgemm_torch_diff)}\n"
            f" - Torch vs Triton: {_mark(triton_torch_diff)}\n"
            f" - DeepGEMM vs Triton: {_mark(deepgemm_triton_diff)}\n"
        )
|
|
|
|
|
|
def get_weight_shapes(tp_size):
    """Return the list of (N, K) weight shapes to benchmark for ``tp_size``.

    The shapes fall into three buckets, emitted in this order:
    - shapes that are never sharded;
    - shapes whose N dimension is integer-divided by ``tp_size``;
    - shapes whose K dimension is integer-divided by ``tp_size``.
    """
    # Shapes that cannot be split across tensor-parallel ranks.
    unsharded = [
        (512 + 64, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (7168, 16384),
        (7168, 18432),
    ]
    # Shapes split along N.
    n_sharded = [
        (18432 * 2, 7168),
        ((128 + 64) * 128, 7168),
        (128 * (128 + 128), 512),
        (24576, 1536),
        (4096, 7168),
    ]
    # Shapes split along K.
    k_sharded = [(7168, 18432), (7168, 16384), (7168, 2048)]

    return (
        list(unsharded)
        + [(n_dim // tp_size, k_dim) for n_dim, k_dim in n_sharded]
        + [(n_dim, k_dim // tp_size) for n_dim, k_dim in k_sharded]
    )
|
|
|
|
|
|
def create_benchmark_configs(tp_size):
    """Enumerate benchmark points as ``(m, n, k, num_groups, tp_size)`` tuples.

    Every (N, K) weight shape from ``get_weight_shapes(tp_size)`` is crossed
    with batch sizes 2048 and 4096 and with group counts 4 and 8. The weight
    shape varies slowest and the group count fastest.
    """
    shapes = get_weight_shapes(tp_size)
    return [
        (m, n, k, num_groups, tp_size)
        for n, k in shapes
        for m in (2048, 4096)
        for num_groups in (4, 8)
    ]
|
|
|
|
|
|
def get_benchmark(tp_size):
    """Build a ``triton.testing.perf_report`` benchmark for one TP size.

    The returned callable sweeps every config from
    ``create_benchmark_configs(tp_size)``. It times the "deepgemm" and
    "triton" providers with ``triton.testing.do_bench`` and reports the
    median, 20th and 80th percentile run times in microseconds.
    """
    all_configs = create_benchmark_configs(tp_size)

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["m", "n", "k", "num_groups", "tp_size"],
            x_vals=list(all_configs),
            line_arg="provider",
            line_vals=["deepgemm", "triton"],
            line_names=["DeepGEMM", "Triton"],
            styles=[("blue", "-"), ("red", "-")],
            # do_bench reports milliseconds; values are returned *1000 below.
            ylabel="us",
            plot_name=f"fp8-group-gemm-performance-comparison-tp{tp_size}",
            args={},
        )
    )
    def benchmark(m, n, k, num_groups, tp_size, provider):
        print(
            f"Shape (m={m}, n={n}, k={k}, tp={tp_size}, num_groups={num_groups}), Provider: {provider}"
        )
        x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
        y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
        x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, out_torch = (
            construct_grouped_and_flat_fp8(x, y, num_groups, is_masked=False)
        )
        m_per_group = m // num_groups
        # Group id for every row of x (contiguous grouped layout).
        m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
        m_indices = (
            m_indices.unsqueeze(-1)
            .expand(num_groups, m_per_group)
            .contiguous()
            .view(-1)
        )

        # do_bench returns one value per quantile, in this order:
        # median, 20th percentile (lower bound), 80th percentile (upper bound).
        quantiles = [0.5, 0.2, 0.8]

        if provider == "deepgemm":
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_group_deepgemm(
                    x_fp8_grouped,
                    y_fp8_grouped,
                    out,
                    m_indices,
                ),
                quantiles=quantiles,
            )
        elif provider == "triton":
            # Prepare inputs for Triton outside the timed lambda, so the
            # comparison with DeepGEMM (whose inputs are also prebuilt) is fair.
            a, a_scale = x_fp8_flat
            b, b_scale = y_fp8_flat
            b = b.T.contiguous()
            # Ensure scales are in the right format and contiguous
            a_scale, b_scale = a_scale.contiguous(), b_scale.contiguous()
            M, _ = a.shape
            _, N = b.shape
            c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
            ms, min_ms, max_ms = triton.testing.do_bench(
                lambda: fp8_gemm_group_triton(
                    (a, a_scale),
                    (b, b_scale),
                    c,
                    num_groups,
                ),
                quantiles=quantiles,
            )
        else:
            raise ValueError(f"Unknown provider: {provider}")

        # Calculate TFLOPS (ms -> s for the rate)
        flops = 2 * m * n * k  # multiply-adds
        tflops = flops / (ms * 1e-3) / 1e12

        print(f"Time: {ms * 1000:.2f} us, TFLOPS: {tflops:.2f}")
        # perf_report unpacks the return value as (y, y_min, y_max);
        # convert ms -> us.
        return ms * 1000, min_ms * 1000, max_ms * 1000

    return benchmark
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--save_path",
        type=str,
        default="./configs/benchmark_ops/fp8_group_gemm/",
        help="Path to save deepgemm fp8 group gemm benchmark results",
    )
    arg_parser.add_argument(
        "--run_correctness",
        action="store_true",
        help="Whether to run correctness test",
    )
    arg_parser.add_argument(
        "--tp_size",
        type=int,
        default=1,
        help="Tensor parallelism size to benchmark (default: 1)",
    )
    args = arg_parser.parse_args()

    # Fixed seeds so the random inputs are reproducible across runs.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Enable TF32, following DeepGEMM's tests:
    # https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py#L148
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    if args.run_correctness:
        print("Running correctness tests...")
        # (m, n, k, num_groups) shapes checked before benchmarking.
        for correctness_shape in (
            (8192, 7168, 4096, 4),
            (8192, 2048, 7168, 4),
            (4096, 7168, 4096, 8),
            (4096, 2048, 7168, 8),
            (4096, 576, 7168, 8),
        ):
            calculate_diff(*correctness_shape)

    benchmark = get_benchmark(args.tp_size)
    print(f"Running performance benchmark for TP size = {args.tp_size}...")
    benchmark.run(print_data=True, save_path=args.save_path)
|