/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include "fbgemm/Utils.h" #if HAVE_SVE #define FBGEMM_EXPORTS #include #include #include #include //for std::min/std::max #include //for assert #include // for FLT_MAX #include //for nearbyint #include //for memcpy #include //for numeric_limits #include "fbgemm/FloatConversion.h" #include "fbgemm/QuantUtilsNeon.h" #include "fbgemm/Types.h" namespace fbgemm { using namespace std; //////////////////////////////////////////////////////////////////////////////// // Utility functions template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( const std::uint8_t* input, size_t input_rows, int input_columns, OutputType* output) { size_t output_columns = std::max(input_columns - 2 * sizeof(float), 0); svbool_t allTruePred = svptrue_b32(); size_t output_columns_mod = output_columns % 8; svbool_t lastPredA = svwhilelt_b32_u64(0, output_columns_mod); svbool_t lastPredB = svwhilelt_b32_u64(4, output_columns_mod); svbool_t lastPredC = svwhilelt_b16_u64(0, output_columns_mod); const uint64_t* input_row_v_0 = reinterpret_cast(input); const uint64_t* input_row_v_1 = reinterpret_cast(input + 4); OutputType* output_row = output; for (; input_rows > 0; --input_rows) { const float* input_row_scale_bias = reinterpret_cast( reinterpret_cast(input_row_v_0) + output_columns); float scale = input_row_scale_bias[0]; float bias = input_row_scale_bias[1]; svfloat32_t scale_v = svdup_n_f32(scale); svfloat32_t bias_v = svdup_n_f32(bias); float32x4x2_t* output_row_v = reinterpret_cast(output_row); float16x8_t* output_row_v_half = reinterpret_cast(output_row); size_t colIndex = 0; for (size_t colMax = output_columns / 8; __builtin_expect(colIndex < colMax, 1); ++colIndex) { svuint32_t in_v_0 = svld1ub_u32( allTruePred, reinterpret_cast(input_row_v_0 + colIndex)); svuint32_t in_v_1 = svld1ub_u32( allTruePred, reinterpret_cast(input_row_v_1 + colIndex)); svfloat32_t in_v_0_f = svcvt_f32_u32_x(allTruePred, in_v_0); svfloat32_t in_v_1_f = svcvt_f32_u32_x(allTruePred, in_v_1); in_v_0_f = svmad_f32_m(allTruePred, in_v_0_f, scale_v, bias_v); in_v_1_f = svmad_f32_m(allTruePred, in_v_1_f, scale_v, bias_v); if constexpr (std::is_same()) { output_row_v[colIndex].val[0] = svget_neonq(in_v_0_f); output_row_v[colIndex].val[1] = svget_neonq(in_v_1_f); } else { float16x4_t dequantzed_v_half_low_low = vcvt_f16_f32(svget_neonq(in_v_0_f)); float16x8_t dequantzed_v_half_low = vcvt_high_f16_f32(dequantzed_v_half_low_low, svget_neonq(in_v_1_f)); output_row_v_half[colIndex] = dequantzed_v_half_low; } } svuint32_t in_v_0 = svld1ub_u32( lastPredA, reinterpret_cast(input_row_v_0 + colIndex)); svuint32_t in_v_1 = svld1ub_u32( lastPredB, reinterpret_cast(input_row_v_1 + colIndex)); svfloat32_t in_v_0_f = svcvt_f32_u32_x(lastPredA, in_v_0); svfloat32_t in_v_1_f = svcvt_f32_u32_x(lastPredB, in_v_1); in_v_0_f = svmad_f32_m(lastPredA, in_v_0_f, scale_v, bias_v); in_v_1_f = svmad_f32_m(lastPredB, in_v_1_f, scale_v, bias_v); if constexpr (std::is_same()) { svst1_f32(lastPredA, (float32_t*)&(output_row_v[colIndex]), in_v_0_f); svst1_f32( lastPredB, (float32_t*)&(output_row_v[colIndex].val[1]), in_v_1_f); } else { float16x4_t dequantzed_v_half_low_low = vcvt_f16_f32(svget_neonq(in_v_0_f)); float16x8_t dequantzed_v_half_low = vcvt_high_f16_f32(dequantzed_v_half_low_low, svget_neonq(in_v_1_f)); svst1_f16( lastPredC, (float16_t*)&(output_row_v_half[colIndex]), svset_neonq_f16(svundef_f16(), dequantzed_v_half_low)); } input_row_v_0 = reinterpret_cast( reinterpret_cast(input_row_v_0) + input_columns); input_row_v_1 = reinterpret_cast( reinterpret_cast(input_row_v_1) + input_columns); output_row += output_columns; } // for each row } #define INSTANTIATE_QuantizationNeonFunctions8Bits(type) \ template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon( \ const std::uint8_t* input, \ size_t input_rows, \ int input_columns, \ type* output); // clang-format off INSTANTIATE_QuantizationNeonFunctions8Bits(float) INSTANTIATE_QuantizationNeonFunctions8Bits(float16) // clang-format on #undef INSTANTIATE_QuantizationNeonFunctions8Bits } // namespace fbgemm #endif // __aarch64__