sglang_v0.5.2/flashinfer_0.3.1/benchmarks/flashinfer_benchmark.py

import argparse
import sys
from routines.attention import parse_attention_args, run_attention_test
from routines.flashinfer_benchmark_utils import (
    benchmark_apis,
    full_output_columns,
    output_column_dict,
)
from routines.gemm import parse_gemm_args, run_gemm_test
from routines.moe import parse_moe_args, run_moe_test
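

# Example CLI invocation (sketch only; "<routine>" is a placeholder for one of the
# routine names registered in benchmark_apis, and most routines also require
# routine-specific flags parsed by parse_attention_args / parse_gemm_args /
# parse_moe_args):
#
#   python3 flashinfer_benchmark.py --routine <routine> --num_iters 30 \
#       --dry_run_iters 5 --refcheck --output_path results.csv --case_tag demo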


def run_test(args):
    """
    Route a single FlashInfer test case to its test routine and run it.

    Args:
        args: Parsed command line arguments containing the test configuration.
    """
    ## Depending on the routine type, route to the corresponding test routine
    if args.routine in benchmark_apis["attention"]:
        res = run_attention_test(args)
    elif args.routine in benchmark_apis["gemm"]:
        res = run_gemm_test(args)
    elif args.routine in benchmark_apis["moe"]:
        res = run_moe_test(args)
    else:
        raise ValueError(f"Unsupported routine: {args.routine}")

    # Write results to the output file if specified
    if args.output_path is not None:
        with open(args.output_path, "a") as fout:
            for cur_res in res:
                for key in output_column_dict["general"]:
                    cur_res[key] = getattr(args, key)
                output_line = ",".join(
                    [str(cur_res[col]) for col in full_output_columns]
                )
                fout.write(output_line + "\n")
                fout.flush()
    return
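

# Programmatic use (sketch; assumes the chosen routine's parser is satisfied by the
# flags shown, which real routines generally are not, since they need
# routine-specific arguments on top of these shared ones):
#
#   args = parse_args(["--routine", "<routine>", "--num_iters", "10", "--refcheck"])
#   run_test(args)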


def parse_args(line=sys.argv[1:]):
    """
    Parse command line arguments for the test configuration.

    Shared arguments are parsed first, then routine-specific arguments.

    Args:
        line: Command line arguments (default: sys.argv[1:])

    Returns:
        Parsed argument namespace
    """
    ## Shared arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--routine",
        "-R",
        type=str,
        required=True,
        choices=list(benchmark_apis["attention"])
        + list(benchmark_apis["gemm"])
        + list(benchmark_apis["moe"]),
    )
    args, _ = parser.parse_known_args(line[:])

    parser.add_argument(
        "--no_cuda_graph",
        action="store_true",
        default=False,
        help="Disable CUDA graph so that kernels execute outside of a graph.",
    )
    parser.add_argument(
        "--refcheck",
        action="store_true",
        default=False,
        help="Run a reference check that ensures outputs are correct.",
    )
    parser.add_argument(
        "--allow_output_mismatch",
        action="store_true",
        default=False,
        help="Allow output mismatches between backends during reference checks. An error message will be printed but the test will continue.",
    )
    parser.add_argument(
        "--random_seed", type=int, default=42, help="Random seed for reproducibility."
    )
    parser.add_argument(
        "--verbose", "-v", action="count", help="Set verbosity level.", default=0
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=False,
        default=None,
        help="Output path for results. If not specified, results will not be written to a file.",
    )
    parser.add_argument(
        "--num_iters",
        "-n",
        type=int,
        required=False,
        default=30,
        help="Number of iterations to run for measurement.",
    )
    parser.add_argument(
        "--dry_run_iters",
        "-d",
        type=int,
        required=False,
        default=5,
        help="Number of dry-run (warmup) iterations.",
    )
    parser.add_argument(
        "--case_tag",
        type=str,
        required=False,
        default=None,
        help="Optional tag for the test case, used to annotate output.",
    )
    parser.add_argument(
        "--generate_repro_command",
        action="store_true",
        default=False,
        help="If set, print the reproducer command and store it in the output CSV.",
    )
    parser.add_argument(
        "--repro_command",
        type=str,
        required=False,
        default="",
        help="Placeholder for the generated reproducer command for the test case. Not to be used directly.",
    )

    ## Check the routine and pass on to the routine-specific argument parser
    if args.routine in benchmark_apis["attention"]:
        args = parse_attention_args(line, parser)
    elif args.routine in benchmark_apis["gemm"]:
        args = parse_gemm_args(line, parser)
    elif args.routine in benchmark_apis["moe"]:
        args = parse_moe_args(line, parser)
    else:
        raise ValueError(f"Unsupported routine: {args.routine}")

    if args.generate_repro_command:
        args.repro_command = "python3 flashinfer_benchmark.py " + " ".join(line)
    return args
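

# With --generate_repro_command, args.repro_command is just the parsed argv joined
# back into a runnable command line, e.g. (placeholder routine name):
#
#   python3 flashinfer_benchmark.py --routine <routine> -n 30 --refcheck --generate_repro_command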


if __name__ == "__main__":
    # Parse the testlist argument first
    testlist_parser = argparse.ArgumentParser(add_help=False)
    testlist_parser.add_argument(
        "--testlist",
        type=str,
        required=False,
        default=None,
        help="Optional testlist file to run multiple cases.",
    )
    testlist_parser.add_argument(
        "--output_path",
        type=str,
        required=False,
        default=None,
        help="Output path for the results CSV.",
    )
    testlist_args, _ = testlist_parser.parse_known_args()

    # Set up the output file if specified
    if testlist_args.output_path is not None:
        with open(testlist_args.output_path, "w") as fout:
            fout.write(",".join(full_output_columns) + "\n")

    # Process tests either from a testlist file or from command line arguments
    if testlist_args.testlist is not None:
        # If a testlist is given, run each test in the testlist
        with open(testlist_args.testlist, "r") as f:
            import shlex

            for line in f.readlines():
                line = line.strip()
                if not line or line.startswith("#"):
                    continue
                try:
                    line_args = parse_args(shlex.split(line))
                    line_args.output_path = testlist_args.output_path
                    run_test(line_args)
                except Exception as e:
                    print(f"[ERROR] Error running test: {line}")
                    print(f"[ERROR] Error: {e}")
                    continue
    else:
        # If no testlist is given, just run the single case from the command line
        args = parse_args()
        args.output_path = testlist_args.output_path
        run_test(args)
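

# Example testlist file (one benchmark invocation per line; blank lines and lines
# starting with "#" are skipped; routine names below are placeholders and each line
# would normally also carry routine-specific flags):
#
#   # attention sweep
#   --routine <attention_routine> --num_iters 50 --refcheck
#   --routine <gemm_routine> --num_iters 50 --case_tag gemm_baseline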