import argparse
import sys

from routines.attention import parse_attention_args, run_attention_test
from routines.flashinfer_benchmark_utils import (
    benchmark_apis,
    full_output_columns,
    output_column_dict,
)
from routines.gemm import parse_gemm_args, run_gemm_test
from routines.moe import parse_moe_args, run_moe_test


def run_test(args):
    """
    Route a single FlashInfer test case to its test routine and run it.

    Args:
        args: Parsed command line arguments containing the test configuration.
    """
    ## Depending on routine type, route to the corresponding test routine
    if args.routine in benchmark_apis["attention"]:
        res = run_attention_test(args)
    elif args.routine in benchmark_apis["gemm"]:
        res = run_gemm_test(args)
    elif args.routine in benchmark_apis["moe"]:
        res = run_moe_test(args)
    else:
        raise ValueError(f"Unsupported routine: {args.routine}")

    # Write results to output file if specified
    if args.output_path is not None:
        with open(args.output_path, "a") as fout:
            for cur_res in res:
                for key in output_column_dict["general"]:
                    cur_res[key] = getattr(args, key)
                output_line = ",".join(
                    [str(cur_res[col]) for col in full_output_columns]
                )
                fout.write(output_line + "\n")
                fout.flush()
    return


def parse_args(line=sys.argv[1:]):
    """
    Parse command line arguments for test configuration.
    Shared arguments are parsed first, then routine-specific arguments.

    Args:
        line: Command line arguments (default: sys.argv[1:]).

    Returns:
        Parsed argument namespace.
    """
    ## Shared arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--routine",
        "-R",
        type=str,
        required=True,
        choices=list(benchmark_apis["attention"])
        + list(benchmark_apis["gemm"])
        + list(benchmark_apis["moe"]),
    )
    args, _ = parser.parse_known_args(line[:])
    parser.add_argument(
        "--no_cuda_graph",
        action="store_true",
        default=False,
        help="Disable CUDA graph and execute kernels outside of a graph.",
    )
    parser.add_argument(
        "--refcheck",
        action="store_true",
        default=False,
        help="Run a reference check that ensures outputs are correct.",
    )
    parser.add_argument(
        "--allow_output_mismatch",
        action="store_true",
        default=False,
        help="Allow output mismatches between backends during reference checks. An error message will be printed but the test will continue.",
    )
    parser.add_argument(
        "--random_seed", type=int, default=42, help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--verbose", "-v", action="count", help="Set verbosity level.", default=0
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=False,
        default=None,
        help="Output path for results. If not specified, results will not be written to a file.",
    )
    parser.add_argument(
        "--num_iters",
        "-n",
        type=int,
        required=False,
        default=30,
        help="Number of iterations to run for measurement.",
    )
    parser.add_argument(
        "--dry_run_iters",
        "-d",
        type=int,
        required=False,
        default=5,
        help="Number of dry-run iterations.",
    )
    parser.add_argument(
        "--case_tag",
        type=str,
        required=False,
        default=None,
        help="Optional tag for annotating the test case in the output.",
    )
    parser.add_argument(
        "--generate_repro_command",
        action="store_true",
        default=False,
        help="If set, print the reproducer command and store it in the output CSV.",
    )
    parser.add_argument(
        "--repro_command",
        type=str,
        required=False,
        default="",
        help="Placeholder for the generated reproducer command for the test case. Not to be used directly.",
    )

    ## Check the routine and pass on to the routine-specific argument parser
    if args.routine in benchmark_apis["attention"]:
        args = parse_attention_args(line, parser)
    elif args.routine in benchmark_apis["gemm"]:
        args = parse_gemm_args(line, parser)
    elif args.routine in benchmark_apis["moe"]:
        args = parse_moe_args(line, parser)
    else:
        raise ValueError(f"Unsupported routine: {args.routine}")

    if args.generate_repro_command:
        args.repro_command = "python3 flashinfer_benchmark.py " + " ".join(line)
    return args


if __name__ == "__main__":
    # Parse the testlist and output path arguments first
    testlist_parser = argparse.ArgumentParser(add_help=False)
    testlist_parser.add_argument(
        "--testlist",
        type=str,
        required=False,
        default=None,
        help="Optional testlist file to run multiple cases.",
    )
    testlist_parser.add_argument(
        "--output_path",
        type=str,
        required=False,
        default=None,
        help="Output path for results csv.",
    )
    testlist_args, _ = testlist_parser.parse_known_args()

    # Set up the output file (with CSV header) if specified
    if testlist_args.output_path is not None:
        with open(testlist_args.output_path, "w") as fout:
            fout.write(",".join(full_output_columns) + "\n")

    # Process tests either from a testlist file or from the command line arguments
    if testlist_args.testlist is not None:
        # If a testlist is given, run each test case in it
        with open(testlist_args.testlist, "r") as f:
            import shlex

            for line in f.readlines():
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                try:
                    line_args = parse_args(shlex.split(line))
                    line_args.output_path = testlist_args.output_path
                    run_test(line_args)
                except Exception as e:
                    print(f"[ERROR] Error running test: {line}")
                    print(f"[ERROR] Error: {e}")
                    continue
    else:
        # Otherwise, run the single test described by the command line
        args = parse_args()
        args.output_path = testlist_args.output_path
        run_test(args)
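
# ---------------------------------------------------------------------------
# Example usage (a sketch, not part of the script's logic). "<routine_name>"
# is a placeholder; the actual choices for --routine come from
# routines.flashinfer_benchmark_utils.benchmark_apis.
#
#   Run a single case with a reference check and CSV output:
#     python3 flashinfer_benchmark.py --routine <routine_name> --refcheck -vv \
#         --num_iters 30 --dry_run_iters 5 --output_path results.csv
#
#   Run a batch of cases from a testlist file, where each non-empty,
#   non-comment line holds the full argument list for one case:
#     python3 flashinfer_benchmark.py --testlist cases.txt --output_path results.csv
# ---------------------------------------------------------------------------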