2340 lines
81 KiB
C++
2340 lines
81 KiB
C++
/*
|
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
* All rights reserved.
|
|
*
|
|
* This source code is licensed under the BSD-style license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
#define FBGEMM_EXPORTS
|
|
#include "./RefImplementations.h"
|
|
|
|
#include "fbgemm/FbgemmBuild.h"
|
|
#include "fbgemm/FbgemmConvert.h"
|
|
#include "fbgemm/FloatConversion.h"
|
|
|
|
#include <algorithm>
|
|
#include <cassert>
|
|
#include <cmath>
|
|
#include <cstring>
|
|
#include <iostream>
|
|
#include <numeric>
|
|
#include <thread>
|
|
|
|
using namespace std;
|
|
|
|
namespace fbgemm {
|
|
typedef union {
|
|
uint32_t I;
|
|
float F;
|
|
} fint32;
|
|
|
|
// Thread-safe random number generator
|
|
//
|
|
// Return a random 32bit integer using xoshiro128++
|
|
// http://prng.di.unimi.it/xoshiro128plusplus.c
|
|
inline uint32_t rnd128_next(int idx, int vlen) {
|
|
constexpr int VLEN_MAX = 16; // max vector size
|
|
alignas(64) static thread_local uint32_t g_rnd128_buffer[4 * VLEN_MAX];
|
|
static thread_local bool g_rnd128_initialized = false;
|
|
|
|
// Splitmix64: http://prng.di.unimi.it/splitmix64.c
|
|
auto rnd128_init_next = [](uint64_t& x) {
|
|
uint64_t z = (x += 0x9e3779b97f4a7c15);
|
|
z = (z ^ (z >> 30)) * 0xbf58476d1ce4e5b9;
|
|
z = (z ^ (z >> 27)) * 0x94d049bb133111eb;
|
|
return z ^ (z >> 31);
|
|
};
|
|
|
|
auto rotl = [](const uint32_t x, int k) {
|
|
return (x << k) | (x >> (32 - k));
|
|
};
|
|
|
|
if (!g_rnd128_initialized) {
|
|
// Initialize rand buffer with uniq values per thread
|
|
uint64_t h0 = std::hash<std::thread::id>{}(std::this_thread::get_id());
|
|
for (auto i = 0; i < 4; ++i) {
|
|
// Use thread hash as seed
|
|
g_rnd128_buffer[i * VLEN_MAX] = rnd128_init_next(h0);
|
|
uint64_t h1 = g_rnd128_buffer[i * VLEN_MAX];
|
|
for (auto v = 1; v < VLEN_MAX; ++v) {
|
|
g_rnd128_buffer[i * VLEN_MAX + v] = rnd128_init_next(h1);
|
|
}
|
|
}
|
|
g_rnd128_initialized = true;
|
|
}
|
|
|
|
const uint32_t result =
|
|
rotl(g_rnd128_buffer[idx] + g_rnd128_buffer[3 * vlen + idx], 7) +
|
|
g_rnd128_buffer[idx];
|
|
|
|
const uint32_t t = g_rnd128_buffer[1 * vlen + idx] << 9;
|
|
|
|
g_rnd128_buffer[2 * vlen + idx] ^= g_rnd128_buffer[0 * vlen + idx];
|
|
g_rnd128_buffer[3 * vlen + idx] ^= g_rnd128_buffer[1 * vlen + idx];
|
|
g_rnd128_buffer[1 * vlen + idx] ^= g_rnd128_buffer[2 * vlen + idx];
|
|
g_rnd128_buffer[0 * vlen + idx] ^= g_rnd128_buffer[3 * vlen + idx];
|
|
|
|
g_rnd128_buffer[2 * vlen + idx] ^= t;
|
|
|
|
g_rnd128_buffer[3 * vlen + idx] = rotl(g_rnd128_buffer[3 * vlen + idx], 11);
|
|
|
|
return result;
|
|
}
|
|
|
|
void FloatToFloat16_ref(
|
|
const float* src,
|
|
float16* dst,
|
|
size_t size,
|
|
bool do_clip) {
|
|
constexpr float FP16_MAX = 65504.f;
|
|
if (do_clip) {
|
|
for (size_t i = 0; i < size; i++) {
|
|
float cur_src = std::max(-FP16_MAX, std::min(src[i], FP16_MAX));
|
|
dst[i] = cpu_float2half_rn(cur_src);
|
|
}
|
|
} else {
|
|
for (size_t i = 0; i < size; i++) {
|
|
dst[i] = cpu_float2half_rn(src[i]);
|
|
}
|
|
}
|
|
}
|
|
|
|
void Float16ToFloat_ref(const float16* src, float* dst, size_t size) {
|
|
for (size_t i = 0; i < size; i++) {
|
|
dst[i] = cpu_half2float(src[i]);
|
|
}
|
|
}
|
|
|
|
void FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size) {
|
|
for (size_t i = 0; i < size; i++) {
|
|
// Add 2^15 and right shift 16 to do round-nearest
|
|
dst[i] = (*reinterpret_cast<const uint32_t*>(src + i) + (1 << 15)) >> 16;
|
|
}
|
|
}
|
|
|
|
void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size) {
|
|
for (size_t i = 0; i < size; i++) {
|
|
uint32_t val_fp32 =
|
|
static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(src)[i]) << 16;
|
|
reinterpret_cast<uint32_t*>(dst)[i] = val_fp32;
|
|
}
|
|
}
|
|
|
|
void FloatToFloat8_ref(
|
|
const float input,
|
|
uint8_t* output,
|
|
int exponent_bits,
|
|
int exponent_bias) {
|
|
float max_pos = (1 << ((1 << exponent_bits) - 2 - exponent_bias)) *
|
|
(2 - std::pow(2, exponent_bits - 7));
|
|
int mantissa_bits = 7 - exponent_bits;
|
|
fint32 val_out, bouncer, smallest_normal;
|
|
|
|
val_out.F = input;
|
|
uint32_t sign_bit = val_out.I & 0x80000000;
|
|
val_out.I = val_out.I & 0x7FFFFFFF;
|
|
val_out.F = fminf(val_out.F, max_pos);
|
|
|
|
smallest_normal.I = (127 - exponent_bias + 1)
|
|
<< 23; // smallest hfp8 normal number in FP32
|
|
// I don't know if the input "min_pos" is the smallest denormalized number
|
|
// or the smallest normalized number. The test below needs to be done with
|
|
// the smallest normal number, which is the numerical value 2^(1-bias)
|
|
|
|
// The conversion for denormalized values are slightly different. HFP8 is so
|
|
// low precision that gradual underflow is probably crucial
|
|
if (val_out.F >= smallest_normal.F) {
|
|
// Use round to nearest even. We make use of the standard rounding mechanism
|
|
// in FP32 rather than rounding the mantissa and handling tie-to-even and
|
|
// incrementing exponent We want to round of 23-mbits of the FP32 value
|
|
// val_in This can be done by adding a power of 2 exactly 23-mbits larger
|
|
// than the exponent of val_in This forces val_in to be moved to the right
|
|
// and rounding exact at the location corresponding to having mbits of
|
|
// explicit mantissa left
|
|
bouncer.I = (val_out.I & 0xFF800000) + ((23 - mantissa_bits) << 23);
|
|
val_out.F = (bouncer.F + val_out.F) - bouncer.F;
|
|
// adding the bouncer rounds off bits, and subtracting bouncer
|
|
// leaves the desired value, albeit in FP32 encoding
|
|
// All we need is to change the exponent encoding to using "bias"
|
|
val_out.I = uint32_t(val_out.I - ((127 - exponent_bias) << 23))
|
|
<< (8 - exponent_bits);
|
|
val_out.I =
|
|
((val_out.I | sign_bit) >>
|
|
24); // the 8 lsbs is the desired HFP8 encoding
|
|
|
|
} else {
|
|
// When the value is in the denormal range, IEEE numbers essentially becomes
|
|
// a fixed point number. The lsb is the smallest non-zero number
|
|
// 2^(1-bias-mbits) Hence, we define the bouncer so that its lsb is this
|
|
// smallest non-zero number Adding the input to this bouncer forces rounding
|
|
// to occur appropriately Also, in this situation, after adding the bouncer,
|
|
// the 8 least significant bits of the sum is already the HFP8 encoding of
|
|
// the desired result. Just need to restore the sign bit
|
|
bouncer.I = (127 + (23 + (1 - exponent_bias - mantissa_bits))) << 23;
|
|
val_out.F = bouncer.F + val_out.F;
|
|
val_out.I = val_out.I | (sign_bit >> 24);
|
|
}
|
|
|
|
*output = val_out.I; // get the 8 lsbs
|
|
}
|
|
|
|
void Float8ToFloat_ref(
|
|
const uint8_t input,
|
|
float* output,
|
|
int exponent_bits,
|
|
int exponent_bias) {
|
|
fint32 val_out, sign, multiplier;
|
|
|
|
sign.I = (input & 0x80) << 24;
|
|
val_out.I = (input & 0x7F) << (24 - (8 - exponent_bits));
|
|
// so that the mantissa bits start at the mantissa bit positions of FP32
|
|
// encoding
|
|
|
|
// Let the hfp8 mantissa bits correspond to the value frac, 0 <= frac < 1
|
|
// So if the hfp8 value is a normal number, it's value is 2^e x (1+frac)
|
|
// where e is its (true, unbiased) exponent
|
|
// If the hfp8 value is denormal, the value is 2^(1-bias) x frac
|
|
|
|
// However, the bit pattern in the 8-bit exponent field of val_out.F
|
|
// is bias+e when hfp8 is normal, and 0 when hfp8 is subnormal.
|
|
// So, as an FP32 value, when hfp8 is normal, val_out.F represents the value
|
|
// of 2^(bias+e-127) * (1+frac)
|
|
// And when hfp8 is subnormal, val_out.F is also subnormal, and represents the
|
|
// value of 2^(-126) * frac In either case, val_out.F corresponds to
|
|
// 2^(bias-127) * (value of hfp8 input) Thus, if we multiply val_out.F by
|
|
// 2^(127-bias), we obtain the hfp8 value as an FP32 number
|
|
|
|
multiplier.I = (127 + (127 - exponent_bias))
|
|
<< 23; // multiplier.F is 2^(127-bias)
|
|
val_out.F *= multiplier.F;
|
|
val_out.I |= sign.I;
|
|
*output = val_out.F;
|
|
}
|
|
|
|
void requantize_u8acc32_ref(
|
|
int M,
|
|
int N,
|
|
int ld,
|
|
const int32_t* inp,
|
|
uint8_t* out,
|
|
int32_t C_multiplier,
|
|
int32_t C_right_shift,
|
|
int32_t C_zero_point,
|
|
int32_t A_zero_point,
|
|
int32_t B_zero_point,
|
|
const int32_t* row_offsets,
|
|
const int32_t* col_offsets,
|
|
const int32_t* bias,
|
|
bool fuse_relu) {
|
|
int64_t nudge = 1ll << std::max(0, C_right_shift - 1);
|
|
for (int i = 0; i < M; ++i) {
|
|
for (int j = 0; j < N; ++j) {
|
|
int32_t raw = inp[i * ld + j];
|
|
if (A_zero_point) {
|
|
raw -= A_zero_point * col_offsets[j];
|
|
}
|
|
if (B_zero_point) {
|
|
raw -= B_zero_point * row_offsets[i];
|
|
}
|
|
if (bias) {
|
|
raw += bias[j];
|
|
}
|
|
|
|
int64_t ab_64 =
|
|
static_cast<int64_t>(raw) * static_cast<int64_t>(C_multiplier);
|
|
int64_t rounded = ((ab_64 + nudge) >> C_right_shift) + C_zero_point;
|
|
|
|
out[i * ld + j] = std::max(
|
|
fuse_relu ? static_cast<int64_t>(C_zero_point) : 0l,
|
|
std::min(static_cast<int64_t>(255l), rounded));
|
|
}
|
|
}
|
|
}
|
|
|
|
void requantize_u8acc32_ref(
|
|
int M,
|
|
int N,
|
|
int ld,
|
|
const int32_t* inp,
|
|
uint8_t* out,
|
|
const float* C_multiplier,
|
|
int32_t C_zero_point,
|
|
int32_t A_zero_point,
|
|
const int32_t* B_zero_point,
|
|
const int32_t* row_offsets,
|
|
const int32_t* col_offsets,
|
|
const int32_t* bias,
|
|
int ncols_per_quant_group,
|
|
bool fuse_relu) {
|
|
for (int i = 0; i < M; ++i) {
|
|
for (int j = 0; j < N; ++j) {
|
|
int32_t raw = inp[i * ld + j];
|
|
if (A_zero_point) {
|
|
raw -= A_zero_point * col_offsets[j];
|
|
}
|
|
raw -= B_zero_point[j / ncols_per_quant_group] * row_offsets[i];
|
|
if (bias) {
|
|
raw += bias[j];
|
|
}
|
|
|
|
float result = raw * C_multiplier[j / ncols_per_quant_group];
|
|
long rounded = lrintf(result) + C_zero_point;
|
|
out[i * ld + j] = std::max(
|
|
fuse_relu ? static_cast<long>(C_zero_point) : 0l,
|
|
std::min(255l, rounded));
|
|
}
|
|
}
|
|
}
|
|
|
|
void matmul_u8i8acc32_ref(
|
|
int M,
|
|
int N,
|
|
int K,
|
|
int lda,
|
|
int ldb,
|
|
int ldc,
|
|
const uint8_t* Aint8,
|
|
const int8_t* Bint8,
|
|
int32_t* Cint32) {
|
|
for (int i = 0; i < M; ++i) {
|
|
for (int j = 0; j < N; ++j) {
|
|
int32_t sum = 0;
|
|
for (int k = 0; k < K; ++k) {
|
|
sum += static_cast<int32_t>(Aint8[i * lda + k]) *
|
|
static_cast<int32_t>(Bint8[k * ldb + j]);
|
|
}
|
|
Cint32[i * ldc + j] = sum;
|
|
}
|
|
}
|
|
}
|
|
|
|
void matmul_u8i8acc16_ref(
|
|
int M,
|
|
int N,
|
|
int K,
|
|
int lda,
|
|
int ldb,
|
|
int ldc,
|
|
int brow,
|
|
const uint8_t* Aint8,
|
|
const int8_t* Bint8,
|
|
int32_t* Cint32) {
|
|
for (int i = 0; i < M; ++i) {
|
|
for (int j = 0; j < N; ++j) {
|
|
int32_t sum = 0, sum_32bit = 0;
|
|
for (int k = 0; k < K; k += 2) {
|
|
int a0 = Aint8[i * lda + k];
|
|
int b0 = Bint8[k * ldb + j];
|
|
int a1 = 0, b1 = 0;
|
|
if (k + 1 < K) {
|
|
a1 = Aint8[i * lda + k + 1];
|
|
b1 = Bint8[(k + 1) * ldb + j];
|
|
}
|
|
sum = clip_16bit(sum + clip_16bit(a0 * b0 + a1 * b1));
|
|
if ((k % brow) == (brow - 2)) {
|
|
sum_32bit += sum;
|
|
sum = 0;
|
|
}
|
|
}
|
|
Cint32[i * ldc + j] = sum_32bit + sum;
|
|
}
|
|
}
|
|
}
|
|
|
|
void cblas_sgemm_ref(
|
|
const matrix_op_t transa,
|
|
const matrix_op_t transb,
|
|
const int m,
|
|
const int n,
|
|
const int k,
|
|
float alpha,
|
|
const float* Afp32,
|
|
int lda,
|
|
const float* Bfp32,
|
|
int ldb,
|
|
float beta,
|
|
float* Cfp32,
|
|
int ldc) {
|
|
for (int i = 0; i < m; ++i) {
|
|
for (int j = 0; j < n; ++j) {
|
|
float sum = 0;
|
|
for (int p = 0; p < k; ++p) {
|
|
float a =
|
|
(transa == matrix_op_t::NoTranspose ? Afp32[i * lda + p]
|
|
: Afp32[p * lda + i]);
|
|
float b =
|
|
(transb == matrix_op_t::NoTranspose ? Bfp32[p * ldb + j]
|
|
: Bfp32[j * ldb + p]);
|
|
sum += a * b;
|
|
}
|
|
if (beta == 0) {
|
|
Cfp32[i * ldc + j] = alpha * sum;
|
|
} else {
|
|
Cfp32[i * ldc + j] = alpha * sum + beta * Cfp32[i * ldc + j];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
// From https://stackoverflow.com/questions/31652875
|
|
uint64_t umul64wide(uint64_t a, uint64_t b) {
|
|
uint64_t a_lo = static_cast<uint32_t>(a);
|
|
uint64_t a_hi = a >> 32;
|
|
uint64_t b_lo = static_cast<uint32_t>(b);
|
|
uint64_t b_hi = b >> 32;
|
|
|
|
uint64_t p0 = a_lo * b_lo;
|
|
uint64_t p1 = a_lo * b_hi;
|
|
uint64_t p2 = a_hi * b_lo;
|
|
|
|
return p0 + (p1 << 32) + (p2 << 32);
|
|
}
|
|
} // namespace
|
|
|
|
// Expected to have overflows
|
|
NO_SANITIZE("undefined")
|
|
void cblas_gemm_i64_i64acc_ref(
|
|
matrix_op_t transa,
|
|
matrix_op_t transb,
|
|
int M,
|
|
int N,
|
|
int K,
|
|
const int64_t* A,
|
|
int lda,
|
|
const int64_t* B,
|
|
int ldb,
|
|
bool accumulate,
|
|
int64_t* C,
|
|
int ldc) {
|
|
for (int i = 0; i < M; ++i) {
|
|
for (int j = 0; j < N; ++j) {
|
|
int64_t acc;
|
|
if (accumulate) {
|
|
acc = C[i * ldc + j];
|
|
} else {
|
|
acc = 0;
|
|
}
|
|
for (int k = 0; k < K; ++k) {
|
|
int64_t a =
|
|
A[transa == matrix_op_t::Transpose ? i + k * lda : i * lda + k];
|
|
int64_t b =
|
|
B[transb == matrix_op_t::Transpose ? k + j * ldb : k * ldb + j];
|
|
int64_t lo = umul64wide(a, b);
|
|
acc += lo;
|
|
}
|
|
C[i * ldc + j] = acc;
|
|
} // j
|
|
} // i
|
|
}
|
|
|
|
void row_offsets_u8acc32_ref(
|
|
int M,
|
|
int K,
|
|
int ld,
|
|
const uint8_t* Aint8,
|
|
int32_t* row_offsets) {
|
|
// row offset
|
|
for (int i = 0; i < M; ++i) {
|
|
int32_t sum = 0;
|
|
for (int k = 0; k < K; ++k) {
|
|
sum += static_cast<int32_t>(Aint8[i * ld + k]);
|
|
}
|
|
row_offsets[i] = sum;
|
|
}
|
|
}
|
|
|
|
void col_offsets_with_zero_pt_s8acc32_ref(
|
|
int K,
|
|
int N,
|
|
int ld,
|
|
const int8_t* Bint8,
|
|
const int32_t* B_zero_point,
|
|
int32_t* col_offsets,
|
|
int ncols_per_quant_group) {
|
|
for (int j = 0; j < N; ++j) {
|
|
int32_t sum = 0;
|
|
for (int k = 0; k < K; ++k) {
|
|
sum += Bint8[k * ld + j];
|
|
}
|
|
col_offsets[j] = sum - B_zero_point[j / ncols_per_quant_group] * K;
|
|
}
|
|
}
|
|
|
|
void spmdm_ref(
|
|
int M,
|
|
const uint8_t* A,
|
|
int lda,
|
|
fbgemm::CompressedSparseColumn& B,
|
|
bool accumulation,
|
|
int32_t* C,
|
|
int ldc,
|
|
int groups /*=1*/) {
|
|
int N = B.NumOfCols();
|
|
assert(N % groups == 0);
|
|
if (!accumulation) {
|
|
for (int i = 0; i < M; ++i) {
|
|
for (int j = 0; j < N; ++j) {
|
|
C[i * ldc + j] = 0;
|
|
}
|
|
}
|
|
}
|
|
for (int g = 0; g < groups; ++g) {
|
|
for (int j = g * (N / groups); j < (g + 1) * (N / groups); ++j) {
|
|
for (int k = B.ColPtr()[j]; k < B.ColPtr()[j + 1]; ++k) {
|
|
int row = g * B.NumOfRows() + B.RowIdx()[k];
|
|
int w = B.Values()[k];
|
|
for (int i = 0; i < M; ++i) {
|
|
C[i * ldc + j] += A[i * lda + row] * w;
|
|
}
|
|
}
|
|
} // for each column of B
|
|
} // for each group
|
|
}
|
|
|
|
int32_t clip_16bit(int32_t x) {
|
|
if (x > numeric_limits<int16_t>::max()) {
|
|
return std::min<int>(numeric_limits<int16_t>::max(), x);
|
|
} else if (x < numeric_limits<int16_t>::min()) {
|
|
return std::max<int>(numeric_limits<int16_t>::min(), x);
|
|
} else {
|
|
return x;
|
|
}
|
|
}
|
|
|
|
/* Imitate the Im2Col<float, CPUContext, StorageOrder::NWC> function
|
|
* from caffe2/utils/math_cpu.cc
|
|
* NWC StorageOrder/Layout
|
|
* A: NWC: NW_0 x C_0
|
|
* Ao: NWC: NW_1 x G KW C_0/G
|
|
*/
|
|
template <>
|
|
FBGEMM_API void im2col_ref(
|
|
const conv_param_t<1>& conv_p,
|
|
const uint8_t* A,
|
|
int32_t A_zero_point,
|
|
uint8_t* Ao) {
|
|
int IC = conv_p.IC;
|
|
int G = conv_p.G;
|
|
assert(IC % G == 0);
|
|
array<int, 1> IN_DIM = conv_p.IN_DIM;
|
|
array<int, 1> OUT_DIM = conv_p.OUT_DIM;
|
|
array<int, 1> K = conv_p.K;
|
|
|
|
if (conv_p.transposed) {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int ow = 0; ow < OUT_DIM[0]; ++ow) {
|
|
for (int s = 0; s < K[0]; ++s) {
|
|
int w = ow + conv_p.pad[0] - s * conv_p.dilation[0];
|
|
int w_in = w / conv_p.stride[0];
|
|
if (w_in * conv_p.stride[0] == w && w_in >= 0 && w_in < IN_DIM[0]) {
|
|
for (int g = 0; g < G; ++g) {
|
|
memcpy(
|
|
Ao + (((n * OUT_DIM[0] + ow) * G + g) * K[0] + s) * (IC / G),
|
|
A + (n * IN_DIM[0] + w_in) * IC + g * (IC / G),
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
} else {
|
|
for (int g = 0; g < G; ++g) {
|
|
memset(
|
|
Ao + (((n * OUT_DIM[0] + ow) * G + g) * K[0] + s) * (IC / G),
|
|
A_zero_point,
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
}
|
|
} // for each s
|
|
} // for each ow
|
|
} // for each n
|
|
} else {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int w = 0; w < OUT_DIM[0]; ++w) {
|
|
for (int s = 0; s < K[0]; ++s) {
|
|
int w_in =
|
|
-conv_p.pad[0] + w * conv_p.stride[0] + s * conv_p.dilation[0];
|
|
if (w_in < 0 || w_in >= IN_DIM[0]) {
|
|
for (int g = 0; g < G; ++g) {
|
|
memset(
|
|
Ao + (((n * OUT_DIM[0] + w) * G + g) * K[0] + s) * (IC / G),
|
|
A_zero_point,
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
} else {
|
|
for (int g = 0; g < G; ++g) {
|
|
memcpy(
|
|
Ao + (((n * OUT_DIM[0] + w) * G + g) * K[0] + s) * (IC / G),
|
|
A + (n * IN_DIM[0] + w_in) * IC + g * (IC / G),
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
}
|
|
} // for each s
|
|
} // for each w
|
|
} // for each n
|
|
}
|
|
}
|
|
|
|
/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function
|
|
* from caffe2/utils/math_cpu.cc
|
|
* NHWC StorageOrder/Layout
|
|
* A: NHWC: NH_0W_0 x C_0
|
|
* Ao: NHWC: NH_1W_1 x G RS C_0/G
|
|
*/
|
|
template <>
|
|
FBGEMM_API void im2col_ref(
|
|
const conv_param_t<2>& conv_p,
|
|
const uint8_t* A,
|
|
int32_t A_zero_point,
|
|
uint8_t* Ao) {
|
|
int IC = conv_p.IC;
|
|
int G = conv_p.G;
|
|
assert(IC % G == 0);
|
|
array<int, 2> IN_DIM = conv_p.IN_DIM;
|
|
array<int, 2> OUT_DIM = conv_p.OUT_DIM;
|
|
array<int, 2> K = conv_p.K;
|
|
|
|
if (conv_p.transposed) {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
|
|
for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
|
|
for (int r = 0; r < K[0]; ++r) {
|
|
for (int s = 0; s < K[1]; ++s) {
|
|
int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
|
|
int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
|
|
int h_in = h / conv_p.stride[0];
|
|
int w_in = w / conv_p.stride[1];
|
|
if (h_in * conv_p.stride[0] == h && h_in >= 0 &&
|
|
h_in < IN_DIM[0] && w_in * conv_p.stride[1] == w &&
|
|
w_in >= 0 && w_in < IN_DIM[1]) {
|
|
for (int g = 0; g < G; ++g) {
|
|
memcpy(
|
|
Ao +
|
|
(((((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * G +
|
|
g) *
|
|
K[0] +
|
|
r) *
|
|
K[1] +
|
|
s) *
|
|
(IC / G),
|
|
A + ((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
|
|
g * (IC / G),
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
} else {
|
|
for (int g = 0; g < G; ++g) {
|
|
memset(
|
|
Ao +
|
|
(((((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * G +
|
|
g) *
|
|
K[0] +
|
|
r) *
|
|
K[1] +
|
|
s) *
|
|
(IC / G),
|
|
A_zero_point,
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
}
|
|
} // for each s
|
|
} // for each r
|
|
} // for each ow
|
|
} // for each oh
|
|
} // for each n
|
|
} else {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int h = 0; h < OUT_DIM[0]; ++h) {
|
|
for (int w = 0; w < OUT_DIM[1]; ++w) {
|
|
for (int r = 0; r < K[0]; ++r) {
|
|
int h_in =
|
|
-conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
|
|
for (int s = 0; s < K[1]; ++s) {
|
|
int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
|
|
s * conv_p.dilation[1];
|
|
if (h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
|
|
w_in >= IN_DIM[1]) {
|
|
for (int g = 0; g < G; ++g) {
|
|
memset(
|
|
Ao +
|
|
(((((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * G + g) *
|
|
K[0] +
|
|
r) *
|
|
K[1] +
|
|
s) *
|
|
(IC / G),
|
|
A_zero_point,
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
} else {
|
|
for (int g = 0; g < G; ++g) {
|
|
memcpy(
|
|
Ao +
|
|
(((((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * G + g) *
|
|
K[0] +
|
|
r) *
|
|
K[1] +
|
|
s) *
|
|
(IC / G),
|
|
A + ((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
|
|
g * (IC / G),
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
}
|
|
} // for each s
|
|
} // for each r
|
|
} // for each w
|
|
} // for each h
|
|
} // for each n
|
|
}
|
|
}
|
|
|
|
/* Imitate the Im2Col<float, CPUContext, StorageOrder::NHWC> function
|
|
* from caffe2/utils/math_cpu.cc
|
|
* NHWC StorageOrder/Layout
|
|
* A: NHWC: NT_0H_0W_0 x C_0
|
|
* Ao: NHWC: NT_1H_1W_1 x G QRS C_0/G
|
|
*/
|
|
template <>
|
|
FBGEMM_API void im2col_ref(
|
|
const conv_param_t<3>& conv_p,
|
|
const uint8_t* A,
|
|
int32_t A_zero_point,
|
|
uint8_t* Ao) {
|
|
int IC = conv_p.IC;
|
|
int G = conv_p.G;
|
|
assert(IC % G == 0);
|
|
array<int, 3> IN_DIM = conv_p.IN_DIM;
|
|
array<int, 3> OUT_DIM = conv_p.OUT_DIM;
|
|
array<int, 3> K = conv_p.K;
|
|
|
|
if (conv_p.transposed) {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
|
|
for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
|
|
for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
|
|
for (int q = 0; q < K[0]; ++q) {
|
|
for (int r = 0; r < K[1]; ++r) {
|
|
for (int s = 0; s < K[2]; ++s) {
|
|
int t = ot + conv_p.pad[0] - q * conv_p.dilation[0];
|
|
int h = oh + conv_p.pad[1] - r * conv_p.dilation[1];
|
|
int w = ow + conv_p.pad[2] - s * conv_p.dilation[2];
|
|
int t_in = t / conv_p.stride[0];
|
|
int h_in = h / conv_p.stride[1];
|
|
int w_in = w / conv_p.stride[2];
|
|
if (t_in * conv_p.stride[0] == t && t_in >= 0 &&
|
|
t_in < IN_DIM[0] && h_in * conv_p.stride[1] == h &&
|
|
h_in >= 0 && h_in < IN_DIM[1] &&
|
|
w_in * conv_p.stride[2] == w && w_in >= 0 &&
|
|
w_in < IN_DIM[2]) {
|
|
for (int g = 0; g < G; ++g) {
|
|
memcpy(
|
|
Ao +
|
|
(((((((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) *
|
|
OUT_DIM[2] +
|
|
ow) *
|
|
G +
|
|
g) *
|
|
K[0] +
|
|
q) *
|
|
K[1] +
|
|
r) *
|
|
K[2] +
|
|
s) *
|
|
(IC / G),
|
|
A +
|
|
(((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
|
|
IN_DIM[2] +
|
|
w_in) *
|
|
IC +
|
|
g * (IC / G),
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
} else {
|
|
for (int g = 0; g < G; ++g) {
|
|
memset(
|
|
Ao +
|
|
(((((((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) *
|
|
OUT_DIM[2] +
|
|
ow) *
|
|
G +
|
|
g) *
|
|
K[0] +
|
|
q) *
|
|
K[1] +
|
|
r) *
|
|
K[2] +
|
|
s) *
|
|
(IC / G),
|
|
A_zero_point,
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
}
|
|
} // for each s
|
|
} // for each r
|
|
} // for each q
|
|
} // for each ow
|
|
} // for each oh
|
|
} // for each ot
|
|
} // for each n
|
|
} else {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int t = 0; t < OUT_DIM[0]; ++t) {
|
|
for (int h = 0; h < OUT_DIM[1]; ++h) {
|
|
for (int w = 0; w < OUT_DIM[2]; ++w) {
|
|
for (int q = 0; q < K[0]; ++q) {
|
|
int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
|
|
q * conv_p.dilation[0];
|
|
for (int r = 0; r < K[1]; ++r) {
|
|
int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
|
|
r * conv_p.dilation[1];
|
|
for (int s = 0; s < K[2]; ++s) {
|
|
int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
|
|
s * conv_p.dilation[2];
|
|
if (t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
|
|
h_in >= IN_DIM[1] || w_in < 0 || w_in >= IN_DIM[2]) {
|
|
for (int g = 0; g < G; ++g) {
|
|
memset(
|
|
Ao +
|
|
(((((((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) *
|
|
OUT_DIM[2] +
|
|
w) *
|
|
G +
|
|
g) *
|
|
K[0] +
|
|
q) *
|
|
K[1] +
|
|
r) *
|
|
K[2] +
|
|
s) *
|
|
(IC / G),
|
|
A_zero_point,
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
} else {
|
|
for (int g = 0; g < G; ++g) {
|
|
memcpy(
|
|
Ao +
|
|
(((((((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) *
|
|
OUT_DIM[2] +
|
|
w) *
|
|
G +
|
|
g) *
|
|
K[0] +
|
|
q) *
|
|
K[1] +
|
|
r) *
|
|
K[2] +
|
|
s) *
|
|
(IC / G),
|
|
A +
|
|
(((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
|
|
IN_DIM[2] +
|
|
w_in) *
|
|
IC +
|
|
g * (IC / G),
|
|
sizeof(uint8_t) * (IC / G));
|
|
}
|
|
}
|
|
} // for each s
|
|
} // for each r
|
|
} // for each q
|
|
} // for each w
|
|
} // for each h
|
|
} // for each t
|
|
} // for each n
|
|
}
|
|
}
|
|
|
|
// 1D Conv
|
|
template <>
|
|
FBGEMM_API void conv_ref(
|
|
const conv_param_t<1>& conv_p,
|
|
const uint8_t* A,
|
|
int32_t A_zero_point,
|
|
const int8_t* B,
|
|
int32_t* C) {
|
|
// A is assumed to be (N Lin Cin)
|
|
// B is assumed to be (G K Cin/G Cout/G)
|
|
// C is assumed to be (N Lout Cout)
|
|
int IC = conv_p.IC;
|
|
int OC = conv_p.OC;
|
|
int G = conv_p.G;
|
|
assert(IC % G == 0);
|
|
assert(OC % G == 0);
|
|
array<int, 1> IN_DIM = conv_p.IN_DIM;
|
|
array<int, 1> OUT_DIM = conv_p.OUT_DIM;
|
|
array<int, 1> K = conv_p.K;
|
|
|
|
if (conv_p.transposed) {
|
|
// for ref implementation, there is no padding on the input buffer,
|
|
// padding specifies how much we remove from the output buffers
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int ow = 0; ow < OUT_DIM[0]; ++ow) {
|
|
// stride on output is fractional stride on input
|
|
// conv index is
|
|
// int w_in = -conv_p.pad[0] + w* conv_p.stride[0] + r*
|
|
// conv_p.dilation[0];
|
|
// so we reverse it
|
|
for (int g = 0; g < G; ++g) {
|
|
for (int oc = 0; oc < OC / G; ++oc) {
|
|
int sum = 0;
|
|
for (int r = 0; r < K[0]; ++r) {
|
|
int w = ow + conv_p.pad[0] - r * conv_p.dilation[0];
|
|
int w_in = w / conv_p.stride[0];
|
|
for (int ic = 0; ic < IC / G; ++ic) {
|
|
int a = (w_in * conv_p.stride[0] == w && w_in >= 0 &&
|
|
w_in < IN_DIM[0])
|
|
? A[(n * IN_DIM[0] + w_in) * IC + g * (IC / G) + ic]
|
|
: A_zero_point;
|
|
int b =
|
|
B[((g * K[0] + r) * IC / G + ic) * (OC / G) +
|
|
oc]; // G K IC/G OC/G after transpose
|
|
sum += a * b;
|
|
} // for each ic
|
|
} // for each r
|
|
C[(n * OUT_DIM[0] + ow) * OC + g * (OC / G) + oc] = sum;
|
|
} // for each oc
|
|
} // for each g
|
|
} // for each w
|
|
} // for each n
|
|
} else {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int w = 0; w < OUT_DIM[0]; ++w) {
|
|
for (int g = 0; g < G; ++g) {
|
|
for (int m = 0; m < OC / G; ++m) {
|
|
int sum = 0;
|
|
for (int r = 0; r < K[0]; ++r) {
|
|
int w_in = -conv_p.pad[0] + w * conv_p.stride[0] +
|
|
r * conv_p.dilation[0];
|
|
for (int c = 0; c < IC / G; ++c) {
|
|
int a = w_in < 0 || w_in >= IN_DIM[0]
|
|
? A_zero_point
|
|
: A[(n * IN_DIM[0] + w_in) * IC + g * (IC / G) + c];
|
|
int b =
|
|
B[((g * K[0] + r) * (IC / G) + c) * (OC / G) +
|
|
m]; // G K IC/G OC/G after transpose
|
|
sum += a * b;
|
|
} // for each c
|
|
} // for each r
|
|
C[(n * OUT_DIM[0] + w) * OC + g * (OC / G) + m] = sum;
|
|
} // for each w
|
|
} // for each m
|
|
} // for each group
|
|
} // for each n
|
|
}
|
|
}
|
|
|
|
// 2D Conv
|
|
template <>
|
|
FBGEMM_API void conv_ref(
|
|
const conv_param_t<2>& conv_p,
|
|
const uint8_t* A,
|
|
int32_t A_zero_point,
|
|
const int8_t* B,
|
|
int32_t* C) {
|
|
// filters are assumed to be in G RS C/G x K format
|
|
int IC = conv_p.IC;
|
|
int OC = conv_p.OC;
|
|
int G = conv_p.G;
|
|
assert(IC % G == 0);
|
|
assert(OC % G == 0);
|
|
array<int, 2> IN_DIM = conv_p.IN_DIM;
|
|
array<int, 2> OUT_DIM = conv_p.OUT_DIM;
|
|
array<int, 2> K = conv_p.K;
|
|
|
|
if (conv_p.transposed) {
|
|
// for ref implementation, there is no padding on the input buffer,
|
|
// padding specifies how much we remove from the output buffers
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int oh = 0; oh < OUT_DIM[0]; ++oh) {
|
|
for (int ow = 0; ow < OUT_DIM[1]; ++ow) {
|
|
// stride on output is fractional stride on input
|
|
// conv index is
|
|
// int h_in =
|
|
// -conv_p.pad[0] + h * conv_p.stride[0] + r * conv_p.dilation[0];
|
|
// int w_in =
|
|
// -conv_p.pad[1] + w * conv_p.stride[1] + s * conv_p.dilation[1];
|
|
// so we reverse it
|
|
for (int g = 0; g < G; ++g) {
|
|
for (int oc = 0; oc < OC / G; ++oc) {
|
|
int sum = 0;
|
|
for (int r = 0; r < K[0]; ++r) {
|
|
for (int s = 0; s < K[1]; ++s) {
|
|
int h = oh + conv_p.pad[0] - r * conv_p.dilation[0];
|
|
int w = ow + conv_p.pad[1] - s * conv_p.dilation[1];
|
|
int h_in = h / conv_p.stride[0];
|
|
int w_in = w / conv_p.stride[1];
|
|
for (int ic = 0; ic < IC / G; ++ic) {
|
|
int a = (h_in * conv_p.stride[0] == h && h_in >= 0 &&
|
|
h_in < IN_DIM[0] && w_in * conv_p.stride[1] == w &&
|
|
w_in >= 0 && w_in < IN_DIM[1])
|
|
? A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
|
|
g * (IC / G) + ic]
|
|
: A_zero_point;
|
|
int b =
|
|
B[((((g * K[0] + r) * K[1] + s) * (IC / G) + ic) * OC /
|
|
G) +
|
|
oc]; // G R S IC OC after transpose
|
|
sum += a * b;
|
|
} // for each ic
|
|
} // for each s
|
|
} // for each r
|
|
C[((n * OUT_DIM[0] + oh) * OUT_DIM[1] + ow) * OC + g * (OC / G) +
|
|
oc] = sum;
|
|
} // for each oc
|
|
} // for each g
|
|
} // for each w
|
|
} // for each h
|
|
} // for each n
|
|
} else {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int h = 0; h < OUT_DIM[0]; ++h) {
|
|
for (int w = 0; w < OUT_DIM[1]; ++w) {
|
|
for (int g = 0; g < G; ++g) {
|
|
for (int m = 0; m < OC / G; ++m) {
|
|
int sum = 0;
|
|
for (int r = 0; r < K[0]; ++r) {
|
|
int h_in = -conv_p.pad[0] + h * conv_p.stride[0] +
|
|
r * conv_p.dilation[0];
|
|
for (int s = 0; s < K[1]; ++s) {
|
|
int w_in = -conv_p.pad[1] + w * conv_p.stride[1] +
|
|
s * conv_p.dilation[1];
|
|
for (int c = 0; c < IC / G; ++c) {
|
|
int a = h_in < 0 || h_in >= IN_DIM[0] || w_in < 0 ||
|
|
w_in >= IN_DIM[1]
|
|
? A_zero_point
|
|
: A[((n * IN_DIM[0] + h_in) * IN_DIM[1] + w_in) * IC +
|
|
g * (IC / G) + c];
|
|
int b =
|
|
B[(((g * K[0] + r) * K[1] + s) * (IC / G) + c) *
|
|
(OC / G) +
|
|
m];
|
|
sum += a * b;
|
|
} // for each c
|
|
} // for each s
|
|
} // for each r
|
|
C[((n * OUT_DIM[0] + h) * OUT_DIM[1] + w) * OC + g * (OC / G) +
|
|
m] = sum;
|
|
} // for each m
|
|
} // for each group
|
|
} // for each w
|
|
} // for each h
|
|
} // for each n
|
|
}
|
|
}
|
|
|
|
// 3D Conv
|
|
template <>
|
|
FBGEMM_API void conv_ref(
|
|
const conv_param_t<3>& conv_p,
|
|
const uint8_t* A,
|
|
int32_t A_zero_point,
|
|
const int8_t* B,
|
|
int32_t* C) {
|
|
// filters are assumed to be in G QRS C/G x K format
|
|
int IC = conv_p.IC;
|
|
int OC = conv_p.OC;
|
|
int G = conv_p.G;
|
|
assert(IC % G == 0);
|
|
assert(OC % G == 0);
|
|
array<int, 3> IN_DIM = conv_p.IN_DIM;
|
|
array<int, 3> OUT_DIM = conv_p.OUT_DIM;
|
|
array<int, 3> K = conv_p.K;
|
|
|
|
if (conv_p.transposed) {
|
|
// for ref implementation, there is no padding on the input buffer,
|
|
// padding specifies how much we remove from the output buffers
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int ot = 0; ot < OUT_DIM[0]; ++ot) {
|
|
for (int oh = 0; oh < OUT_DIM[1]; ++oh) {
|
|
for (int ow = 0; ow < OUT_DIM[2]; ++ow) {
|
|
// stride on output is fractional stride on input
|
|
// conv index is
|
|
// int t_in =
|
|
// -conv_p.pad[0] + t * conv_p.stride[0] + q *
|
|
// conv_p.dilation[0];
|
|
// int h_in =
|
|
// -conv_p.pad[1] + h * conv_p.stride[1] + r *
|
|
// conv_p.dilation[1];
|
|
// int w_in =
|
|
// -conv_p.pad[2] + w * conv_p.stride[2] + s *
|
|
// conv_p.dilation[2];
|
|
// so we reverse it
|
|
for (int g = 0; g < G; ++g) {
|
|
for (int oc = 0; oc < OC / G; ++oc) {
|
|
int sum = 0;
|
|
for (int q = 0; q < K[0]; ++q) {
|
|
for (int r = 0; r < K[1]; ++r) {
|
|
for (int s = 0; s < K[2]; ++s) {
|
|
int t = ot + conv_p.pad[0] - q * conv_p.dilation[0];
|
|
int h = oh + conv_p.pad[1] - r * conv_p.dilation[1];
|
|
int w = ow + conv_p.pad[2] - s * conv_p.dilation[2];
|
|
int t_in = t / conv_p.stride[0];
|
|
int h_in = h / conv_p.stride[1];
|
|
int w_in = w / conv_p.stride[2];
|
|
for (int ic = 0; ic < IC / G; ++ic) {
|
|
int a =
|
|
(t_in * conv_p.stride[0] == t && t_in >= 0 &&
|
|
t_in < IN_DIM[0] && h_in * conv_p.stride[1] == h &&
|
|
h_in >= 0 && h_in < IN_DIM[1] &&
|
|
w_in * conv_p.stride[2] == w && w_in >= 0 &&
|
|
w_in < IN_DIM[2])
|
|
? A[((((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
|
|
IN_DIM[2]) +
|
|
w_in) *
|
|
IC +
|
|
g * (IC / G) + ic]
|
|
: A_zero_point;
|
|
int b =
|
|
B[((((((g * K[0] + q)) * K[1] + r) * K[2] + s) *
|
|
(IC / G) +
|
|
ic) *
|
|
(OC / G)) +
|
|
oc]; // G Q R S Cin/G Cout/G after transpose
|
|
sum += a * b;
|
|
} // for each ic
|
|
} // for each s
|
|
} // for each r
|
|
} // for each q
|
|
C[(((n * OUT_DIM[0] + ot) * OUT_DIM[1] + oh) * OUT_DIM[2] +
|
|
ow) *
|
|
OC +
|
|
g * (OC / G) + oc] = sum;
|
|
} // for each oc
|
|
} // for each g
|
|
} // for each ow
|
|
} // for each oh
|
|
} // for each ot
|
|
} // for each n
|
|
} else {
|
|
for (int n = 0; n < conv_p.MB; ++n) {
|
|
for (int t = 0; t < OUT_DIM[0]; ++t) {
|
|
for (int h = 0; h < OUT_DIM[1]; ++h) {
|
|
for (int w = 0; w < OUT_DIM[2]; ++w) {
|
|
for (int g = 0; g < G; ++g) {
|
|
for (int m = 0; m < OC / G; ++m) {
|
|
int sum = 0;
|
|
for (int q = 0; q < K[0]; ++q) {
|
|
int t_in = -conv_p.pad[0] + t * conv_p.stride[0] +
|
|
q * conv_p.dilation[0];
|
|
for (int r = 0; r < K[1]; ++r) {
|
|
int h_in = -conv_p.pad[1] + h * conv_p.stride[1] +
|
|
r * conv_p.dilation[1];
|
|
for (int s = 0; s < K[2]; ++s) {
|
|
int w_in = -conv_p.pad[2] + w * conv_p.stride[2] +
|
|
s * conv_p.dilation[2];
|
|
for (int c = 0; c < IC / G; ++c) {
|
|
int a = t_in < 0 || t_in >= IN_DIM[0] || h_in < 0 ||
|
|
h_in >= IN_DIM[1] || w_in < 0 ||
|
|
w_in >= IN_DIM[2]
|
|
? A_zero_point
|
|
: A[(((n * IN_DIM[0] + t_in) * IN_DIM[1] + h_in) *
|
|
IN_DIM[2] +
|
|
w_in) *
|
|
IC +
|
|
g * (IC / G) + c];
|
|
int b =
|
|
B[((((g * K[0] + q) * K[1] + r) * K[2] + s) *
|
|
(IC / G) +
|
|
c) *
|
|
(OC / G) +
|
|
m];
|
|
sum += a * b;
|
|
} // for each c
|
|
} // for each s
|
|
} // for each r
|
|
} // for each q
|
|
C[(((n * OUT_DIM[0] + t) * OUT_DIM[1] + h) * OUT_DIM[2] + w) *
|
|
OC +
|
|
g * (OC / G) + m] = sum;
|
|
} // for each m
|
|
} // for each group
|
|
} // for each w
|
|
} // for each h
|
|
} // for each t
|
|
} // for each n
|
|
}
|
|
}
|
|
|
|
template <int SPATIAL_DIM>
|
|
void transposeConvWeights(
|
|
const conv_param_t<SPATIAL_DIM>& conv_p,
|
|
const std::int8_t* src,
|
|
std::int8_t* dest) {
|
|
int G = conv_p.G;
|
|
int IC_per_G = conv_p.IC / conv_p.G;
|
|
int OC_per_G = conv_p.OC / conv_p.G;
|
|
|
|
int filter_prod = std::accumulate(
|
|
conv_p.K.begin(),
|
|
conv_p.K.begin() + SPATIAL_DIM,
|
|
1,
|
|
std::multiplies<int>());
|
|
// Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format.
|
|
for (int g = 0; g < G; ++g) {
|
|
for (int k = 0; k < OC_per_G; ++k) {
|
|
for (int f = 0; f < filter_prod; ++f) {
|
|
for (int c = 0; c < IC_per_G; ++c) {
|
|
dest[((g * filter_prod + f) * IC_per_G + c) * OC_per_G + k] =
|
|
src[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c];
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
template float convert_to_float_ref(float src, bool is_bf16_out);
|
|
template float convert_to_float_ref(uint16_t src, bool is_bf16_out);
|
|
template float convert_from_float_ref(float src, bool is_bf16_out);
|
|
template uint16_t convert_from_float_ref(float bfloat16, bool is_bf16_out);
|
|
|
|
template <
|
|
typename InType,
|
|
typename IndexType,
|
|
typename OffsetType,
|
|
typename OutType>
|
|
bool EmbeddingSpMDM_ref(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const InType* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights, // optional, can be null for non-weighted sum
|
|
bool normalize_by_lengths,
|
|
OutType* out,
|
|
bool is_weight_positional /*=false*/,
|
|
bool use_offsets /*=true*/,
|
|
int64_t output_stride /*=-1*/,
|
|
int64_t input_stride /*=-1*/,
|
|
bool scale_bias_last /*=true*/,
|
|
bool no_bag /*=false*/,
|
|
bool is_bf16_out /*=false*/,
|
|
bool is_bf16_in /*=false*/) {
|
|
const bool isWeight8bit = is_same<InType, uint8_t>::value;
|
|
const bool isOutput8bit = is_same<OutType, uint8_t>::value;
|
|
if (output_stride == -1) {
|
|
output_stride = block_size;
|
|
}
|
|
if constexpr (isOutput8bit) {
|
|
assert(input_stride == output_stride);
|
|
}
|
|
vector<float> buf(block_size);
|
|
|
|
if (isWeight8bit) {
|
|
// block_size is the number of elements and fused_block_size is the size of
|
|
// an entire row, including scale and bias.
|
|
if (input_stride == -1) {
|
|
// scale_bias_last == false is for table batched embedding that stores
|
|
// scale and bias in float16
|
|
const auto scale_bias_offset =
|
|
2 * (scale_bias_last ? sizeof(float) : sizeof(float16));
|
|
input_stride = block_size + scale_bias_offset;
|
|
}
|
|
int64_t current = 0;
|
|
|
|
if (no_bag) {
|
|
for (int m = 0; m < output_size; ++m) {
|
|
int64_t idx = indices[m];
|
|
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
if constexpr (isOutput8bit) {
|
|
const InType* input_row_ptr = input + input_stride * idx;
|
|
memcpy(out, input_row_ptr, sizeof(InType) * input_stride);
|
|
} else {
|
|
memset(buf.data(), 0, sizeof(float) * block_size);
|
|
const float* scale_bias = reinterpret_cast<const float*>(
|
|
input + input_stride * idx + (scale_bias_last ? block_size : 0));
|
|
|
|
float weight = 1.0f;
|
|
if (weights) {
|
|
weight = weights[m];
|
|
}
|
|
|
|
float scale, bias;
|
|
if (scale_bias_last) {
|
|
scale = weight * scale_bias[0];
|
|
bias = weight * scale_bias[1];
|
|
} else {
|
|
scale = weight *
|
|
cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]);
|
|
bias = weight *
|
|
cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]);
|
|
}
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
buf[j] = std::fma(
|
|
scale,
|
|
input
|
|
[input_stride * idx + j +
|
|
(scale_bias_last ? 0 : 2 * sizeof(float16))],
|
|
buf[j] + bias);
|
|
}
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
|
|
}
|
|
}
|
|
out += output_stride;
|
|
} // m
|
|
return true;
|
|
} // no_bag
|
|
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(buf.data(), 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
if (current + len > index_size) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < len; ++i, ++current) {
|
|
int64_t idx = indices[current];
|
|
if (!scale_bias_last && idx == -1) {
|
|
// When scale_bias_last == false, assume this is for table batched
|
|
// embedding (TBE) that can get -1 for pruned rows.
|
|
continue;
|
|
}
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
const float* scale_bias = reinterpret_cast<const float*>(
|
|
input + input_stride * idx + (scale_bias_last ? block_size : 0));
|
|
|
|
float weight = 1.0f;
|
|
if (weights) {
|
|
weight = weights[is_weight_positional ? i : current];
|
|
}
|
|
float scale, bias;
|
|
if (scale_bias_last) {
|
|
scale = weight * scale_bias[0];
|
|
bias = weight * scale_bias[1];
|
|
} else {
|
|
scale = weight *
|
|
cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[0]);
|
|
bias = weight *
|
|
cpu_half2float(reinterpret_cast<const float16*>(scale_bias)[1]);
|
|
}
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
buf[j] = std::fma(
|
|
scale,
|
|
input
|
|
[input_stride * idx + j +
|
|
(scale_bias_last ? 0 : 2 * sizeof(float16))],
|
|
buf[j] + bias);
|
|
}
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
buf[j] *= scale;
|
|
}
|
|
}
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
|
|
}
|
|
out += output_stride;
|
|
}
|
|
return current == index_size;
|
|
} else {
|
|
if (input_stride == -1) {
|
|
input_stride = block_size;
|
|
}
|
|
|
|
if (no_bag) {
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(buf.data(), 0, sizeof(float) * block_size);
|
|
int64_t idx = indices[m];
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
float w = 1.f;
|
|
if (weights) {
|
|
w = weights[m];
|
|
}
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
const InType* inptr = input + input_stride * idx + j;
|
|
buf[j] =
|
|
std::fma(w, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
|
|
}
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
|
|
}
|
|
out += output_stride;
|
|
} // m
|
|
return true;
|
|
} // no_bag
|
|
|
|
// Reference implementation of FP32 SLS
|
|
int64_t current = 0;
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(buf.data(), 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
if (current + len > index_size) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < len; ++i) {
|
|
int64_t idx = indices[current];
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
float w = 1.f;
|
|
if (weights) {
|
|
w = weights[is_weight_positional ? i : current];
|
|
}
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
const InType* inptr = input + input_stride * idx + j;
|
|
buf[j] =
|
|
std::fma(w, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
|
|
}
|
|
|
|
++current;
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
buf[j] *= scale;
|
|
}
|
|
}
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
|
|
}
|
|
out += output_stride;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
}
|
|
|
|
template <typename IndexType, typename OffsetType, typename OutType>
|
|
bool EmbeddingSpMDMNBit_ref(
|
|
int input_bit_rate,
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const uint8_t* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights, // optional, can be null for non-weighted sum
|
|
bool normalize_by_lengths,
|
|
OutType* out,
|
|
bool is_weight_positional /*=false*/,
|
|
bool use_offsets /*=true*/,
|
|
int64_t output_stride /*=-1*/,
|
|
int64_t input_stride /*=-1*/,
|
|
const bool scale_bias_last /*=true*/,
|
|
const bool is_bf16_out /*=false*/,
|
|
const bool no_bag /*=false*/,
|
|
int output_bit_rate /*=-1*/) {
|
|
if (output_bit_rate == -1) {
|
|
output_bit_rate = 8 * sizeof(OutType);
|
|
}
|
|
nbit_embedding_sanity_check<OutType>(input_bit_rate, output_bit_rate, no_bag);
|
|
int num_elem_per_byte = 8 / input_bit_rate;
|
|
|
|
if (output_stride == -1) {
|
|
output_stride = block_size;
|
|
}
|
|
|
|
// block_size is the number of elements and fused_block_size is the size of
|
|
// an entire row, including scale and bias.
|
|
const auto scale_bias_offset = 2 * sizeof(float16);
|
|
if (input_stride == -1) {
|
|
input_stride = (block_size + num_elem_per_byte - 1) / num_elem_per_byte +
|
|
scale_bias_offset;
|
|
}
|
|
|
|
if (no_bag) {
|
|
// We currently only support int4 to int4 for sequential TBE in this nbit
|
|
// kernel. Note that assert() will be ignored in release mode, so we check
|
|
// here to double check and also avoid "unused variable" warning
|
|
if (!(input_bit_rate == 4 && output_bit_rate == 4)) {
|
|
WARN_ONCE("no_bag is only supported for int4 to int4");
|
|
return false;
|
|
}
|
|
for (int64_t i = 0; i < output_size; ++i) {
|
|
const auto idx = indices[i];
|
|
if (idx < 0 || idx > data_size) {
|
|
return false;
|
|
}
|
|
const uint8_t* input_row = input + input_stride * idx;
|
|
memcpy(out, input_row, sizeof(uint8_t) * input_stride);
|
|
out += input_stride;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
int64_t current = 0;
|
|
vector<float> buf(block_size);
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(buf.data(), 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
if (current + len > index_size) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < len; ++i, ++current) {
|
|
int64_t idx = indices[current];
|
|
if (!scale_bias_last && idx == -1) {
|
|
// When scale_bias_last == false, assume this is for table batched
|
|
// embedding (TBE) that can get -1 for pruned rows.
|
|
continue;
|
|
}
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
const float16* scale_bias = reinterpret_cast<const float16*>(
|
|
input + input_stride * idx +
|
|
(scale_bias_last
|
|
? (block_size + num_elem_per_byte - 1) / num_elem_per_byte
|
|
: 0));
|
|
|
|
float weight = 1.0f;
|
|
if (weights) {
|
|
weight = weights[is_weight_positional ? i : current];
|
|
}
|
|
const float scale = weight * cpu_half2float(scale_bias[0]);
|
|
const float bias = weight * cpu_half2float(scale_bias[1]);
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
uint8_t quantized = input
|
|
[input_stride * idx + j / num_elem_per_byte +
|
|
(scale_bias_last ? 0 : scale_bias_offset)];
|
|
quantized >>= (j % num_elem_per_byte) * input_bit_rate;
|
|
quantized &= (1 << input_bit_rate) - 1;
|
|
|
|
buf[j] = std::fma(scale, quantized, buf[j] + bias);
|
|
}
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
buf[j] *= scale;
|
|
}
|
|
}
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = convert_from_float_ref<OutType>(buf[j], is_bf16_out);
|
|
}
|
|
out += output_stride;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
|
|
template <typename IndexType, typename OffsetType, typename OutType>
|
|
bool EmbeddingSpMDMFP8_ref(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const uint8_t* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights,
|
|
bool normalize_by_lengths,
|
|
OutType* out,
|
|
bool is_weight_positional,
|
|
bool use_offsets,
|
|
int64_t output_stride,
|
|
int64_t input_stride,
|
|
int exponent_bits,
|
|
int exponent_bias,
|
|
bool is_bf16_out /*=false*/) {
|
|
if (output_stride == -1) {
|
|
output_stride = block_size;
|
|
}
|
|
|
|
vector<float> buf(block_size);
|
|
|
|
if (input_stride == -1) {
|
|
input_stride = block_size;
|
|
}
|
|
|
|
// Reference implementation of FP8 SLS. The algorithm is similar to FP32 SLS
|
|
// except for the FP8->FP32 conversion after reading the embedding weight.
|
|
int64_t current = 0;
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(buf.data(), 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
if (current + len > index_size) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < len; ++i) {
|
|
int64_t idx = indices[current];
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
float w = 1.f;
|
|
if (weights) {
|
|
w = weights[is_weight_positional ? i : current];
|
|
}
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
const uint8_t* inptr = input + input_stride * idx + j;
|
|
float input_f;
|
|
// Dequantize FP8 to FP32 before compute
|
|
Float8ToFloat_ref(*inptr, &input_f, exponent_bits, exponent_bias);
|
|
buf[j] = std::fma(w, input_f, buf[j]);
|
|
}
|
|
|
|
++current;
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
buf[j] *= scale;
|
|
}
|
|
}
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = is_same<OutType, uint16_t>::value
|
|
? convert_from_float_ref<OutType>(buf[j], is_bf16_out)
|
|
: buf[j];
|
|
}
|
|
out += output_stride;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
|
|
template <typename InType, typename IndexType, typename OffsetType>
|
|
bool EmbeddingSpMDMRowWiseSparse_ref(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t uncompressed_data_size,
|
|
// const int64_t compressed_data_size,
|
|
const InType* input,
|
|
const IndexType* indices,
|
|
const int32_t* compressed_indices_table,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights, // optional, can be null for non-weighted sum
|
|
bool normalize_by_lengths,
|
|
float* out,
|
|
bool is_weight_positional,
|
|
bool use_offsets) {
|
|
bool is8bit = is_same<InType, uint8_t>::value;
|
|
|
|
if (is8bit) {
|
|
// block_size is the number of elements and fused_block_size is the size
|
|
// of an entire row, including scale and bias.
|
|
const auto scale_bias_offset = 2 * sizeof(float);
|
|
const int64_t fused_block_size = block_size + scale_bias_offset;
|
|
int64_t current = 0;
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(out, 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
if (current + len > index_size) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < len; ++i) {
|
|
IndexType uncompressed_idx = indices[current];
|
|
if (uncompressed_idx < 0 ||
|
|
uncompressed_idx >= uncompressed_data_size) {
|
|
return false;
|
|
}
|
|
IndexType idx = compressed_indices_table[uncompressed_idx];
|
|
if (idx == -1) {
|
|
++current;
|
|
continue;
|
|
}
|
|
// if (idx < 0 || idx >= compressed_data_size) {
|
|
// return false;
|
|
// }
|
|
|
|
const float* scale_bias = reinterpret_cast<const float*>(
|
|
input + fused_block_size * idx + block_size);
|
|
|
|
float weight = 1.0f;
|
|
if (weights) {
|
|
weight = weights[is_weight_positional ? i : current];
|
|
}
|
|
const float scale = weight * scale_bias[0];
|
|
const float bias = weight * scale_bias[1];
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] =
|
|
std::fma(scale, input[fused_block_size * idx + j], out[j] + bias);
|
|
}
|
|
|
|
++current;
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] *= scale;
|
|
}
|
|
}
|
|
out += block_size;
|
|
}
|
|
return current == index_size;
|
|
} else {
|
|
// Reference implementation of FP32 SLS
|
|
int64_t current = 0;
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(out, 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
if (current + len > index_size) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < len; ++i) {
|
|
IndexType uncompressed_idx = indices[current];
|
|
if (uncompressed_idx < 0 ||
|
|
uncompressed_idx >= uncompressed_data_size) {
|
|
return false;
|
|
}
|
|
IndexType idx = compressed_indices_table[uncompressed_idx];
|
|
if (idx == -1) {
|
|
++current;
|
|
continue;
|
|
}
|
|
// if (idx < 0 || idx >= compressed_data_size) {
|
|
// return false;
|
|
// }
|
|
|
|
float w = 1.f;
|
|
if (weights) {
|
|
w = weights[is_weight_positional ? i : current];
|
|
}
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
const InType* inptr = input + block_size * idx + j;
|
|
out[j] = std::fma(
|
|
w,
|
|
is_same<InType, float16>::value ? cpu_half2float(*inptr) : *inptr,
|
|
out[j]);
|
|
}
|
|
|
|
++current;
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] *= scale;
|
|
}
|
|
}
|
|
out += block_size;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
}
|
|
|
|
template <typename IndexType, typename OffsetType>
|
|
bool EmbeddingSpMDMNBitRowWiseSparse_ref(
|
|
int bit_rate,
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t uncompressed_data_size,
|
|
// const int64_t compressed_data_size,
|
|
const uint8_t* input,
|
|
const IndexType* indices,
|
|
const int32_t* compressed_indices_table,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights, // optional, can be null for non-weighted sum
|
|
bool normalize_by_lengths,
|
|
float* out,
|
|
bool is_weight_positional,
|
|
bool use_offsets) {
|
|
assert((bit_rate == 2 || bit_rate == 4) && "bit_rate must be 2 or 4");
|
|
int num_elem_per_byte = 8 / bit_rate;
|
|
|
|
// block_size is the number of elements and fused_block_size is the size of
|
|
// an entire row, including scale and bias.
|
|
const auto scale_bias_offset = 2 * sizeof(float16);
|
|
const int64_t fused_block_size =
|
|
(block_size + num_elem_per_byte - 1) / num_elem_per_byte +
|
|
scale_bias_offset;
|
|
int64_t current = 0;
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(out, 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
if (current + len > index_size) {
|
|
return false;
|
|
}
|
|
for (int i = 0; i < len; ++i, ++current) {
|
|
IndexType uncompressed_idx = indices[current];
|
|
if (uncompressed_idx < 0 || uncompressed_idx >= uncompressed_data_size) {
|
|
return false;
|
|
}
|
|
IndexType idx = compressed_indices_table[uncompressed_idx];
|
|
if (idx == -1) {
|
|
continue;
|
|
}
|
|
// if (idx < 0 || idx >= compressed_data_size) {
|
|
// return false;
|
|
// }
|
|
|
|
const float16* scale_bias = reinterpret_cast<const float16*>(
|
|
input + fused_block_size * idx +
|
|
(block_size + num_elem_per_byte - 1) / num_elem_per_byte);
|
|
|
|
float weight = 1.0f;
|
|
if (weights) {
|
|
weight = weights[is_weight_positional ? i : current];
|
|
}
|
|
const float scale = weight * cpu_half2float(scale_bias[0]);
|
|
const float bias = weight * cpu_half2float(scale_bias[1]);
|
|
|
|
for (int j = 0; j < block_size; ++j) {
|
|
uint8_t quantized =
|
|
input[fused_block_size * idx + j / num_elem_per_byte];
|
|
quantized >>= (j % num_elem_per_byte) * bit_rate;
|
|
quantized &= (1 << bit_rate) - 1;
|
|
|
|
out[j] = std::fma(scale, quantized, out[j] + bias);
|
|
}
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] *= scale;
|
|
}
|
|
}
|
|
out += block_size;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
|
|
template <typename IndexType>
|
|
int sparse_adagrad_ref(
|
|
int num_rows, // number of rows reading
|
|
int block_size, // number of parameters per rows
|
|
uint64_t param_size, // total number of parameters
|
|
float* w, // input parameters
|
|
const float* g, // input gradients
|
|
float* h, // input momentums
|
|
const IndexType* indices, // indices of each row
|
|
float epsilon,
|
|
float lr,
|
|
float weight_decay,
|
|
const double* counter,
|
|
const int64_t counter_halflife) {
|
|
for (auto i = 0; i < num_rows; ++i) {
|
|
uint64_t idx = indices[i];
|
|
auto offsetI = i * block_size;
|
|
auto offsetIdx = idx * block_size;
|
|
|
|
if (block_size + offsetIdx > param_size) {
|
|
return i;
|
|
}
|
|
|
|
float freq =
|
|
(counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;
|
|
|
|
const float* g_;
|
|
const float* h_;
|
|
const float* w_;
|
|
float* nh_;
|
|
float* nw_;
|
|
|
|
g_ = g + offsetI;
|
|
h_ = h + offsetIdx;
|
|
w_ = w + offsetIdx;
|
|
nh_ = h + offsetIdx;
|
|
nw_ = w + offsetIdx;
|
|
|
|
for (auto j = 0; j < block_size; ++j) {
|
|
float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
|
|
float hj = h_[j] + gj * gj;
|
|
nh_[j] = hj;
|
|
nw_[j] = w_[j] + lr * gj / (std::sqrt(hj) + epsilon);
|
|
}
|
|
}
|
|
return num_rows;
|
|
}
|
|
|
|
template <typename IndexType>
|
|
int rowwise_sparse_adagrad_ref(
|
|
int num_rows, // number of rows reading
|
|
int block_size, // number of parameters per rows
|
|
uint64_t param_size, // total number of parameters
|
|
float* w, // input parameters
|
|
const float* g, // input gradients
|
|
float* h, // input momentums
|
|
const IndexType* indices, // indices of each row
|
|
float epsilon,
|
|
float lr,
|
|
float weight_decay,
|
|
const double* counter,
|
|
const int64_t counter_halflife) {
|
|
for (auto i = 0; i < num_rows; ++i) {
|
|
uint64_t idx = indices[i];
|
|
auto offsetI = i * block_size;
|
|
auto offsetIdx = idx * block_size;
|
|
|
|
if (block_size + offsetIdx > param_size) {
|
|
return i;
|
|
}
|
|
|
|
float freq =
|
|
(counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;
|
|
|
|
const float* g_;
|
|
float* h_;
|
|
float* w_;
|
|
|
|
g_ = g + offsetI;
|
|
h_ = h + idx; // This is different from sparse adagrad
|
|
w_ = w + offsetIdx;
|
|
|
|
float final_sum = 0.0f;
|
|
// Note the following code assumes fbgemm will generate AVX2 code for
|
|
// horizontal reduction, which is OK for now because fbgemm always uses
|
|
// AVX2 for SparseAdagrad due to its performance is bounded by memory
|
|
// bandwidth hence no speedup from AVX512. Non-vectorized version would be
|
|
// just for (auto j = 0; j < block_size; ++j) {
|
|
// float gj = g_[j];
|
|
// final_sum += gj * gj;
|
|
// }
|
|
constexpr int VLEN = 8;
|
|
array<float, VLEN> partial_sum = {0.0f};
|
|
for (auto j = 0; j < block_size; ++j) {
|
|
float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
|
|
partial_sum[j % VLEN] += gj * gj;
|
|
}
|
|
final_sum = ((partial_sum[0] + partial_sum[1]) +
|
|
(partial_sum[2] + partial_sum[3])) +
|
|
((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7]));
|
|
final_sum /= block_size;
|
|
float hi = *h_ = *h_ + final_sum;
|
|
float float_step = lr / (std::sqrt(hi) + epsilon);
|
|
|
|
for (auto j = 0; j < block_size; ++j) {
|
|
float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
|
|
w_[j] += gj * float_step;
|
|
}
|
|
}
|
|
return num_rows;
|
|
}
|
|
|
|
template <typename DataType, typename IndexType, typename OffsetType>
int rowwise_sparse_adagrad_fused_ref(
    int64_t block_size,
    int64_t output_size,
    int64_t index_size,
    int64_t data_size,
    DataType* w,
    const float* g,
    float* h,
    const IndexType* indices,
    const OffsetType* offsets_or_lengths,
    float epsilon,
    float lr,
    bool use_offsets,
    bool use_stochastic_rounding,
    int emu_vector_size,
    int64_t grad_stride) {
  if (grad_stride == -1) {
    grad_stride = block_size;
  }

  constexpr bool isFloat16w = std::is_same<float16, DataType>::value;
  // Local random buffer to emulate a SIMD vector
  // R: generated 32-bit base random numbers
  // r: extracted 8 bits used for rounding
  constexpr int VLEN_MAX = 16;
  uint32_t R[VLEN_MAX], r[VLEN_MAX];
  int vlen = emu_vector_size;
  if (vlen != 8 && vlen != 16) {
    // Raise an error: other sizes could overflow the local buffers
    cerr << "Not supported emu_vector_size: " << emu_vector_size << endl;
    return 0;
  }

  int64_t current = 0;
  for (int m = 0; m < output_size; ++m) {
    int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
                          : offsets_or_lengths[m];
    if (current + len > index_size) {
      return false;
    }
    const float* g_ = g + m * grad_stride;
    // Note: the following code assumes fbgemm will generate AVX2 code for the
    // horizontal reduction, which is OK for now because fbgemm always uses
    // AVX2 for SparseAdagrad (its performance is bounded by memory bandwidth,
    // so AVX512 brings no speedup). The non-vectorized version would be just
    // for (auto j = 0; j < block_size; ++j) {
    //   float gj = g_[j];
    //   final_sum += gj * gj;
    // }
    constexpr int VLEN_AVX2 = 8;
    array<float, VLEN_AVX2> partial_sum = {0.0f};
    for (auto j = 0; j < block_size; ++j) {
      float gj = g_[j];
      partial_sum[j % VLEN_AVX2] += gj * gj;
    }
    float final_sum = ((partial_sum[0] + partial_sum[1]) +
                       (partial_sum[2] + partial_sum[3])) +
        ((partial_sum[4] + partial_sum[5]) + (partial_sum[6] + partial_sum[7]));
    final_sum /= block_size;

    for (int i = 0; i < len; ++i, ++current) {
      int64_t idx = indices[current];
      if (idx < 0 || idx >= data_size) {
        return false;
      }

      float* h_ = h + idx;
      DataType* w_ = w + idx * block_size;

      float hi = *h_ = *h_ + final_sum;
      float float_step = lr / (std::sqrt(hi) + epsilon);

      int nvec = (block_size + vlen - 1) / vlen;
      int rem = (block_size % vlen) ? (block_size % vlen) : vlen;

      // Emulate the JIT behavior of stochastic rounding at the given vector
      // length.
      //
      // A new R buffer is generated every 4 iterations of the nvec loop; each
      // 8-bit chunk of an R entry (uint32_t) is used once. The chunk is
      // shifted so that it occupies bits 5..12, then added to the FP32 weight
      // before FP16 conversion.
      //
      // The shifted 8-bit region
      // +-------+--------+--------+--------+
      // |       |        |   xxxxx|xxx     |
      //  31      23       15       7      0
      //
      // Half floats have 10 bits of mantissa while floats have 23; the random
      // bits are shifted into the low part of the fp32 mantissa that half
      // floats cannot represent. Together with the truncating conversion
      // below, this effectively adds a uniform random variable in [0,1) of
      // the fp16 ULP.

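      // Worked example of the schedule below (values chosen for illustration
      // only): with vlen == 8 and block_size == 20, nvec == 3 and rem == 4;
      // iterations n == 0, 1, 2 consume bytes 0, 1 and 2 of each R[v], and a
      // fresh R would only be drawn again at n == 4 (sr_idx == 0).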
      for (int n = 0; n < nvec; ++n) {
        int cur_vlen = (n == nvec - 1) ? rem : vlen;
        int sr_idx = n % 4;

        if (isFloat16w && use_stochastic_rounding) {
          if (sr_idx == 0) {
            for (int v = 0; v < vlen; ++v) {
              R[v] = rnd128_next(v, vlen);
              r[v] = (R[v] & 0xFFU) << 5;
            }
          } else if (sr_idx == 1) {
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF00U) >> 8) << 5;
            }
          } else if (sr_idx == 2) {
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF0000U) >> 16) << 5;
            }
          } else { // 3
            for (int v = 0; v < vlen; ++v) {
              r[v] = ((R[v] & 0xFF000000U) >> 24) << 5;
            }
          }
        }

        for (int v = 0; v < cur_vlen; ++v) {
          int j = n * vlen + v;
          if (isFloat16w) {
            union {
              float w_f32;
              uint32_t w_i32;
            };
            w_f32 = cpu_half2float(w_[j]);
            w_f32 = std::fma(float_step, g_[j], w_f32);
            if (use_stochastic_rounding) {
              w_i32 += r[v];
            }
            // Use truncating rounding to counteract the randomly added bits
            w_[j] = cpu_float2half_rz(w_f32);
          } else { // float
            w_[j] += g_[j] * float_step;
          }
        }
      }
    }
  }

  return current == index_size;
}

template FBGEMM_API void transposeConvWeights(
    const conv_param_t<1>& conv_p,
    const std::int8_t* src,
    std::int8_t* dest);

template FBGEMM_API void transposeConvWeights(
    const conv_param_t<2>& conv_p,
    const std::int8_t* src,
    std::int8_t* dest);

template FBGEMM_API void transposeConvWeights(
    const conv_param_t<3>& conv_p,
    const std::int8_t* src,
    std::int8_t* dest);

#define INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
  template FBGEMM_API bool EmbeddingSpMDM_ref( \
      const int64_t block_size, \
      const int64_t output_size, \
      const int64_t index_size, \
      const int64_t data_size, \
      const IN_TYPE* input, \
      const INDEX_TYPE* indices, \
      const OFFSET_TYPE* offsets_or_lengths, \
      const float* weights, \
      bool normalize_by_lengths, \
      OUT_TYPE* out, \
      bool is_weight_positional, \
      bool use_offsets, \
      int64_t input_stride, \
      int64_t output_stride, \
      bool scale_bias_last, \
      bool no_bag, \
      bool is_bf16_out, \
      bool is_bf16_in);

#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16) \
  INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint8_t) \
  template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( \
      const int64_t block_size, \
      const int64_t output_size, \
      const int64_t index_size, \
      const int64_t uncompressed_data_size, \
      const IN_TYPE* input, \
      const INDEX_TYPE* indices, \
      const int32_t* compressed_indices_table, \
      const OFFSET_TYPE* offsets_or_lengths, \
      const float* weights, \
      bool normalize_by_lengths, \
      float* out, \
      bool is_weight_positional, \
      bool use_offsets);

#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \
  INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int32_t) \
  INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int64_t)

#define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE) \
  INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int32_t) \
  INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int64_t)

INSTANTIATE_SPMDM_INDEX_T(float)
INSTANTIATE_SPMDM_INDEX_T(float16)
INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)

#undef INSTANTIATE_SPMDM_INDEX_T
#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_OUT_T
#undef INSTANTIATE_SPMDM_BASE

#define INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
  template FBGEMM_API bool EmbeddingSpMDMNBit_ref( \
      const int input_bit_rate, \
      const int64_t block_size, \
      const int64_t output_size, \
      const int64_t index_size, \
      const int64_t data_size, \
      const uint8_t* input, \
      const INDEX_TYPE* indices, \
      const OFFSET_TYPE* offsets_or_lengths, \
      const float* weights, \
      bool normalize_by_lengths, \
      OUT_TYPE* out, \
      bool is_weight_positional, \
      bool use_offsets, \
      int64_t output_stride, \
      int64_t input_stride, \
      const bool scale_bias_last, \
      const bool is_bf16_out, \
      const bool no_bag, \
      int output_bit_rate);
#define INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
  template FBGEMM_API bool EmbeddingSpMDMFP8_ref( \
      const int64_t block_size, \
      const int64_t output_size, \
      const int64_t index_size, \
      const int64_t data_size, \
      const uint8_t* input, \
      const INDEX_TYPE* indices, \
      const OFFSET_TYPE* offsets_or_lengths, \
      const float* weights, \
      bool normalize_by_lengths, \
      OUT_TYPE* out, \
      bool is_weight_positional, \
      bool use_offsets, \
      int64_t output_stride, \
      int64_t input_stride, \
      int exponent_bits, \
      int exponent_bias, \
      bool is_bf16_out);

#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
  INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, float) \
  INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, float) \
  INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \
  INSTANTIATE_SPMDM_FP8_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \
  INSTANTIATE_SPMDM_NBIT_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t) \
  template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( \
      int bit_rate, \
      const int64_t block_size, \
      const int64_t output_size, \
      const int64_t index_size, \
      const int64_t uncompressed_data_size, \
      const uint8_t* input, \
      const INDEX_TYPE* indices, \
      const int32_t* compressed_indices_table, \
      const OFFSET_TYPE* offsets_or_lengths, \
      const float* weights, \
      bool normalize_by_lengths, \
      float* out, \
      bool is_weight_positional, \
      bool use_offsets);

#define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \
  INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \
  INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int64_t)

INSTANTIATE_SPMDM_OFFSET_T(int32_t)
INSTANTIATE_SPMDM_OFFSET_T(int64_t)

#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_OUT_T
#undef INSTANTIATE_SPMDM_NBIT_BASE
#undef INSTANTIATE_SPMDM_FP8_BASE

template FBGEMM_API int sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per rows
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int64_t* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife);

template FBGEMM_API int sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per rows
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int32_t* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife);

template FBGEMM_API int rowwise_sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per rows
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int64_t* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife);

template FBGEMM_API int rowwise_sparse_adagrad_ref(
    int num_rows, // number of rows reading
    int block_size, // number of parameters per rows
    std::uint64_t param_size, // total number of parameters
    float* w, // input parameters
    const float* g, // input gradients
    float* h, // input momentums
    const std::int32_t* indices, // indices of each row
    float epsilon,
    float lr,
    float weight_decay,
    const double* counter,
    const int64_t counter_halflife);

#define INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, OFFSET_TYPE) \
  template FBGEMM_API int rowwise_sparse_adagrad_fused_ref( \
      int64_t block_size, \
      int64_t output_size, \
      int64_t index_size, \
      int64_t data_size, \
      DATA_TYPE* w, \
      const float* g, \
      float* h, \
      const INDEX_TYPE* indices, \
      const OFFSET_TYPE* offsets_or_lengths, \
      float epsilon, \
      float lr, \
      bool use_offsets, \
      bool use_stochastic_rounding, \
      int emu_vector_size, \
      int64_t grad_stride);

#define INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, INDEX_TYPE) \
  INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int32_t) \
  INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int64_t)

#define INSTANTIATE_SPMDM_INDEX_T(DATA_TYPE) \
  INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, int32_t) \
  INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, int64_t)

INSTANTIATE_SPMDM_INDEX_T(float)
INSTANTIATE_SPMDM_INDEX_T(float16)

#undef INSTANTIATE_SPMDM_INDEX_T
#undef INSTANTIATE_SPMDM_OFFSET_T
#undef INSTANTIATE_SPMDM_BASE

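// Remaps an indices/offsets pair through compressed_indices_mapping: entries
// whose mapping is -1 (pruned rows) are dropped, surviving indices are
// rewritten to their compressed ids, the optional per-sample weights are
// compacted alongside, and offsets are rebuilt to match.
//
// Illustrative example (hypothetical values): with offsets = {0, 3},
// indices = {2, 5, 7}, mapping[2] = 0, mapping[5] = -1 and mapping[7] = 1,
// the output is out_indices = {0, 1} and out_offsets = {0, 2}.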
template <typename IndexType>
FBGEMM_API void compressed_indices_remap_ref(
    std::int32_t offsets_numel,
    const IndexType* indices,
    const int32_t* compressed_indices_mapping,
    const IndexType* offsets,
    const float* weights, // optional, can be null
    IndexType* out_indices,
    IndexType* out_offsets,
    float* out_weights) {
  bool has_per_sample_weights = (weights != nullptr);
  out_offsets[0] = offsets[0];
  IndexType j = 0;
  for (int i = 1; i < offsets_numel; i++) {
    for (int32_t k = offsets[i - 1]; k < offsets[i]; k++) {
      if (compressed_indices_mapping[indices[k]] != -1) {
        out_indices[j] = compressed_indices_mapping[indices[k]];
        if (has_per_sample_weights) {
          out_weights[j] = weights[k];
        }
        j++;
      }
    }
    out_offsets[i] = j;
  }
}

#define INSTANTIATE_REMAP_BASE(INDEX_TYPE) \
  template FBGEMM_API void compressed_indices_remap_ref( \
      std::int32_t offsets_numel, \
      const INDEX_TYPE* indices, \
      const int32_t* compressed_indices_mapping, \
      const INDEX_TYPE* offsets, \
      const float* weights, \
      INDEX_TYPE* out_indices, \
      INDEX_TYPE* out_offsets, \
      float* out_weights);

INSTANTIATE_REMAP_BASE(int32_t)
INSTANTIATE_REMAP_BASE(int64_t)

#undef INSTANTIATE_REMAP_BASE

} // namespace fbgemm