sglang_v0.5.2/pytorch_2.8.0/third_party/XNNPACK/tools/generate-gemm-test.py

1035 lines
33 KiB
Python
Executable File

#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import codecs
import collections
import os
import re
import sys
import zlib
import yaml
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon
# Command-line interface of the generator: one YAML spec in, one or more test
# sources out (``-o`` may be repeated), plus an optional benchmark source.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s",
    "--spec",
    required=True,
    metavar="FILE",
    help="Spec (YAML) file",
)
parser.add_argument(
    "-o",
    "--output-test",
    required=True,
    metavar="FILE",
    action="append",
    help="Test output (C++ source) file(s)",
)
parser.add_argument(
    "-b",
    "--output-bench",
    required=False,
    metavar="FILE",
    help="Benchmark output (C++ source) file(s)",
)
parser.set_defaults(defines=[])
def split_ukernel_name(name):
    """Decode tile parameters and target information from a micro-kernel name.

    The name is split once at ``"__"`` into a common part and a target part.
    The last ``_``-separated field of the common part is the tile spec, of the
    form ``<MR>x<NR>[v][c<KR>][s<SR>]`` (for example ``4x8c2s4`` or ``7x4v``).

    Args:
      name: full micro-kernel name containing a ``"__"`` separator.

    Returns:
      A 10-tuple ``(mr, nr, kr, sr, mr_packed, vector_tile, requantization,
      arch, isa, assembly)`` where:
        * ``kr`` and ``sr`` default to 1 when the tile spec has no ``c``/``s``
          suffix;
        * ``vector_tile`` is True when the tile spec contains ``v``;
        * ``mr_packed`` is ``mr`` floor-divided by N when the target part
          contains ``mstep<N>``, otherwise ``mr``;
        * ``requantization`` is the third-from-last field of the common part
          if it is one of ``fp32``, ``rndnu``, ``rndnu16``, else ``None``;
        * ``arch``, ``isa``, ``assembly`` come from
          ``xnncommon.parse_target_name`` applied to the target part.
    """
    common_name, target_name = name.split("__", 1)
    name_fields = common_name.split("_")
    tile_spec = name_fields[-1]

    # Optional suffixes are peeled off from the right: "s<SR>", then "c<KR>",
    # then the "v" vector-tile marker.
    sr = 1
    if "s" in tile_spec:
        tile_spec, sr_text = tile_spec.split("s", 1)
        sr = int(sr_text)

    kr = 1
    if "c" in tile_spec:
        tile_spec, kr_text = tile_spec.split("c", 1)
        kr = int(kr_text)

    vector_tile = "v" in tile_spec
    if vector_tile:
        tile_spec, _ = tile_spec.split("v", 1)

    mr, nr = (int(dim) for dim in tile_spec.split("x"))

    arch, isa, assembly = xnncommon.parse_target_name(target_name)

    mstep = re.search(r"mstep([0-9]+)", target_name)
    mr_packed = mr // int(mstep.group(1)) if mstep else mr

    requantization = name_fields[-3]
    if requantization not in ("fp32", "rndnu", "rndnu16"):
        requantization = None

    return (
        mr,
        nr,
        kr,
        sr,
        mr_packed,
        vector_tile,
        requantization,
        arch,
        isa,
        assembly,
    )
# xngen template for the benchmark of a single micro-kernel; rendered by
# xngen.preprocess() in generate_test_cases().
#
# The rendered C++ defines a static function named ${UKERNEL_NAME} that
# forwards the kernel (${GEMM}), the optional init/pack/packed-stride
# functions, the tile parameters and an ISA check (or nullptr) to
# GEMMBenchmark, then registers it with BENCHMARK_GEMM (BENCHMARK_GEMM_BL when
# KERNELTYPE is 'qb4w'). When CPP_CHECK is set the whole snippet is wrapped in
# a matching `#if` / `#endif // ...` pair.
#
# Template variables read: CPP_CHECK, UKERNEL_NAME, GEMM, INIT_PARAMS, PACK_FN,
# PACKED_STRIDE_FN, MR, NR, NR_SCALE, KR, SR, DATATYPE, MR_PACKED, ISA_CHECK,
# KERNELTYPE.
GEMM_BENCH_CODE = """\
$if CPP_CHECK:
#if ${CPP_CHECK}
static void ${UKERNEL_NAME}(benchmark::State& state, const char* net) {
GEMMBenchmark(state,
${GEMM},
$if INIT_PARAMS is not None:
${INIT_PARAMS},
$if PACK_FN is not None:
${PACK_FN},
$if PACKED_STRIDE_FN is not None:
${PACKED_STRIDE_FN},
/*mr=*/${MR}, /*nr=*/${NR}${NR_SCALE}, /*kr=*/${KR}, /*sr=*/${SR},
$if DATATYPE in ('qp8',):
/*mr_packed=*/${MR_PACKED},
$if ISA_CHECK:
benchmark::utils::${ISA_CHECK});
$else:
/*isa_check=*/nullptr);
}\n
$if KERNELTYPE in ['qb4w']:
BENCHMARK_GEMM_BL(${UKERNEL_NAME})
$else:
BENCHMARK_GEMM(${UKERNEL_NAME})
$if CPP_CHECK:
#endif // ${CPP_CHECK}
"""
# xngen template for the C++ `CreateTests` factory; rendered by
# xngen.preprocess() in generate_test_cases().
#
# The rendered function builds a std::vector<GemmTestParams> describing the
# parameter sweep for one kernel configuration (k_eq/k_lt/k_gt/k_div, n_gt,
# n_div, strided, subtile, small-kernel, a_offset/zero, qmin/qmax, zero-point
# and block-length cases). main() renames it to `CreateTests<idx>` and shares
# one copy between all kernels whose rendered text is identical.
#
# Template variables read: DATATYPE, KERNELTYPE, NR_SCALE, IS_PIPELINED,
# ACTIVATION.
#
# Changes relative to the previous revision of this template:
#   * the `mr_packed` parameter is now guarded by `DATATYPE in ('qp8',)` (a
#     tuple membership test, like every other qp8 check in this file) instead
#     of `DATATYPE in ('qp8')`, which is a substring test on a plain string;
#   * the "n_gt_*_strided_cn" case no longer repeats the qb4w `.bl(32)`
#     directive twice in a row.
#
# NOTE(review): the "k_gt_*_strided_a" and "k_div_*_strided_a" cases are
# emitted under `if (is_igemm)` whereas every other "*_strided_a" case uses
# `if (!is_igemm)` -- confirm this is intended.
GEMM_CREATE_TESTS_CODE = """\
std::vector<GemmTestParams> CreateTests(
size_t k_block, size_t adj_k_block,
size_t mr, size_t nr, size_t kr, size_t sr,
$if DATATYPE in ('qp8',):
size_t mr_packed,
bool is_igemm,
bool unsigned_inputs,
std::function<void(GemmMicrokernelTester& tester)> test_func,
std::function<void()> isa_check = nullptr) {
std::string kbs = std::to_string(k_block);
std::string kb2s = std::to_string(k_block * 2);
std::string akbs = std::to_string(adj_k_block);
$if NR_SCALE != "":
nr = nr${NR_SCALE};
std::string nrs = std::to_string(nr);
$if DATATYPE in ('qp8',):
const GemmMicrokernelTester tester = GemmMicrokernelTester()
.mr(mr).nr(nr).kr(kr).sr(sr).mr_packed(mr_packed).unsigned_inputs(unsigned_inputs);
$else:
const GemmMicrokernelTester tester = GemmMicrokernelTester()
.mr(mr).nr(nr).kr(kr).sr(sr).unsigned_inputs(unsigned_inputs);
std::vector<GemmTestParams> gemm_tests;
gemm_tests.reserve(42);
gemm_tests.push_back(GemmTestParams(
"k_eq_" + kbs,
tester.clone()
.m(mr).n(nr).k(k_block)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check));
$if DATATYPE != "qp8":
gemm_tests.push_back(GemmTestParams(
"strided_cn",
tester.clone()
.m(mr).n(nr).k(k_block)
.cn_stride(xnnpack::NextPrime(nr + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check));
if (!is_igemm) {
gemm_tests.push_back(GemmTestParams(
"k_eq_" + kbs + "_strided_a",
tester.clone()
.m(mr).n(nr).k(k_block)
.a_stride(xnnpack::NextPrime(k_block + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check));
}
gemm_tests.push_back(GemmTestParams(
"k_eq_" + kbs + "_subtile",
tester.clone()
.k(k_block).iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_n(1, nr)
.loop_m(1, mr));
gemm_tests.push_back(GemmTestParams(
"k_eq_" + kbs + "_subtile_m",
tester.clone()
.n(nr).k(k_block).iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_m(1, mr));
gemm_tests.push_back(GemmTestParams(
"k_eq_" + kbs + "_subtile_n",
tester.clone()
.m(mr).k(k_block).iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_n(1, nr));
$if IS_PIPELINED:
gemm_tests.push_back(GemmTestParams(
"k_eq_" + kb2s,
tester.clone()
.m(mr).n(nr).k(k_block * 2)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check));
if (!is_igemm) {
gemm_tests.push_back(GemmTestParams(
"k_eq_" + kb2s + "_strided_a",
tester.clone()
.m(mr).n(nr).k(k_block * 2)
.a_stride(xnnpack::NextPrime(k_block * 2 + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check));
}
gemm_tests.push_back(GemmTestParams(
"k_eq_" + kb2s + "_subtile",
tester.clone()
.k(k_block * 2).iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_n(1, nr)
.loop_m(1, mr));
$if KERNELTYPE not in ['qb4w']:
if (k_block > 1) {
gemm_tests.push_back(GemmTestParams(
"k_lt_" + akbs,
tester.clone()
.m(mr).n(nr)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(1, adj_k_block - 1));
if (!is_igemm) {
gemm_tests.push_back(GemmTestParams(
"k_lt_" + akbs + "_strided_a",
tester.clone()
.m(mr).n(nr)
.a_stride(xnnpack::NextPrime(adj_k_block + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(1, adj_k_block - 1));
}
gemm_tests.push_back(GemmTestParams(
"k_lt_" + akbs + "_subtile",
tester.clone()
.iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(1, adj_k_block - 1)
.loop_n(1, nr)
.loop_m(1, mr));
}
gemm_tests.push_back(GemmTestParams(
"k_gt_" + akbs,
tester.clone()
.m(mr).n(nr)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block));
if (is_igemm) {
gemm_tests.push_back(GemmTestParams(
"k_gt_" + akbs + "_strided_a",
tester.clone()
.m(mr).n(nr)
.a_stride(xnnpack::NextPrime(adj_k_block * 2 + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block));
}
gemm_tests.push_back(GemmTestParams(
"k_gt_" + akbs + "_subtile",
tester.clone()
.iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(adj_k_block + 1, adj_k_block * 2 - 1, k_block)
.loop_n(1, nr)
.loop_m(1, mr));
if (k_block > 1) {
gemm_tests.push_back(GemmTestParams(
"k_div_" + kbs,
tester.clone()
.m(mr).n(nr)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(adj_k_block + k_block, k_block * 5, k_block));
if (is_igemm) {
gemm_tests.push_back(GemmTestParams(
"k_div_" + kbs + "_strided_a",
tester.clone()
.m(mr).n(nr)
.a_stride(xnnpack::NextPrime(k_block * 3 + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(adj_k_block + k_block, k_block * 3, k_block));
}
gemm_tests.push_back(GemmTestParams(
"k_div_" + kbs + "_subtile",
tester.clone()
.iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(adj_k_block + k_block, k_block * 5, k_block)
.loop_n(1, nr)
.loop_m(1, mr));
}
gemm_tests.push_back(GemmTestParams(
"n_gt_" + nrs,
tester.clone()
.m(mr)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
$if NR_SCALE != "":
.loop_n(nr + 1, nr * 2 - 1, 4)
$else:
.loop_n(nr + 1, nr * 2 - 1)
.loop_k(1, k_block * 3, k_block + 1));
$if DATATYPE != "qp8":
gemm_tests.push_back(GemmTestParams(
"n_gt_" + nrs + "_strided_cn",
tester.clone()
.m(mr)
.cn_stride(xnnpack::NextPrime(nr + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
$if NR_SCALE != "":
.loop_n(nr + 1, nr * 2 - 1, 4)
$else:
.loop_n(nr + 1, nr * 2 - 1)
.loop_k(1, k_block * 3, k_block + 1));
if (!is_igemm) {
gemm_tests.push_back(GemmTestParams(
"n_gt_" + nrs + "_strided_a",
tester.clone()
.m(mr)
.a_stride(xnnpack::NextPrime(k_block * 3 + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
$if NR_SCALE != "":
.loop_n(nr + 1, nr * 2 - 1, 4)
$else:
.loop_n(nr + 1, nr * 2 - 1)
.loop_k(1, k_block * 3, k_block));
}
gemm_tests.push_back(GemmTestParams(
"n_gt_" + nrs + "_subtile",
tester.clone()
.iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
$if NR_SCALE != "":
.loop_n(nr + 1, nr * 2 - 1, 4)
$else:
.loop_n(nr + 1, nr * 2 - 1)
.loop_k(1, k_block * 3, k_block + 1)
.loop_m(1, mr));
gemm_tests.push_back(GemmTestParams(
"n_div_" + nrs,
tester.clone()
.m(mr)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_n(nr * 2, nr * 3, nr)
.loop_k(1, k_block * 3, k_block + 1));
$if DATATYPE != "qp8":
gemm_tests.push_back(GemmTestParams(
"n_div_" + nrs + "_strided_cn",
tester.clone()
.m(mr)
.cn_stride(xnnpack::NextPrime(nr + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_n(nr * 2, nr * 3, nr)
.loop_k(1, k_block * 3, k_block + 1));
if (!is_igemm) {
gemm_tests.push_back(GemmTestParams(
"n_div_" + nrs + "_strided_a",
tester.clone()
.m(mr)
.a_stride(xnnpack::NextPrime(k_block * 3 + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_n(nr * 2, nr * 3, nr)
.loop_k(1, k_block * 3, k_block));
}
gemm_tests.push_back(GemmTestParams(
"n_div_" + nrs + "_subtile",
tester.clone()
.iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_n(nr * 2, nr * 3, nr)
.loop_k(1, k_block * 3, k_block + 1)
.loop_m(1, mr));
if (is_igemm) {
gemm_tests.push_back(GemmTestParams(
"small_kernel",
tester.clone()
.m(mr).n(nr).ks(3)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(1, k_block * 3, k_block + 1));
gemm_tests.push_back(GemmTestParams(
"small_kernel_subtile",
tester.clone()
.ks(3).iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(1, k_block * 3, k_block + 1)
.loop_n(1, nr)
.loop_m(1, mr));
gemm_tests.push_back(GemmTestParams(
"n_gt_" + nrs + "_small_kernel",
tester.clone()
.m(mr).ks(3)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
$if NR_SCALE != "":
.loop_n(nr + 1, nr * 2 - 1, 4)
$else:
.loop_n(nr + 1, nr * 2 - 1)
.loop_k(1, k_block * 3, k_block + 1));
gemm_tests.push_back(GemmTestParams(
"n_div_" + nrs + "_small_kernel",
tester.clone()
.m(mr).ks(3)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_n(nr * 2, nr * 3, nr)
.loop_k(1, k_block * 3, k_block + 1));
}
gemm_tests.push_back(GemmTestParams(
"strided_cm_subtile",
tester.clone()
.mr(mr).nr(nr).kr(kr).sr(sr)
.cm_stride(xnnpack::NextPrime(nr + 1))
.iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(1, k_block * 3, k_block + 1)
.loop_n(1, nr)
.loop_m(1, mr));
if (is_igemm) {
gemm_tests.push_back(GemmTestParams(
"a_offset",
tester.clone()
.m(mr).n(nr).ks(3)
.a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(1, k_block * 3, k_block + 1));
gemm_tests.push_back(GemmTestParams(
"zero",
tester.clone()
.m(mr).n(nr).ks(3)
.a_offset(xnnpack::NextPrime(mr * k_block * 3 + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check)
.loop_k(1, k_block * 3, k_block + 1)
.loop_zi(0, mr - 1));
}
$if ACTIVATION == "MINMAX":
gemm_tests.push_back(GemmTestParams(
"qmin",
tester.clone()
.m(mr).n(nr).k(k_block).qmin(128)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check));
gemm_tests.push_back(GemmTestParams(
"qmax",
tester.clone()
.m(mr).n(nr).k(k_block).qmax(128)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check));
gemm_tests.push_back(GemmTestParams(
"strided_cm",
tester.clone()
.m(mr).n(nr).k(k_block)
.cm_stride(xnnpack::NextPrime(nr + 1))
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
$if KERNELTYPE in ['qb4w']:
.bl(32)
, test_func, isa_check));
$if DATATYPE == "qu8":
gemm_tests.push_back(GemmTestParams(
"no_a_zero_point",
tester.clone()
.m(mr).n(nr).a_zero_point(0)
, test_func, isa_check)
.loop_k(1, k_block * 3, k_block + 1));
$if DATATYPE == "qu8":
gemm_tests.push_back(GemmTestParams(
"no_b_zero_point",
tester.clone()
.m(mr).n(nr).b_zero_point(0)
, test_func, isa_check)
.loop_k(1, k_block * 3, k_block + 1));
gemm_tests.push_back(GemmTestParams(
"b_zero_point",
tester.clone()
.m(mr).n(nr).k(k_block)
, test_func, isa_check)
.loop_bzp(0, 255));
gemm_tests.push_back(GemmTestParams(
"no_zero_point",
tester.clone()
.m(mr).n(nr)
.a_zero_point(0)
.b_zero_point(0)
, test_func, isa_check)
.loop_k(1, k_block * 3, k_block + 1));
$if KERNELTYPE in ['qb4w']:
gemm_tests.push_back(GemmTestParams(
"bl",
tester.clone()
.m(mr).n(nr).k(k_block * 12)
.b_zero_point(8)
, test_func, isa_check)
.loop_k(k_block, k_block * 12, k_block, LoopStepType::Linear)
.loop_bl(32, k_block * 32, 32));
return gemm_tests;
}
"""
# xngen template for the per-kernel test code; rendered by xngen.preprocess()
# in generate_test_cases().
#
# The rendered C++ instantiates the parameterized `GemmTest` suite
# (INSTANTIATE_TEST_SUITE_P) with the values returned by `CreateTests(...)`,
# passing a lambda that calls `tester.Test(<TEST_ARGS>)` and, when ISA_CHECK is
# set, a second lambda performing the ISA check. The test name of each
# parameter set is taken from `info.param.test_name`.
#
# For test names starting with 'GENERATE' and DATATYPE 'f32' or 'f16' it also
# emits a `subtile_m_upto_mr` test and, when PROTOTYPE is not None, a
# `matches_assembly` test wrapped in `#if XNN_ENABLE_ASSEMBLY`.
# When CPP_CHECK is set the whole snippet is wrapped in `#if` / `#endif // ...`.
#
# Template variables read: CPP_CHECK, TEST_NAME, KBLOCK, ADJKBLOCK, MR, NR, KR,
# SR, DATATYPE, MR_PACKED, UKERNEL_TYPE, UNSIGNED_INPUTS, TEST_ARGS, ISA_CHECK,
# KERNELTYPE, PROTOTYPE.
GEMM_TEST_CODE = """\
$if CPP_CHECK:
#if ${CPP_CHECK}
INSTANTIATE_TEST_SUITE_P(
${TEST_NAME}, GemmTest,
testing::ValuesIn(CreateTests(
/*k_block=*/${KBLOCK},
/*adj_k_block=*/${ADJKBLOCK},
/*mr=*/${MR}, /*nr=*/${NR}, /*kr=*/${KR}, /*sr=*/${SR},
$if DATATYPE in ('qp8',):
/*mr_packed=*/${MR_PACKED},
/*is_igemm=*/${"true" if UKERNEL_TYPE.startswith("IGEMM") else "false"},
/*unsigned_inputs=*/${"true" if UNSIGNED_INPUTS else "false"},
[](GemmMicrokernelTester& tester) {
tester.Test(${",\\n ".join(TEST_ARGS)});
$if ISA_CHECK:
},
[]() {
${ISA_CHECK};
})),
$else:
})),
[](const testing::TestParamInfo<GemmTest::ParamType>& info) {
return info.param.test_name;
});
$if TEST_NAME.startswith('GENERATE') and DATATYPE in ['f32', 'f16']:
TEST(${TEST_NAME}, subtile_m_upto_mr) {
$if ISA_CHECK:
${ISA_CHECK};
for (uint32_t max_mr = 1; max_mr <= ${MR}; max_mr++) {
for (uint32_t m = 1; m <= max_mr; m++) {
for (size_t k = 1; k <= ${KBLOCK * 2}; k += 1) {
GemmMicrokernelTester()
.mr(max_mr)
$if NR > 1:
.nr(${NR})
$if KR > 1:
.kr(${KR})
$if SR > 1:
.sr(${SR})
.m(m)
$if NR > 1:
.n(${NR})
.k(k)
.iterations(1)
$if KERNELTYPE in ['qb4w', 'qc4w']:
.b_zero_point(8)
.Test(${", ".join(TEST_ARGS)});
}
}
}
}
$if TEST_NAME.startswith('GENERATE') and DATATYPE in ['f32', 'f16'] and PROTOTYPE is not None:
#if XNN_ENABLE_ASSEMBLY
TEST(${TEST_NAME}, matches_assembly) {
$if ISA_CHECK:
${ISA_CHECK};
GemmMicrokernelTester()
$if MR > 1:
.mr(${MR})
$if NR > 1:
.nr(${NR})
$if KR > 1:
.kr(${KR})
$if SR > 1:
.sr(${SR})
$if MR > 1:
.m(${MR})
$if NR > 1:
.n(${NR})
.k(${KBLOCK})
.Test(
${", ".join(TEST_ARGS)},
&${PROTOTYPE});
}
#endif // XNN_ENABLE_ASSEMBLY
$if CPP_CHECK:
#endif // ${CPP_CHECK}
"""
def generate_test_cases(
    ukernel,
    mr,
    nr,
    kr,
    sr,
    mr_packed,
    k_block,
    unsigned_inputs,
    vector_tile,
    init_fn,
    pack_fn,
    packed_stride_fn,
    requantization,
    is_pipelined,
    cpp_check,
    isa,
    prototype,
):
    """Renders the test and benchmark code for one GEMM micro-kernel.

    The data type, kernel (weights) type, micro-kernel type and activation are
    decoded from the ``_``-separated fields of ``ukernel``; the three templates
    GEMM_CREATE_TESTS_CODE, GEMM_TEST_CODE and GEMM_BENCH_CODE are then
    expanded with xngen.preprocess().

    Args:
      ukernel: C name of the micro-kernel function.
      mr: MR parameter of the GEMM micro-kernel.
      nr: NR parameter of the GEMM micro-kernel.
      kr: KR parameter of the GEMM micro-kernel.
      sr: SR parameter of the GEMM micro-kernel.
      mr_packed: MR parameter for the left-hand packing function (substituted
        as MR_PACKED; the templates only emit it for the 'qp8' data type).
      k_block: Number of K values processed per one iteration of the main loop
        of the micro-kernel.
      unsigned_inputs: Truthy if the inputs should be converted to unsigned
        integers.
      vector_tile: Truthy if NR is specified in vectors rather than elements;
        NR is then scaled by a per-ISA run-time expression (only 'rvv' is
        known here).
      init_fn: C name of the function to initialize microkernel parameters, or
        a falsy value if there is none.
      pack_fn: C name of the function to pack the weights, or a falsy value.
      packed_stride_fn: C name of the function to compute the packed weights
        stride, or a falsy value.
      requantization: Name of the requantization scheme used by the
        micro-kernel, or None. When both this and ``init_fn`` are set, the
        matching ``xnn_<datatype>_requantize_<scheme>`` function is appended
        to the tester arguments.
      is_pipelined: Truthy if the micro-kernel is software pipelined; the
        adjusted K block becomes ``2 * k_block`` and the templates emit the
        additional ``k_block * 2`` cases.
      cpp_check: Optional preprocessor condition wrapped around the generated
        test and benchmark code.
      isa: Instruction set required to run the micro-kernel, used to build the
        ISA check macros.
      prototype: Substituted as PROTOTYPE; when not None the test template may
        emit a ``matches_assembly`` test that passes ``&PROTOTYPE`` to the
        tester.

    Returns:
      A tuple ``(create_test_case, test_case, benchmark)`` with the rendered
      CreateTests function, test instantiation and benchmark code.
    """
    _prefix, ukernel_name = ukernel.split("_", 1)

    # Default layout: xnn_<datatype>_<ukernel_type>_<activation>_...
    _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
    kerneltype = datatype

    # Float kernels with quantized weights:
    # xnn_<datatype>_<kerneltype>_<ukernel_type>_<activation>_...
    if datatype in ["f16", "f32"] and ukernel_type in ["qc8w", "qc4w"]:
        _, datatype, kerneltype, ukernel_type, activation, _ = ukernel.split(
            "_", 5
        )
        datatype = datatype + "_" + kerneltype

    # Dynamically quantized / packed inputs:
    # xnn_<datatype>_<out>_<kerneltype>_<ukernel_type>_<activation>_...
    has_quantized_input = datatype in ("qd8", "qp8")
    has_float_output = ukernel_type in ["f16", "f32"]
    has_quantized_weights = activation in ["qc8w", "qc4w", "qb4w"]
    if has_quantized_input and has_float_output and has_quantized_weights:
        _, datatype, _, kerneltype, ukernel_type, activation, _ = ukernel.split(
            "_", 6
        )

    if activation == "ukernel":
        activation = "linear"
    if activation in ["qs8w"]:
        _, _, _, _, _, activation, _ = ukernel.split("_", 6)

    # Arguments of the generated `tester.Test(...)` call: the kernel itself
    # followed by whichever helper functions are configured.
    tester_args = [ukernel]
    tester_args.extend(
        fn for fn in (init_fn, pack_fn, packed_stride_fn) if fn
    )
    if init_fn and requantization:
        requantization_datatype = {"qc8": "qs8"}.get(datatype, datatype)
        tester_args.append(
            "xnn_%s_requantize_%s" % (requantization_datatype, requantization)
        )

    nr_scale = ""
    if vector_tile:
        ctype_by_datatype = {
            "qs8": "int8_t",
            "qd8": "int32_t",
            "qp8": "int8_t",
            "qu8": "uint8_t",
            "f16": "uint16_t",
            "f32": "float",
        }
        ctype = ctype_by_datatype[datatype]
        scale_by_isa = {
            "rvv": " * xnn_init_hardware_config()->vlenb / sizeof(%s)" % ctype
        }
        nr_scale = scale_by_isa[isa]

    adj_k_block = 2 * k_block if is_pipelined else k_block
    test_template_vars = {
        "TEST_NAME": ukernel_name.upper().replace("UKERNEL_", ""),
        "TEST_ARGS": tester_args,
        "UKERNEL_TYPE": ukernel_type.upper(),
        "DATATYPE": datatype,
        "KERNELTYPE": kerneltype,
        "ACTIVATION": activation.upper(),
        "MR": mr,
        "NR": nr,
        "KR": kr,
        "SR": sr,
        "MR_PACKED": mr_packed,
        "KBLOCK": k_block,
        "UNSIGNED_INPUTS": unsigned_inputs,
        "NR_SCALE": nr_scale,
        "ADJKBLOCK": adj_k_block,
        "IS_PIPELINED": is_pipelined,
        "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
        "next_prime": next_prime,
        "PROTOTYPE": prototype,
        "CPP_CHECK": cpp_check,
    }
    create_test_case = xngen.preprocess(
        GEMM_CREATE_TESTS_CODE, test_template_vars
    )
    test_case = xngen.preprocess(GEMM_TEST_CODE, test_template_vars)

    bench_template_vars = {
        "UKERNEL_NAME": ukernel_name,
        "GEMM": ukernel,
        "KERNELTYPE": kerneltype,
        "DATATYPE": datatype,
        "INIT_PARAMS": init_fn,
        "PACK_FN": pack_fn,
        "PACKED_STRIDE_FN": packed_stride_fn,
        "MR": mr,
        "NR": nr,
        "KR": kr,
        "SR": sr,
        "MR_PACKED": mr_packed,
        "NR_SCALE": nr_scale,
        "ISA_CHECK": xnncommon.generate_isa_utilcheck_macro(isa),
        "CPP_CHECK": cpp_check,
    }
    benchmark = xngen.preprocess(GEMM_BENCH_CODE, bench_template_vars)

    return create_test_case, test_case, benchmark
def _merge_adjacent_cpp_guards(source):
    """Removes `#endif // X` lines that are directly followed by `#if X`.

    Consecutive snippets guarded by the same preprocessor condition are merged
    into one guarded region by replacing the `#endif`/`#if` pair (and any blank
    lines between them) with a single newline. The substitution is applied
    twice, as the generator has always done -- presumably so that pairs which
    only become adjacent after the first pass are merged as well.

    Args:
      source: generated C++ source text.

    Returns:
      The source text with redundant guard pairs removed.
    """
    for _ in range(2):
        source = re.sub(
            r"^ *\#endif // ([^\n]+)\n+ *\#if \1\n",
            "\n",
            source,
            flags=re.MULTILINE,
        )
    return source


def main(args):
    """Generates GEMM test (and optionally benchmark) sources from a YAML spec.

    Every micro-kernel listed in the spec is rendered with
    generate_test_cases(). Test code is distributed over the ``--output-test``
    files by the CRC32 of the kernel name; identical ``CreateTests`` functions
    are emitted once and shared under a numbered name. Benchmark code is
    grouped by ISA and written to ``--output-bench`` when that option is given.
    Files are written through xnncommon.overwrite_if_changed().

    Args:
      args: command-line arguments, without the program name.

    Raises:
      ValueError: if the top-level YAML node of the spec is not a list.
    """
    options = parser.parse_args(args)
    num_output_files = len(options.output_test)

    # The file only needs to stay open while the YAML is being parsed.
    with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
        spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
        raise ValueError("expected a list of micro-kernels in the spec")

    tests = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: {specification}
// Generator: {generator}
#include <cstddef>
#include <functional>
#include <string>
#include <vector>
#include <gtest/gtest.h>
#include "xnnpack/allocator.h"
#include "xnnpack/common.h"
#include "xnnpack/gemm.h"
#include "xnnpack/igemm.h"
#include "xnnpack/isa-checks.h"
#include "xnnpack/microparams-init.h"
#include "xnnpack/pack.h"
#include "xnnpack/packw.h"
#include "xnnpack/ppmm.h"
#include "xnnpack/requantization.h"
#include "gemm-microkernel-tester.h"
#include "next_prime.h"
""".format(specification=options.spec, generator=sys.argv[0])

    bench_outputs = """\
// Copyright 2023 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: {specification}
// Generator: {generator}
#include <benchmark/benchmark.h>
#include "gemm-benchmark.h"
#include "utils.h"
#include "xnnpack/common.h"
#include "xnnpack/gemm.h"
#include "xnnpack/isa-checks.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microparams-init.h"
#include "xnnpack/pack.h"
#include "xnnpack/packw.h"
""".format(specification=options.spec, generator=sys.argv[0])

    test_outputs = collections.defaultdict(str)

    isa_hierarchy = xnncommon.isa_hierarchy_map()

    # Cached `CreateTests` functions. The cache is keyed on the rendered text
    # itself rather than on `hash(text)`, so two different functions can never
    # be conflated by a hash collision. Indices are handed out in order of
    # first appearance, starting at 1.
    idx_from_create_tests = collections.defaultdict(
        lambda: len(idx_from_create_tests) + 1
    )
    create_tests_from_idx = {}

    # One benchmark bucket per ISA hierarchy level.
    benches_by_isa = [""] * len(isa_hierarchy)

    for ukernel_spec in spec_yaml:
        name = ukernel_spec["name"]
        k_block = int(ukernel_spec["k-block"])
        if "unsigned-inputs" in ukernel_spec:
            unsigned_inputs = int(ukernel_spec["unsigned-inputs"])
        else:
            unsigned_inputs = False
        init_fn = ukernel_spec.get("init")
        pack_fn = ukernel_spec.get("pack")
        packed_stride_fn = ukernel_spec.get("packed-stride")
        pipelined = bool(ukernel_spec.get("pipelined", False))
        cpp_check = ukernel_spec.get("cpp-check", False)
        prototype = ukernel_spec.get("prototype")

        (
            mr,
            nr,
            kr,
            sr,
            mr_packed,
            vector_tile,
            requantization,
            arch,
            isa,
            assembly,
        ) = split_ukernel_name(name)

        create_tests, test_case, bench_case = generate_test_cases(
            name,
            mr,
            nr,
            kr,
            sr,
            mr_packed,
            k_block,
            unsigned_inputs,
            vector_tile,
            init_fn,
            pack_fn,
            packed_stride_fn,
            requantization,
            pipelined,
            cpp_check,
            isa,
            prototype,
        )

        # Store or reuse the `CreateTests` function?
        create_tests_idx = idx_from_create_tests[create_tests]
        if create_tests_idx not in create_tests_from_idx:
            create_tests_from_idx[create_tests_idx] = create_tests.replace(
                "CreateTests(", f"CreateTests{create_tests_idx}("
            )
            if isa == 'rvv':
                create_tests_from_idx[create_tests_idx] = (
                    xnncommon.postprocess_test_case(
                        create_tests_from_idx[create_tests_idx],
                        arch,
                        isa,
                        assembly,
                    )
                )
        test_case = test_case.replace(
            "CreateTests(", f"CreateTests{create_tests_idx}("
        )

        # Hash the name of each microkernel and figure out which output file to
        # write it to.
        output_index = zlib.crc32(bytes(name, "utf-8")) % num_output_files
        test_outputs[
            options.output_test[output_index]
        ] += "\n\n" + xnncommon.postprocess_test_case(
            test_case, arch, isa, assembly
        )
        benches_by_isa[
            isa_hierarchy.get(isa, 0)
        ] += "\n\n" + xnncommon.postprocess_test_case(
            bench_case, arch, isa, assembly
        )

    for arch_idx in reversed(range(len(isa_hierarchy))):
        bench_outputs += benches_by_isa[arch_idx]

    bench_outputs += """\n
#ifndef XNNPACK_BENCHMARK_NO_MAIN
BENCHMARK_MAIN();
#endif
"""

    if options.output_bench:
        # Strip out consecutive preprocessor `endif`/`if` pairs.
        bench_outputs = _merge_adjacent_cpp_guards(bench_outputs)
        output_name = options.output_bench
        xnncommon.overwrite_if_changed(output_name, bench_outputs)

    create_tests = (
        "namespace {\n\n"
        + "\n".join(create_tests_from_idx.values())
        + "\n} // namespace\n"
    )

    # Build the final text for EVERY requested output file, including files
    # that no kernel happened to hash into. Previously only populated files
    # were kept, so the write loop below raised KeyError for an empty bucket.
    # Strip out consecutive preprocessor `endif`/`if` pairs while at it.
    test_outputs = {
        output_name: _merge_adjacent_cpp_guards(
            tests + "\n" + create_tests + test_outputs.get(output_name, "")
        )
        for output_name in options.output_test
    }

    for output_name in options.output_test:
        xnncommon.overwrite_if_changed(output_name, test_outputs[output_name])
# Script entry point: pass the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])