/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#define FBGEMM_EXPORTS
#include "./RefImplementations.h"

#include "fbgemm/FbgemmBuild.h"
#include "fbgemm/FbgemmConvert.h"
#include "fbgemm/FloatConversion.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstring>
#include <iostream>
#include <numeric>
#include <thread>

using namespace std;

namespace fbgemm {

typedef union {
  uint32_t I;
  float F;
} fint32;

// Thread-safe random number generator
//
// Returns a random 32-bit integer using xoshiro128++
// http://prng.di.unimi.it/xoshiro128plusplus.c
inline uint32_t rnd128_next(int idx, int vlen) {
  constexpr int VLEN_MAX = 16; // max vector size
  alignas(64) static thread_local uint32_t g_rnd128_buffer[4 * VLEN_MAX];
  static thread_local bool g_rnd128_initialized = false;

  // Splitmix64: http://prng.di.unimi.it/splitmix64.c
  auto rnd128_init_next = [](uint64_t& x) {
    uint64_t z = (x += 0x9e3779b97f4a7c15);
    z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
    z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
    return z ^ (z >> 31);
  };

  auto rotl = [](const uint32_t x, int k) {
    return (x << k) | (x >> (32 - k));
  };

  if (!g_rnd128_initialized) {
    // Initialize the random buffer with unique values per thread
    uint64_t h0 = std::hash<std::thread::id>{}(std::this_thread::get_id());
    for (auto i = 0; i < 4; ++i) {
      // Use the thread hash as seed
      g_rnd128_buffer[i * VLEN_MAX] = rnd128_init_next(h0);
      uint64_t h1 = g_rnd128_buffer[i * VLEN_MAX];
      for (auto v = 1; v < VLEN_MAX; ++v) {
        g_rnd128_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1);
      }
    }
    g_rnd128_initialized = true;
  }

  const uint32_t result =
      rotl(g_rnd128_buffer[idx] + g_rnd128_buffer[3 * vlen + idx], 7) +
      g_rnd128_buffer[idx];

  const uint32_t t = g_rnd128_buffer[1 * vlen + idx] << 9;

  g_rnd128_buffer[2 * vlen + idx] ^= g_rnd128_buffer[0 * vlen + idx];
  g_rnd128_buffer[3 * vlen + idx] ^= g_rnd128_buffer[1 * vlen + idx];
  g_rnd128_buffer[1 * vlen + idx] ^= g_rnd128_buffer[2 * vlen + idx];
  g_rnd128_buffer[0 * vlen + idx] ^= g_rnd128_buffer[3 * vlen + idx];

  g_rnd128_buffer[2 * vlen + idx] ^= t;

  g_rnd128_buffer[3 * vlen + idx] = rotl(g_rnd128_buffer[3 * vlen + idx], 11);

  return result;
}

void FloatToFloat16_ref(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip) {
  constexpr float FP16_MAX = 65504.f;
  if (do_clip) {
    for (size_t i = 0; i < size; i++) {
      float cur_src = std::max(-FP16_MAX, std::min(src[i], FP16_MAX));
      dst[i] = cpu_float2half_rn(cur_src);
    }
  } else {
    for (size_t i = 0; i < size; i++) {
      dst[i] = cpu_float2half_rn(src[i]);
    }
  }
}

void Float16ToFloat_ref(const float16* src, float* dst, size_t size) {
  for (size_t i = 0; i < size; i++) {
    dst[i] = cpu_half2float(src[i]);
  }
}

void FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size) {
  for (size_t i = 0; i < size; i++) {
    // Add 2^15 and right shift by 16 to round to nearest
    dst[i] = (*reinterpret_cast<const uint32_t*>(src + i) + (1 << 15)) >> 16;
  }
}

void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size) {
  for (size_t i = 0; i < size; i++) {
    uint32_t val_fp32 =
        static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(src)[i]) << 16;
    reinterpret_cast<uint32_t*>(dst)[i] = val_fp32;
  }
}
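// Editor's illustrative sketch (not part of the library): how the
// round-to-nearest in FloatToBfloat16_ref above works on a concrete value.
// The float 1.00390625f has bit pattern 0x3F808000; adding the rounding bias
// 0x8000 gives 0x3F810000, and shifting right by 16 yields the bfloat16
// pattern 0x3F81, i.e. the value rounds up to 1.0078125.
//
//   float src[1] = {1.00390625f}; // 0x3F808000, exactly halfway
//   bfloat16 dst[1];
//   FloatToBfloat16_ref(src, dst, 1);
//   // dst[0] == 0x3F81: the bias add rounds halfway cases away from zero,
//   // not half-to-even.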
void FloatToFloat8_ref(
    const float input,
    uint8_t* output,
    int exponent_bits,
    int exponent_bias) {
  float max_pos = (1 << ((1 << exponent_bits) - 2 - exponent_bias)) *
      (2 - std::pow(2, exponent_bits - 7));
  int mantissa_bits = 7 - exponent_bits;

  fint32 val_out, bouncer, smallest_normal;

  val_out.F = input;
  uint32_t sign_bit = val_out.I & 0x80000000;
  val_out.I = val_out.I & 0x7FFFFFFF;
  val_out.F = fminf(val_out.F, max_pos);

  smallest_normal.I = (127 - exponent_bias + 1)
      << 23; // smallest hfp8 normal number in FP32
  // I don't know if the input "min_pos" is the smallest denormalized number
  // or the smallest normalized number. The test below needs to be done with
  // the smallest normal number, which is the numerical value 2^(1-bias).

  // The conversion for denormalized values is slightly different. HFP8 is so
  // low precision that gradual underflow is probably crucial.
  if (val_out.F >= smallest_normal.F) {
    // Use round to nearest even. We make use of the standard rounding
    // mechanism in FP32 rather than rounding the mantissa and handling
    // tie-to-even and incrementing the exponent.
    // We want to round off 23 - mbits of the FP32 value val_in. This can be
    // done by adding a power of 2 exactly 23 - mbits larger than the exponent
    // of val_in. This forces val_in to be moved to the right, with the
    // rounding happening exactly at the location corresponding to having
    // mbits of explicit mantissa left.
    bouncer.I = (val_out.I & 0xFF800000) + ((23 - mantissa_bits) << 23);
    val_out.F = (bouncer.F + val_out.F) - bouncer.F;
    // Adding the bouncer rounds off bits, and subtracting the bouncer
    // leaves the desired value, albeit in FP32 encoding.
    // All we need is to change the exponent encoding to use "bias".
    val_out.I = uint32_t(val_out.I - ((127 - exponent_bias) << 23))
        << (8 - exponent_bits);
    val_out.I = ((val_out.I | sign_bit) >> 24);
    // the 8 lsbs are the desired HFP8 encoding
  } else {
    // When the value is in the denormal range, IEEE numbers essentially
    // become fixed-point numbers. The lsb is the smallest non-zero number
    // 2^(1 - bias - mbits). Hence, we define the bouncer so that its lsb is
    // this smallest non-zero number. Adding the input to this bouncer forces
    // rounding to occur appropriately. Also, in this situation, after adding
    // the bouncer, the 8 least significant bits of the sum are already the
    // HFP8 encoding of the desired result. We just need to restore the sign
    // bit.
    bouncer.I = (127 + (23 + (1 - exponent_bias - mantissa_bits))) << 23;
    val_out.F = bouncer.F + val_out.F;
    val_out.I = val_out.I | (sign_bit >> 24);
  }

  *output = val_out.I; // get the 8 lsbs
}
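// Editor's illustrative sketch (not part of the library): encoding 1.0f as
// HFP8 with, for instance, exponent_bits = 4 and exponent_bias = 7 (these
// parameters are an assumption chosen for the example, not defaults mandated
// by this file). With 4 exponent bits there are 3 mantissa bits, so the
// bouncer is 2^20; (2^20 + 1.0f) - 2^20 == 1.0f (no bits are rounded off),
// the exponent is re-biased from 127 to 7, and the result is the byte 0x38
// (sign 0, exponent 0111, mantissa 000).
//
//   uint8_t enc;
//   FloatToFloat8_ref(1.0f, &enc, /*exponent_bits=*/4, /*exponent_bias=*/7);
//   float dec;
//   Float8ToFloat_ref(enc, &dec, 4, 7);
//   // enc == 0x38, dec == 1.0f: the round trip is exact for values
//   // representable in HFP8.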
void Float8ToFloat_ref(
    const uint8_t input,
    float* output,
    int exponent_bits,
    int exponent_bias) {
  fint32 val_out, sign, multiplier;

  sign.I = (input & 0x80) << 24;
  val_out.I = (input & 0x7F) << (24 - (8 - exponent_bits));
  // so that the mantissa bits start at the mantissa bit positions of FP32
  // encoding

  // Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1.
  // So if the hfp8 value is a normal number, its value is 2^e x (1+frac),
  // where e is its (true, unbiased) exponent.
  // If the hfp8 value is denormal, the value is 2^(1-bias) x frac.

  // However, the bit pattern in the 8-bit exponent field of val_out.F
  // is bias+e when hfp8 is normal, and 0 when hfp8 is subnormal.
  // So, as an FP32 value, when hfp8 is normal, val_out.F represents the value
  // of 2^(bias+e-127) * (1+frac).
  // And when hfp8 is subnormal, val_out.F is also subnormal, and represents
  // the value of 2^(-126) * frac.
  // In either case, val_out.F corresponds to 2^(bias-127) * (value of hfp8
  // input). Thus, if we multiply val_out.F by 2^(127-bias), we obtain the
  // hfp8 value as an FP32 number.
  multiplier.I = (127 + (127 - exponent_bias))
      << 23; // multiplier.F is 2^(127-bias)
  val_out.F *= multiplier.F;
  val_out.I |= sign.I;
  *output = val_out.F;
}

void requantize_u8acc32_ref(
    int M,
    int N,
    int ld,
    const int32_t* inp,
    uint8_t* out,
    int32_t C_multiplier,
    int32_t C_right_shift,
    int32_t C_zero_point,
    int32_t A_zero_point,
    int32_t B_zero_point,
    const int32_t* row_offsets,
    const int32_t* col_offsets,
    const int32_t* bias,
    bool fuse_relu) {
  int64_t nudge = 1ll << std::max(0, C_right_shift - 1);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int32_t raw = inp[i * ld + j];
      if (A_zero_point) {
        raw -= A_zero_point * col_offsets[j];
      }
      if (B_zero_point) {
        raw -= B_zero_point * row_offsets[i];
      }
      if (bias) {
        raw += bias[j];
      }

      int64_t ab_64 =
          static_cast<int64_t>(raw) * static_cast<int64_t>(C_multiplier);
      int64_t rounded = ((ab_64 + nudge) >> C_right_shift) + C_zero_point;

      out[i * ld + j] = std::max(
          fuse_relu ? static_cast<int64_t>(C_zero_point) : 0l,
          std::min(static_cast<int64_t>(255l), rounded));
    }
  }
}

void requantize_u8acc32_ref(
    int M,
    int N,
    int ld,
    const int32_t* inp,
    uint8_t* out,
    const float* C_multiplier,
    int32_t C_zero_point,
    int32_t A_zero_point,
    const int32_t* B_zero_point,
    const int32_t* row_offsets,
    const int32_t* col_offsets,
    const int32_t* bias,
    int ncols_per_quant_group,
    bool fuse_relu) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int32_t raw = inp[i * ld + j];
      if (A_zero_point) {
        raw -= A_zero_point * col_offsets[j];
      }
      raw -= B_zero_point[j / ncols_per_quant_group] * row_offsets[i];
      if (bias) {
        raw += bias[j];
      }

      float result = raw * C_multiplier[j / ncols_per_quant_group];
      long rounded = lrintf(result) + C_zero_point;

      out[i * ld + j] = std::max(
          fuse_relu ? static_cast<long>(C_zero_point) : 0l,
          std::min(255l, rounded));
    }
  }
}

void matmul_u8i8acc32_ref(
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldc,
    const uint8_t* Aint8,
    const int8_t* Bint8,
    int32_t* Cint32) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int32_t sum = 0;
      for (int k = 0; k < K; ++k) {
        sum += static_cast<int32_t>(Aint8[i * lda + k]) *
            static_cast<int32_t>(Bint8[k * ldb + j]);
      }
      Cint32[i * ldc + j] = sum;
    }
  }
}

void matmul_u8i8acc16_ref(
    int M,
    int N,
    int K,
    int lda,
    int ldb,
    int ldc,
    int brow,
    const uint8_t* Aint8,
    const int8_t* Bint8,
    int32_t* Cint32) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int32_t sum = 0, sum_32bit = 0;
      for (int k = 0; k < K; k += 2) {
        int a0 = Aint8[i * lda + k];
        int b0 = Bint8[k * ldb + j];
        int a1 = 0, b1 = 0;
        if (k + 1 < K) {
          a1 = Aint8[i * lda + k + 1];
          b1 = Bint8[(k + 1) * ldb + j];
        }
        sum = clip_16bit(sum + clip_16bit(a0 * b0 + a1 * b1));
        if ((k % brow) == (brow - 2)) {
          sum_32bit += sum;
          sum = 0;
        }
      }
      Cint32[i * ldc + j] = sum_32bit + sum;
    }
  }
}
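// Editor's illustrative sketch (not part of the library): why
// matmul_u8i8acc16_ref above spills to 32 bits every brow entries of K.
// Two uint8 x int8 products can reach 2 * 255 * 127 = 64770, which
// overflows int16 (max 32767), so each pairwise sum and the running sum
// are saturated with clip_16bit, and every brow accumulated entries the
// saturated 16-bit sum is flushed into the 32-bit accumulator:
//
//   // K = 4, brow = 2, a0 = a1 = 200, b0 = b1 = 100:
//   // a0*b0 + a1*b1 = 40000 saturates to 32767 in the 16-bit
//   // accumulator before being added into sum_32bit, mirroring what
//   // 16-bit SIMD accumulation does in the JIT kernels.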
void cblas_sgemm_ref(
    const matrix_op_t transa,
    const matrix_op_t transb,
    const int m,
    const int n,
    const int k,
    float alpha,
    const float* Afp32,
    int lda,
    const float* Bfp32,
    int ldb,
    float beta,
    float* Cfp32,
    int ldc) {
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float sum = 0;
      for (int p = 0; p < k; ++p) {
        float a =
            (transa == matrix_op_t::NoTranspose ? Afp32[i * lda + p]
                                                : Afp32[p * lda + i]);
        float b =
            (transb == matrix_op_t::NoTranspose ? Bfp32[p * ldb + j]
                                                : Bfp32[j * ldb + p]);
        sum += a * b;
      }
      if (beta == 0) {
        Cfp32[i * ldc + j] = alpha * sum;
      } else {
        Cfp32[i * ldc + j] = alpha * sum + beta * Cfp32[i * ldc + j];
      }
    }
  }
}

namespace {
// From https://stackoverflow.com/questions/31652875
uint64_t umul64wide(uint64_t a, uint64_t b) {
  uint64_t a_lo = static_cast<uint32_t>(a);
  uint64_t a_hi = a >> 32;
  uint64_t b_lo = static_cast<uint32_t>(b);
  uint64_t b_hi = b >> 32;

  uint64_t p0 = a_lo * b_lo;
  uint64_t p1 = a_lo * b_hi;
  uint64_t p2 = a_hi * b_lo;

  return p0 + (p1 << 32) + (p2 << 32);
}
} // namespace

// Expected to have overflows
NO_SANITIZE("undefined")
void cblas_gemm_i64_i64acc_ref(
    matrix_op_t transa,
    matrix_op_t transb,
    int M,
    int N,
    int K,
    const int64_t* A,
    int lda,
    const int64_t* B,
    int ldb,
    bool accumulate,
    int64_t* C,
    int ldc) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int64_t acc;
      if (accumulate) {
        acc = C[i * ldc + j];
      } else {
        acc = 0;
      }
      for (int k = 0; k < K; ++k) {
        int64_t a =
            A[transa == matrix_op_t::Transpose ? i + k * lda : i * lda + k];
        int64_t b =
            B[transb == matrix_op_t::Transpose ? k + j * ldb : k * ldb + j];
        int64_t lo = umul64wide(a, b);
        acc += lo;
      }
      C[i * ldc + j] = acc;
    } // j
  } // i
}

void row_offsets_u8acc32_ref(
    int M,
    int K,
    int ld,
    const uint8_t* Aint8,
    int32_t* row_offsets) {
  // row offset
  for (int i = 0; i < M; ++i) {
    int32_t sum = 0;
    for (int k = 0; k < K; ++k) {
      sum += static_cast<int32_t>(Aint8[i * ld + k]);
    }
    row_offsets[i] = sum;
  }
}

void col_offsets_with_zero_pt_s8acc32_ref(
    int K,
    int N,
    int ld,
    const int8_t* Bint8,
    const int32_t* B_zero_point,
    int32_t* col_offsets,
    int ncols_per_quant_group) {
  for (int j = 0; j < N; ++j) {
    int32_t sum = 0;
    for (int k = 0; k < K; ++k) {
      sum += Bint8[k * ld + j];
    }
    col_offsets[j] = sum - B_zero_point[j / ncols_per_quant_group] * K;
  }
}

void spmdm_ref(
    int M,
    const uint8_t* A,
    int lda,
    fbgemm::CompressedSparseColumn& B,
    bool accumulation,
    int32_t* C,
    int ldc,
    int groups /*=1*/) {
  int N = B.NumOfCols();
  assert(N % groups == 0);
  if (!accumulation) {
    for (int i = 0; i < M; ++i) {
      for (int j = 0; j < N; ++j) {
        C[i * ldc + j] = 0;
      }
    }
  }
  for (int g = 0; g < groups; ++g) {
    for (int j = g * (N / groups); j < (g + 1) * (N / groups); ++j) {
      for (int k = B.ColPtr()[j]; k < B.ColPtr()[j + 1]; ++k) {
        int row = g * B.NumOfRows() + B.RowIdx()[k];
        int w = B.Values()[k];
        for (int i = 0; i < M; ++i) {
          C[i * ldc + j] += A[i * lda + row] * w;
        }
      }
    } // for each column of B
  } // for each group
}

int32_t clip_16bit(int32_t x) {
  if (x > numeric_limits<int16_t>::max()) {
    return std::min<int32_t>(numeric_limits<int16_t>::max(), x);
  } else if (x < numeric_limits<int16_t>::min()) {
    return std::max<int32_t>(numeric_limits<int16_t>::min(), x);
  } else {
    return x;
  }
}
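// Editor's illustrative sketch (not part of the library): the shape contract
// of the 1D im2col below. For MB = 1, IC = 4, G = 1, input width 8, K = 3,
// stride = 1, pad = 1, the output width is OUT_DIM[0] = 8 and Ao is laid out
// as (1 x 8) x (3 * 4): each output position owns K * IC contiguous bytes,
// with out-of-range taps filled with A_zero_point.
//
//   // Ao row for ow = 0 (pad on the left), zp = A_zero_point:
//   //   [zp zp zp zp | A(0,:) | A(1,:)]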
/* Imitate the Im2Col function
 * from caffe2/utils/math_cpu.cc
 * NWC StorageOrder/Layout
 * A:  NWC: NW_0 x C_0
 * Ao: NWC: NW_1 x G KW C_0/G
 */
template <>
FBGEMM_API void im2col_ref(
    const conv_param_t<1>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    uint8_t* Ao) {
  int IC = conv_p.IC;
  int G = conv_p.G;
  assert(IC % G == 0);
  array<int, 1> IN_DIM = conv_p.IN_DIM;
  array<int, 1> OUT_DIM = conv_p.OUT_DIM;
  array<int, 1> K = conv_p.K;

  if (conv_p.transposed) {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int ow = 0; ow < OUT_DIM[0]; ++ow) {
        for (int s = 0; s < K[0]; ++s) {
          int w = ow + conv_p.pad[0] - s * conv_p.dilation[0];
          int w_in = w / conv_p.stride[0];
          if (w_in * conv_p.stride[0] == w && w_in >= 0 && w_in < IN_DIM[0]) {
            for (int g = 0; g < G; ++g) {
              memcpy(
                  Ao + (((n * OUT_DIM[0] + ow) * G + g) * K[0] + s) * (IC / G),
                  A + (n * IN_DIM[0] + w_in) * IC + g * (IC / G),
                  sizeof(uint8_t) * (IC / G));
            }
          } else {
            for (int g = 0; g < G; ++g) {
              memset(
                  Ao + (((n * OUT_DIM[0] + ow) * G + g) * K[0] + s) * (IC / G),
                  A_zero_point,
                  sizeof(uint8_t) * (IC / G));
            }
          }
        } // for each s
      } // for each ow
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int w = 0; w < OUT_DIM[0]; ++w) {
        for (int s = 0; s < K[0]; ++s) {
          int w_in =
              -conv_p.pad[0] + w * conv_p.stride[0] + s * conv_p.dilation[0];
          if (w_in < 0 || w_in >= IN_DIM[0]) {
            for (int g = 0; g < G; ++g) {
              memset(
                  Ao + (((n * OUT_DIM[0] + w) * G + g) * K[0] + s) * (IC / G),
                  A_zero_point,
                  sizeof(uint8_t) * (IC / G));
            }
          } else {
            for (int g = 0; g < G; ++g) {
              memcpy(
                  Ao + (((n * OUT_DIM[0] + w) * G + g) * K[0] + s) * (IC / G),
                  A + (n * IN_DIM[0] + w_in) * IC + g * (IC / G),
                  sizeof(uint8_t) * (IC / G));
            }
          }
        } // for each s
      } // for each w
    } // for each n
  }
}

/* Imitate the Im2Col function
 * from caffe2/utils/math_cpu.cc
 * NHWC StorageOrder/Layout
 * A:  NHWC: NH_0W_0 x C_0
 * Ao: NHWC: NH_1W_1 x G RS C_0/G
 */
template <>
FBGEMM_API void im2col_ref(
    const conv_param_t<2>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    uint8_t* Ao) {
  int IC = conv_p.IC;
  int G = conv_p.G;
  assert(IC % G == 0);
  array<int, 2> IN_DIM = conv_p.IN_DIM;
  array<int, 2> OUT_DIM = conv_p.OUT_DIM;
  array<int, 2> K = conv_p.K;

  if (conv_p.transposed) {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
        for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
          for (int r = 0; r < K[0]; ++r) {
            for (int s = 0; s < K[1]; ++s) {
              int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
              int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
              int h_in = h / conv_p.stride[0];
              int w_in = w / conv_p.stride[1];
              if (h_in * conv_p.stride[0] == h && h_in >= 0 &&
                  h_in < IN_DIM[0] && w_in * conv_p.stride[1] == w &&
                  w_in >= 0 && w_in < IN_DIM[1]) {
                for (int g = 0; g < G; ++g) {
                  memcpy(
                      Ao + (((((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * G + g) *
                                 K[0] + r) * K[1] + s) * (IC / G),
                      A + ((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
                          g * (IC / G),
                      sizeof(uint8_t) * (IC / G));
                }
              } else {
                for (int g = 0; g < G; ++g) {
                  memset(
                      Ao + (((((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * G + g) *
                                 K[0] + r) * K[1] + s) * (IC / G),
                      A_zero_point,
                      sizeof(uint8_t) * (IC / G));
                }
              }
            } // for each s
          } // for each r
        } // for each ow
      } // for each oh
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int h = 0; h < OUT_DIM[0]; ++h) {
        for (int w = 0; w < OUT_DIM[1]; ++w) {
          for (int r = 0; r < K[0]; ++r) {
            int h_in =
                -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
            for (int s = 0; s < K[1]; ++s) {
              int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
                  s * conv_p.dilation[1];
              if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
                  w_in >= IN_DIM[1]) {
                for (int g = 0; g < G; ++g) {
                  memset(
                      Ao + (((((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * G + g) *
                                 K[0] + r) * K[1] + s) * (IC / G),
                      A_zero_point,
                      sizeof(uint8_t) * (IC / G));
                }
              } else {
                for (int g = 0; g < G; ++g) {
                  memcpy(
                      Ao + (((((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * G + g) *
                                 K[0] + r) * K[1] + s) * (IC / G),
                      A + ((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
                          g * (IC / G),
                      sizeof(uint8_t) * (IC / G));
                }
              }
            } // for each s
          } // for each r
        } // for each w
      } // for each h
    } // for each n
  }
}
/* Imitate the Im2Col function
 * from caffe2/utils/math_cpu.cc
 * NHWC StorageOrder/Layout
 * A:  NHWC: NT_0H_0W_0 x C_0
 * Ao: NHWC: NT_1H_1W_1 x G QRS C_0/G
 */
template <>
FBGEMM_API void im2col_ref(
    const conv_param_t<3>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    uint8_t* Ao) {
  int IC = conv_p.IC;
  int G = conv_p.G;
  assert(IC % G == 0);
  array<int, 3> IN_DIM = conv_p.IN_DIM;
  array<int, 3> OUT_DIM = conv_p.OUT_DIM;
  array<int, 3> K = conv_p.K;

  if (conv_p.transposed) {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
        for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
          for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
            for (int q = 0; q < K[0]; ++q) {
              for (int r = 0; r < K[1]; ++r) {
                for (int s = 0; s < K[2]; ++s) {
                  int t = ot + conv_p.pad[0] - q * conv_p.dilation[0];
                  int h = oh + conv_p.pad[1] - r * conv_p.dilation[1];
                  int w = ow + conv_p.pad[2] - s * conv_p.dilation[2];
                  int t_in = t / conv_p.stride[0];
                  int h_in = h / conv_p.stride[1];
                  int w_in = w / conv_p.stride[2];
                  if (t_in * conv_p.stride[0] == t && t_in >= 0 &&
                      t_in < IN_DIM[0] && h_in * conv_p.stride[1] == h &&
                      h_in >= 0 && h_in < IN_DIM[1] &&
                      w_in * conv_p.stride[2] == w && w_in >= 0 &&
                      w_in < IN_DIM[2]) {
                    for (int g = 0; g < G; ++g) {
                      memcpy(
                          Ao + (((((((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) *
                                        OUT_DIM[2] + ow) * G + g) *
                                      K[0] + q) * K[1] + r) * K[2] + s) *
                                  (IC / G),
                          A + (((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
                                   IN_DIM[2] + w_in) * IC + g * (IC / G),
                          sizeof(uint8_t) * (IC / G));
                    }
                  } else {
                    for (int g = 0; g < G; ++g) {
                      memset(
                          Ao + (((((((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) *
                                        OUT_DIM[2] + ow) * G + g) *
                                      K[0] + q) * K[1] + r) * K[2] + s) *
                                  (IC / G),
                          A_zero_point,
                          sizeof(uint8_t) * (IC / G));
                    }
                  }
                } // for each s
              } // for each r
            } // for each q
          } // for each ow
        } // for each oh
      } // for each ot
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int t = 0; t < OUT_DIM[0]; ++t) {
        for (int h = 0; h < OUT_DIM[1]; ++h) {
          for (int w = 0; w < OUT_DIM[2]; ++w) {
            for (int q = 0; q < K[0]; ++q) {
              int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
                  q * conv_p.dilation[0];
              for (int r = 0; r < K[1]; ++r) {
                int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
                    r * conv_p.dilation[1];
                for (int s = 0; s < K[2]; ++s) {
                  int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
                      s * conv_p.dilation[2];
                  if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
                      h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) {
                    for (int g = 0; g < G; ++g) {
                      memset(
                          Ao + (((((((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) *
                                        OUT_DIM[2] + w) * G + g) *
                                      K[0] + q) * K[1] + r) * K[2] + s) *
                                  (IC / G),
                          A_zero_point,
                          sizeof(uint8_t) * (IC / G));
                    }
                  } else {
                    for (int g = 0; g < G; ++g) {
                      memcpy(
                          Ao + (((((((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) *
                                        OUT_DIM[2] + w) * G + g) *
                                      K[0] + q) * K[1] + r) * K[2] + s) *
                                  (IC / G),
                          A + (((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
                                   IN_DIM[2] + w_in) * IC + g * (IC / G),
                          sizeof(uint8_t) * (IC / G));
                    }
                  }
                } // for each s
              } // for each r
            } // for each q
          } // for each w
        } // for each h
      } // for each t
    } // for each n
  }
}
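// Editor's illustrative note (not part of the library): how the transposed
// branches in these reference kernels invert the convolution index map.
// A forward conv reads
//   w_in = -pad + w * stride + s * dilation,
// so a transposed conv asks, for each output position ow and tap s, which
// input w_in could have produced it:
//   w = ow + pad - s * dilation,  w_in = w / stride,
// and the candidate only counts when w_in * stride == w exactly (the
// fractional-stride check) and w_in lies inside the input.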
// 1D Conv
template <>
FBGEMM_API void conv_ref(
    const conv_param_t<1>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    const int8_t* B,
    int32_t* C) {
  // A is assumed to be (N Lin Cin)
  // B is assumed to be (G K Cin/G Cout/G)
  // C is assumed to be (N Lout Cout)
  int IC = conv_p.IC;
  int OC = conv_p.OC;
  int G = conv_p.G;
  assert(IC % G == 0);
  assert(OC % G == 0);
  array<int, 1> IN_DIM = conv_p.IN_DIM;
  array<int, 1> OUT_DIM = conv_p.OUT_DIM;
  array<int, 1> K = conv_p.K;

  if (conv_p.transposed) {
    // For the reference implementation, there is no padding on the input
    // buffer; padding specifies how much we remove from the output buffers.
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int ow = 0; ow < OUT_DIM[0]; ++ow) {
        // Stride on the output is fractional stride on the input.
        // The conv index is
        //   int w_in =
        //       -conv_p.pad[0] + w * conv_p.stride[0] + r * conv_p.dilation[0];
        // so we reverse it.
        for (int g = 0; g < G; ++g) {
          for (int oc = 0; oc < OC / G; ++oc) {
            int sum = 0;
            for (int r = 0; r < K[0]; ++r) {
              int w = ow + conv_p.pad[0] - r * conv_p.dilation[0];
              int w_in = w / conv_p.stride[0];
              for (int ic = 0; ic < IC / G; ++ic) {
                int a = (w_in * conv_p.stride[0] == w && w_in >= 0 &&
                         w_in < IN_DIM[0])
                    ? A[(n * IN_DIM[0] + w_in) * IC + g * (IC / G) + ic]
                    : A_zero_point;
                int b = B[((g * K[0] + r) * (IC / G) + ic) * (OC / G) + oc];
                // G K IC/G OC/G after transpose
                sum += a * b;
              } // for each ic
            } // for each r
            C[(n * OUT_DIM[0] + ow) * OC + g * (OC / G) + oc] = sum;
          } // for each oc
        } // for each g
      } // for each ow
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int w = 0; w < OUT_DIM[0]; ++w) {
        for (int g = 0; g < G; ++g) {
          for (int m = 0; m < OC / G; ++m) {
            int sum = 0;
            for (int r = 0; r < K[0]; ++r) {
              int w_in = -conv_p.pad[0] + w * conv_p.stride[0] +
                  r * conv_p.dilation[0];
              for (int c = 0; c < IC / G; ++c) {
                int a = w_in < 0 || w_in >= IN_DIM[0]
                    ? A_zero_point
                    : A[(n * IN_DIM[0] + w_in) * IC + g * (IC / G) + c];
                int b = B[((g * K[0] + r) * (IC / G) + c) * (OC / G) + m];
                // G K IC/G OC/G after transpose
                sum += a * b;
              } // for each c
            } // for each r
            C[(n * OUT_DIM[0] + w) * OC + g * (OC / G) + m] = sum;
          } // for each m
        } // for each g
      } // for each w
    } // for each n
  }
}
// 2D Conv
template <>
FBGEMM_API void conv_ref(
    const conv_param_t<2>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    const int8_t* B,
    int32_t* C) {
  // filters are assumed to be in G RS C/G x K format
  int IC = conv_p.IC;
  int OC = conv_p.OC;
  int G = conv_p.G;
  assert(IC % G == 0);
  assert(OC % G == 0);
  array<int, 2> IN_DIM = conv_p.IN_DIM;
  array<int, 2> OUT_DIM = conv_p.OUT_DIM;
  array<int, 2> K = conv_p.K;

  if (conv_p.transposed) {
    // For the reference implementation, there is no padding on the input
    // buffer; padding specifies how much we remove from the output buffers.
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
        for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
          // Stride on the output is fractional stride on the input.
          // The conv indices are
          //   int h_in =
          //       -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
          //   int w_in =
          //       -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1];
          // so we reverse them.
          for (int g = 0; g < G; ++g) {
            for (int oc = 0; oc < OC / G; ++oc) {
              int sum = 0;
              for (int r = 0; r < K[0]; ++r) {
                for (int s = 0; s < K[1]; ++s) {
                  int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
                  int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
                  int h_in = h / conv_p.stride[0];
                  int w_in = w / conv_p.stride[1];
                  for (int ic = 0; ic < IC / G; ++ic) {
                    int a = (h_in * conv_p.stride[0] == h && h_in >= 0 &&
                             h_in < IN_DIM[0] &&
                             w_in * conv_p.stride[1] == w && w_in >= 0 &&
                             w_in < IN_DIM[1])
                        ? A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
                            g * (IC / G) + ic]
                        : A_zero_point;
                    int b = B[((((g * K[0] + r) * K[1] + s) * (IC / G) + ic) *
                               (OC / G)) + oc];
                    // G R S IC OC after transpose
                    sum += a * b;
                  } // for each ic
                } // for each s
              } // for each r
              C[((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * OC +
                g * (OC / G) + oc] = sum;
            } // for each oc
          } // for each g
        } // for each ow
      } // for each oh
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int h = 0; h < OUT_DIM[0]; ++h) {
        for (int w = 0; w < OUT_DIM[1]; ++w) {
          for (int g = 0; g < G; ++g) {
            for (int m = 0; m < OC / G; ++m) {
              int sum = 0;
              for (int r = 0; r < K[0]; ++r) {
                int h_in = -conv_p.pad[0] + h * conv_p.stride[0] +
                    r * conv_p.dilation[0];
                for (int s = 0; s < K[1]; ++s) {
                  int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
                      s * conv_p.dilation[1];
                  for (int c = 0; c < IC / G; ++c) {
                    int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
                            w_in >= IN_DIM[1]
                        ? A_zero_point
                        : A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
                            g * (IC / G) + c];
                    int b = B[(((g * K[0] + r) * K[1] + s) * (IC / G) + c) *
                              (OC / G) + m];
                    sum += a * b;
                  } // for each c
                } // for each s
              } // for each r
              C[((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * OC + g * (OC / G) +
                m] = sum;
            } // for each m
          } // for each g
        } // for each w
      } // for each h
    } // for each n
  }
}
// 3D Conv
template <>
FBGEMM_API void conv_ref(
    const conv_param_t<3>& conv_p,
    const uint8_t* A,
    int32_t A_zero_point,
    const int8_t* B,
    int32_t* C) {
  // filters are assumed to be in G QRS C/G x K format
  int IC = conv_p.IC;
  int OC = conv_p.OC;
  int G = conv_p.G;
  assert(IC % G == 0);
  assert(OC % G == 0);
  array<int, 3> IN_DIM = conv_p.IN_DIM;
  array<int, 3> OUT_DIM = conv_p.OUT_DIM;
  array<int, 3> K = conv_p.K;

  if (conv_p.transposed) {
    // For the reference implementation, there is no padding on the input
    // buffer; padding specifies how much we remove from the output buffers.
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
        for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
          for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
            // Stride on the output is fractional stride on the input.
            // The conv indices are
            //   int t_in =
            //       -conv_p.pad[0] + t * conv_p.stride[0] + q * conv_p.dilation[0];
            //   int h_in =
            //       -conv_p.pad[1] + h * conv_p.stride[1] + r * conv_p.dilation[1];
            //   int w_in =
            //       -conv_p.pad[2] + w * conv_p.stride[2] + s * conv_p.dilation[2];
            // so we reverse them.
            for (int g = 0; g < G; ++g) {
              for (int oc = 0; oc < OC / G; ++oc) {
                int sum = 0;
                for (int q = 0; q < K[0]; ++q) {
                  for (int r = 0; r < K[1]; ++r) {
                    for (int s = 0; s < K[2]; ++s) {
                      int t = ot + conv_p.pad[0] - q * conv_p.dilation[0];
                      int h = oh + conv_p.pad[1] - r * conv_p.dilation[1];
                      int w = ow + conv_p.pad[2] - s * conv_p.dilation[2];
                      int t_in = t / conv_p.stride[0];
                      int h_in = h / conv_p.stride[1];
                      int w_in = w / conv_p.stride[2];
                      for (int ic = 0; ic < IC / G; ++ic) {
                        int a = (t_in * conv_p.stride[0] == t && t_in >= 0 &&
                                 t_in < IN_DIM[0] &&
                                 h_in * conv_p.stride[1] == h && h_in >= 0 &&
                                 h_in < IN_DIM[1] &&
                                 w_in * conv_p.stride[2] == w && w_in >= 0 &&
                                 w_in < IN_DIM[2])
                            ? A[((((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
                                      IN_DIM[2]) + w_in) * IC + g * (IC / G) +
                                ic]
                            : A_zero_point;
                        int b = B[(((((g * K[0] + q) * K[1] + r) * K[2] + s) *
                                        (IC / G) + ic) * (OC / G)) + oc];
                        // G Q R S Cin/G Cout/G after transpose
                        sum += a * b;
                      } // for each ic
                    } // for each s
                  } // for each r
                } // for each q
                C[(((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) * OUT_DIM[2] +
                   ow) * OC + g * (OC / G) + oc] = sum;
              } // for each oc
            } // for each g
          } // for each ow
        } // for each oh
      } // for each ot
    } // for each n
  } else {
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int t = 0; t < OUT_DIM[0]; ++t) {
        for (int h = 0; h < OUT_DIM[1]; ++h) {
          for (int w = 0; w < OUT_DIM[2]; ++w) {
            for (int g = 0; g < G; ++g) {
              for (int m = 0; m < OC / G; ++m) {
                int sum = 0;
                for (int q = 0; q < K[0]; ++q) {
                  int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
                      q * conv_p.dilation[0];
                  for (int r = 0; r < K[1]; ++r) {
                    int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
                        r * conv_p.dilation[1];
                    for (int s = 0; s < K[2]; ++s) {
                      int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
                          s * conv_p.dilation[2];
                      for (int c = 0; c < IC / G; ++c) {
                        int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
                                h_in >= IN_DIM[1] || w_in < 0 ||
                                w_in >= IN_DIM[2]
                            ? A_zero_point
                            : A[(((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
                                     IN_DIM[2] + w_in) * IC + g * (IC / G) +
                                c];
                        int b = B[((((g * K[0] + q) * K[1] + r) * K[2] + s) *
                                       (IC / G) + c) * (OC / G) + m];
                        sum += a * b;
                      } // for each c
                    } // for each s
                  } // for each r
                } // for each q
                C[(((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) * OUT_DIM[2] + w) *
                      OC + g * (OC / G) + m] = sum;
              } // for each m
            } // for each g
          } // for each w
        } // for each h
      } // for each t
    } // for each n
  }
}
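// Editor's illustrative sketch (not part of the library): the index map
// implemented by transposeConvWeights below. For a 2D conv with G = 1,
// OC_per_G = K, IC_per_G = C and filter_prod = R * S, the element at
//   src[((k * R + r) * S + s) * C + c]   // layout G K/G (R S C/G)
// lands at
//   dest[((r * S + s) * C + c) * K + k]  // layout G (R S C/G) K/G
// i.e. the output-channel dimension moves from outermost (after G) to
// innermost.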
template <int SPATIAL_DIM>
void transposeConvWeights(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const std::int8_t* src,
    std::int8_t* dest) {
  int G = conv_p.G;
  int IC_per_G = conv_p.IC / conv_p.G;
  int OC_per_G = conv_p.OC / conv_p.G;

  int filter_prod = std::accumulate(
      conv_p.K.begin(),
      conv_p.K.begin() + SPATIAL_DIM,
      1,
      std::multiplies<int>());
  // Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format.
  for (int g = 0; g < G; ++g) {
    for (int k = 0; k < OC_per_G; ++k) {
      for (int f = 0; f < filter_prod; ++f) {
        for (int c = 0; c < IC_per_G; ++c) {
          dest[((g * filter_prod + f) * IC_per_G + c) * OC_per_G + k] =
              src[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c];
        }
      }
    }
  }
}

template float convert_to_float_ref(float src, bool is_bf16_out);
template float convert_to_float_ref(uint16_t src, bool is_bf16_out);
template float convert_from_float_ref(float src, bool is_bf16_out);
template uint16_t convert_from_float_ref(float bfloat16, bool is_bf16_out);
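// Editor's illustrative note (not part of the library): the fused row layout
// assumed by the 8-bit path of EmbeddingSpMDM_ref below. With
// scale_bias_last == true and block_size = 64, each row holds 64 bytes of
// quantized data followed by a float scale and a float bias, so
// input_stride = 64 + 2 * sizeof(float) = 72 bytes; with
// scale_bias_last == false the row instead starts with a float16 scale and
// bias (2 * sizeof(float16) = 4 bytes) followed by the 64 quantized bytes,
// the layout used by table-batched embeddings (TBE).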
template <
    typename InType,
    typename IndexType,
    typename OffsetType,
    typename OutType>
bool EmbeddingSpMDM_ref(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const InType* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional /*=false*/,
    bool use_offsets /*=true*/,
    int64_t output_stride /*=-1*/,
    int64_t input_stride /*=-1*/,
    bool scale_bias_last /*=true*/,
    bool no_bag /*=false*/,
    bool is_bf16_out /*=false*/,
    bool is_bf16_in /*=false*/) {
  const bool isWeight8bit = is_same<InType, uint8_t>::value;
  const bool isOutput8bit = is_same<OutType, uint8_t>::value;
  if (output_stride == -1) {
    output_stride = block_size;
  }
  if constexpr (isOutput8bit) {
    assert(input_stride == output_stride);
  }
  vector<float> buf(block_size);

  if (isWeight8bit) {
    // block_size is the number of elements and fused_block_size is the size
    // of an entire row, including scale and bias.
    if (input_stride == -1) {
      // scale_bias_last == false is for table-batched embedding that stores
      // scale and bias in float16
      const auto scale_bias_offset =
          2 * (scale_bias_last ? sizeof(float) : sizeof(float16));
      input_stride = block_size + scale_bias_offset;
    }

    int64_t current = 0;

    if (no_bag) {
      for (int m = 0; m < output_size; ++m) {
        int64_t idx = indices[m];

        if (idx < 0 || idx >= data_size) {
          return false;
        }

        if constexpr (isOutput8bit) {
          const InType* input_row_ptr = input + input_stride * idx;
          memcpy(out, input_row_ptr, sizeof(InType) * input_stride);
        } else {
          memset(buf.data(), 0, sizeof(float) * block_size);
          const float* scale_bias = reinterpret_cast<const float*>(
              input + input_stride * idx + (scale_bias_last ? block_size : 0));

          float weight = 1.0f;
          if (weights) {
            weight = weights[m];
          }

          float scale, bias;
          if (scale_bias_last) {
            scale = weight * scale_bias[0];
            bias = weight * scale_bias[1];
          } else {
            scale = weight *
                cpu_half2float(
                    reinterpret_cast<const float16*>(scale_bias)[0]);
            bias = weight *
                cpu_half2float(
                    reinterpret_cast<const float16*>(scale_bias)[1]);
          }

          for (int j = 0; j < block_size; ++j) {
            buf[j] = std::fma(
                scale,
                input
                    [input_stride * idx + j +
                     (scale_bias_last ? 0 : 2 * sizeof(float16))],
                buf[j] + bias);
          }
          for (int j = 0; j < block_size; ++j) {
            out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
          }
        }
        out += output_stride;
      } // m
      return true;
    } // no_bag

    for (int m = 0; m < output_size; ++m) {
      memset(buf.data(), 0, sizeof(float) * block_size);
      int len = use_offsets
          ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
          : offsets_or_lengths[m];
      if (current + len > index_size) {
        return false;
      }
      for (int i = 0; i < len; ++i, ++current) {
        int64_t idx = indices[current];
        if (!scale_bias_last && idx == -1) {
          // When scale_bias_last == false, assume this is for table-batched
          // embedding (TBE) that can get -1 for pruned rows.
          continue;
        }
        if (idx < 0 || idx >= data_size) {
          return false;
        }

        const float* scale_bias = reinterpret_cast<const float*>(
            input + input_stride * idx + (scale_bias_last ? block_size : 0));

        float weight = 1.0f;
        if (weights) {
          weight = weights[is_weight_positional ? i : current];
        }
        float scale, bias;
        if (scale_bias_last) {
          scale = weight * scale_bias[0];
          bias = weight * scale_bias[1];
        } else {
          scale = weight *
              cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]);
          bias = weight *
              cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]);
        }

        for (int j = 0; j < block_size; ++j) {
          buf[j] = std::fma(
              scale,
              input
                  [input_stride * idx + j +
                   (scale_bias_last ? 0 : 2 * sizeof(float16))],
              buf[j] + bias);
        }
      }
      if (normalize_by_lengths && len) {
        float scale = 1.f / len;
        for (int j = 0; j < block_size; ++j) {
          buf[j] *= scale;
        }
      }
      for (int j = 0; j < block_size; ++j) {
        out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
      }
      out += output_stride;
    }
    return current == index_size;
  } else {
    if (input_stride == -1) {
      input_stride = block_size;
    }

    if (no_bag) {
      for (int m = 0; m < output_size; ++m) {
        memset(buf.data(), 0, sizeof(float) * block_size);
        int64_t idx = indices[m];
        if (idx < 0 || idx >= data_size) {
          return false;
        }

        float w = 1.f;
        if (weights) {
          w = weights[m];
        }

        for (int j = 0; j < block_size; ++j) {
          const InType* inptr = input + input_stride * idx + j;
          buf[j] =
              std::fma(w, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
        }
        for (int j = 0; j < block_size; ++j) {
          out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
        }
        out += output_stride;
      } // m
      return true;
    } // no_bag

    // Reference implementation of FP32 SLS
    int64_t current = 0;
    for (int m = 0; m < output_size; ++m) {
      memset(buf.data(), 0, sizeof(float) * block_size);
      int len = use_offsets
          ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
          : offsets_or_lengths[m];
      if (current + len > index_size) {
        return false;
      }
      for (int i = 0; i < len; ++i) {
        int64_t idx = indices[current];
        if (idx < 0 || idx >= data_size) {
          return false;
        }

        float w = 1.f;
        if (weights) {
          w = weights[is_weight_positional ? i : current];
        }

        for (int j = 0; j < block_size; ++j) {
          const InType* inptr = input + input_stride * idx + j;
          buf[j] =
              std::fma(w, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
        }

        ++current;
      }
      if (normalize_by_lengths && len) {
        float scale = 1.f / len;
        for (int j = 0; j < block_size; ++j) {
          buf[j] *= scale;
        }
      }
      for (int j = 0; j < block_size; ++j) {
        out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
      }
      out += output_stride;
    }
    return current == index_size;
  }
}
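// Editor's illustrative sketch (not part of the library): the sub-byte
// unpacking done in EmbeddingSpMDMNBit_ref below. With input_bit_rate = 4
// there are 2 elements per byte; element j lives in byte j / 2, in the low
// nibble when j is even and the high nibble when j is odd:
//
//   uint8_t byte = 0xB7;            // packs elements {7, 11}
//   uint8_t e0 = (byte >> 0) & 0xF; // j = 0 -> 7
//   uint8_t e1 = (byte >> 4) & 0xF; // j = 1 -> 11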
template <typename IndexType, typename OffsetType, typename OutType>
bool EmbeddingSpMDMNBit_ref(
    int input_bit_rate,
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional /*=false*/,
    bool use_offsets /*=true*/,
    int64_t output_stride /*=-1*/,
    int64_t input_stride /*=-1*/,
    const bool scale_bias_last /*=true*/,
    const bool is_bf16_out /*=false*/,
    const bool no_bag /*=false*/,
    int output_bit_rate /*=-1*/) {
  if (output_bit_rate == -1) {
    output_bit_rate = 8 * sizeof(OutType);
  }
  nbit_embedding_sanity_check(input_bit_rate, output_bit_rate, no_bag);
  int num_elem_per_byte = 8 / input_bit_rate;

  if (output_stride == -1) {
    output_stride = block_size;
  }

  // block_size is the number of elements and fused_block_size is the size of
  // an entire row, including scale and bias.
  const auto scale_bias_offset = 2 * sizeof(float16);
  if (input_stride == -1) {
    input_stride = (block_size + num_elem_per_byte - 1) / num_elem_per_byte +
        scale_bias_offset;
  }

  if (no_bag) {
    // We currently only support int4 to int4 for sequential TBE in this nbit
    // kernel. Note that assert() will be ignored in release mode, so we check
    // here to double-check and also avoid an "unused variable" warning.
    if (!(input_bit_rate == 4 && output_bit_rate == 4)) {
      WARN_ONCE("no_bag is only supported for int4 to int4");
      return false;
    }

    for (int64_t i = 0; i < output_size; ++i) {
      const auto idx = indices[i];
      if (idx < 0 || idx >= data_size) {
        return false;
      }
      const uint8_t* input_row = input + input_stride * idx;
      memcpy(out, input_row, sizeof(uint8_t) * input_stride);
      out += input_stride;
    }
    return true;
  }

  int64_t current = 0;
  vector<float> buf(block_size);
  for (int m = 0; m < output_size; ++m) {
    memset(buf.data(), 0, sizeof(float) * block_size);
    int len = use_offsets
        ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
        : offsets_or_lengths[m];
    if (current + len > index_size) {
      return false;
    }

    for (int i = 0; i < len; ++i, ++current) {
      int64_t idx = indices[current];
      if (!scale_bias_last && idx == -1) {
        // When scale_bias_last == false, assume this is for table-batched
        // embedding (TBE) that can get -1 for pruned rows.
        continue;
      }
      if (idx < 0 || idx >= data_size) {
        return false;
      }

      const float16* scale_bias = reinterpret_cast<const float16*>(
          input + input_stride * idx +
          (scale_bias_last
               ? (block_size + num_elem_per_byte - 1) / num_elem_per_byte
               : 0));

      float weight = 1.0f;
      if (weights) {
        weight = weights[is_weight_positional ? i : current];
      }
      const float scale = weight * cpu_half2float(scale_bias[0]);
      const float bias = weight * cpu_half2float(scale_bias[1]);

      for (int j = 0; j < block_size; ++j) {
        uint8_t quantized = input
            [input_stride * idx + j / num_elem_per_byte +
             (scale_bias_last ? 0 : scale_bias_offset)];
        quantized >>= (j % num_elem_per_byte) * input_bit_rate;
        quantized &= (1 << input_bit_rate) - 1;

        buf[j] = std::fma(scale, quantized, buf[j] + bias);
      }
    }
    if (normalize_by_lengths && len) {
      float scale = 1.f / len;
      for (int j = 0; j < block_size; ++j) {
        buf[j] *= scale;
      }
    }
    for (int j = 0; j < block_size; ++j) {
      out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
    }
    out += output_stride;
  }
  return current == index_size;
}
template <typename IndexType, typename OffsetType, typename OutType>
bool EmbeddingSpMDMFP8_ref(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    const float* weights,
    bool normalize_by_lengths,
    OutType* out,
    bool is_weight_positional,
    bool use_offsets,
    int64_t output_stride,
    int64_t input_stride,
    int exponent_bits,
    int exponent_bias,
    bool is_bf16_out /*=false*/) {
  if (output_stride == -1) {
    output_stride = block_size;
  }
  vector<float> buf(block_size);

  if (input_stride == -1) {
    input_stride = block_size;
  }

  // Reference implementation of FP8 SLS. The algorithm is similar to FP32 SLS
  // except for the FP8->FP32 conversion after reading the embedding weight.
  int64_t current = 0;
  for (int m = 0; m < output_size; ++m) {
    memset(buf.data(), 0, sizeof(float) * block_size);
    int len = use_offsets
        ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
        : offsets_or_lengths[m];
    if (current + len > index_size) {
      return false;
    }
    for (int i = 0; i < len; ++i) {
      int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }

      float w = 1.f;
      if (weights) {
        w = weights[is_weight_positional ? i : current];
      }

      for (int j = 0; j < block_size; ++j) {
        const uint8_t* inptr = input + input_stride * idx + j;
        float input_f;
        // Dequantize FP8 to FP32 before the compute
        Float8ToFloat_ref(*inptr, &input_f, exponent_bits, exponent_bias);
        buf[j] = std::fma(w, input_f, buf[j]);
      }

      ++current;
    }
    if (normalize_by_lengths && len) {
      float scale = 1.f / len;
      for (int j = 0; j < block_size; ++j) {
        buf[j] *= scale;
      }
    }
    for (int j = 0; j < block_size; ++j) {
      out[j] = is_same<OutType, float16>::value
          ? convert_from_float_ref<OutType>(buf[j], is_bf16_out)
          : buf[j];
    }
    out += output_stride;
  }
  return current == index_size;
}
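// Editor's illustrative note (not part of the library): the row-wise sparse
// variants below look indices up through compressed_indices_table, which
// maps an uncompressed row id to its position in a pruned (compressed)
// table, with -1 marking rows that were pruned away. For example, with
//
//   compressed_indices_table = {0, -1, 1, -1, 2}
//
// uncompressed row 2 is fetched from compressed row 1, while lookups of
// rows 1 and 3 are skipped (they contribute nothing to the bag).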
template <typename InType, typename IndexType, typename OffsetType>
bool EmbeddingSpMDMRowWiseSparse_ref(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t uncompressed_data_size,
    // const int64_t compressed_data_size,
    const InType* input,
    const IndexType* indices,
    const int32_t* compressed_indices_table,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    float* out,
    bool is_weight_positional,
    bool use_offsets) {
  bool is8bit = is_same<InType, uint8_t>::value;

  if (is8bit) {
    // block_size is the number of elements and fused_block_size is the size
    // of an entire row, including scale and bias.
    const auto scale_bias_offset = 2 * sizeof(float);
    const int64_t fused_block_size = block_size + scale_bias_offset;
    int64_t current = 0;
    for (int m = 0; m < output_size; ++m) {
      memset(out, 0, sizeof(float) * block_size);
      int len = use_offsets
          ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
          : offsets_or_lengths[m];
      if (current + len > index_size) {
        return false;
      }
      for (int i = 0; i < len; ++i) {
        IndexType uncompressed_idx = indices[current];
        if (uncompressed_idx < 0 ||
            uncompressed_idx >= uncompressed_data_size) {
          return false;
        }
        IndexType idx = compressed_indices_table[uncompressed_idx];
        if (idx == -1) {
          ++current;
          continue;
        }
        // if (idx < 0 || idx >= compressed_data_size) {
        //   return false;
        // }

        const float* scale_bias = reinterpret_cast<const float*>(
            input + fused_block_size * idx + block_size);

        float weight = 1.0f;
        if (weights) {
          weight = weights[is_weight_positional ? i : current];
        }
        const float scale = weight * scale_bias[0];
        const float bias = weight * scale_bias[1];

        for (int j = 0; j < block_size; ++j) {
          out[j] = std::fma(
              scale, input[fused_block_size * idx + j], out[j] + bias);
        }

        ++current;
      }
      if (normalize_by_lengths && len) {
        float scale = 1.f / len;
        for (int j = 0; j < block_size; ++j) {
          out[j] *= scale;
        }
      }
      out += block_size;
    }
    return current == index_size;
  } else {
    // Reference implementation of FP32 SLS
    int64_t current = 0;
    for (int m = 0; m < output_size; ++m) {
      memset(out, 0, sizeof(float) * block_size);
      int len = use_offsets
          ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
          : offsets_or_lengths[m];
      if (current + len > index_size) {
        return false;
      }
      for (int i = 0; i < len; ++i) {
        IndexType uncompressed_idx = indices[current];
        if (uncompressed_idx < 0 ||
            uncompressed_idx >= uncompressed_data_size) {
          return false;
        }
        IndexType idx = compressed_indices_table[uncompressed_idx];
        if (idx == -1) {
          ++current;
          continue;
        }
        // if (idx < 0 || idx >= compressed_data_size) {
        //   return false;
        // }

        float w = 1.f;
        if (weights) {
          w = weights[is_weight_positional ? i : current];
        }

        for (int j = 0; j < block_size; ++j) {
          const InType* inptr = input + block_size * idx + j;
          out[j] = std::fma(
              w,
              is_same<InType, float16>::value ? cpu_half2float(*inptr)
                                              : *inptr,
              out[j]);
        }

        ++current;
      }
      if (normalize_by_lengths && len) {
        float scale = 1.f / len;
        for (int j = 0; j < block_size; ++j) {
          out[j] *= scale;
        }
      }
      out += block_size;
    }
    return current == index_size;
  }
}

template <typename IndexType, typename OffsetType>
bool EmbeddingSpMDMNBitRowWiseSparse_ref(
    int bit_rate,
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t uncompressed_data_size,
    // const int64_t compressed_data_size,
    const uint8_t* input,
    const IndexType* indices,
    const int32_t* compressed_indices_table,
    const OffsetType* offsets_or_lengths,
    const float* weights, // optional, can be null for non-weighted sum
    bool normalize_by_lengths,
    float* out,
    bool is_weight_positional,
    bool use_offsets) {
  assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
  int num_elem_per_byte = 8 / bit_rate;

  // block_size is the number of elements and fused_block_size is the size of
  // an entire row, including scale and bias.
  const auto scale_bias_offset = 2 * sizeof(float16);
  const int64_t fused_block_size =
      (block_size + num_elem_per_byte - 1) / num_elem_per_byte +
      scale_bias_offset;
  int64_t current = 0;
  for (int m = 0; m < output_size; ++m) {
    memset(out, 0, sizeof(float) * block_size);
    int len = use_offsets
        ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
        : offsets_or_lengths[m];
    if (current + len > index_size) {
      return false;
    }
    for (int i = 0; i < len; ++i, ++current) {
      IndexType uncompressed_idx = indices[current];
      if (uncompressed_idx < 0 ||
          uncompressed_idx >= uncompressed_data_size) {
        return false;
      }
      IndexType idx = compressed_indices_table[uncompressed_idx];
      if (idx == -1) {
        continue;
      }
      // if (idx < 0 || idx >= compressed_data_size) {
      //   return false;
      // }

      const float16* scale_bias = reinterpret_cast<const float16*>(
          input + fused_block_size * idx +
          (block_size + num_elem_per_byte - 1) / num_elem_per_byte);

      float weight = 1.0f;
      if (weights) {
        weight = weights[is_weight_positional ? i : current];
      }
      const float scale = weight * cpu_half2float(scale_bias[0]);
      const float bias = weight * cpu_half2float(scale_bias[1]);

      for (int j = 0; j < block_size; ++j) {
        uint8_t quantized =
            input[fused_block_size * idx + j / num_elem_per_byte];
        quantized >>= (j % num_elem_per_byte) * bit_rate;
        quantized &= (1 << bit_rate) - 1;

        out[j] = std::fma(scale, quantized, out[j] + bias);
      }
    }
    if (normalize_by_lengths && len) {
      float scale = 1.f / len;
      for (int j = 0; j < block_size; ++j) {
        out[j] *= scale;
      }
    }
    out += block_size;
  }
  return current == index_size;
}
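// Editor's illustrative note (not part of the library): the update rule
// implemented by sparse_adagrad_ref below, per element j of the selected
// row (with counter-based regularization folded into the gradient):
//
//   g_j' = g_j + weight_decay * freq * w_j
//   h_j  = h_j + g_j'^2
//   w_j  = w_j + lr * g_j' / (sqrt(h_j) + epsilon)
//
// rowwise_sparse_adagrad_ref differs only in keeping a single momentum
// value per row: h += mean_j(g_j'^2), and every element shares the step
// size lr / (sqrt(h) + epsilon).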
template <typename IndexType>
int sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per row
    uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const IndexType* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife) {
  for (auto i = 0; i < num_rows; ++i) {
    uint64_t idx = indices[i];
    auto offsetI = i * block_size;
    auto offsetIdx = idx * block_size;

    if (block_size + offsetIdx > param_size) {
      return i;
    }

    float freq =
        (counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;

    const float* g_;
    const float* h_;
    const float* w_;
    float* nh_;
    float* nw_;

    g_ = g + offsetI;
    h_ = h + offsetIdx;
    w_ = w + offsetIdx;
    nh_ = h + offsetIdx;
    nw_ = w + offsetIdx;

    for (auto j = 0; j < block_size; ++j) {
      float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
      float hj = h_[j] + gj * gj;
      nh_[j] = hj;
      nw_[j] = w_[j] + lr * gj / (std::sqrt(hj) + epsilon);
    }
  }
  return num_rows;
}

template <typename IndexType>
int rowwise_sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per row
    uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const IndexType* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife) {
  for (auto i = 0; i < num_rows; ++i) {
    uint64_t idx = indices[i];
    auto offsetI = i * block_size;
    auto offsetIdx = idx * block_size;

    if (block_size + offsetIdx > param_size) {
      return i;
    }

    float freq =
        (counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;

    const float* g_;
    float* h_;
    float* w_;

    g_ = g + offsetI;
    h_ = h + idx; // This is different from sparse adagrad
    w_ = w + offsetIdx;

    float final_sum = 0.0f;
    // Note the following code assumes fbgemm will generate AVX2 code for
    // horizontal reduction, which is OK for now because fbgemm always uses
    // AVX2 for SparseAdagrad because its performance is bounded by memory
    // bandwidth, hence no speedup from AVX512. The non-vectorized version
    // would be just
    // for (auto j = 0; j < block_size; ++j) {
    //   float gj = g_[j];
    //   final_sum += gj * gj;
    // }
    constexpr int VLEN = 8;
    array<float, VLEN> partial_sum = {0.0f};
    for (auto j = 0; j < block_size; ++j) {
      float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
      partial_sum[j % VLEN] += gj * gj;
    }
    final_sum = ((partial_sum[0] + partial_sum[1]) +
                 (partial_sum[2] + partial_sum[3])) +
        ((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7]));
    final_sum /= block_size;

    float hi = *h_ = *h_ + final_sum;
    float float_step = lr / (std::sqrt(hi) + epsilon);

    for (auto j = 0; j < block_size; ++j) {
      float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
      w_[j] += gj * float_step;
    }
  }
  return num_rows;
}
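// Editor's illustrative sketch (not part of the library): how the fused
// kernel below emulates SIMD stochastic rounding when the weights are
// float16. Each 32-bit draw from rnd128_next is consumed one byte at a
// time; the byte is shifted into bits [5..13), i.e. just below the 10
// mantissa bits a half float keeps, so that adding it to the FP32 weight
// and then truncating (cpu_float2half_rz) rounds up with probability
// approximately proportional to the discarded fraction:
//
//   uint32_t R = rnd128_next(/*idx=*/0, /*vlen=*/8);
//   uint32_t r = (R & 0xFFU) << 5; // one byte of R, reused over 4 steps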
template <typename DataType, typename IndexType, typename OffsetType>
int rowwise_sparse_adagrad_fused_ref(
    int64_t block_size,
    int64_t output_size,
    int64_t index_size,
    int64_t data_size,
    DataType* w,
    const float* g,
    float* h,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    float epsilon,
    float lr,
    bool use_offsets,
    bool use_stochastic_rounding,
    int emu_vector_size,
    int64_t grad_stride) {
  if (grad_stride == -1) {
    grad_stride = block_size;
  }

  constexpr bool isFloat16w = std::is_same<DataType, float16>::value;
  // Local random buffer to emulate a SIMD vector
  // R: generated 32-bit base random numbers
  // r: extracted 8 bits for rounding
  constexpr int VLEN_MAX = 16;
  uint32_t R[VLEN_MAX], r[VLEN_MAX];
  int vlen = emu_vector_size;
  if (vlen != 8 && vlen != 16) {
    // Raise an error since other sizes may cause a buffer overflow
    cerr << "Not supported emu_vector_size: " << emu_vector_size << endl;
    return 0;
  }

  int64_t current = 0;
  for (int m = 0; m < output_size; ++m) {
    int len = use_offsets
        ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
        : offsets_or_lengths[m];
    if (current + len > index_size) {
      return false;
    }
    const float* g_ = g + m * grad_stride;
    // Note the following code assumes fbgemm will generate AVX2 code for
    // horizontal reduction, which is OK for now because fbgemm always uses
    // AVX2 for SparseAdagrad because its performance is bounded by memory
    // bandwidth, hence no speedup from AVX512. The non-vectorized version
    // would be just
    // for (auto j = 0; j < block_size; ++j) {
    //   float gj = g_[j];
    //   final_sum += gj * gj;
    // }
    constexpr int VLEN_AVX2 = 8;
    array<float, VLEN_AVX2> partial_sum = {0.0f};
    for (auto j = 0; j < block_size; ++j) {
      float gj = g_[j];
      partial_sum[j % VLEN_AVX2] += gj * gj;
    }
    float final_sum = ((partial_sum[0] + partial_sum[1]) +
                       (partial_sum[2] + partial_sum[3])) +
        ((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7]));
    final_sum /= block_size;

    for (int i = 0; i < len; ++i, ++current) {
      int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }

      float* h_ = h + idx;
      DataType* w_ = w + idx * block_size;

      float hi = *h_ = *h_ + final_sum;
      float float_step = lr / (std::sqrt(hi) + epsilon);

      int nvec = (block_size + vlen - 1) / vlen;
      int rem = (block_size % vlen) ? (block_size % vlen) : vlen;

      // Emulate the JIT behavior of stochastic rounding with vector length.
      //
      // Generate an R buffer every 4 steps of the nvec loop. Each 8-bit lane
      // of R (uint32_t) is used once. It is shifted to bits [5..13) and then
      // added to the FP32 weight before the FP16 conversion.
      //
      // The shifted 8-bit region:
      // +-------+--------+--------+--------+
      // |       |        |   xxxxx|xxx     |
      // 31      23       15       7        0
      //
      // A half float has 10 bits of mantissa, and a float has 23, so we are
      // shifting the random bits into the region that half floats cannot
      // represent: bits 13-23 of the FP32 mantissa. This effectively adds a
      // random variable of [0, 1] ulps before truncation.
      for (int n = 0; n < nvec; ++n) {
        int cur_vlen = (n == nvec - 1) ? rem : vlen;
        int sr_idx = n % 4;

        if (isFloat16w && use_stochastic_rounding) {
          if (sr_idx == 0) {
            for (int v = 0; v < vlen; ++v) {
              R[v] = rnd128_next(v, vlen);
              r[v] = (R[v] & 0xFFU) << 5;
            }
          } else if (sr_idx == 1) {
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF00U) >> 8) << 5;
            }
          } else if (sr_idx == 2) {
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF0000U) >> 16) << 5;
            }
          } else { // 3
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF000000U) >> 24) << 5;
            }
          }
        }

        for (int v = 0; v < cur_vlen; ++v) {
          int j = n * vlen + v;
          if (isFloat16w) {
            union {
              float w_f32;
              uint32_t w_i32;
            };
            w_f32 = cpu_half2float(w_[j]);
            w_f32 = std::fma(float_step, g_[j], w_f32);
            if (use_stochastic_rounding) {
              w_i32 += r[v];
            }
            // Use truncation ("round toward zero") to counteract the
            // randomly added part
            w_[j] = cpu_float2half_rz(w_f32);
          } else { // float
            w_[j] += g_[j] * float_step;
          }
        }
      }
    }
  }
  return current == index_size;
}

template FBGEMM_API void transposeConvWeights(
    const conv_param_t<1>& conv_p,
    const std::int8_t* src,
    std::int8_t* dest);

template FBGEMM_API void transposeConvWeights(
    const conv_param_t<2>& conv_p,
    const std::int8_t* src,
    std::int8_t* dest);

template FBGEMM_API void transposeConvWeights(
    const conv_param_t<3>& conv_p,
    const std::int8_t* src,
    std::int8_t* dest);

#define INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
  template FBGEMM_API bool EmbeddingSpMDM_ref(                             \
      const int64_t block_size,                                            \
      const int64_t output_size,                                           \
      const int64_t index_size,                                            \
      const int64_t data_size,                                             \
      const IN_TYPE* input,                                                \
      const INDEX_TYPE* indices,                                           \
      const OFFSET_TYPE* offsets_or_lengths,                               \
      const float* weights,                                                \
      bool normalize_by_lengths,                                           \
      OUT_TYPE* out,                                                       \
      bool is_weight_positional,                                           \
      bool use_offsets,                                                    \
      int64_t output_stride,                                               \
      int64_t input_stride,                                                \
      bool scale_bias_last,                                                \
      bool no_bag,                                                         \
      bool is_bf16_out,                                                    \
      bool is_bf16_in);

#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE)        \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float)        \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16)      \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint8_t) \
  template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref(              \
      const int64_t block_size,                                          \
      const int64_t output_size,                                         \
      const int64_t index_size,                                          \
      const int64_t uncompressed_data_size,                              \
      const IN_TYPE* input,                                              \
      const INDEX_TYPE* indices,                                         \
      const int32_t* compressed_indices_table,                           \
      const OFFSET_TYPE* offsets_or_lengths,                             \
      const float* weights,                                              \
      bool normalize_by_lengths,                                         \
      float* out,                                                        \
      bool is_weight_positional,                                         \
      bool use_offsets);

#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE)      \
  INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int32_t) \
  INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int64_t)

#define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE)          \
  INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int32_t) \
  INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int64_t)

INSTANTIATE_SPMDM_INDEX_T(float)
INSTANTIATE_SPMDM_INDEX_T(float16)
INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)

#undef INSTANTIATE_SPMDM_INDEX_T
#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_OUT_T
#undef INSTANTIATE_SPMDM_BASE
#define INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
  template FBGEMM_API bool EmbeddingSpMDMNBit_ref(                     \
      const int input_bit_rate,                                        \
      const int64_t block_size,                                        \
      const int64_t output_size,                                       \
      const int64_t index_size,                                        \
      const int64_t data_size,                                         \
      const uint8_t* input,                                            \
      const INDEX_TYPE* indices,                                       \
      const OFFSET_TYPE* offsets_or_lengths,                           \
      const float* weights,                                            \
      bool normalize_by_lengths,                                       \
      OUT_TYPE* out,                                                   \
      bool is_weight_positional,                                       \
      bool use_offsets,                                                \
      int64_t output_stride,                                           \
      int64_t input_stride,                                            \
      const bool scale_bias_last,                                      \
      const bool is_bf16_out,                                          \
      const bool no_bag,                                               \
      int output_bit_rate);

#define INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
  template FBGEMM_API bool EmbeddingSpMDMFP8_ref(                     \
      const int64_t block_size,                                       \
      const int64_t output_size,                                      \
      const int64_t index_size,                                       \
      const int64_t data_size,                                        \
      const uint8_t* input,                                           \
      const INDEX_TYPE* indices,                                      \
      const OFFSET_TYPE* offsets_or_lengths,                          \
      const float* weights,                                           \
      bool normalize_by_lengths,                                      \
      OUT_TYPE* out,                                                  \
      bool is_weight_positional,                                      \
      bool use_offsets,                                               \
      int64_t output_stride,                                          \
      int64_t input_stride,                                           \
      int exponent_bits,                                              \
      int exponent_bias,                                              \
      bool is_bf16_out);

#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE)        \
  INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, float)   \
  INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, float)    \
  INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \
  INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, float16)  \
  INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t) \
  template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( \
      int bit_rate,                                             \
      const int64_t block_size,                                 \
      const int64_t output_size,                                \
      const int64_t index_size,                                 \
      const int64_t uncompressed_data_size,                     \
      const uint8_t* input,                                     \
      const INDEX_TYPE* indices,                                \
      const int32_t* compressed_indices_table,                  \
      const OFFSET_TYPE* offsets_or_lengths,                    \
      const float* weights,                                     \
      bool normalize_by_lengths,                                \
      float* out,                                               \
      bool is_weight_positional,                                \
      bool use_offsets);

#define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \
  INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \
  INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int64_t)

INSTANTIATE_SPMDM_OFFSET_T(int32_t)
INSTANTIATE_SPMDM_OFFSET_T(int64_t)

#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_OUT_T
#undef INSTANTIATE_SPMDM_NBIT_BASE
#undef INSTANTIATE_SPMDM_FP8_BASE

template FBGEMM_API int sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per row
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int64_t* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife);

template FBGEMM_API int sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per row
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int32_t* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife);

template FBGEMM_API int rowwise_sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per row
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int64_t* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife);

template FBGEMM_API int rowwise_sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per row
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int32_t* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife);
#define INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, OFFSET_TYPE) \
  template FBGEMM_API int rowwise_sparse_adagrad_fused_ref(        \
      int64_t block_size,                                          \
      int64_t output_size,                                         \
      int64_t index_size,                                          \
      int64_t data_size,                                           \
      DATA_TYPE* w,                                                \
      const float* g,                                              \
      float* h,                                                    \
      const INDEX_TYPE* indices,                                   \
      const OFFSET_TYPE* offsets_or_lengths,                       \
      float epsilon,                                               \
      float lr,                                                    \
      bool use_offsets,                                            \
      bool use_stochastic_rounding,                                \
      int emu_vector_size,                                         \
      int64_t grad_stride);

#define INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, INDEX_TYPE) \
  INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int32_t)  \
  INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int64_t)

#define INSTANTIATE_SPMDM_INDEX_T(DATA_TYPE)     \
  INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, int32_t) \
  INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, int64_t)

INSTANTIATE_SPMDM_INDEX_T(float)
INSTANTIATE_SPMDM_INDEX_T(float16)

#undef INSTANTIATE_SPMDM_INDEX_T
#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_BASE

template <typename IndexType>
FBGEMM_API void compressed_indices_remap_ref(
    std::int32_t offsets_numel,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights) {
  bool has_per_sample_weights = (weights != nullptr);
  out_offsets[0] = offsets[0];
  IndexType j = 0;
  for (int i = 1; i < offsets_numel; i++) {
    for (int32_t k = offsets[i - 1]; k < offsets[i]; k++) {
      if (compressed_indices_mapping[indices[k]] != -1) {
        out_indices[j] = compressed_indices_mapping[indices[k]];
        if (has_per_sample_weights) {
          out_weights[j] = weights[k];
        }
        j++;
      }
    }
    out_offsets[i] = j;
  }
}

#define INSTANTIATE_REMAP_BASE(INDEX_TYPE)               \
  template FBGEMM_API void compressed_indices_remap_ref( \
      std::int32_t offsets_numel,                        \
      const INDEX_TYPE* indices,                         \
      const int32_t* compressed_indices_mapping,         \
      const INDEX_TYPE* offsets,                         \
      const float* weights,                              \
      INDEX_TYPE* out_indices,                           \
      INDEX_TYPE* out_offsets,                           \
      float* out_weights);

INSTANTIATE_REMAP_BASE(int32_t)
INSTANTIATE_REMAP_BASE(int64_t)

#undef INSTANTIATE_REMAP_BASE

} // namespace fbgemm