import torch

from flashinfer.testing.utils import set_seed

# Output columns for the test results.
output_column_dict = {
    "perf": [
        "routine",
        "median_time",
        "std_time",
        "tflops",
        "tb_per_sec",
        "backend",
    ],
    "attention": [
        "page_size",
        "batch_size",
        "s_qo",
        "s_kv",
        "num_qo_heads",
        "num_kv_heads",
        "head_dim_qk",
        "head_dim_vo",
        "head_dim_ckv",
        "head_dim_kpe",
        "causal",
        "q_dtype",
        "kv_dtype",
        "avg_actual_seq_len",
        "random_actual_seq_len",
    ],
    "gemm": [
        "m",
        "n",
        "k",
        "group_size",
        "tile_size",
        "scale_major_mode",
        "out_dtype",
        "mma_sm",
        "use_128x4_sf_layout",
    ],
    "moe": [
        "num_tokens",
        "hidden_size",
        "intermediate_size",
        "num_experts",
        "top_k",
        "n_group",
        "topk_group",
        "routed_scaling_factor",
        "local_expert_offset",
        "local_num_experts",
        "tile_tokens_dim",
        "routing_method",
        "use_shuffled_weight",
        "weight_layout",
        "use_routing_bias",
        "use_routing_scales_on_input",
        "input_dtype",
        "weight_dtype",
        "gated_act",
        # CUTLASS fused MoE specific
        "cutlass_variant",
        "quantized_input",
        "tp_size",
        "tp_rank",
        "ep_size",
        "ep_rank",
    ],
    "general": [
        "refcheck",
        "no_cuda_graph",
        "allow_output_mismatch",
        "random_seed",
        "case_tag",
        "generate_repro_command",
        "repro_command",
    ],
}

full_output_columns = (
    output_column_dict["perf"]
    + output_column_dict["attention"]
    + output_column_dict["gemm"]
    + output_column_dict["moe"]
    + output_column_dict["general"]
)

# Benchmarkable routines, grouped by module.
benchmark_apis = {
    "attention": [
        "BatchDecodeWithPagedKVCacheWrapper",
        "BatchPrefillWithPagedKVCacheWrapper",
        "BatchPrefillWithRaggedKVCacheWrapper",
        "BatchMLAPagedAttentionWrapper",
    ],
    "gemm": [
        "gemm_fp8_nt_groupwise",
        "group_gemm_fp8_nt_groupwise",
        "bmm_fp8",
        "mm_fp4",
    ],
    "moe": [
        "trtllm_fp4_block_scale_moe",
        "trtllm_fp8_block_scale_moe",
        "trtllm_fp8_per_tensor_scale_moe",
        "cutlass_fused_moe",
    ],
}


def print_perf_metrics(backend, median_time, std_time, tflops, tb_per_sec):
    """Print a one-line performance summary for a backend."""
    output_backend_width = 15
    print(
        f"[PERF] {backend.ljust(output_backend_width)[:output_backend_width]}:: "
        f"median time {median_time:.3f} ms; std {std_time:.3f} ms; "
        f"achieved tflops {tflops:.3f} TFLOPs/sec; "
        f"achieved tb_per_sec {tb_per_sec:.3f} TB/sec"
    )


def get_device(args):
    """Return the CUDA device (or CPU fallback) after clearing state and seeding RNGs."""
    # Synchronize to ensure that the device is ready after previous tests.
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    set_seed(args.random_seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device()).replace(" ", "_")
    if args.verbose >= 2:
        print(f"[VVERBOSE] {gpu_name = }")
    return device


def is_close_stats(input, other, rtol=1e-5, atol=1e-8):
    """Compare two tensors elementwise and return mismatch statistics."""
    close_tensor = torch.isclose(input, other, rtol=rtol, atol=atol)
    num_elements = close_tensor.numel()
    num_different_elements = num_elements - close_tensor.sum().item()
    return (
        num_different_elements,  # number of different elements
        num_elements,  # total number of elements in the tensor
        num_different_elements / num_elements * 100.0,  # percentage of different elements
    )


def dtype_str_to_torch_dtype(dtype_str):
    """Map a dtype string (e.g. from the CLI) to the corresponding torch dtype."""
    if dtype_str == "bfloat16":
        return torch.bfloat16
    elif dtype_str == "float16":
        return torch.float16
    elif dtype_str == "float32":
        return torch.float32
    elif dtype_str == "float64":
        return torch.float64
    elif dtype_str == "fp8_e4m3":
        return torch.float8_e4m3fn
    elif dtype_str == "fp8_e5m2":
        return torch.float8_e5m2
    else:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
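
# --- Illustrative usage sketch (not part of the benchmark harness itself) ---
# A minimal, hedged example of how the helpers above might be combined when
# validating a result tensor and reporting timings. The tensors, tolerances,
# backend name, and timing numbers below are made up purely for illustration.
if __name__ == "__main__":
    ref = torch.randn(4, 8, dtype=dtype_str_to_torch_dtype("float32"))
    out = ref + 1e-7 * torch.randn_like(ref)  # stand-in for a kernel's output

    num_diff, num_total, pct_diff = is_close_stats(out, ref, rtol=1e-5, atol=1e-6)
    print(f"mismatched elements: {num_diff}/{num_total} ({pct_diff:.4f}%)")

    # Hypothetical timing numbers, just to show the report format.
    print_perf_metrics(
        "example_backend", median_time=0.123, std_time=0.004, tflops=1.5, tb_per_sec=0.8
    )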