sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/src/QuantUtilsNeon.cc

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "fbgemm/Utils.h"
#if HAVE_SVE
#define FBGEMM_EXPORTS
#include <arm_neon.h>
#include <arm_sve.h>
#include <arm_neon_sve_bridge.h>
#include <algorithm> //for std::min/std::max
#include <cassert> //for assert
#include <cfloat> // for FLT_MAX
#include <cmath> //for nearbyint
#include <cstring> //for memcpy
#include <limits> //for numeric_limits
#include "fbgemm/FloatConversion.h"
#include "fbgemm/QuantUtilsNeon.h"
#include "fbgemm/Types.h"
namespace fbgemm {
using namespace std;
////////////////////////////////////////////////////////////////////////////////
// Utility functions
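
// Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon dequantizes fused 8-bit
// rowwise-quantized rows into float or float16 output. Each input row stores
// output_columns quantized uint8 values followed by that row's scale and bias
// as two floats, so output_columns = input_columns - 2 * sizeof(float).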
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon(
    const std::uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output) {
  size_t output_columns = std::max<int>(input_columns - 2 * sizeof(float), 0);

  svbool_t allTruePred = svptrue_b32();
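
  // Tail predicates for the trailing output_columns % 8 elements: lastPredA
  // covers the first up-to-4 lanes, lastPredB the next up-to-4 (offset by 4),
  // and lastPredC is the 16-bit-lane predicate used to store up to 8 float16
  // outputs at once.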
  size_t output_columns_mod = output_columns % 8;
  svbool_t lastPredA = svwhilelt_b32_u64(0, output_columns_mod);
  svbool_t lastPredB = svwhilelt_b32_u64(4, output_columns_mod);
  svbool_t lastPredC = svwhilelt_b16_u64(0, output_columns_mod);
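
  // input_row_v_0 and input_row_v_1 point 4 bytes apart, so each colIndex
  // step (a stride of one uint64_t, i.e. 8 bytes) dequantizes 8 consecutive
  // quantized bytes as two groups of 4.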
  const uint64_t* input_row_v_0 = reinterpret_cast<const uint64_t*>(input);
  const uint64_t* input_row_v_1 = reinterpret_cast<const uint64_t*>(input + 4);
  OutputType* output_row = output;

  for (; input_rows > 0; --input_rows) {
    const float* input_row_scale_bias = reinterpret_cast<const float*>(
        reinterpret_cast<const uint8_t*>(input_row_v_0) + output_columns);
    float scale = input_row_scale_bias[0];
    float bias = input_row_scale_bias[1];
    svfloat32_t scale_v = svdup_n_f32(scale);
    svfloat32_t bias_v = svdup_n_f32(bias);

    float32x4x2_t* output_row_v = reinterpret_cast<float32x4x2_t*>(output_row);
    float16x8_t* output_row_v_half = reinterpret_cast<float16x8_t*>(output_row);

    size_t colIndex = 0;
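    // Main loop: each iteration widens 2 x 4 uint8 values to uint32, converts
    // them to float, and applies the row's scale and bias with a fused
    // multiply-add (value * scale + bias), producing 8 float or float16
    // outputs.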
    for (size_t colMax = output_columns / 8;
         __builtin_expect(colIndex < colMax, 1);
         ++colIndex) {
      svuint32_t in_v_0 = svld1ub_u32(
          allTruePred,
          reinterpret_cast<const uint8_t*>(input_row_v_0 + colIndex));
      svuint32_t in_v_1 = svld1ub_u32(
          allTruePred,
          reinterpret_cast<const uint8_t*>(input_row_v_1 + colIndex));
      svfloat32_t in_v_0_f = svcvt_f32_u32_x(allTruePred, in_v_0);
      svfloat32_t in_v_1_f = svcvt_f32_u32_x(allTruePred, in_v_1);
      in_v_0_f = svmad_f32_m(allTruePred, in_v_0_f, scale_v, bias_v);
      in_v_1_f = svmad_f32_m(allTruePred, in_v_1_f, scale_v, bias_v);

      if constexpr (std::is_same<OutputType, float>()) {
        output_row_v[colIndex].val[0] = svget_neonq(in_v_0_f);
        output_row_v[colIndex].val[1] = svget_neonq(in_v_1_f);
      } else {
        float16x4_t dequantized_v_half_low_low =
            vcvt_f16_f32(svget_neonq(in_v_0_f));
        float16x8_t dequantized_v_half_low = vcvt_high_f16_f32(
            dequantized_v_half_low_low, svget_neonq(in_v_1_f));
        output_row_v_half[colIndex] = dequantized_v_half_low;
      }
    }
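
    // Tail: dequantize the remaining output_columns % 8 columns of this row
    // with the partial predicates computed above, so nothing is written past
    // the end of the output row.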
    svuint32_t in_v_0 = svld1ub_u32(
        lastPredA, reinterpret_cast<const uint8_t*>(input_row_v_0 + colIndex));
    svuint32_t in_v_1 = svld1ub_u32(
        lastPredB, reinterpret_cast<const uint8_t*>(input_row_v_1 + colIndex));
    svfloat32_t in_v_0_f = svcvt_f32_u32_x(lastPredA, in_v_0);
    svfloat32_t in_v_1_f = svcvt_f32_u32_x(lastPredB, in_v_1);
    in_v_0_f = svmad_f32_m(lastPredA, in_v_0_f, scale_v, bias_v);
    in_v_1_f = svmad_f32_m(lastPredB, in_v_1_f, scale_v, bias_v);

    if constexpr (std::is_same<OutputType, float>()) {
      svst1_f32(lastPredA, (float32_t*)&(output_row_v[colIndex]), in_v_0_f);
      svst1_f32(
          lastPredB, (float32_t*)&(output_row_v[colIndex].val[1]), in_v_1_f);
    } else {
      float16x4_t dequantized_v_half_low_low =
          vcvt_f16_f32(svget_neonq(in_v_0_f));
      float16x8_t dequantized_v_half_low = vcvt_high_f16_f32(
          dequantized_v_half_low_low, svget_neonq(in_v_1_f));
      svst1_f16(
          lastPredC,
          (float16_t*)&(output_row_v_half[colIndex]),
          svset_neonq_f16(svundef_f16(), dequantized_v_half_low));
    }

    input_row_v_0 = reinterpret_cast<const uint64_t*>(
        reinterpret_cast<const uint8_t*>(input_row_v_0) + input_columns);
    input_row_v_1 = reinterpret_cast<const uint64_t*>(
        reinterpret_cast<const uint8_t*>(input_row_v_1) + input_columns);
    output_row += output_columns;
  } // for each row
}
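
// Illustrative usage sketch (rows, cols, and fused are hypothetical
// caller-side names): a fused row of `cols` bytes dequantizes to
// cols - 2 * sizeof(float) values, so a float destination can be sized and
// filled as
//
//   std::vector<float> out(rows * (cols - 2 * sizeof(float)));
//   fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon<float>(
//       fused.data(), rows, cols, out.data());
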
#define INSTANTIATE_QuantizationNeonFunctions8Bits(type)                  \
  template void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfNeon<type>(  \
      const std::uint8_t* input,                                          \
      size_t input_rows,                                                  \
      int input_columns,                                                  \
      type* output);
// clang-format off
INSTANTIATE_QuantizationNeonFunctions8Bits(float)
INSTANTIATE_QuantizationNeonFunctions8Bits(float16)
// clang-format on
#undef INSTANTIATE_QuantizationNeonFunctions8Bits
} // namespace fbgemm
#endif // HAVE_SVE