# Benchmark FBGEMM grouped GEMM kernels against SGLang's grouped GEMM.
#
# Example:
#   python3 benchmark/fbgemm/benchmark_fbgemm_grouped_gemm.py --model Qwen/Qwen2-57B-A14B-Instruct --tp-size 4 --use-fp8-w8a8
|
|
import argparse
|
|
|
|
import torch
|
|
import triton
|
|
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
|
|
quantize_fp8_row,
|
|
triton_quantize_fp8_row,
|
|
)
|
|
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
|
|
grouped_gemm as fbgemm_grouped_gemm,
|
|
)
|
|
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
|
|
grouped_gemm_fp8_rowwise as fbgemm_grouped_gemm_fp8_rowwise,
|
|
)
|
|
from transformers import AutoConfig
|
|
|
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
|
grouped_gemm_triton as sglang_grouped_gemm,
|
|
)
|
|
|
|
|
|
def get_model_config(model_name: str, tp_size: int):
    """
    Load a model's config and extract the grouped-GEMM problem shape.

    Args:
        model_name: Model identifier passed to ``AutoConfig.from_pretrained``
            (loaded with ``trust_remote_code=True``).
        tp_size: Tensor parallelism size. Accepted for the CLI interface; the
            body does not use it.

    Returns:
        Dict with keys ``num_groups`` (expert count), ``hidden_size``,
        ``intermediate_size`` and ``dtype`` (the config's ``torch_dtype``).
        The dict is also printed.
    """
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    arch = config.architectures[0]

    # Each architecture family stores the expert count and the expert FFN
    # width under different attribute names.
    if arch == "DbrxForCausalLM":
        ffn_config = config.ffn_config
        num_groups = ffn_config.moe_num_experts
        intermediate_size = ffn_config.ffn_hidden_size
    elif arch == "JambaForCausalLM":
        num_groups = config.num_experts
        intermediate_size = config.intermediate_size
    elif arch in {"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"}:
        num_groups = config.num_experts
        intermediate_size = config.moe_intermediate_size
    elif arch in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}:
        num_groups = config.n_routed_experts
        intermediate_size = config.moe_intermediate_size
    elif arch == "Llama4ForConditionalGeneration":
        text_config = config.text_config
        num_groups = text_config.num_local_experts
        intermediate_size = text_config.intermediate_size
    elif arch in {"Grok1ForCausalLM", "Grok1ImgGen", "Grok1AForCausalLM"}:
        num_groups = config.num_local_experts
        intermediate_size = config.moe_intermediate_size
    else:
        # Fallback for every architecture not listed above.
        num_groups = config.num_local_experts
        intermediate_size = config.intermediate_size

    # The local name is part of the printed output (``shape_configs=...``).
    shape_configs = dict(
        num_groups=num_groups,
        hidden_size=config.hidden_size,
        intermediate_size=intermediate_size,
        dtype=config.torch_dtype,
    )
    print(f"{shape_configs=}")
    return shape_configs
|
|
|
|
|
|
def create_test_data(
    batch_size, num_groups, hidden_size, intermediate_size, device="cuda"
):
    """
    Build BF16 inputs for the FBGEMM and SGLang grouped GEMM kernels.

    The ``batch_size`` tokens are split across ``num_groups`` groups as evenly
    as possible: when the division is not exact, the first
    ``batch_size % num_groups`` groups receive one extra token. The group
    sizes therefore always sum to ``batch_size`` and every row of ``x``
    belongs to exactly one group.

    Args:
        batch_size: Total number of token rows.
        num_groups: Number of groups (experts).
        hidden_size: Hidden dimension size (columns of ``x``).
        intermediate_size: Intermediate dimension size (rows of each weight).
        device: Device on which all tensors are created. Defaults to "cuda".

    Returns:
        Tuple ``(x, w_fbgemm, w_sglang, c_fbgemm, c_sglang, m_sizes,
        seg_indptr, weight_indices)``:
        ``x`` is ``(batch_size, hidden_size)``; ``w_sglang`` is
        ``(num_groups, intermediate_size, hidden_size)`` and ``w_fbgemm`` is
        the same data reshaped to ``(num_groups * intermediate_size,
        hidden_size)``; ``c_fbgemm`` / ``c_sglang`` are uninitialised
        ``(batch_size, intermediate_size)`` output buffers; ``m_sizes`` holds
        the int32 per-group token counts; ``seg_indptr`` holds the int32
        prefix sums of ``m_sizes`` with a leading zero; ``weight_indices`` is
        ``arange(num_groups)`` as int32.
    """
    torch.manual_seed(42)

    # Spread the remainder over the leading groups so no row of x is left
    # outside every group when batch_size is not a multiple of num_groups.
    tokens_per_group, remainder = divmod(batch_size, num_groups)
    m_sizes = torch.full(
        (num_groups,), tokens_per_group, dtype=torch.int32, device=device
    )
    m_sizes[:remainder] += 1

    x = torch.randn(batch_size, hidden_size, dtype=torch.bfloat16, device=device)

    base_weights = torch.randn(
        num_groups, intermediate_size, hidden_size, dtype=torch.bfloat16, device=device
    )

    # Same weights in the two layouts the kernels are called with.
    w_fbgemm = base_weights.reshape(num_groups * intermediate_size, hidden_size)
    w_sglang = base_weights

    c_fbgemm = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )
    c_sglang = torch.empty(
        batch_size, intermediate_size, dtype=torch.bfloat16, device=device
    )

    # Segment boundaries: seg_indptr[i] is the first row of group i. One
    # cumsum replaces a Python loop of per-element tensor writes.
    seg_indptr = torch.zeros(num_groups + 1, dtype=torch.int32, device=device)
    seg_indptr[1:] = torch.cumsum(m_sizes, dim=0)

    weight_indices = torch.arange(num_groups, dtype=torch.int32, device=device)

    return (
        x,
        w_fbgemm,
        w_sglang,
        c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    )
|
|
|
|
|
|
def create_fp8_test_data(
    batch_size, num_groups, hidden_size, intermediate_size, backend="triton"
):
    """
    Create test data for FP8 grouped GEMM operations.

    The ``batch_size`` tokens are split across ``num_groups`` groups as evenly
    as possible: when the division is not exact, the first
    ``batch_size % num_groups`` groups receive one extra token. For both
    backends ``m_sizes`` therefore sums to ``batch_size`` and the input has
    exactly ``batch_size`` rows.

    Args:
        batch_size: Total batch size
        num_groups: Number of groups
        hidden_size: Hidden dimension size
        intermediate_size: Intermediate dimension size
        backend: "triton" for Triton GEMM, "cutlass" for CUTLASS GEMM

    Returns:
        For triton: (x_fp8, w_fp8, m_sizes, x_scale, w_scale), with the
            per-group weights and scales concatenated along dim 0, int32
            ``m_sizes`` and the input already quantized.
        For cutlass: (x, wq, w_scale, m_sizes), with the per-group weights and
            scales stacked along a new leading dim, int64 ``m_sizes`` and the
            input left unquantized in float16.

    Raises:
        ValueError: If ``backend`` is neither "triton" nor "cutlass".
    """
    # Fail fast, before allocating and quantizing any weights.
    if backend not in ("triton", "cutlass"):
        raise ValueError(f"Unsupported backend: {backend}")

    torch.manual_seed(42)

    # Spread the remainder over the leading groups so that the group sizes
    # always add up to batch_size.
    tokens_per_group, remainder = divmod(batch_size, num_groups)
    group_sizes = [
        tokens_per_group + (1 if group < remainder else 0)
        for group in range(num_groups)
    ]

    # Create weight matrices for each group
    w_list = [
        torch.randn(intermediate_size, hidden_size, dtype=torch.float16, device="cuda")
        for _ in range(num_groups)
    ]

    # Quantize weights using quantize_fp8_row for each group
    wq_list, w_scale_list = zip(*[quantize_fp8_row(w) for w in w_list])

    if backend == "triton":
        # Triton format: concatenated weights
        w_fp8 = torch.concat(wq_list, dim=0).contiguous()
        w_scale = torch.concat(w_scale_list, dim=0).contiguous()

        # Create m_sizes as int32 for triton
        m_sizes = torch.tensor(group_sizes, dtype=torch.int32, device="cuda")

        # Create and quantize input
        x_fp16 = torch.randn(
            batch_size, hidden_size, dtype=torch.float16, device="cuda"
        )
        x_fp8, x_scale = triton_quantize_fp8_row(x_fp16)
        x_scale = x_scale.view(batch_size, -1)

        return x_fp8, w_fp8, m_sizes, x_scale, w_scale

    # CUTLASS format: stacked weights
    wq = torch.stack(wq_list, dim=0).contiguous()
    w_scale = torch.stack(w_scale_list, dim=0).contiguous()

    # Create m_sizes as int64 for cutlass
    m_sizes = torch.tensor(group_sizes, dtype=torch.int64, device="cuda")

    # Create input data - separate for each group, then concatenate into a
    # single (batch_size, hidden_size) tensor
    x_list = [
        torch.randn(rows, hidden_size, dtype=torch.float16, device="cuda")
        for rows in group_sizes
    ]
    x = torch.concat(x_list, dim=0).contiguous()

    return x, wq, w_scale, m_sizes
|
|
|
|
|
|
def calculate_memory_bandwidth(m_sizes, hidden_size, intermediate_size, dtype):
|
|
"""
|
|
Calculate memory bandwidth based on accessed expert weights.
|
|
|
|
Args:
|
|
m_sizes: Tensor containing batch sizes for each group
|
|
hidden_size: Hidden dimension size
|
|
intermediate_size: Intermediate dimension size
|
|
dtype: Data type of weights
|
|
|
|
Returns:
|
|
Memory size in bytes for accessed expert weights
|
|
"""
|
|
# Count non-zero groups (active experts)
|
|
if hasattr(m_sizes, "cpu"):
|
|
active_experts = torch.count_nonzero(m_sizes).item()
|
|
else:
|
|
active_experts = sum(1 for m in m_sizes if m > 0)
|
|
|
|
# Calculate bytes per element based on dtype
|
|
if dtype in [torch.float16, torch.bfloat16]:
|
|
bytes_per_element = 2
|
|
elif dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
|
bytes_per_element = 1
|
|
elif dtype == torch.float32:
|
|
bytes_per_element = 4
|
|
else:
|
|
# Default to 2 bytes for unknown dtypes
|
|
bytes_per_element = 2
|
|
|
|
# Memory per expert weight matrix
|
|
memory_per_expert = hidden_size * intermediate_size * bytes_per_element
|
|
|
|
# Total memory for active experts
|
|
total_memory_bytes = active_experts * memory_per_expert
|
|
|
|
return total_memory_bytes
|
|
|
|
|
|
def get_benchmark_config(use_fp8_w8a8=False):
    """
    Return the plot series (providers) to benchmark.

    Args:
        use_fp8_w8a8: When truthy, select the FP8 series; otherwise the BF16
            series.

    Returns:
        Dict with three parallel lists: ``line_vals`` (provider identifiers),
        ``line_names`` (legend labels) and ``styles`` ((color, linestyle)
        tuples).
    """
    # One row per series: (provider id, legend label, plot style).
    if use_fp8_w8a8:
        series = [
            (
                "fbgemm_triton_grouped_gemm_fp8",
                "FBGEMM Triton Grouped GEMM FP8",
                ("blue", "-"),
            ),
            (
                "fbgemm_cutlass_f8f8bf16_rowwise",
                "FBGEMM CUTLASS F8F8BF16 Rowwise",
                ("orange", "-"),
            ),
            ("sglang_grouped_gemm", "SGLang Grouped GEMM FP8", ("red", "-")),
        ]
    else:
        series = [
            (
                "fbgemm_triton_grouped_gemm",
                "FBGEMM Triton Grouped GEMM BF16",
                ("blue", "-"),
            ),
            ("sglang_grouped_gemm", "SGLang Grouped GEMM BF16", ("green", "-")),
        ]

    return {
        "line_vals": [provider for provider, _, _ in series],
        "line_names": [label for _, label, _ in series],
        "styles": [style for _, _, style in series],
    }
|
|
|
|
|
|
def run_benchmark(
    model_config, use_fp8_w8a8=False, save_path="./benchmark_grouped_gemm/"
):
    """
    Benchmark the grouped GEMM providers over a sweep of batch sizes.

    For every batch size in ``[256, 512, 1024, 2048, 4096]`` and every
    provider from ``get_benchmark_config`` the kernel is warmed up 10 times,
    timed with ``triton.testing.do_bench`` and reported as bandwidth (GB/s):
    bytes of accessed expert weights divided by the measured time.

    A provider that cannot be set up, warmed up or timed is reported as
    ``0.0`` GB/s for that point, so a failure never shows up as the fastest
    series on the bandwidth axis.

    NOTE(review): in FP8 mode the "sglang_grouped_gemm" provider is labelled
    "SGLang Grouped GEMM FP8" but runs through the BF16 data path below.

    Args:
        model_config: Dict with ``num_groups``, ``hidden_size`` and
            ``intermediate_size``.
        use_fp8_w8a8: Select the FP8 provider set instead of the BF16 one.
        save_path: Directory passed to the Triton perf report for its output.
    """
    config = get_benchmark_config(use_fp8_w8a8)

    benchmark_config = triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[256, 512, 1024, 2048, 4096],
        line_arg="provider",
        line_vals=config["line_vals"],
        line_names=config["line_names"],
        styles=config["styles"],
        ylabel="Bandwidth (GB/s)",
        plot_name="grouped-gemm-performance",
        args={},
    )

    # (median, low, high) bandwidth reported for any failed measurement.
    failed = (0.0, 0.0, 0.0)

    @triton.testing.perf_report(benchmark_config)
    def dynamic_benchmark(batch_size, provider, model_config, use_fp8_w8a8=False):
        print(f"Benchmarking {provider} with batch_size={batch_size}")
        torch.cuda.manual_seed_all(0)

        num_groups = model_config["num_groups"]
        hidden_size = model_config["hidden_size"]
        intermediate_size = model_config["intermediate_size"]

        if provider == "fbgemm_triton_grouped_gemm_fp8":
            try:
                x_fp8, w_fp8, m_sizes, x_scale, w_scale = create_fp8_test_data(
                    batch_size,
                    num_groups,
                    hidden_size,
                    intermediate_size,
                    backend="triton",
                )

                # Calculate memory bandwidth
                memory_bytes = calculate_memory_bandwidth(
                    m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
                )

                def run_func():
                    return fbgemm_grouped_gemm_fp8_rowwise(
                        x_fp8, w_fp8, m_sizes, x_scale, w_scale, use_fast_accum=True
                    )

            except Exception as e:
                print(f"FP8 not supported, skipping: {e}")
                return failed

        elif provider == "fbgemm_cutlass_f8f8bf16_rowwise":
            try:
                x, wq, w_scale, m_sizes = create_fp8_test_data(
                    batch_size,
                    num_groups,
                    hidden_size,
                    intermediate_size,
                    backend="cutlass",
                )

                # Calculate memory bandwidth
                memory_bytes = calculate_memory_bandwidth(
                    m_sizes, hidden_size, intermediate_size, torch.float8_e4m3fn
                )

                # Quantize input using triton_quantize_fp8_row
                xq, x_scale = triton_quantize_fp8_row(x)
                # Use the real row count of x rather than assuming it equals
                # batch_size, so the view cannot fail on a shorter input.
                x_scale = x_scale.view(x.shape[0], -1)

                def run_func():
                    return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
                        xq, wq, x_scale, w_scale, m_sizes
                    )

            except Exception as e:
                print(
                    f"CUTLASS f8f8bf16_rowwise_grouped_stacked not supported, "
                    f"skipping: {e}"
                )
                return failed

        else:
            (
                x,
                w_fbgemm,
                w_sglang,
                c_fbgemm,
                c_sglang,
                m_sizes,
                seg_indptr,
                weight_indices,
            ) = create_test_data(
                batch_size, num_groups, hidden_size, intermediate_size
            )

            # Calculate memory bandwidth for BF16 operations
            memory_bytes = calculate_memory_bandwidth(
                m_sizes, hidden_size, intermediate_size, torch.bfloat16
            )

            if provider == "fbgemm_triton_grouped_gemm":

                def run_func():
                    return fbgemm_grouped_gemm(
                        x, w_fbgemm, m_sizes, use_fast_accum=True
                    )

            else:

                def run_func():
                    return sglang_grouped_gemm(
                        x,
                        w_sglang,
                        c_sglang,
                        num_groups,
                        weight_column_major=True,
                        seg_indptr=seg_indptr,
                        weight_indices=weight_indices,
                        c_dtype=c_sglang.dtype,
                    )

        # Warm up; the first failure aborts this measurement.
        try:
            for _ in range(10):
                run_func()
        except Exception as e:
            print(f"Error during warmup for {provider}: {e}")
            return failed

        torch.cuda.synchronize()

        try:
            quantiles = [0.5, 0.2, 0.8]
            ms, min_ms, max_ms = triton.testing.do_bench(run_func, quantiles=quantiles)

            # Convert time (ms) to bandwidth (GB/s)
            # Bandwidth = Memory (bytes) / Time (seconds)
            # Convert ms to seconds and bytes to GB (1e9)
            gb_per_s = (memory_bytes / 1e9) / (ms / 1000)
            # min bandwidth = max time, max bandwidth = min time
            min_gb_per_s = (memory_bytes / 1e9) / (max_ms / 1000)
            max_gb_per_s = (memory_bytes / 1e9) / (min_ms / 1000)

            return gb_per_s, min_gb_per_s, max_gb_per_s
        except Exception as e:
            print(f"Error during benchmarking for {provider}: {e}")
            return failed

    dynamic_benchmark.run(
        show_plots=True,
        print_data=True,
        save_path=save_path,
        model_config=model_config,
        use_fp8_w8a8=use_fp8_w8a8,
    )
|
|
|
|
|
|
def verify_correctness(model_config):
    """
    Check that the FBGEMM and SGLang BF16 grouped GEMMs agree.

    Both kernels are run on the same data from ``create_test_data`` and their
    outputs are compared with ``torch.allclose`` (rtol=1e-3, atol=1e-3). Only
    rows that fall inside a group segment are compared.

    Args:
        model_config: Dict with ``num_groups``, ``hidden_size`` and
            ``intermediate_size``.

    Returns:
        True if the outputs match within tolerance, False otherwise. The
        outcome is also printed.
    """
    print("Verifying correctness...")
    num_groups = model_config["num_groups"]
    hidden_size = model_config["hidden_size"]
    intermediate_size = model_config["intermediate_size"]
    # 128 tokens by default; with more than 128 groups use one token per
    # group so that the comparison is never over zero rows.
    batch_size = max(128, num_groups)

    (
        x,
        w_fbgemm,
        w_sglang,
        c_fbgemm,
        c_sglang,
        m_sizes,
        seg_indptr,
        weight_indices,
    ) = create_test_data(batch_size, num_groups, hidden_size, intermediate_size)

    result_fbgemm = fbgemm_grouped_gemm(x, w_fbgemm, m_sizes, use_fast_accum=True)

    result_sglang = sglang_grouped_gemm(
        x,
        w_sglang,
        c_sglang,
        num_groups,
        weight_column_major=True,
        seg_indptr=seg_indptr,
        weight_indices=weight_indices,
        c_dtype=c_sglang.dtype,
    )

    # Rows past the last segment boundary belong to no group. c_sglang is
    # allocated with torch.empty, so such rows are presumably left
    # uninitialised and would make the comparison fail spuriously.
    valid_rows = int(seg_indptr[-1].item())
    result_fbgemm = result_fbgemm[:valid_rows]
    result_sglang = result_sglang[:valid_rows]

    if torch.allclose(result_fbgemm, result_sglang, rtol=1e-3, atol=1e-3):
        print("✓ BF16 Correctness verification passed!")
    else:
        max_diff = torch.max(torch.abs(result_fbgemm - result_sglang))
        print(f"✗ BF16 Correctness verification failed! Max diff: {max_diff}")
        return False

    return True
|
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Parses the arguments, resolves the model shape (falling back to a built-in
    default shape if the model config cannot be loaded), optionally verifies
    correctness, then runs the benchmark. Failures are printed rather than
    raised.
    """
    cli = argparse.ArgumentParser(
        description="Benchmark FBGEMM vs SGLang Grouped GEMM"
    )
    cli.add_argument(
        "--model",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="Model name to get configuration from",
    )
    cli.add_argument("--tp-size", type=int, default=1, help="Tensor parallelism size")
    cli.add_argument(
        "--use-fp8-w8a8", action="store_true", help="Enable FP8 W8A8 benchmark"
    )
    cli.add_argument(
        "--save-path",
        type=str,
        default="./benchmark_grouped_gemm/",
        help="Path to save benchmark results",
    )
    cli.add_argument(
        "--verify-correctness",
        action="store_true",
        help="Verify correctness before benchmarking",
    )
    opts = cli.parse_args()

    try:
        model_config = get_model_config(opts.model, opts.tp_size)
    except Exception as exc:
        # Best effort: keep going with a fixed default shape.
        print(f"Failed to get model config: {exc}")
        print("Using default configuration...")
        model_config = dict(
            num_groups=8,
            hidden_size=4096,
            intermediate_size=14336,
            dtype=torch.bfloat16,
        )

    print("Running benchmark with:")
    for key in ("num_groups", "hidden_size", "intermediate_size"):
        print(f" {key}: {model_config[key]}")
    print(f" use_fp8_w8a8: {opts.use_fp8_w8a8}")

    # Skip the benchmark entirely when the requested check fails.
    if opts.verify_correctness and not verify_correctness(model_config):
        print("Correctness verification failed. Exiting...")
        return

    try:
        run_benchmark(
            model_config=model_config,
            use_fp8_w8a8=opts.use_fp8_w8a8,
            save_path=opts.save_path,
        )
    except Exception as exc:
        print(f"Benchmark failed: {exc}")
|
|
|
|
|
|
# Run the benchmark CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|