sglang_v0.5.2/pytorch_2.8.0/third_party/XNNPACK/test/unary-ops.h

672 lines
20 KiB
C++

// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#ifndef THIRD_PARTY_XNNPACK_TEST_UNARY_OPS_H_
#define THIRD_PARTY_XNNPACK_TEST_UNARY_OPS_H_
#pragma once
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <limits>
#include <random>
#include <type_traits>
#include "xnnpack.h"
#include "xnnpack/buffer.h"
#include "xnnpack/math.h"
#include "xnnpack/reference-utils.h"
// Tolerance for results that must match the reference bit-for-bit: the
// allowed error is zero regardless of the reference value.
static float TolExact(float /*y_ref*/) {
  return 0.0f;
}
// Tolerance proportional to the reference: 1e-3 times the magnitude of
// `y_ref`.
static float TolExact16(float y_ref) {
  const float magnitude = std::abs(y_ref);
  return magnitude * 1.0e-3f;
}
// Absolute error allowed for a reference value `y_ref` under a relative
// tolerance `rel_tol`.
//
// The naive bound `y_ref * rel_tol` may round differently from
// `y_ref * (1 + rel_tol) - y_ref`, which is the difference actually observed
// between two `float`s that are within the relative tolerance of each other.
// The second form is therefore the one computed here.
static float TolRelative(float y_ref, float rel_tol) {
  const float scaled_ref = y_ref * (1.0f + rel_tol);
  return std::abs(scaled_ref) - std::abs(y_ref);
}
// Combined tolerance: the larger of the absolute tolerance `abs_tol` and the
// absolute error implied by the relative tolerance `rel_tol` at `y_ref`.
static float TolMixed(float y_ref, float abs_tol, float rel_tol) {
  const float magnitude = std::abs(y_ref);
  // Computed as `|y| * (1 + rel) - |y|` (not `|y| * rel`) so the bound is the
  // true float difference at that relative distance.
  const float rel_bound = magnitude * (1.0f + rel_tol) - magnitude;
  return abs_tol < rel_bound ? rel_bound : abs_tol;
}
// A [min, max] range of float values, used to describe the inputs an operator
// is exercised on.
struct Interval {
  float min;
  float max;

  // The unbounded range from -infinity to +infinity.
  static Interval All() {
    const float inf = std::numeric_limits<float>::infinity();
    return {-inf, inf};
  }

  // A range of positive values extending to +infinity. The lower bound
  // depends on the datatype: 0.001 for fp16, float epsilon for fp32, and 1
  // for every other datatype.
  static Interval Positive(xnn_datatype datatype) {
    const float inf = std::numeric_limits<float>::infinity();
    if (datatype == xnn_datatype_fp16) {
      return {0.001f, inf};
    }
    if (datatype == xnn_datatype_fp32) {
      return {std::numeric_limits<float>::epsilon(), inf};
    }
    return {1.0f, inf};
  }
};
// Describes a unary operator in enough detail that it can be tested
// generically, i.e. without the test knowing which operator it is.
struct UnaryOpInfo {
  virtual ~UnaryOpInfo() = default;

  // Reference result for a float input. The base version is marked
  // unreachable; operators provide their own override.
  virtual float ReferenceImpl(float x, const xnn_unary_params& params) const {
    XNN_UNREACHABLE;
  }

  // Reference result for an integer input. The base version is marked
  // unreachable; operators provide their own override.
  virtual int ReferenceImpl(int x, const xnn_unary_params& params) const {
    XNN_UNREACHABLE;
  }

  // Parameters to use by default for this operator.
  virtual xnn_unary_params DefaultParams() const { return xnn_unary_params(); }

  // Allowed error for a result whose reference value is `y_ref`, for the
  // given datatype: 1 for the quantized 8-bit types, `TolExact16` for fp16,
  // and `TolExact` for everything else.
  virtual float Tolerance(float y_ref, xnn_datatype datatype) const {
    if (datatype == xnn_datatype_qint8 || datatype == xnn_datatype_quint8) {
      return 1.0f;
    }
    if (datatype == xnn_datatype_fp16) {
      return TolExact16(y_ref);
    }
    return TolExact(y_ref);
  }

  // Range of inputs the operator is tested on; unbounded by default.
  virtual Interval Domain(xnn_datatype) const { return Interval::All(); }

  // Default input quantization: zero point 150 for quint8 and 0 otherwise,
  // with a scale of 1.
  virtual xnn_quantization_params InputQuantizationParams(
      xnn_datatype datatype) const {
    if (datatype == xnn_datatype_quint8) {
      return {150, 1.0f};
    }
    return {0, 1.0f};
  }

  // Default output quantization: zero point 100 for quint8 and 0 otherwise,
  // with a scale of 1.
  virtual xnn_quantization_params OutputQuantizationParams(
      xnn_datatype datatype) const {
    if (datatype == xnn_datatype_quint8) {
      return {100, 1.0f};
    }
    return {0, 1.0f};
  }
};
// Convert: the reference result is the input value itself, for both the
// float and the integer overload.
struct Convert : public UnaryOpInfo {
  float ReferenceImpl(float value, const xnn_unary_params&) const override {
    return value;
  }

  int ReferenceImpl(int value, const xnn_unary_params&) const override {
    return value;
  }
};
// ReLU: inputs below zero map to zero, all other inputs pass through.
struct ReLU : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    // Same selection as std::max(x, 0.0f): `x` is kept unless it compares
    // less than zero.
    return x < 0.0f ? 0.0f : x;
  }

  int ReferenceImpl(int x, const xnn_unary_params&) const override {
    return x < 0 ? 0 : x;
  }
};
// Abs: the reference result is the magnitude of the input.
struct Abs : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float magnitude = std::fabs(x);
    return magnitude;
  }

  int ReferenceImpl(int x, const xnn_unary_params&) const override {
    const int magnitude = std::abs(x);
    return magnitude;
  }
};
// Negate: the reference result is the input with its sign flipped.
struct Negate : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float negated = -x;
    return negated;
  }

  int ReferenceImpl(int x, const xnn_unary_params&) const override {
    const int negated = -x;
    return negated;
  }
};
// Clamp: limits the input to [params.clamp.min, params.clamp.max].
struct Clamp : public UnaryOpInfo {
  // Defaults to clamping to [-40, 50].
  xnn_unary_params DefaultParams() const override {
    xnn_unary_params clamp_params;
    clamp_params.clamp.max = 50.0f;
    clamp_params.clamp.min = -40.0f;
    return clamp_params;
  }

  float ReferenceImpl(float x, const xnn_unary_params& params) const override {
    // Apply the lower bound first, then the upper bound.
    const float raised = std::max<float>(x, params.clamp.min);
    return std::min<float>(raised, params.clamp.max);
  }

  int ReferenceImpl(int x, const xnn_unary_params& params) const override {
    // The float bounds are converted to int by the explicit <int>
    // instantiations of std::max / std::min.
    const int raised = std::max<int>(x, params.clamp.min);
    return std::min<int>(raised, params.clamp.max);
  }

  // Input quantization is zero point 0, scale 1 for every datatype
  // (overriding the base class's quint8 special case).
  xnn_quantization_params InputQuantizationParams(
      xnn_datatype /*datatype*/) const override {
    return {0, 1.0f};
  }

  // Output quantization is zero point 0, scale 1 for every datatype
  // (overriding the base class's quint8 special case).
  xnn_quantization_params OutputQuantizationParams(
      xnn_datatype /*datatype*/) const override {
    return {0, 1.0f};
  }
};
// ELU: inputs with the sign bit set map to alpha * expm1(x); all other
// inputs pass through unchanged.
struct ELU : public UnaryOpInfo {
  // Defaults to alpha = 1.
  xnn_unary_params DefaultParams() const override {
    xnn_unary_params elu_params;
    elu_params.elu.alpha = 1.0f;
    return elu_params;
  }

  float ReferenceImpl(float x, const xnn_unary_params& params) const override {
    if (!std::signbit(x)) {
      return x;
    }
    return params.elu.alpha * std::expm1(x);
  }

  // Mixed absolute/relative tolerance for fp32 and fp16; 1 for every other
  // datatype.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    if (datatype == xnn_datatype_fp32) {
      return TolMixed(y_ref, 5.0e-6f, 1.0e-5f);
    }
    if (datatype == xnn_datatype_fp16) {
      return TolMixed(y_ref, 1.0e-4f, 5.0e-3f);
    }
    return 1.0f;
  }

  // Tested on [-9, 9] for fp16 and on [-20, 20] for every other datatype.
  Interval Domain(xnn_datatype datatype) const override {
    if (datatype == xnn_datatype_fp16) {
      return {-9.0f, 9.0f};
    }
    return {-20.0f, 20.0f};
  }
};
// GELU: x * 0.5 * (1 + erf(x * sqrt(2) / 2)).
struct GELU : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    // std::sqrt(2) takes an integer argument and so yields a double; the erf
    // argument and the final product are therefore evaluated in double before
    // being narrowed to float on return.
    const double erf_arg = x * std::sqrt(2) / 2;
    return x * 0.5f * (1.0f + std::erf(erf_arg));
  }

  // Floating-point types get a mixed tolerance of 10 epsilon absolute and
  // 5 epsilon relative; the quantized 8-bit types get 1. Other datatypes are
  // marked unreachable.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    const float eps = std::numeric_limits<float>::epsilon();
    switch (datatype) {
      case xnn_datatype_qint8:
      case xnn_datatype_quint8:
        return 1.0f;
      case xnn_datatype_fp32:
      case xnn_datatype_fp16:
      case xnn_datatype_bf16:
        return TolMixed(y_ref, 10 * eps, 5 * eps);
      default:
        XNN_UNREACHABLE;
    }
  }

  // Tested on [-10, 10] for every datatype.
  Interval Domain(xnn_datatype) const override {
    return {-10.0f, 10.0f};
  }
};
// HardSwish: (x / 6) * min(max(x + 3, 0), 6), evaluated in double precision.
struct HardSwish : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    // The gate is x + 3 limited to [0, 6]: upper bound first, then lower.
    const double gate = std::max(std::min(x + 3.0, 6.0), 0.0);
    return (x / 6.0) * gate;
  }

  // Mixed tolerance for the floating-point types (looser for fp16/bf16) and
  // 1 for the quantized 8-bit types. Other datatypes are marked unreachable.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    switch (datatype) {
      case xnn_datatype_qint8:
      case xnn_datatype_quint8:
        return 1.0f;
      case xnn_datatype_fp16:
      case xnn_datatype_bf16:
        return TolMixed(y_ref, 1.0e-3f, 1.0e-2f);
      case xnn_datatype_fp32:
        return TolMixed(y_ref, 5.0e-6f, 1.0e-5f);
      default:
        XNN_UNREACHABLE;
    }
  }

  // Tested on [-4, 4] for every datatype.
  Interval Domain(xnn_datatype) const override {
    return {-4.0f, 4.0f};
  }
};
// LeakyReLU: inputs with the sign bit set are scaled by negative_slope; all
// other inputs pass through unchanged.
struct LeakyReLU : public UnaryOpInfo {
  // Defaults to a negative slope of 0.5.
  xnn_unary_params DefaultParams() const override {
    xnn_unary_params leaky_params;
    leaky_params.leaky_relu.negative_slope = 0.5f;
    return leaky_params;
  }

  float ReferenceImpl(float x, const xnn_unary_params& params) const override {
    if (!std::signbit(x)) {
      return x;
    }
    return x * params.leaky_relu.negative_slope;
  }

  // Exact for fp32, mixed tolerance for fp16/bf16, and 1 for the quantized
  // 8-bit types. Other datatypes are marked unreachable.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    switch (datatype) {
      case xnn_datatype_qint8:
      case xnn_datatype_quint8:
        return 1.0f;
      case xnn_datatype_fp16:
      case xnn_datatype_bf16:
        return TolMixed(y_ref, 1.0e-4f, 1.0e-3f);
      case xnn_datatype_fp32:
        return TolExact(y_ref);
      default:
        XNN_UNREACHABLE;
    }
  }
};
// RoundToNearestEven: the reference is std::nearbyint, which rounds using the
// current floating-point rounding mode.
struct RoundToNearestEven : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float rounded = std::nearbyint(x);
    return rounded;
  }
};
// RoundTowardsZero: the reference is std::trunc, which discards the
// fractional part.
struct RoundTowardsZero : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float truncated = std::trunc(x);
    return truncated;
  }
};
// RoundUp: the reference is std::ceil, the smallest integral value not less
// than the input.
struct RoundUp : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float ceiling = std::ceil(x);
    return ceiling;
  }
};
// RoundDown: the reference is std::floor, the largest integral value not
// greater than the input.
struct RoundDown : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float floored = std::floor(x);
    return floored;
  }
};
// Sigmoid: e^x / (1 + e^x), evaluated in double precision.
struct Sigmoid : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    // Inputs beyond +/-100 return the limit values 1 and 0 directly instead
    // of going through exp().
    if (x > 100) {
      return 1.0f;
    }
    if (x < -100) {
      return 0.0f;
    }
    const double exp_x = std::exp(static_cast<double>(x));
    return exp_x / (1.0 + exp_x);
  }

  // Mixed tolerance for the floating-point types (looser for fp16/bf16), 1
  // for the quantized 8-bit types, and exact for anything else.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    if (datatype == xnn_datatype_fp32) {
      return TolMixed(y_ref, 5.0e-6f, 1.0e-5f);
    }
    if (datatype == xnn_datatype_fp16 || datatype == xnn_datatype_bf16) {
      return TolMixed(y_ref, 1.0e-4f, 5.0e-3f);
    }
    if (datatype == xnn_datatype_qint8 || datatype == xnn_datatype_quint8) {
      return 1.0f;
    }
    return TolExact(y_ref);
  }

  // Tested on [-25, 25] for fp16 and on [-125, 125] for every other datatype.
  Interval Domain(xnn_datatype datatype) const override {
    if (datatype == xnn_datatype_fp16) {
      return {-25.0f, 25.0f};
    }
    return {-125.0f, 125.0f};
  }
};
// Square: the reference result is the input multiplied by itself.
struct Square : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    return x * x;
  }

  int ReferenceImpl(int x, const xnn_unary_params&) const override {
    // Multiply in 64 bits, then convert the product back to int.
    const int64_t wide = static_cast<int64_t>(x);
    return static_cast<int>(wide * wide);
  }

  // Exact for fp32, mixed tolerance for fp16/bf16, 1 for the quantized 8-bit
  // types. Other datatypes (besides int32) are marked unreachable.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    switch (datatype) {
      case xnn_datatype_int32:
        // Overflow makes this hard to test, so any result is accepted.
        return std::numeric_limits<float>::infinity();
      case xnn_datatype_qint8:
      case xnn_datatype_quint8:
        return 1.0f;
      case xnn_datatype_fp16:
      case xnn_datatype_bf16:
        return TolMixed(y_ref, 1.0e-4f, 5.0e-3f);
      case xnn_datatype_fp32:
        return TolExact(y_ref);
      default:
        XNN_UNREACHABLE;
    }
  }
};
// SquareRoot: the reference is std::sqrt.
struct SquareRoot : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float root = std::sqrt(x);
    return root;
  }

  // 2.5 epsilon relative for fp32, mixed tolerance for fp16/bf16, and 1 for
  // the quantized 8-bit types. Other datatypes are marked unreachable.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    switch (datatype) {
      case xnn_datatype_qint8:
      case xnn_datatype_quint8:
        return 1.0f;
      case xnn_datatype_fp16:
      case xnn_datatype_bf16:
        return TolMixed(y_ref, 1.0e-4f, 5.0e-3f);
      case xnn_datatype_fp32:
        return TolRelative(y_ref, 2.5f * std::numeric_limits<float>::epsilon());
      default:
        XNN_UNREACHABLE;
    }
  }

  // Tested on [0.001, 10] for fp16; every other datatype uses the positive
  // range for that datatype.
  Interval Domain(xnn_datatype datatype) const override {
    if (datatype != xnn_datatype_fp16) {
      return Interval::Positive(datatype);
    }
    return {0.001f, 10.0f};
  }
};
// TanH: the reference is std::tanh.
struct TanH : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float result = std::tanh(x);
    return result;
  }

  // 4 epsilon relative (4 ULP) for fp32, mixed tolerance for fp16/bf16, and 1
  // for every other datatype.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    if (datatype == xnn_datatype_fp32) {
      const float four_ulp = 4.0f * std::numeric_limits<float>::epsilon();
      return TolRelative(y_ref, four_ulp);
    }
    if (datatype == xnn_datatype_fp16 || datatype == xnn_datatype_bf16) {
      return TolMixed(y_ref, /*abs_tol=*/1.0e-4f, /*rel_tol=*/5.0e-3f);
    }
    return 1.0f;
  }

  // Tested on [-5, 5] for fp16 and on [-10, 10] for every other datatype.
  Interval Domain(xnn_datatype datatype) const override {
    if (datatype == xnn_datatype_fp16) {
      return {-5.0f, 5.0f};
    }
    return {-10.0f, 10.0f};
  }
};
// ReciprocalSquareRoot: 1 / sqrt(x).
struct ReciprocalSquareRoot : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    // The square root is taken in float; the division is done in double (the
    // numerator is a double literal) and narrowed on return.
    const float root = std::sqrt(x);
    return 1.0 / root;
  }

  // 4 epsilon relative for fp32, mixed tolerance for fp16/bf16, 1 for the
  // quantized 8-bit types, and exact for anything else.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    if (datatype == xnn_datatype_fp32) {
      return TolRelative(y_ref, 4 * std::numeric_limits<float>::epsilon());
    }
    if (datatype == xnn_datatype_fp16 || datatype == xnn_datatype_bf16) {
      return TolMixed(y_ref, 1.0e-4f, 5.0e-3f);
    }
    if (datatype == xnn_datatype_qint8 || datatype == xnn_datatype_quint8) {
      return 1.0f;
    }
    return TolExact(y_ref);
  }

  // Tested on [1e-4, 10] for fp16; every other datatype uses the positive
  // range for that datatype.
  Interval Domain(xnn_datatype datatype) const override {
    if (datatype != xnn_datatype_fp16) {
      return Interval::Positive(datatype);
    }
    return {1.0e-4f, 10.0f};
  }
};
// Log: the reference is std::log (natural logarithm).
struct Log : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float result = std::log(x);
    return result;
  }

  // The same mixed tolerance (2 epsilon absolute, 6 epsilon relative) is used
  // for every datatype.
  float Tolerance(float y_ref, xnn_datatype /*datatype*/) const override {
    const float eps = std::numeric_limits<float>::epsilon();
    return TolMixed(y_ref, 2 * eps, 6 * eps);
  }

  // Tested on [epsilon, 1000] for every datatype.
  Interval Domain(xnn_datatype /*datatype*/) const override {
    const float lower = std::numeric_limits<float>::epsilon();
    return {lower, 1000.0f};
  }
};
// Exp: the reference is std::exp.
struct Exp : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float result = std::exp(x);
    return result;
  }

  // The same mixed tolerance (2 epsilon absolute, 6 epsilon relative) is used
  // for every datatype.
  float Tolerance(float y_ref, xnn_datatype /*datatype*/) const override {
    const float eps = std::numeric_limits<float>::epsilon();
    return TolMixed(y_ref, 2 * eps, 6 * eps);
  }

  // Tested on [-10, 10] for every datatype.
  Interval Domain(xnn_datatype) const override {
    return {-10.0f, 10.0f};
  }
};
// CubeRoot: the reference is std::cbrt.
struct CubeRoot : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float root = std::cbrt(x);
    return root;
  }

  // 2.5 epsilon relative for fp32, mixed tolerance for fp16/bf16, and 1 for
  // the quantized 8-bit types. Other datatypes are marked unreachable.
  float Tolerance(float y_ref, xnn_datatype datatype) const override {
    switch (datatype) {
      case xnn_datatype_qint8:
      case xnn_datatype_quint8:
        return 1.0f;
      case xnn_datatype_fp16:
      case xnn_datatype_bf16:
        return TolMixed(y_ref, 1.0e-4f, 5.0e-3f);
      case xnn_datatype_fp32:
        return TolRelative(y_ref, 2.5f * std::numeric_limits<float>::epsilon());
      default:
        XNN_UNREACHABLE;
    }
  }

  // Tested on [0.001, 10] for fp16 and bf16; every other datatype uses the
  // positive range for that datatype.
  Interval Domain(xnn_datatype datatype) const override {
    const bool is_half_type =
        datatype == xnn_datatype_fp16 || datatype == xnn_datatype_bf16;
    if (!is_half_type) {
      return Interval::Positive(datatype);
    }
    return {0.001f, 10.0f};
  }
};
// Cosine: the reference is std::cos.
struct Cosine : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float result = std::cos(x);
    return result;
  }

  // A fixed absolute tolerance of 1e-6, independent of the reference value
  // and the datatype.
  float Tolerance(float /*y_ref*/, xnn_datatype /*datatype*/) const override {
    return 1e-6f;
  }

  // Tested on [-10, 10] for every datatype.
  Interval Domain(xnn_datatype) const override {
    return {-10.0f, 10.0f};
  }
};
// Sine: the reference is std::sin.
struct Sine : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    const float result = std::sin(x);
    return result;
  }

  // A fixed absolute tolerance of 1e-6, independent of the reference value
  // and the datatype.
  float Tolerance(float /*y_ref*/, xnn_datatype /*datatype*/) const override {
    return 1e-6f;
  }

  // Tested on [-10, 10] for every datatype.
  Interval Domain(xnn_datatype) const override {
    return {-10.0f, 10.0f};
  }
};
// CountLeadingZeros: the reference is math_clz_u32 applied to the input as an
// integer.
struct CountLeadingZeros : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    // The float input is converted to int first, then the count is converted
    // back to float.
    const int as_int = static_cast<int>(x);
    return static_cast<float>(math_clz_u32(as_int));
  }

  int ReferenceImpl(int x, const xnn_unary_params&) const override {
    return math_clz_u32(x);
  }
};
// BitwiseNot: the reference is the bitwise complement of the input as an
// integer.
struct BitwiseNot : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    // The float input is converted to int, complemented, and the integer
    // result is converted back to float on return.
    const int as_int = static_cast<int>(x);
    return ~as_int;
  }

  int ReferenceImpl(int x, const xnn_unary_params&) const override {
    const int complement = ~x;
    return complement;
  }
};
// Popcount: the reference is math_popcount_u32 applied to the input as an
// integer.
struct Popcount : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    // The float input is converted to int first, then the count is converted
    // back to float.
    const int as_int = static_cast<int>(x);
    return static_cast<float>(math_popcount_u32(as_int));
  }

  int ReferenceImpl(int x, const xnn_unary_params&) const override {
    return math_popcount_u32(x);
  }
};
// Sign: -1 for inputs below zero, +1 for inputs above zero, and 0 otherwise
// (which for floats includes any input that compares neither below nor above
// zero).
struct Sign : public UnaryOpInfo {
  float ReferenceImpl(float x, const xnn_unary_params&) const override {
    if (x < 0.0f) {
      return -1.0f;
    }
    if (x > 0.0f) {
      return 1.0f;
    }
    return 0.0f;
  }

  int ReferenceImpl(int x, const xnn_unary_params&) const override {
    if (x < 0) {
      return -1;
    }
    if (x > 0) {
      return 1;
    }
    return 0;
  }
};
// Declaration only; the definition is not in this header. Presumably returns
// the `UnaryOpInfo` describing `op` — the behavior for operators without a
// matching description is not visible here (TODO confirm against the
// definition).
const UnaryOpInfo* GetUnaryOpInfo(xnn_unary_operator op);
// Fill `x[0..n)` with uniformly distributed random values drawn from `domain`.
//
// `domain` is given as unquantized (real) values. It is first mapped into the
// quantized value space using `quantization` (q = real / scale + zero_point,
// the inverse of the dequantization used by `UnaryReferenceImpl`), and the
// mapped bounds are then limited to what `T` can represent and to +/-1e6.
//
// Parameters:
//   rng          - random number generator passed to the distribution.
//   x            - output buffer with room for `n` elements.
//   n            - number of elements to generate.
//   domain       - range of real values to sample from.
//   quantization - zero point and scale of `x`; the default {0, 1.0f} leaves
//                  the domain unchanged.
template <typename T, typename Rng>
void FillRandom(Rng& rng, T* x, size_t n, const Interval& domain,
                const xnn_quantization_params& quantization = {0, 1.0f}) {
  // Map the real-valued domain into quantized space.
  float min = domain.min / quantization.scale + quantization.zero_point;
  float max = domain.max / quantization.scale + quantization.zero_point;

  // Clamp the *mapped* bounds (not the raw domain) to the range of `T`, so the
  // quantization parameters are not discarded.
  min = std::max<float>(min, xnnpack::NumericLimits<T>::min());
  max = std::min<float>(max, xnnpack::NumericLimits<T>::max());

  // Keep the range finite and moderate so unbounded domains can be sampled.
  min = std::max<float>(min, -1e6f);
  max = std::min<float>(max, 1e6f);

  std::uniform_real_distribution<float> dist(min, max);
  for (size_t i = 0; i < n; ++i) {
    x[i] = static_cast<T>(dist(rng));
  }
}
// Reference evaluation of a unary operator for quantized input and quantized
// output.
//
// Each of the `n` inputs is dequantized with `input_quantization`, passed
// through `op_info.ReferenceImpl`, requantized with `output_quantization`,
// and rounded to `Out` before being stored in `y`.
template <typename In, typename Out, typename UnaryOp>
void UnaryReferenceImpl(
    const xnnpack::quantized<In>* x, size_t n, xnnpack::quantized<Out>* y,
    const UnaryOp& op_info,
    const xnn_quantization_params& input_quantization = {0, 1.0f},
    const xnn_quantization_params& output_quantization = {0, 1.0f},
    const xnn_unary_params& params = xnn_unary_params()) {
  for (size_t idx = 0; idx != n; ++idx) {
    // real = (q - zero_point) * scale
    const float real_in =
        (x[idx] - input_quantization.zero_point) * input_quantization.scale;
    const float real_out = op_info.ReferenceImpl(real_in, params);
    // q = real / scale + zero_point
    const float quantized_out =
        real_out / output_quantization.scale + output_quantization.zero_point;
    y[idx] = xnnpack::round_float_to_int<Out>(quantized_out);
  }
}
// Reference evaluation of a unary operator for non-quantized input and
// quantized output.
//
// Each of the `n` inputs is converted to float, passed through
// `op_info.ReferenceImpl`, requantized with `output_quantization`, and
// rounded to `Out` before being stored in `y`. `input_quantization` is
// accepted for signature parity with the other overloads but is not used.
template <typename In, typename Out, typename UnaryOp>
void UnaryReferenceImpl(
    const In* x, size_t n, xnnpack::quantized<Out>* y, const UnaryOp& op_info,
    const xnn_quantization_params& input_quantization = {0, 1.0f},
    const xnn_quantization_params& output_quantization = {0, 1.0f},
    const xnn_unary_params& params = xnn_unary_params()) {
  static_assert(!xnnpack::is_quantized<In>::value, "");
  for (size_t idx = 0; idx != n; ++idx) {
    const float real_in = static_cast<float>(x[idx]);
    const float real_out = op_info.ReferenceImpl(real_in, params);
    // q = real / scale + zero_point
    const float quantized_out =
        real_out / output_quantization.scale + output_quantization.zero_point;
    y[idx] = xnnpack::round_float_to_int<Out>(quantized_out);
  }
}
// Reference evaluation of a unary operator for quantized input and
// non-quantized output.
//
// Each of the `n` inputs is dequantized with `input_quantization` and passed
// through `op_info.ReferenceImpl`. The result is rounded when `Out` is an
// integral type and stored as-is otherwise. `output_quantization` is accepted
// for signature parity with the other overloads but is not used.
template <typename In, typename Out, typename UnaryOp>
void UnaryReferenceImpl(
    const xnnpack::quantized<In>* x, size_t n, Out* y, const UnaryOp& op_info,
    const xnn_quantization_params& input_quantization = {0, 1.0f},
    const xnn_quantization_params& output_quantization = {0, 1.0f},
    const xnn_unary_params& params = xnn_unary_params()) {
  static_assert(!xnnpack::is_quantized<Out>::value, "");
  constexpr bool kIntegralOut = std::is_integral<Out>::value;
  for (size_t idx = 0; idx != n; ++idx) {
    // real = (q - zero_point) * scale
    const float real_in =
        (x[idx] - input_quantization.zero_point) * input_quantization.scale;
    const float real_out = op_info.ReferenceImpl(real_in, params);
    if (kIntegralOut) {
      y[idx] = xnnpack::round_float_to_int<Out>(real_out);
    } else {
      y[idx] = real_out;
    }
  }
}
// Reference evaluation of a unary operator for non-quantized input and
// non-quantized output.
//
// - Integral input and integral output: the input is converted to int and the
//   result of `op_info.ReferenceImpl` is stored directly.
// - Otherwise: `ReferenceImpl` is called with the input converted to int
//   (integral `In`) or to float (any other `In`), its result is held as a
//   float, and it is then rounded for integral `Out` or stored as-is.
//
// Neither quantization argument is used; they are accepted for signature
// parity with the other overloads.
template <typename In, typename Out, typename UnaryOp>
void UnaryReferenceImpl(
    const In* x, size_t n, Out* y, const UnaryOp& op_info,
    const xnn_quantization_params& input_quantization = {0, 1.0f},
    const xnn_quantization_params& output_quantization = {0, 1.0f},
    const xnn_unary_params& params = xnn_unary_params()) {
  static_assert(!xnnpack::is_quantized<In>::value, "");
  static_assert(!xnnpack::is_quantized<Out>::value, "");
  constexpr bool kIntegralIn = std::is_integral<In>::value;
  constexpr bool kIntegralOut = std::is_integral<Out>::value;
  for (size_t idx = 0; idx != n; ++idx) {
    if (kIntegralIn && kIntegralOut) {
      // Pure integer path: no round trip through float.
      y[idx] = op_info.ReferenceImpl(static_cast<int>(x[idx]), params);
      continue;
    }

    float reference;
    if (kIntegralIn) {
      reference = op_info.ReferenceImpl(static_cast<int>(x[idx]), params);
    } else {
      reference = op_info.ReferenceImpl(static_cast<float>(x[idx]), params);
    }

    if (kIntegralOut) {
      y[idx] = xnnpack::round_float_to_int<Out>(reference);
    } else {
      y[idx] = reference;
    }
  }
}
#endif // THIRD_PARTY_XNNPACK_TEST_UNARY_OPS_H_