// Copyright 2019 Google LLC // // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. #ifndef THIRD_PARTY_XNNPACK_TEST_UNARY_OPS_H_ #define THIRD_PARTY_XNNPACK_TEST_UNARY_OPS_H_ #pragma once #include #include #include #include #include #include #include #include #include #include "xnnpack.h" #include "xnnpack/buffer.h" #include "xnnpack/math.h" #include "xnnpack/reference-utils.h" static float TolExact(float) { return 0.0f; } static float TolExact16(float y_ref) { return std::abs(y_ref) * 1.0e-3f; } static float TolRelative(float y_ref, float rel_tol) { // Note that `y_ref * rel_tol`, i.e. the expected absolute difference, // may round differently than `y_ref * (1 + rel_tol) - y_ref`, i.e. the // effective absolute difference computed in `float`s. We therefore use // the latter form since it is the true difference between two `float`s // within the given relative tolerance. return std::abs(y_ref * (1.0f + rel_tol)) - std::abs(y_ref); } static float TolMixed(float y_ref, float abs_tol, float rel_tol) { return std::max(abs_tol, std::abs(y_ref) * (1.0f + rel_tol) - std::abs(y_ref)); } struct Interval { float min; float max; static Interval All() { return {-std::numeric_limits::infinity(), std::numeric_limits::infinity()}; } static Interval Positive(xnn_datatype datatype) { switch (datatype) { case xnn_datatype_fp16: return {0.001f, std::numeric_limits::infinity()}; case xnn_datatype_fp32: return {std::numeric_limits::epsilon(), std::numeric_limits::infinity()}; default: return {1.0f, std::numeric_limits::infinity()}; } } }; // This struct describes a unary operator enough such that we can test them // without knowing anything about the specific operator. struct UnaryOpInfo { virtual ~UnaryOpInfo() = default; virtual float ReferenceImpl(float x, const xnn_unary_params& params) const { XNN_UNREACHABLE; } virtual int ReferenceImpl(int x, const xnn_unary_params& params) const { XNN_UNREACHABLE; } // Get the parameters to use by default for this operator. virtual xnn_unary_params DefaultParams() const { return xnn_unary_params(); } // Compute the tolerance for error given the reference result and the // datatype. virtual float Tolerance(float y_ref, xnn_datatype datatype) const { switch (datatype) { case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; case xnn_datatype_fp16: return TolExact16(y_ref); default: return TolExact(y_ref); } } virtual Interval Domain(xnn_datatype) const { return Interval::All(); } // Quantization parameters to use by default. virtual xnn_quantization_params InputQuantizationParams( xnn_datatype datatype) const { switch (datatype) { case xnn_datatype_quint8: return {150, 1.0f}; default: return {0, 1.0f}; } } virtual xnn_quantization_params OutputQuantizationParams( xnn_datatype datatype) const { switch (datatype) { case xnn_datatype_quint8: return {100, 1.0f}; default: return {0, 1.0f}; } } }; struct Convert : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return x; } int ReferenceImpl(int x, const xnn_unary_params&) const override { return x; } }; struct ReLU : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::max(x, 0.0f); } int ReferenceImpl(int x, const xnn_unary_params&) const override { return std::max(x, 0); } }; struct Abs : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::abs(x); } int ReferenceImpl(int x, const xnn_unary_params&) const override { return std::abs(x); } }; struct Negate : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return -x; } int ReferenceImpl(int x, const xnn_unary_params&) const override { return -x; } }; struct Clamp : public UnaryOpInfo { xnn_unary_params DefaultParams() const override { xnn_unary_params params; params.clamp.min = -40.0f; params.clamp.max = 50.0f; return params; } float ReferenceImpl(float x, const xnn_unary_params& params) const override { return std::min(std::max(x, params.clamp.min), params.clamp.max); } int ReferenceImpl(int x, const xnn_unary_params& params) const override { return std::min(std::max(x, params.clamp.min), params.clamp.max); } xnn_quantization_params InputQuantizationParams( xnn_datatype datatype) const override { return {0, 1.0f}; } xnn_quantization_params OutputQuantizationParams( xnn_datatype datatype) const override { return {0, 1.0f}; } }; struct ELU : public UnaryOpInfo { xnn_unary_params DefaultParams() const override { xnn_unary_params params; params.elu.alpha = 1.0f; return params; } float ReferenceImpl(float x, const xnn_unary_params& params) const override { return std::signbit(x) ? params.elu.alpha * std::expm1(x) : x; } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolMixed(y_ref, 5.0e-6f, 1.0e-5f); case xnn_datatype_fp16: return TolMixed(y_ref, 1.0e-4f, 5.0e-3f); default: return 1; } } Interval Domain(xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp16: return {-9.0f, 9.0f}; default: return {-20.0f, 20.0f}; } } }; struct GELU : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return x * 0.5f * (1.0f + std::erf(x * std::sqrt(2) / 2)); } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, 10 * std::numeric_limits::epsilon(), 5 * std::numeric_limits::epsilon()); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; default: XNN_UNREACHABLE; } } Interval Domain(xnn_datatype) const override { return {-10.0f, 10.0f}; } }; struct HardSwish : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return (x / 6.0) * std::max(std::min(x + 3.0, 6.0), 0.0); } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolMixed(y_ref, 5.0e-6f, 1.0e-5f); case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, 1.0e-3f, 1.0e-2f); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; default: XNN_UNREACHABLE; } } Interval Domain(xnn_datatype) const override { return {-4.0f, 4.0f}; } }; struct LeakyReLU : public UnaryOpInfo { xnn_unary_params DefaultParams() const override { xnn_unary_params params; params.leaky_relu.negative_slope = 0.5f; return params; } float ReferenceImpl(float x, const xnn_unary_params& params) const override { return std::signbit(x) ? x * params.leaky_relu.negative_slope : x; } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolExact(y_ref); case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, 1.0e-4f, 1.0e-3f); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; default: XNN_UNREACHABLE; } } }; struct RoundToNearestEven : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::nearbyint(x); } }; struct RoundTowardsZero : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::trunc(x); } }; struct RoundUp : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::ceil(x); } }; struct RoundDown : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::floor(x); } }; struct Sigmoid : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { if (x > 100) { return 1.0f; } else if (x < -100) { return 0.0f; } else { const double e = std::exp(static_cast(x)); return e / (1.0 + e); } } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolMixed(y_ref, 5.0e-6f, 1.0e-5f); case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, 1.0e-4f, 5.0e-3f); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; default: return TolExact(y_ref); } } Interval Domain(xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp16: return {-25.0f, 25.0f}; default: return {-125.0f, 125.0f}; } } }; struct Square : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return x * x; } int ReferenceImpl(int x, const xnn_unary_params&) const override { return static_cast(static_cast(x) * static_cast(x)); } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolExact(y_ref); case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, 1.0e-4f, 5.0e-3f); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; case xnn_datatype_int32: // Overflow makes this hard to test. return std::numeric_limits::infinity(); default: XNN_UNREACHABLE; } } }; struct SquareRoot : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::sqrt(x); } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolRelative(y_ref, 2.5f * std::numeric_limits::epsilon()); case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, 1.0e-4f, 5.0e-3f); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; default: XNN_UNREACHABLE; } } Interval Domain(xnn_datatype datatype) const override { if (datatype == xnn_datatype_fp16) { return {0.001f, 10.0f}; } else { return Interval::Positive(datatype); } } }; struct TanH : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::tanh(x); } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolRelative( y_ref, 4.0f * std::numeric_limits::epsilon()); // 4 ULP case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, /*abs_tol=*/1.0e-4f, /*rel_tol=*/5.0e-3f); default: return 1; } } Interval Domain(xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp16: return {-5.0f, 5.0f}; default: return {-10.0f, 10.0f}; } } }; struct ReciprocalSquareRoot : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return 1.0 / std::sqrt(x); } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolRelative(y_ref, 4 * std::numeric_limits::epsilon()); case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, 1.0e-4f, 5.0e-3f); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; default: return TolExact(y_ref); } } Interval Domain(xnn_datatype datatype) const override { if (datatype == xnn_datatype_fp16) { return {1.0e-4f, 10.0f}; } else { return Interval::Positive(datatype); } } }; struct Log : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::log(x); } float Tolerance(float y_ref, xnn_datatype datatype) const override { return TolMixed(y_ref, 2 * std::numeric_limits::epsilon(), 6 * std::numeric_limits::epsilon()); } Interval Domain(xnn_datatype datatype) const override { return {std::numeric_limits::epsilon(), 1000.0f}; } }; struct Exp : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::exp(x); } float Tolerance(float y_ref, xnn_datatype datatype) const override { return TolMixed(y_ref, 2 * std::numeric_limits::epsilon(), 6 * std::numeric_limits::epsilon()); } Interval Domain(xnn_datatype) const override { return {-10.0f, 10.0f}; } }; struct CubeRoot : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::cbrt(x); } float Tolerance(float y_ref, xnn_datatype datatype) const override { switch (datatype) { case xnn_datatype_fp32: return TolRelative(y_ref, 2.5f * std::numeric_limits::epsilon()); case xnn_datatype_fp16: case xnn_datatype_bf16: return TolMixed(y_ref, 1.0e-4f, 5.0e-3f); case xnn_datatype_qint8: case xnn_datatype_quint8: return 1; default: XNN_UNREACHABLE; } } Interval Domain(xnn_datatype datatype) const override { if (datatype == xnn_datatype_fp16 || datatype == xnn_datatype_bf16) { return {0.001f, 10.0f}; } else { return Interval::Positive(datatype); } } }; struct Cosine : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::cos(x); } float Tolerance(float y_ref, xnn_datatype datatype) const override { return 1e-6f; } Interval Domain(xnn_datatype) const override { return {-10.0f, 10.0f}; } }; struct Sine : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return std::sin(x); } float Tolerance(float y_ref, xnn_datatype datatype) const override { return 1e-6f; } Interval Domain(xnn_datatype) const override { return {-10.0f, 10.0f}; } }; struct CountLeadingZeros : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return (float)math_clz_u32((int)x); } int ReferenceImpl(int x, const xnn_unary_params&) const override { return math_clz_u32(x); } }; struct BitwiseNot : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return ~(int)x; } int ReferenceImpl(int x, const xnn_unary_params&) const override { return ~x; } }; struct Popcount : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return (float)math_popcount_u32((int)x); } int ReferenceImpl(int x, const xnn_unary_params&) const override { return math_popcount_u32(x); } }; struct Sign : public UnaryOpInfo { float ReferenceImpl(float x, const xnn_unary_params&) const override { return x < 0.0f ? -1.0f : x > 0.0f ? 1.0f : 0.0f; } int ReferenceImpl(int x, const xnn_unary_params&) const override { return x < 0 ? -1 : x > 0 ? 1 : 0; } }; const UnaryOpInfo* GetUnaryOpInfo(xnn_unary_operator op); // Generate random data in the given domain, where the domain is given as // unquantized values. template void FillRandom(Rng& rng, T* x, size_t n, const Interval& domain, const xnn_quantization_params& quantization = {0, 1.0f}) { float min = domain.min; float max = domain.max; min = min * quantization.scale + quantization.zero_point; max = max * quantization.scale + quantization.zero_point; min = std::max(domain.min, xnnpack::NumericLimits::min()); max = std::min(domain.max, xnnpack::NumericLimits::max()); min = std::max(min, -1e6f); max = std::min(max, 1e6f); std::uniform_real_distribution dist(min, max); for (size_t i = 0; i < n; ++i) { x[i] = static_cast(dist(rng)); } } // Compute the result of a unary operator using the reference implementation. template void UnaryReferenceImpl( const xnnpack::quantized* x, size_t n, xnnpack::quantized* y, const UnaryOp& op_info, const xnn_quantization_params& input_quantization = {0, 1.0f}, const xnn_quantization_params& output_quantization = {0, 1.0f}, const xnn_unary_params& params = xnn_unary_params()) { for (size_t i = 0; i < n; i++) { float x_i = (x[i] - input_quantization.zero_point) * input_quantization.scale; float y_i = op_info.ReferenceImpl(x_i, params); y_i = y_i / output_quantization.scale + output_quantization.zero_point; y[i] = xnnpack::round_float_to_int(y_i); } } // Compute the result of a unary operator using the reference implementation. template void UnaryReferenceImpl( const In* x, size_t n, xnnpack::quantized* y, const UnaryOp& op_info, const xnn_quantization_params& input_quantization = {0, 1.0f}, const xnn_quantization_params& output_quantization = {0, 1.0f}, const xnn_unary_params& params = xnn_unary_params()) { static_assert(!xnnpack::is_quantized::value, ""); for (size_t i = 0; i < n; i++) { float y_i = op_info.ReferenceImpl(static_cast(x[i]), params); y_i = y_i / output_quantization.scale + output_quantization.zero_point; y[i] = xnnpack::round_float_to_int(y_i); } } // Compute the result of a unary operator using the reference implementation. template void UnaryReferenceImpl( const xnnpack::quantized* x, size_t n, Out* y, const UnaryOp& op_info, const xnn_quantization_params& input_quantization = {0, 1.0f}, const xnn_quantization_params& output_quantization = {0, 1.0f}, const xnn_unary_params& params = xnn_unary_params()) { static_assert(!xnnpack::is_quantized::value, ""); for (size_t i = 0; i < n; i++) { float x_i = (x[i] - input_quantization.zero_point) * input_quantization.scale; float y_i = op_info.ReferenceImpl(x_i, params); if (std::is_integral::value) { y[i] = xnnpack::round_float_to_int(y_i); } else { y[i] = y_i; } } } // Compute the result of a unary operator using the reference implementation. template void UnaryReferenceImpl( const In* x, size_t n, Out* y, const UnaryOp& op_info, const xnn_quantization_params& input_quantization = {0, 1.0f}, const xnn_quantization_params& output_quantization = {0, 1.0f}, const xnn_unary_params& params = xnn_unary_params()) { static_assert(!xnnpack::is_quantized::value, ""); static_assert(!xnnpack::is_quantized::value, ""); for (size_t i = 0; i < n; i++) { float y_i; if (std::is_integral::value && std::is_integral::value) { y[i] = op_info.ReferenceImpl((int)x[i], params); } else { if (std::is_integral::value) { y_i = op_info.ReferenceImpl((int)x[i], params); } else { y_i = op_info.ReferenceImpl(static_cast(x[i]), params); } if (std::is_integral::value) { y[i] = xnnpack::round_float_to_int(y_i); } else { y[i] = y_i; } } } } #endif // THIRD_PARTY_XNNPACK_TEST_UNARY_OPS_H_