# sglang_v0.5.2/flashinfer_0.3.1/benchmarks/routines/flashinfer_benchmark_utils.py

import torch
from flashinfer.testing.utils import set_seed

# Output columns for the test results.
output_column_dict = {
    "perf": [
        "routine",
        "median_time",
        "std_time",
        "tflops",
        "tb_per_sec",
        "backend",
    ],
    "attention": [
        "page_size",
        "batch_size",
        "s_qo",
        "s_kv",
        "num_qo_heads",
        "num_kv_heads",
        "head_dim_qk",
        "head_dim_vo",
        "head_dim_ckv",
        "head_dim_kpe",
        "causal",
        "q_dtype",
        "kv_dtype",
        "avg_actual_seq_len",
        "random_actual_seq_len",
    ],
    "gemm": [
        "m",
        "n",
        "k",
        "group_size",
        "tile_size",
        "scale_major_mode",
        "out_dtype",
        "mma_sm",
        "use_128x4_sf_layout",
    ],
    "moe": [
        "num_tokens",
        "hidden_size",
        "intermediate_size",
        "num_experts",
        "top_k",
        "n_group",
        "topk_group",
        "routed_scaling_factor",
        "local_expert_offset",
        "local_num_experts",
        "tile_tokens_dim",
        "routing_method",
        "use_shuffled_weight",
        "weight_layout",
        "use_routing_bias",
        "use_routing_scales_on_input",
        "input_dtype",
        "weight_dtype",
        "gated_act",
        # CUTLASS fused MoE specific
        "cutlass_variant",
        "quantized_input",
        "tp_size",
        "tp_rank",
        "ep_size",
        "ep_rank",
    ],
    "general": [
        "refcheck",
        "no_cuda_graph",
        "allow_output_mismatch",
        "random_seed",
        "case_tag",
        "generate_repro_command",
        "repro_command",
    ],
}

full_output_columns = (
    output_column_dict["perf"]
    + output_column_dict["attention"]
    + output_column_dict["gemm"]
    + output_column_dict["moe"]
    + output_column_dict["general"]
)
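
# Illustrative note (a sketch, not part of the original harness): the combined
# column list can seed an empty results table so every routine writes the same
# schema, e.g. with pandas (assumed available for this sketch):
#
#     import pandas as pd
#     results = pd.DataFrame(columns=full_output_columns)
#     results.loc[len(results)] = dict.fromkeys(full_output_columns)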

benchmark_apis = {
    "attention": [
        "BatchDecodeWithPagedKVCacheWrapper",
        "BatchPrefillWithPagedKVCacheWrapper",
        "BatchPrefillWithRaggedKVCacheWrapper",
        "BatchMLAPagedAttentionWrapper",
    ],
    "gemm": [
        "gemm_fp8_nt_groupwise",
        "group_gemm_fp8_nt_groupwise",
        "bmm_fp8",
        "mm_fp4",
    ],
    "moe": [
        "trtllm_fp4_block_scale_moe",
        "trtllm_fp8_block_scale_moe",
        "trtllm_fp8_per_tensor_scale_moe",
        "cutlass_fused_moe",
    ],
}


def print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec):
    """Print a one-line performance summary for a single backend."""
    output_backend_width = 15
    print(
        f"[PERF] {backend.ljust(output_backend_width)[:output_backend_width]}:: "
        f"median time {median_time:.3f} ms; std {std_time:.3f} ms; "
        f"achieved tflops {tflops:.3f} TFLOPs/sec; "
        f"achieved tb_per_sec {tb_per_sec:.3f} TB/sec"
    )


def get_device(args):
    """Return the torch device to benchmark on and seed the RNGs.

    Expects ``args`` to carry ``random_seed`` and ``verbose`` attributes.
    """
    # Synchronize to ensure that the device is ready after previous tests.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    set_seed(args.random_seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()).replace(" ", "_")
    if args.verbose >= 2:
        print(f"[VVERBOSE] {gpu_name = }")
    return device
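
# Illustrative call (a sketch; the Namespace fields mirror the attributes the
# function reads above):
#
#     import argparse
#     args = argparse.Namespace(random_seed=42, verbose=0)
#     device = get_device(args)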


def is_close_stats(input, other, rtol=1e-5, atol=1e-8):
    """Compare two tensors elementwise and summarize how many entries differ."""
    close_tensor = torch.isclose(input, other, rtol=rtol, atol=atol)
    num_elements = close_tensor.numel()
    num_different_elements = num_elements - close_tensor.sum().item()
    return (
        num_different_elements,  # number of different elements
        num_elements,  # total number of elements in tensor
        num_different_elements / num_elements * 100.0,  # percentage of different elements
    )
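
# Illustrative reference check (a sketch, not called by the harness itself):
#
#     ref = torch.randn(1024, device=device)
#     out = ref + 1e-7 * torch.randn_like(ref)
#     n_diff, n_total, pct = is_close_stats(out, ref, rtol=1e-3, atol=1e-5)
#     print(f"{n_diff}/{n_total} elements differ ({pct:.2f}%)")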


def dtype_str_to_torch_dtype(dtype_str):
    """Map a dtype name from the CLI to the corresponding torch dtype."""
    if dtype_str == "bfloat16":
        return torch.bfloat16
    elif dtype_str == "float16":
        return torch.float16
    elif dtype_str == "float32":
        return torch.float32
    elif dtype_str == "float64":
        return torch.float64
    elif dtype_str == "fp8_e4m3":
        return torch.float8_e4m3fn
    elif dtype_str == "fp8_e5m2":
        return torch.float8_e5m2
    else:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
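

# Minimal smoke test (an illustrative sketch with made-up numbers, not part of
# the benchmark suite): exercises the CPU-only helpers above without requiring
# a GPU.
if __name__ == "__main__":
    assert dtype_str_to_torch_dtype("bfloat16") is torch.bfloat16
    a = torch.ones(8)
    b = torch.ones(8)
    b[0] += 1.0  # force exactly one mismatched element
    n_diff, n_total, pct = is_close_stats(a, b)
    print(f"[CHECK] {n_diff}/{n_total} elements differ ({pct:.1f}%)")
    # Fabricated perf numbers purely to show the output format.
    print_perf_metrics("example_backend", 1.234, 0.056, 123.456, 1.234)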