1839 lines
79 KiB
C++
1839 lines
79 KiB
C++
/*
|
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
* All rights reserved.
|
|
*
|
|
* This source code is licensed under the BSD-style license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
#ifdef __linux__
|
|
|
|
#define FBGEMM_EXPORTS
|
|
#include "./EmbeddingSpMDMAutovec.h"
|
|
#include "./RefImplementations.h"
|
|
#include "fbgemm/FbgemmBuild.h"
|
|
#include "fbgemm/FloatConversion.h"
|
|
|
|
#include <algorithm>
|
|
#include <array>
|
|
#include <cassert>
|
|
#include <cmath>
|
|
#include <cstring>
|
|
#include <new>
|
|
#include <numeric>
|
|
#include <thread>
|
|
|
|
/// @defgroup tbe-cpu-autovec TBE CPU Autovectorization (FP8/16/32)
|
|
|
|
// Prefetch hint wrapper. When _WIN32 is defined the macro expands to nothing;
// otherwise all arguments are forwarded unchanged to __builtin_prefetch.
#ifdef _WIN32
|
|
#define do_prefetch(...)
|
|
#else
|
|
#define do_prefetch(...) __builtin_prefetch(__VA_ARGS__)
|
|
#endif
|
|
|
|
// FBGEMM_VECTOR_WIDTH is only defined for clang builds that target AVX-512
// (16), AVX2 (8) or SSE (4). Kernels below test `#ifdef FBGEMM_VECTOR_WIDTH`
// and, when it is defined, split their inner loop into a main part whose trip
// count is a multiple of the width plus a remainder loop.
#ifdef __clang__
|
|
// https://github.com/llvm/llvm-project/issues/114891 / T206675074
|
|
// Work around LLVM loop vectorization not producing optimal code when
|
|
// `block_size` is not a multiple of the natural vector size.
|
|
#ifdef __AVX512F__
|
|
#define FBGEMM_VECTOR_WIDTH 16
|
|
#elif __AVX2__
|
|
#define FBGEMM_VECTOR_WIDTH 8
|
|
#elif __SSE__
|
|
#define FBGEMM_VECTOR_WIDTH 4
|
|
#endif
|
|
#endif // #ifdef __clang__
|
|
|
|
namespace fbgemm {
|
|
|
|
// Number of floats in the std::array accumulation buffer that each kernel
// declares locally. Kernels use that array when their block fits and fall back
// to a heap allocation (std::unique_ptr<float[]>) when it does not.
static constexpr size_t LOCAL_STORAGE_SIZE = 512;
|
|
|
|
template <typename OutType>
|
|
static inline void fill_output(
|
|
OutType* out,
|
|
const float* src,
|
|
const int64_t block_size,
|
|
const bool is_bf16_out) {
|
|
if (std::is_same<OutType, float>::value) {
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = src[j];
|
|
}
|
|
} else if (std::is_same<OutType, uint16_t>::value && is_bf16_out) {
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = cpu_float2bfloat16(src[j]);
|
|
}
|
|
} else {
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] = cpu_float2half(src[j]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename IndexType, typename OffsetType, typename OutType>
|
|
static bool ALWAYS_INLINE EmbeddingSpMDM8Bit_autovec(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const uint8_t* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights, // optional, can be null for non-weighted sum
|
|
const bool normalize_by_lengths,
|
|
OutType* out,
|
|
const bool is_weight_positional,
|
|
const bool use_offsets,
|
|
const int64_t output_stride,
|
|
const int64_t input_stride,
|
|
const bool scale_bias_last,
|
|
const bool no_bag,
|
|
const bool is_bf16_out) {
|
|
constexpr bool isOutput8bit = std::is_same<OutType, uint8_t>::value;
|
|
if (data_size < 0) {
|
|
return false;
|
|
}
|
|
if constexpr (isOutput8bit) {
|
|
assert(input_stride == output_stride);
|
|
}
|
|
|
|
constexpr int64_t CACHE_LINE_SIZE = 64;
|
|
constexpr int64_t MAX_INITIAL_PREFETCH_ROWS = 16;
|
|
const int64_t prefetch_stride =
|
|
std::min(MAX_INITIAL_PREFETCH_ROWS, index_size);
|
|
for (int64_t pf_idx = 0; pf_idx < prefetch_stride; ++pf_idx) {
|
|
const uint8_t* prefetch_addr = input + input_stride * indices[pf_idx];
|
|
for (int64_t offset = 0; offset < input_stride; offset += CACHE_LINE_SIZE) {
|
|
do_prefetch(prefetch_addr + offset, 0, 0);
|
|
}
|
|
}
|
|
|
|
const int64_t scale_bias_size = 2 * sizeof(float16);
|
|
const int64_t scale_bias_offset = scale_bias_last ? block_size : 0;
|
|
const int64_t input_offset = scale_bias_last ? 0 : scale_bias_size;
|
|
|
|
std::array<float, LOCAL_STORAGE_SIZE> local_storage;
|
|
std::unique_ptr<float[]> heap_storage;
|
|
float* buf;
|
|
if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
|
|
buf = local_storage.data();
|
|
} else {
|
|
heap_storage.reset(new float[block_size]);
|
|
buf = heap_storage.get();
|
|
}
|
|
|
|
if (no_bag) {
|
|
for (int64_t m = 0; m < output_size; ++m) {
|
|
const IndexType idx = indices[m];
|
|
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
const uint8_t* input_row_base = input + input_stride * idx;
|
|
if constexpr (isOutput8bit) {
|
|
memcpy(out, input_row_base, sizeof(uint8_t) * input_stride);
|
|
} else {
|
|
memset(buf, 0, sizeof(float) * block_size);
|
|
|
|
float scale;
|
|
float bias;
|
|
const uint8_t* scale_bias_addr = input_row_base + scale_bias_offset;
|
|
if (scale_bias_last) {
|
|
memcpy(&scale, scale_bias_addr, sizeof(float));
|
|
memcpy(&bias, scale_bias_addr + sizeof(float), sizeof(float));
|
|
} else {
|
|
float16 scale16;
|
|
float16 bias16;
|
|
memcpy(&scale16, scale_bias_addr, sizeof(float16));
|
|
memcpy(&bias16, scale_bias_addr + sizeof(float16), sizeof(float16));
|
|
scale = cpu_half2float(scale16);
|
|
bias = cpu_half2float(bias16);
|
|
}
|
|
if (weights) {
|
|
float weight = weights[m];
|
|
scale *= weight;
|
|
bias *= weight;
|
|
}
|
|
|
|
const uint8_t* input_row = input_row_base + input_offset;
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
|
|
uint8_t value = input_row[j];
|
|
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
|
|
}
|
|
#endif
|
|
for (; j < block_size; ++j) {
|
|
uint8_t value = input_row[j];
|
|
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
|
|
}
|
|
fill_output(out, buf, block_size, is_bf16_out);
|
|
}
|
|
out += output_stride;
|
|
} // m
|
|
return true;
|
|
} // no_bag
|
|
|
|
int64_t current = 0;
|
|
for (int64_t m = 0; m < output_size; ++m) {
|
|
memset(buf, 0, sizeof(float) * block_size);
|
|
const OffsetType len = use_offsets
|
|
? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
int64_t end = current + len;
|
|
if (end > index_size) {
|
|
return false;
|
|
}
|
|
|
|
const float* weights_addr = weights != nullptr
|
|
? (is_weight_positional ? weights : weights + current)
|
|
: nullptr;
|
|
for (; current < end; ++current) {
|
|
IndexType idx = indices[current];
|
|
|
|
IndexType prefetch_idx =
|
|
indices[std::min(current + prefetch_stride, index_size - 1)];
|
|
const uint8_t* prefetch_addr = input + input_stride * prefetch_idx;
|
|
for (int64_t offset = 0; offset < input_stride;
|
|
offset += CACHE_LINE_SIZE) {
|
|
do_prefetch(prefetch_addr + offset, 1);
|
|
}
|
|
if (idx < 0 || idx >= data_size) {
|
|
if (!scale_bias_last && idx == -1) {
|
|
// When scale_bias_last == false, assume this is for table batched
|
|
// embedding (TBE) that can get -1 for pruned rows.
|
|
continue;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
const uint8_t* input_row_base = input + input_stride * idx;
|
|
|
|
const uint8_t* scale_bias_addr = input_row_base + scale_bias_offset;
|
|
float scale;
|
|
float bias;
|
|
if (scale_bias_last) {
|
|
memcpy(&scale, scale_bias_addr, sizeof(float));
|
|
memcpy(&bias, scale_bias_addr + sizeof(float), sizeof(float));
|
|
} else {
|
|
float16 scale16;
|
|
float16 bias16;
|
|
memcpy(&scale16, scale_bias_addr, sizeof(float16));
|
|
memcpy(&bias16, scale_bias_addr + sizeof(float16), sizeof(float16));
|
|
scale = cpu_half2float(scale16);
|
|
bias = cpu_half2float(bias16);
|
|
}
|
|
|
|
if (weights != nullptr) {
|
|
float weight = *weights_addr++;
|
|
scale *= weight;
|
|
bias *= weight;
|
|
}
|
|
|
|
const uint8_t* input_row = input_row_base + input_offset;
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
|
|
uint8_t value = input_row[j];
|
|
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
|
|
}
|
|
#endif
|
|
for (; j < block_size; ++j) {
|
|
uint8_t value = input_row[j];
|
|
buf[j] = std::fma(scale, (float)value, buf[j] + bias);
|
|
}
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int64_t j = 0; j < block_size; ++j) {
|
|
buf[j] *= scale;
|
|
}
|
|
}
|
|
fill_output(out, buf, block_size, is_bf16_out);
|
|
out += output_stride;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
|
|
template <typename IndexType, typename OffsetType, typename OutType>
|
|
static bool ALWAYS_INLINE EmbeddingSpMDMNBit_autovec(
|
|
const int input_bit_rate,
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const uint8_t* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights, // optional, can be null for non-weighted sum
|
|
const bool normalize_by_lengths,
|
|
OutType* out,
|
|
const bool is_weight_positional,
|
|
const bool use_offsets,
|
|
const int64_t output_stride,
|
|
const int64_t input_stride,
|
|
const bool scale_bias_last,
|
|
const bool is_bf16_out,
|
|
const bool no_bag,
|
|
int output_bit_rate) {
|
|
nbit_embedding_sanity_check<OutType>(input_bit_rate, output_bit_rate, no_bag);
|
|
if (data_size < 0) {
|
|
return false;
|
|
}
|
|
|
|
// more prefetch
|
|
// TODO: in the future we should adjust max_prefetch_bytes based on CPU cache
|
|
// size
|
|
constexpr int64_t max_prefetch_bytes = 4096;
|
|
// 16 is manually tuned for Neoverse-V2 for best performance
|
|
constexpr int64_t max_initial_prefetch_rows = 16;
|
|
constexpr int64_t CACHE_LINE_SIZE = 64;
|
|
const int64_t rows_to_prefetch =
|
|
std::min(max_initial_prefetch_rows, max_prefetch_bytes / input_stride);
|
|
const int64_t prefetch_stride = std::min(rows_to_prefetch, index_size);
|
|
const int num_elem_per_byte = 8 / input_bit_rate;
|
|
const int64_t scale_bias_offset =
|
|
scale_bias_last ? div_up(block_size, num_elem_per_byte) : 0;
|
|
const size_t scale_bias_size = 2 * sizeof(float16);
|
|
const int64_t input_row_offset = scale_bias_last ? 0 : scale_bias_size;
|
|
// The following prefetch loop is written in this way for better performance.
|
|
// My understanding is that manually separating the case of input_stride being
|
|
// greater or not greater than cache line size will make the branch predictor
|
|
// work better. Same for line 113-126.
|
|
for (int64_t pf_idx = 0; pf_idx < prefetch_stride; ++pf_idx) {
|
|
const uint8_t* prefetch_addr = input + input_stride * indices[pf_idx];
|
|
for (int64_t offset = 0; offset < input_stride; offset += CACHE_LINE_SIZE) {
|
|
do_prefetch(prefetch_addr + offset, 0, 0);
|
|
}
|
|
}
|
|
|
|
if (no_bag) {
|
|
// We currently only support int4 to int4 for sequential TBE in this nbit
|
|
// kernel. Note that assert() will be ignored in release mode, so we check
|
|
// here to double check and also avoid "unused variable" warning
|
|
if (!(input_bit_rate == 4 && output_bit_rate == 4)) {
|
|
WARN_ONCE("no_bag is only supported for int4 to int4");
|
|
return false;
|
|
}
|
|
for (int64_t i = 0; i < output_size; ++i) {
|
|
const auto idx = indices[i];
|
|
if (idx < 0 || idx > data_size) {
|
|
return false;
|
|
}
|
|
const uint8_t* input_row = input + input_stride * idx;
|
|
memcpy(out, input_row, sizeof(uint8_t) * input_stride);
|
|
out += input_stride;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
int64_t current = 0;
|
|
const int64_t rounded_block_size = round_up(block_size, num_elem_per_byte);
|
|
|
|
std::array<float, LOCAL_STORAGE_SIZE> local_storage;
|
|
std::unique_ptr<float[]> heap_storage;
|
|
float* buf;
|
|
if (static_cast<size_t>(rounded_block_size) <= LOCAL_STORAGE_SIZE) {
|
|
buf = local_storage.data();
|
|
} else {
|
|
heap_storage.reset(new float[rounded_block_size]);
|
|
buf = heap_storage.get();
|
|
}
|
|
|
|
for (int64_t m = 0; m < output_size; ++m) {
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
int64_t end = current + len;
|
|
if (end > index_size) {
|
|
return false;
|
|
}
|
|
memset(buf, 0, sizeof(float) * rounded_block_size);
|
|
|
|
const float* weights_addr = weights != nullptr
|
|
? (is_weight_positional ? weights : weights + current)
|
|
: nullptr;
|
|
for (; current < end; ++current) {
|
|
int64_t idx = indices[current];
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
int64_t prefetch_idx =
|
|
indices[std::min(current + prefetch_stride, index_size - 1)];
|
|
|
|
const uint8_t* input_row_base = input + input_stride * idx;
|
|
const uint8_t* scale_bias_addr = input_row_base + scale_bias_offset;
|
|
const uint8_t* input_row = input_row_base + input_row_offset;
|
|
|
|
float16 scale16;
|
|
float16 bias16;
|
|
memcpy(&scale16, scale_bias_addr, sizeof(float16));
|
|
memcpy(&bias16, scale_bias_addr + sizeof(float16), sizeof(float16));
|
|
static_assert(sizeof(scale16) + sizeof(bias16) == scale_bias_size);
|
|
|
|
float scale = cpu_half2float(scale16);
|
|
float bias = cpu_half2float(bias16);
|
|
if (weights != nullptr) {
|
|
float weight = *weights_addr++;
|
|
scale *= weight;
|
|
bias *= weight;
|
|
}
|
|
|
|
if (input_bit_rate == 4) {
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 2));
|
|
j += 2) {
|
|
uint8_t tmp = *input_row++;
|
|
float quantized1 = float(tmp & 0xf);
|
|
float quantized2 = float(tmp >> 4);
|
|
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
|
|
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
|
|
}
|
|
#endif
|
|
for (; j < block_size; j += 2) {
|
|
uint8_t tmp = *input_row++;
|
|
float quantized1 = float(tmp & 0xf);
|
|
float quantized2 = float(tmp >> 4);
|
|
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
|
|
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
|
|
}
|
|
} else if (input_bit_rate == 2) {
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % (FBGEMM_VECTOR_WIDTH * 4));
|
|
j += 4) {
|
|
uint8_t tmp = *input_row++;
|
|
float quantized1 = float(tmp & 0x3);
|
|
float quantized2 = float((tmp & 0xC) >> 2);
|
|
float quantized3 = float((tmp & 0x30) >> 4);
|
|
float quantized4 = float(tmp >> 6);
|
|
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
|
|
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
|
|
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
|
|
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
|
|
}
|
|
#endif
|
|
for (; j < block_size; j += 4) {
|
|
uint8_t tmp = *input_row++;
|
|
float quantized1 = float(tmp & 0x3);
|
|
float quantized2 = float((tmp & 0xC) >> 2);
|
|
float quantized3 = float((tmp & 0x30) >> 4);
|
|
float quantized4 = float(tmp >> 6);
|
|
buf[j] = std::fma(scale, quantized1, buf[j] + bias);
|
|
buf[j + 1] = std::fma(scale, quantized2, buf[j + 1] + bias);
|
|
buf[j + 2] = std::fma(scale, quantized3, buf[j + 2] + bias);
|
|
buf[j + 3] = std::fma(scale, quantized4, buf[j + 3] + bias);
|
|
}
|
|
}
|
|
|
|
const uint8_t* prefetch_addr = input + input_stride * prefetch_idx;
|
|
for (int64_t offset = 0; offset < input_stride;
|
|
offset += CACHE_LINE_SIZE) {
|
|
do_prefetch(prefetch_addr + offset, 0, 0);
|
|
}
|
|
}
|
|
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int64_t j = 0; j < block_size; ++j) {
|
|
buf[j] *= scale;
|
|
}
|
|
}
|
|
fill_output(out, buf, block_size, is_bf16_out);
|
|
out += output_stride;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
|
|
/// @ingroup tbe-cpu-autovec
|
|
///
|
|
/// Autovectorized version of method `EmbeddingSpMDM_ref` for FP32 weight type.
|
|
///
|
|
/// @tparam InType input data type (`uint8_t` is used)
|
|
/// @tparam IndexType index data type (`int64_t` is used)
|
|
/// @tparam OffsetType offset data type (`int32_t` is used)
|
|
/// @tparam OutType output data type (`float` is used)
|
|
///
|
|
/// @param block_size Number of elements in a block (`int64_t`)
|
|
/// @param output_size Number of elements in output (`int64_t`)
|
|
/// @param index_size Number of elements in index (`int64_t`)
|
|
/// @param data_size Number of elements in data (`int64_t`)
|
|
/// @param input Address of input (`InType*`)
|
|
/// @param indices Address of index (`IndexType*`)
|
|
/// @param offsets_or_lengths Address of offset (`OffsetType*`)
|
|
/// @param weights Weights of sum; optional, can be null for non-weighted sum
|
|
/// (`float*`)
|
|
/// @param normalize_by_lengths Whether or not to normalize by lengths (`bool`)
|
|
/// @param out Address of output (`OutType*`)
|
|
/// @param is_weight_positional If `true`, weight is positional; set to `false`
|
|
/// for FP32 autovec implementation (`bool`)
|
|
/// @param use_offsets If `true`, will use offsets instead of lengths; set to
|
|
/// `true` for FP32 autovec implementation (`bool`)
|
|
/// @param output_stride If -1, output_stride is same as block_size; set to -1
|
|
/// for FP32 autovec implementation (`int64_t`)
|
|
/// @param input_stride If -1, input_stride is same as block_size; set to -1
|
|
/// for FP32 autovec implementation (`int64_t`)
|
|
/// @param scale_bias_last If `true`, scale and bias appear at end of each row;
|
|
/// set to `true` for FP32 autovec implementation (`bool`)
|
|
/// @param no_bag If `true`, no embedding bag; set to `false` for FP32 autovec
|
|
/// implementation (`bool`)
|
|
/// @param is_bf16_out If `true`, output is `BFLOAT16` type; set to `false` for
|
|
/// FP32 autovec implementation (`bool`)
|
|
/// @param is_bf16_in If `true`, input is `BFLOAT16` type; set to `false` for
|
|
/// FP32 autovec implementation (`bool`)
|
|
template <
|
|
typename InType,
|
|
typename IndexType,
|
|
typename OffsetType,
|
|
typename OutType>
|
|
static bool ALWAYS_INLINE EmbeddingSpMDM_autovec(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const InType* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights, // optional, can be null for non-weighted sum
|
|
bool normalize_by_lengths,
|
|
OutType* out,
|
|
const bool is_weight_positional,
|
|
const bool use_offsets,
|
|
const int64_t output_stride,
|
|
const int64_t input_stride,
|
|
const bool no_bag,
|
|
const bool is_bf16_out,
|
|
const bool is_bf16_in) {
|
|
if (data_size < 0) {
|
|
return false;
|
|
}
|
|
|
|
std::array<float, LOCAL_STORAGE_SIZE> local_storage;
|
|
std::unique_ptr<float[]> heap_storage;
|
|
float* buf;
|
|
if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
|
|
buf = local_storage.data();
|
|
} else {
|
|
heap_storage.reset(new float[block_size]);
|
|
buf = heap_storage.get();
|
|
}
|
|
|
|
if (no_bag) {
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(buf, 0, sizeof(float) * block_size);
|
|
int64_t idx = indices[m];
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
if (weights != nullptr) {
|
|
float weight = weights[m];
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
|
|
const InType* inptr = input + input_stride * idx + j;
|
|
buf[j] = std::fma(
|
|
weight, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
|
|
}
|
|
#endif
|
|
for (; j < block_size; ++j) {
|
|
const InType* inptr = input + input_stride * idx + j;
|
|
buf[j] = std::fma(
|
|
weight, convert_to_float_ref(*inptr, is_bf16_in), buf[j]);
|
|
}
|
|
} else {
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
|
|
const InType* inptr = input + input_stride * idx + j;
|
|
buf[j] += convert_to_float_ref(*inptr, is_bf16_in);
|
|
}
|
|
#endif
|
|
for (; j < block_size; ++j) {
|
|
const InType* inptr = input + input_stride * idx + j;
|
|
buf[j] += convert_to_float_ref(*inptr, is_bf16_in);
|
|
}
|
|
}
|
|
fill_output(out, buf, block_size, is_bf16_out);
|
|
out += output_stride;
|
|
} // m
|
|
return true;
|
|
} // no_bag
|
|
|
|
// more prefetch
|
|
// TODO: in the future we should adjust max_prefetch_bytes based on CPU
|
|
// cache size
|
|
constexpr int64_t max_prefetch_bytes = 4096;
|
|
// 16 is manually tuned for Neoverse-V2 for best performance
|
|
constexpr int64_t max_initial_prefetch_rows = 8;
|
|
constexpr int64_t CACHE_LINE_SIZE = 64;
|
|
const int64_t rows_to_prefetch =
|
|
std::min(max_initial_prefetch_rows, max_prefetch_bytes / input_stride);
|
|
const int64_t prefetch_stride = std::min(rows_to_prefetch, index_size);
|
|
// The following prefetch loop is written in this way for better
|
|
// performance. My understanding is that manually separating the case of
|
|
// input_stride being greater or not greater than cache line size will make
|
|
// the branch predictor work better. Same for line 113-126.
|
|
for (int pf_idx = 0; pf_idx < prefetch_stride; ++pf_idx) {
|
|
const uint8_t* prefetch_addr = reinterpret_cast<const uint8_t*>(
|
|
input + input_stride * indices[pf_idx]);
|
|
for (int64_t offset = 0; offset < input_stride; offset += CACHE_LINE_SIZE) {
|
|
do_prefetch(prefetch_addr + offset, 0, 0);
|
|
}
|
|
}
|
|
|
|
// Reference implementation of FP32 SLS
|
|
int64_t current = 0;
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(buf, 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
if (current + len > index_size) {
|
|
return false;
|
|
}
|
|
|
|
for (int i = 0; i < len; ++i) {
|
|
int64_t idx = indices[current];
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
int64_t prefetch_idx =
|
|
indices[std::min(current + prefetch_stride, index_size - 1)];
|
|
|
|
do_prefetch(
|
|
reinterpret_cast<const char*>(input + input_stride * prefetch_idx),
|
|
0,
|
|
0);
|
|
if (input_stride > CACHE_LINE_SIZE) {
|
|
for (int64_t offset = CACHE_LINE_SIZE; offset < input_stride;
|
|
offset += CACHE_LINE_SIZE) {
|
|
do_prefetch(
|
|
reinterpret_cast<const char*>(
|
|
input + input_stride * prefetch_idx + offset),
|
|
0,
|
|
0);
|
|
}
|
|
}
|
|
|
|
float w = 1.f;
|
|
if (weights) {
|
|
w = weights[is_weight_positional ? i : current];
|
|
}
|
|
|
|
const InType* input_row = input + input_stride * idx;
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
|
|
InType value = *input_row++;
|
|
buf[j] = std::fma(w, convert_to_float_ref(value, is_bf16_in), buf[j]);
|
|
}
|
|
#endif
|
|
for (; j < block_size; ++j) {
|
|
InType value = *input_row++;
|
|
buf[j] = std::fma(w, convert_to_float_ref(value, is_bf16_in), buf[j]);
|
|
}
|
|
|
|
++current;
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
|
|
for (int64_t j = 0; j < block_size; ++j) {
|
|
buf[j] *= scale;
|
|
}
|
|
}
|
|
|
|
fill_output(out, buf, block_size, is_bf16_out);
|
|
out += output_stride;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
|
|
template <typename InType, typename IndexType, typename OffsetType>
|
|
static bool ALWAYS_INLINE EmbeddingSpMDMRowWiseSparse_autovec(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t uncompressed_data_size,
|
|
// const int64_t compressed_data_size,
|
|
const InType* input,
|
|
const IndexType* indices,
|
|
const int32_t* compressed_indices_table,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights, // optional, can be null for non-weighted sum
|
|
const bool normalize_by_lengths,
|
|
float* out,
|
|
const bool is_weight_positional,
|
|
const bool use_offsets) {
|
|
bool is8bit = std::is_same<InType, uint8_t>::value;
|
|
|
|
if (is8bit) {
|
|
// block_size is the number of elements and fused_block_size is the size
|
|
// of an entire row, including scale and bias.
|
|
const auto scale_bias_offset = 2 * sizeof(float);
|
|
const int64_t fused_block_size = block_size + scale_bias_offset;
|
|
int64_t current = 0;
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(out, 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
int64_t end = current + len;
|
|
if (end > index_size) {
|
|
return false;
|
|
}
|
|
const float* weights_addr = weights != nullptr
|
|
? (is_weight_positional ? weights : weights + current)
|
|
: nullptr;
|
|
for (; current < end; ++current) {
|
|
IndexType uncompressed_idx = indices[current];
|
|
if (uncompressed_idx < 0 ||
|
|
uncompressed_idx >= uncompressed_data_size) {
|
|
return false;
|
|
}
|
|
IndexType idx = compressed_indices_table[uncompressed_idx];
|
|
if (idx == -1) {
|
|
continue;
|
|
}
|
|
// if (idx < 0 || idx >= compressed_data_size) {
|
|
// return false;
|
|
// }
|
|
|
|
const uint8_t* scale_bias_addr = reinterpret_cast<const uint8_t*>(
|
|
input + fused_block_size * idx + block_size);
|
|
|
|
float scale;
|
|
float bias;
|
|
memcpy(&scale, scale_bias_addr, sizeof(float));
|
|
memcpy(&bias, scale_bias_addr + sizeof(float), sizeof(float));
|
|
if (weights != nullptr) {
|
|
float weight = *weights_addr++;
|
|
scale *= weight;
|
|
bias *= weight;
|
|
}
|
|
|
|
const InType* input_row = input + fused_block_size * idx;
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
|
|
InType value = *input_row++;
|
|
out[j] = std::fma(scale, value, out[j] + bias);
|
|
}
|
|
#endif
|
|
for (; j < block_size; ++j) {
|
|
InType value = *input_row++;
|
|
out[j] = std::fma(scale, value, out[j] + bias);
|
|
}
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] *= scale;
|
|
}
|
|
}
|
|
out += block_size;
|
|
}
|
|
return current == index_size;
|
|
} else {
|
|
// Reference implementation of FP32 SLS
|
|
|
|
int64_t current = 0;
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(out, 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
int64_t end = current + len;
|
|
if (end > index_size) {
|
|
return false;
|
|
}
|
|
|
|
const float* weights_addr = weights != nullptr
|
|
? (is_weight_positional ? weights : weights + current)
|
|
: nullptr;
|
|
for (; current < end; ++current) {
|
|
IndexType uncompressed_idx = indices[current];
|
|
if (uncompressed_idx < 0 ||
|
|
uncompressed_idx >= uncompressed_data_size) {
|
|
return false;
|
|
}
|
|
IndexType idx = compressed_indices_table[uncompressed_idx];
|
|
if (idx == -1) {
|
|
continue;
|
|
}
|
|
|
|
float weight = 1.f;
|
|
if (weights != nullptr) {
|
|
weight = *weights_addr++;
|
|
}
|
|
|
|
const InType* input_row = input + block_size * idx;
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
|
|
const InType* inptr = input_row++;
|
|
out[j] = std::fma(
|
|
weight,
|
|
std::is_same<InType, float16>::value ? cpu_half2float(*inptr)
|
|
: *inptr,
|
|
out[j]);
|
|
}
|
|
#endif
|
|
for (; j < block_size; ++j) {
|
|
const InType* inptr = input_row++;
|
|
out[j] = std::fma(
|
|
weight,
|
|
std::is_same<InType, float16>::value ? cpu_half2float(*inptr)
|
|
: *inptr,
|
|
out[j]);
|
|
}
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
out[j] *= scale;
|
|
}
|
|
}
|
|
out += block_size;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
}
|
|
|
|
namespace {
/// Converts `count` 8-bit floating point values to `float`.
///
/// Each input byte is read as 1 sign bit followed by 7 exponent/mantissa
/// bits. The 7 magnitude bits are shifted so that the `exponent_bits`
/// exponent bits land at the low end of the IEEE-754 single precision
/// exponent field; the resulting value is then multiplied by
/// `2^(127 - exponent_bias)` to re-bias the exponent, and the sign bit is
/// OR-ed back in.
///
/// @param input         source bytes, at least `count` elements
/// @param output        destination floats, at least `count` elements
/// @param count         number of values to convert
/// @param exponent_bits number of exponent bits in the 8-bit format
/// @param exponent_bias exponent bias of the 8-bit format
void Float8ToFloat_ref_batch(
    const uint8_t* input,
    float* output,
    int count,
    int exponent_bits,
    int exponent_bias) {
  // Bit patterns are moved between uint32_t and float with memcpy; reading
  // one type through a reinterpret_cast'ed pointer to the other violates the
  // strict aliasing rule.

  // 2^(127 - exponent_bias) as raw float bits; identical for every element.
  const uint32_t multiplier_bits =
      static_cast<uint32_t>(127 + (127 - exponent_bias)) << 23;
  float multiplier;
  std::memcpy(&multiplier, &multiplier_bits, sizeof(multiplier));

  // Left shift that aligns the 8-bit exponent with the float exponent field.
  const int magnitude_shift = 24 - (8 - exponent_bits);

  for (int i = 0; i < count; ++i) {
    // Widen to unsigned first so the shift into bit 31 is well defined.
    const uint32_t inp = input[i];
    const uint32_t sign = (inp & 0x80u) << 24;
    const uint32_t magnitude_bits = (inp & 0x7Fu) << magnitude_shift;

    float magnitude;
    std::memcpy(&magnitude, &magnitude_bits, sizeof(magnitude));
    const float scaled = magnitude * multiplier;

    uint32_t out_bits;
    std::memcpy(&out_bits, &scaled, sizeof(out_bits));
    out_bits |= sign;
    std::memcpy(&output[i], &out_bits, sizeof(out_bits));
  }
}
} // namespace
|
|
|
|
/// @ingroup tbe-cpu-autovec
|
|
///
|
|
/// Autovectorized version of method `EmbeddingSpMDM_ref` for FP8 weight type.
|
|
///
|
|
/// @tparam InType input data type (`uint8_t` is used)
|
|
/// @tparam IndexType index data type (`int64_t` is used)
|
|
/// @tparam OffsetType offset data type (`int32_t` is used)
|
|
/// @tparam OutType output data type (`float` is used)
|
|
///
|
|
/// @param block_size Number of elements in a block (`int64_t`)
|
|
/// @param output_size Number of elements in output (`int64_t`)
|
|
/// @param index_size Number of elements in index (`int64_t`)
|
|
/// @param data_size Number of elements in data (`int64_t`)
|
|
/// @param input Address of input (`InType*`)
|
|
/// @param indices Address of index (`IndexType*`)
|
|
/// @param offsets_or_lengths Address of offset (`OffsetType*`)
|
|
/// @param weights Weights of sum; optional, can be null for non-weighted sum
|
|
/// (`float*`)
|
|
/// @param normalize_by_lengths Whether or not to normalize by lengths (`bool`)
|
|
/// @param out Address of output (`OutType*`)
|
|
/// @param is_weight_positional If `true`, weight is positional; set to `false`
|
|
/// for FP8 autovec implementation (`bool`)
|
|
/// @param use_offsets If `true`, will use offsets instead of lengths; set to
|
|
/// `true` for FP8 autovec implementation (`bool`)
|
|
/// @param output_stride If -1, output_stride is same as block_size; set to -1
|
|
/// for FP8 autovec implementation (`int64_t`)
|
|
/// @param exponent_bits Bits to use in exponent
|
|
/// @param exponent_bias Bias to use in exponent
|
|
/// @param is_bf16_out If `true`, output is `BFLOAT16` type; set to `false` for
|
|
/// FP8 autovec implementation (`bool`)
|
|
template <typename IndexType, typename OffsetType, typename OutType>
|
|
static bool ALWAYS_INLINE EmbeddingSpMDMFP8_autovec(
|
|
const int64_t block_size,
|
|
const int64_t output_size,
|
|
const int64_t index_size,
|
|
const int64_t data_size,
|
|
const uint8_t* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights,
|
|
bool normalize_by_lengths,
|
|
OutType* out,
|
|
const bool is_weight_positional,
|
|
const bool use_offsets,
|
|
const int64_t output_stride,
|
|
const int64_t input_stride,
|
|
const int exponent_bits,
|
|
const int exponent_bias,
|
|
const bool is_bf16_out) {
|
|
if (data_size < 0) {
|
|
return false;
|
|
}
|
|
|
|
std::array<float, LOCAL_STORAGE_SIZE> local_storage;
|
|
std::unique_ptr<float[]> heap_storage;
|
|
float* buf;
|
|
if (static_cast<size_t>(block_size) <= LOCAL_STORAGE_SIZE) {
|
|
buf = local_storage.data();
|
|
} else {
|
|
heap_storage.reset(new float[block_size]);
|
|
buf = heap_storage.get();
|
|
}
|
|
|
|
// more prefetch: prefetch up to 16 rows from the embedding table. Increasing
|
|
// prefetching helps reduce backend stall and therefore enable vectorization
|
|
// reach better of its potential. 16 is tuned for Neoverse-V2.
|
|
|
|
// more prefetch
|
|
// TODO: in the future we should adjust max_prefetch_bytes based on CPU cache
|
|
// size
|
|
constexpr int64_t max_prefetch_bytes = 4096;
|
|
// 16 is manually tuned for Neoverse-V2 for best performance
|
|
constexpr int64_t max_initial_prefetch_rows = 16;
|
|
constexpr int64_t CACHE_LINE_SIZE = 64;
|
|
const int64_t rows_to_prefetch =
|
|
std::min(max_initial_prefetch_rows, max_prefetch_bytes / input_stride);
|
|
const int64_t prefetch_stride = std::min(rows_to_prefetch, index_size);
|
|
// The following prefetch loop is written in this way for better performance.
|
|
// My understanding is that manually separating the case of input_stride being
|
|
// greater or not greater than cache line size will make the branch predictor
|
|
// work better. Same for line 113-126.
|
|
for (int pf_idx = 0; pf_idx < prefetch_stride; ++pf_idx) {
|
|
const uint8_t* prefetch_addr = input + input_stride * indices[pf_idx];
|
|
for (int64_t offset = 0; offset < input_stride; offset += CACHE_LINE_SIZE) {
|
|
do_prefetch(prefetch_addr + offset, 0, 0);
|
|
}
|
|
}
|
|
|
|
// Reference implementation of FP8 SLS. The algorithm is similar to FP32 SLS
|
|
// except for the FP8->FP32 conversion after reading the embedding weight.
|
|
int64_t current = 0;
|
|
|
|
for (int m = 0; m < output_size; ++m) {
|
|
memset(buf, 0, sizeof(float) * block_size);
|
|
int len = use_offsets ? offsets_or_lengths[m + 1] - offsets_or_lengths[m]
|
|
: offsets_or_lengths[m];
|
|
int64_t end = current + len;
|
|
if (end > index_size) {
|
|
return false;
|
|
}
|
|
|
|
// Adjust these as necessary to reflect actual batch size
|
|
const int batch_size = block_size; // Assuming the entire block is
|
|
// processed at once; adjust if needed
|
|
|
|
// Temporary buffer to hold the converted floats
|
|
std::unique_ptr<float[]> converted_inputs(new float[batch_size]);
|
|
|
|
const float* weights_addr = weights != nullptr
|
|
? (is_weight_positional ? weights : weights + current)
|
|
: nullptr;
|
|
for (; current < end; ++current) {
|
|
int64_t idx = indices[current];
|
|
if (idx < 0 || idx >= data_size) {
|
|
return false;
|
|
}
|
|
|
|
int64_t prefetch_idx =
|
|
indices[std::min(current + prefetch_stride, index_size - 1)];
|
|
|
|
do_prefetch(
|
|
reinterpret_cast<const char*>(input + input_stride * prefetch_idx),
|
|
0,
|
|
0);
|
|
if (input_stride > CACHE_LINE_SIZE) {
|
|
for (int64_t offset = CACHE_LINE_SIZE; offset < input_stride;
|
|
offset += CACHE_LINE_SIZE) {
|
|
do_prefetch(
|
|
reinterpret_cast<const char*>(
|
|
input + input_stride * prefetch_idx + offset),
|
|
0,
|
|
0);
|
|
}
|
|
}
|
|
|
|
float w = 1.f;
|
|
if (weights != nullptr) {
|
|
w = *weights_addr++;
|
|
}
|
|
// check whether each loop iteration depends on the others;
|
|
// if not, the loop can be processed in parallel.
|
|
// the code iterates through the dimensions of an embedding vector
|
|
|
|
// Perform the batch conversion
|
|
Float8ToFloat_ref_batch(
|
|
input + input_stride * idx,
|
|
converted_inputs.get(),
|
|
batch_size,
|
|
exponent_bits,
|
|
exponent_bias);
|
|
|
|
// Now accumulate the results using vectorized operations if possible
|
|
const float* input_row = converted_inputs.get();
|
|
int64_t j = 0;
|
|
#ifdef FBGEMM_VECTOR_WIDTH
|
|
for (; j < block_size - (block_size % FBGEMM_VECTOR_WIDTH); ++j) {
|
|
float value = *input_row++;
|
|
buf[j] = std::fma(w, value, buf[j]);
|
|
}
|
|
#endif
|
|
for (; j < block_size; ++j) {
|
|
float value = *input_row++;
|
|
buf[j] = std::fma(w, value, buf[j]);
|
|
}
|
|
}
|
|
if (normalize_by_lengths && len) {
|
|
float scale = 1.f / len;
|
|
for (int j = 0; j < block_size; ++j) {
|
|
buf[j] *= scale;
|
|
}
|
|
}
|
|
|
|
fill_output(out, buf, block_size, is_bf16_out);
|
|
out += output_stride;
|
|
}
|
|
return current == index_size;
|
|
}
|
|
|
|
namespace {
|
|
namespace specialization_helper {
|
|
|
|
/// The idea with the specialization helper is to create a copy of a given
|
|
/// algorithm with some parameters set to fixed values (specialized) so the
|
|
/// compiler can perform additional optimization for the specific variant.
|
|
/// This is achieved by marking the generic functions `ALWAYS_INLINE` inline
|
|
/// and defining a macro invoking match/specialize so you can choose between
|
|
/// fixed and variable values for each parameter.
|
|
|
|
template <typename T>
|
|
struct FixedParameter {
|
|
T value;
|
|
};
|
|
struct VariableParameter {};
|
|
|
|
template <typename T>
|
|
ALWAYS_INLINE constexpr FixedParameter<T> fixed(T value) {
|
|
return FixedParameter<T>{value};
|
|
}
|
|
static constexpr VariableParameter var = VariableParameter();
|
|
|
|
template <typename T>
|
|
ALWAYS_INLINE bool match(VariableParameter, T) {
|
|
return true;
|
|
}
|
|
template <typename T>
|
|
ALWAYS_INLINE bool match(FixedParameter<T> fixed_parameter, T value) {
|
|
return fixed_parameter.value == value;
|
|
}
|
|
|
|
template <typename T>
|
|
ALWAYS_INLINE T specialize(VariableParameter, T value) {
|
|
return value;
|
|
}
|
|
template <typename T>
|
|
ALWAYS_INLINE T specialize(FixedParameter<T> fixed_parameter, T) {
|
|
return fixed_parameter.value;
|
|
}
|
|
} // namespace specialization_helper
|
|
} // namespace
|
|
|
|
/// Default input stride used when the caller does not supply one.
///
/// For every `InType` other than `uint8_t` the result is `block_size`
/// unchanged. For `uint8_t` rows, room for two extra values (named scale/bias
/// in this file) is added: each is `sizeof(float)` bytes when
/// `scale_bias_last` is true and `sizeof(uint16_t)` bytes otherwise.
template <typename InType>
static int64_t stride_SpMDMWithStrides(
    int64_t block_size,
    bool scale_bias_last) {
  // Only uint8_t rows carry the extra scale/bias storage.
  if (!std::is_same<InType, uint8_t>::value) {
    return block_size;
  }

  const size_t bytes_per_value =
      scale_bias_last ? sizeof(float) : sizeof(uint16_t);
  const size_t scale_bias_offset = 2 * bytes_per_value;
  return block_size + scale_bias_offset;
}
|
|
|
|
template <
|
|
typename InType,
|
|
typename IndexType,
|
|
typename OffsetType,
|
|
typename OutType>
|
|
typename EmbeddingSpMDMKernelSignature<InType, IndexType, OffsetType, OutType>::
|
|
Type
|
|
GenerateEmbeddingSpMDMWithStrides_autovec(
|
|
int64_t block_size,
|
|
bool has_weight,
|
|
bool normalize_by_lengths,
|
|
[[maybe_unused]] int prefetch,
|
|
bool is_weight_positional,
|
|
bool use_offsets,
|
|
int64_t output_stride,
|
|
int64_t input_stride,
|
|
bool scale_bias_last,
|
|
bool no_bag,
|
|
bool is_bf16_out,
|
|
bool is_bf16_in) {
|
|
if (output_stride == -1) {
|
|
output_stride = block_size;
|
|
}
|
|
if (input_stride == -1) {
|
|
input_stride = stride_SpMDMWithStrides<InType>(block_size, scale_bias_last);
|
|
}
|
|
|
|
using specialization_helper::fixed;
|
|
using specialization_helper::match;
|
|
using specialization_helper::specialize;
|
|
using specialization_helper::var;
|
|
|
|
#define SPECIALIZE( \
|
|
BLOCK_SIZE, \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
OUTPUT_STRIDE, \
|
|
INPUT_STRIDE, \
|
|
SCALE_BIAS_LAST, \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
if (match(BLOCK_SIZE, block_size) && match(HAS_WEIGHT, has_weight) && \
|
|
match(NORMALIZE_BY_LENGTHS, normalize_by_lengths) && \
|
|
match(PREFETCH, prefetch) && \
|
|
match(IS_WEIGHT_POSITIONAL, is_weight_positional) && \
|
|
match(USE_OFFSETS, use_offsets) && \
|
|
match(OUTPUT_STRIDE, output_stride) && \
|
|
match(INPUT_STRIDE, input_stride) && \
|
|
match(SCALE_BIAS_LAST, scale_bias_last) && match(NO_BAG, no_bag) && \
|
|
match(IS_BF16_OUT, is_bf16_out) && match(IS_BF16_IN, is_bf16_in)) { \
|
|
return [=](int64_t output_size, \
|
|
int64_t index_size, \
|
|
int64_t data_size, \
|
|
const InType* input, \
|
|
const IndexType* indices, \
|
|
const OffsetType* offsets_or_lengths, \
|
|
const float* weights, \
|
|
OutType* out) { \
|
|
const uint8_t* input_u8 = reinterpret_cast<const uint8_t*>(input); \
|
|
if (specialize(HAS_WEIGHT, has_weight)) { \
|
|
__builtin_assume(weights != nullptr); \
|
|
} else { \
|
|
weights = nullptr; \
|
|
} \
|
|
if (std::is_same<InType, uint8_t>::value) { \
|
|
assert(!specialize(IS_BF16_IN, is_bf16_in)); \
|
|
return EmbeddingSpMDM8Bit_autovec( \
|
|
specialize(BLOCK_SIZE, block_size), \
|
|
output_size, \
|
|
index_size, \
|
|
data_size, \
|
|
input_u8, \
|
|
indices, \
|
|
offsets_or_lengths, \
|
|
weights, \
|
|
specialize(NORMALIZE_BY_LENGTHS, normalize_by_lengths), \
|
|
out, \
|
|
specialize(IS_WEIGHT_POSITIONAL, is_weight_positional), \
|
|
specialize(USE_OFFSETS, use_offsets), \
|
|
specialize(OUTPUT_STRIDE, output_stride), \
|
|
specialize(INPUT_STRIDE, input_stride), \
|
|
specialize(SCALE_BIAS_LAST, scale_bias_last), \
|
|
specialize(NO_BAG, no_bag), \
|
|
specialize(IS_BF16_OUT, is_bf16_out)); \
|
|
} else { \
|
|
return EmbeddingSpMDM_autovec( \
|
|
/*block_size=*/specialize(BLOCK_SIZE, block_size), \
|
|
/*output_size=*/output_size, \
|
|
/*index_size=*/index_size, \
|
|
/*data_size=*/data_size, \
|
|
/*input=*/input, \
|
|
/*indices=*/indices, \
|
|
/*offsets_or_lengths=*/offsets_or_lengths, \
|
|
/*weights=*/weights, /*normalize_by_lengths=*/ \
|
|
specialize(NORMALIZE_BY_LENGTHS, normalize_by_lengths), \
|
|
/*out=*/out, /*is_weight_positional=*/ \
|
|
specialize(IS_WEIGHT_POSITIONAL, is_weight_positional), \
|
|
/*use_offsets=*/specialize(USE_OFFSETS, use_offsets), \
|
|
/*output_stride=*/specialize(OUTPUT_STRIDE, output_stride), \
|
|
/*input_stride=*/specialize(INPUT_STRIDE, input_stride), \
|
|
/*no_bag=*/specialize(NO_BAG, no_bag), \
|
|
/*is_bf16_out=*/specialize(IS_BF16_OUT, is_bf16_out), \
|
|
/*is_bf16_in=*/specialize(IS_BF16_IN, is_bf16_in)); \
|
|
} \
|
|
}; \
|
|
}
|
|
|
|
#define SPECIALIZE_BLOCK_SIZE( \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
SPECIALIZE( \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{32}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides<InType>(32, false)), \
|
|
/*SCALE_BIAS_LAST*/ fixed(false), \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
SPECIALIZE( \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{64}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides<InType>(64, false)), \
|
|
/*SCALE_BIAS_LAST*/ fixed(false), \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
SPECIALIZE( \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{124}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides<InType>(124, false)), \
|
|
/*SCALE_BIAS_LAST*/ fixed(false), \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
SPECIALIZE( \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{128}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides<InType>(128, false)), \
|
|
/*SCALE_BIAS_LAST*/ fixed(false), \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
SPECIALIZE( \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{252}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides<InType>(252, false)), \
|
|
/*SCALE_BIAS_LAST*/ fixed(false), \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
SPECIALIZE( \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{256}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides<InType>(256, false)), \
|
|
/*SCALE_BIAS_LAST*/ fixed(false), \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
SPECIALIZE( \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{508}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides<InType>(508, false)), \
|
|
/*SCALE_BIAS_LAST*/ fixed(false), \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN) \
|
|
SPECIALIZE( \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{512}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
PREFETCH, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMWithStrides<InType>(512, false)), \
|
|
/*SCALE_BIAS_LAST*/ fixed(false), \
|
|
NO_BAG, \
|
|
IS_BF16_OUT, \
|
|
IS_BF16_IN)
|
|
|
|
#ifdef FBGEMM_MORE_SPECIALIZATION
|
|
SPECIALIZE_BLOCK_SIZE(
|
|
/*HAS_WEIGHT*/ fixed(true),
|
|
/*NORMALIZE_BY_LENGTHS*/ fixed(false),
|
|
/*PREFETCH*/ var,
|
|
/*IS_WEIGHT_POSITIONAL*/ fixed(false),
|
|
/*USE_OFFSETS*/ fixed(true),
|
|
/*NO_BAG*/ fixed(false),
|
|
/*IS_BF16_OUT*/ var,
|
|
/*IS_BF16_IN*/ var)
|
|
SPECIALIZE_BLOCK_SIZE(
|
|
/*HAS_WEIGHT*/ fixed(false),
|
|
/*NORMALIZE_BY_LENGTHS*/ fixed(false),
|
|
/*PREFETCH*/ var,
|
|
/*IS_WEIGHT_POSITIONAL*/ fixed(false),
|
|
/*USE_OFFSETS*/ fixed(true),
|
|
/*NO_BAG*/ fixed(false),
|
|
/*IS_BF16_OUT*/ var,
|
|
/*IS_BF16_IN*/ var)
|
|
WARN_ONCE(
|
|
"fbgemm warning: "
|
|
"using non-specialized EmbeddingSpMDM_autovec (may be slow)\n"
|
|
" parameters: block_size: %ld has_weight: %d normalize_by_lengths: %d "
|
|
"is_weight_positional: %d use_offsets: %d output_stride: %ld "
|
|
"input_stride: %ld scale_bias_last: %d no_bag: %d\n",
|
|
static_cast<long>(block_size),
|
|
static_cast<int>(has_weight),
|
|
static_cast<int>(normalize_by_lengths),
|
|
static_cast<int>(is_weight_positional),
|
|
static_cast<int>(use_offsets),
|
|
static_cast<long>(output_stride),
|
|
static_cast<long>(input_stride),
|
|
static_cast<int>(scale_bias_last),
|
|
static_cast<int>(no_bag));
|
|
#endif
|
|
|
|
// Catch-all case.
|
|
SPECIALIZE(var, var, var, var, var, var, var, var, var, var, var, var)
|
|
abort(); // should not get here
|
|
|
|
#undef SPECIALIZE_BLOCK_SIZE
|
|
#undef SPECIALIZE
|
|
}
|
|
|
|
static int64_t stride_SpMDMNBitWith(int input_bit_rate, int64_t block_size) {
|
|
const int num_elem_per_byte = 8 / input_bit_rate;
|
|
const size_t scale_bias_size = 2 * sizeof(float16);
|
|
return div_up(block_size, num_elem_per_byte) + scale_bias_size;
|
|
}
|
|
|
|
template <typename IndexType, typename OffsetType, typename OutType>
|
|
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
|
|
uint8_t,
|
|
IndexType,
|
|
OffsetType,
|
|
OutType>::Type
|
|
GenerateEmbeddingSpMDMNBitWithStrides_autovec(
|
|
int input_bit_rate,
|
|
int64_t block_size,
|
|
bool has_weight,
|
|
bool normalize_by_lengths,
|
|
[[maybe_unused]] int prefetch,
|
|
bool is_weight_positional,
|
|
bool use_offsets,
|
|
int64_t output_stride,
|
|
int64_t input_stride,
|
|
bool scale_bias_last,
|
|
bool is_bf16_out,
|
|
bool no_bag,
|
|
int output_bit_rate) {
|
|
if (output_bit_rate == -1) {
|
|
output_bit_rate = 8 * sizeof(OutType);
|
|
}
|
|
if (output_stride == -1) {
|
|
output_stride = block_size;
|
|
}
|
|
|
|
if (input_stride == -1) {
|
|
input_stride = stride_SpMDMNBitWith(input_bit_rate, block_size);
|
|
}
|
|
|
|
using specialization_helper::fixed;
|
|
using specialization_helper::match;
|
|
using specialization_helper::specialize;
|
|
using specialization_helper::var;
|
|
|
|
#define SPECIALIZE( \
|
|
INPUT_BIT_RATE, \
|
|
BLOCK_SIZE, \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
OUTPUT_STRIDE, \
|
|
INPUT_STRIDE, \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE) \
|
|
if (match(INPUT_BIT_RATE, input_bit_rate) && \
|
|
match(BLOCK_SIZE, block_size) && match(HAS_WEIGHT, has_weight) && \
|
|
match(NORMALIZE_BY_LENGTHS, normalize_by_lengths) && \
|
|
match(IS_WEIGHT_POSITIONAL, is_weight_positional) && \
|
|
match(USE_OFFSETS, use_offsets) && \
|
|
match(OUTPUT_STRIDE, output_stride) && \
|
|
match(INPUT_STRIDE, input_stride) && \
|
|
match(SCALE_BIAS_LAST, scale_bias_last) && \
|
|
match(IS_BF16_OUT, is_bf16_out) && match(NO_BAG, no_bag) && \
|
|
match(OUTPUT_BIT_RATE, output_bit_rate)) { \
|
|
return [=](int64_t output_size, \
|
|
int64_t index_size, \
|
|
int64_t data_size, \
|
|
const uint8_t* input, \
|
|
const IndexType* indices, \
|
|
const OffsetType* offsets_or_lengths, \
|
|
const float* weights, \
|
|
OutType* out) { \
|
|
if (specialize(HAS_WEIGHT, has_weight)) { \
|
|
__builtin_assume(weights != nullptr); \
|
|
} else { \
|
|
weights = nullptr; \
|
|
} \
|
|
return EmbeddingSpMDMNBit_autovec( \
|
|
/*input_bit_rate=*/specialize(INPUT_BIT_RATE, input_bit_rate), \
|
|
/*block_size=*/specialize(BLOCK_SIZE, block_size), \
|
|
/*output_size=*/output_size, \
|
|
/*index_size=*/index_size, \
|
|
/*data_size=*/data_size, \
|
|
/*input=*/input, \
|
|
/*indices=*/indices, \
|
|
/*offsets_or_lengths=*/offsets_or_lengths, \
|
|
/*weights=*/weights, /*normalize_by_lengths=*/ \
|
|
specialize(NORMALIZE_BY_LENGTHS, normalize_by_lengths), \
|
|
/*out=*/out, /*is_weight_positional=*/ \
|
|
specialize(IS_WEIGHT_POSITIONAL, is_weight_positional), \
|
|
/*use_offsets=*/specialize(USE_OFFSETS, use_offsets), \
|
|
/*output_stride=*/specialize(OUTPUT_STRIDE, output_stride), \
|
|
/*input_stride=*/specialize(INPUT_STRIDE, input_stride), \
|
|
/*scale_bias_last=*/specialize(SCALE_BIAS_LAST, scale_bias_last), \
|
|
/*is_bf16_out=*/specialize(IS_BF16_OUT, is_bf16_out), \
|
|
/*no_bag=*/specialize(NO_BAG, no_bag), \
|
|
/*output_bit_rate=*/specialize(OUTPUT_BIT_RATE, output_bit_rate)); \
|
|
}; \
|
|
}
|
|
|
|
#define SPECIALIZE_BLOCK_SIZE( \
|
|
INPUT_BIT_RATE, \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE) \
|
|
SPECIALIZE( \
|
|
INPUT_BIT_RATE, \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{32}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 32)), \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE) \
|
|
SPECIALIZE( \
|
|
INPUT_BIT_RATE, \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{56}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 56)), \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE) \
|
|
SPECIALIZE( \
|
|
INPUT_BIT_RATE, \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{64}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 64)), \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE) \
|
|
SPECIALIZE( \
|
|
INPUT_BIT_RATE, \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{120}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 120)), \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE) \
|
|
SPECIALIZE( \
|
|
INPUT_BIT_RATE, \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{128}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 128)), \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE) \
|
|
SPECIALIZE( \
|
|
INPUT_BIT_RATE, \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{248}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 248)), \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE) \
|
|
SPECIALIZE( \
|
|
INPUT_BIT_RATE, \
|
|
/*BLOCK_SIZE*/ fixed(int64_t{256}), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
/*OUTPUT_STRIDE*/ var, \
|
|
/*INPUT_STRIDE*/ fixed(stride_SpMDMNBitWith(INPUT_BIT_RATE.value, 256)), \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
OUTPUT_BIT_RATE)
|
|
|
|
#define SPECIALIZE_INPUT_RATE( \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG) \
|
|
SPECIALIZE_BLOCK_SIZE( \
|
|
/*INPUT_BIT_RATE*/ fixed(4), \
|
|
HAS_WEIGHT, \
|
|
NORMALIZE_BY_LENGTHS, \
|
|
IS_WEIGHT_POSITIONAL, \
|
|
USE_OFFSETS, \
|
|
SCALE_BIAS_LAST, \
|
|
IS_BF16_OUT, \
|
|
NO_BAG, \
|
|
/*OUTPUT_BIT_RATE*/ fixed(int{8 * sizeof(OutType)}))
|
|
|
|
#ifdef FBGEMM_MORE_SPECIALIZATION
|
|
SPECIALIZE_INPUT_RATE(
|
|
/*HAS_WEIGHT*/ fixed(true),
|
|
/*NORMALIZE_BY_LENGTHS*/ fixed(false),
|
|
/*IS_WEIGHT_POSITIONAL*/ fixed(false),
|
|
/*USE_OFFSETS*/ fixed(true),
|
|
/*SCALE_BIAS_LAST*/ fixed(false),
|
|
/*IS_BF16_OUT*/ var,
|
|
/*NO_BAG*/ fixed(false))
|
|
SPECIALIZE_INPUT_RATE(
|
|
/*HAS_WEIGHT*/ fixed(false),
|
|
/*NORMALIZE_BY_LENGTHS*/ fixed(false),
|
|
/*IS_WEIGHT_POSITIONAL*/ fixed(false),
|
|
/*USE_OFFSETS*/ fixed(true),
|
|
/*SCALE_BIAS_LAST*/ fixed(false),
|
|
/*IS_BF16_OUT*/ var,
|
|
/*NO_BAG*/ fixed(false))
|
|
WARN_ONCE(
|
|
"fbgemm warning: "
|
|
"using non-specialized EmbeddingSpMDMNBit_autovec (may be slow)\n"
|
|
" parameters: input_bit_rate: %d block_size: %ld has_weight: %d "
|
|
"normalize_by_lengths: %d is_weight_positional: %d use_offsets: %d "
|
|
"output_stride: %ld input_stride: %ld scale_bias_last: %d no_bag: %d "
|
|
"output_bit_rate: %d\n",
|
|
input_bit_rate,
|
|
static_cast<long>(block_size),
|
|
static_cast<int>(has_weight),
|
|
static_cast<int>(normalize_by_lengths),
|
|
static_cast<int>(is_weight_positional),
|
|
static_cast<int>(use_offsets),
|
|
static_cast<long>(output_stride),
|
|
static_cast<long>(input_stride),
|
|
static_cast<int>(scale_bias_last),
|
|
static_cast<int>(no_bag),
|
|
output_bit_rate);
|
|
#endif
|
|
|
|
// Catch-all case: Just specialize for input_bit_rate.
|
|
SPECIALIZE(
|
|
/*INPUT_BIT_RATE*/ fixed(2),
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var);
|
|
SPECIALIZE(
|
|
/*INPUT_BIT_RATE*/ fixed(4),
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var,
|
|
var);
|
|
abort(); // should not get here
|
|
|
|
#undef SPECIALIZE_BLOCK_SIZE
|
|
#undef SPECIALIZE_INPUT_RATE
|
|
#undef SPECIALIZE
|
|
}
|
|
|
|
template <typename IndexType, typename OffsetType, typename OutType>
|
|
typename EmbeddingSpMDMKernelSignature<
|
|
uint8_t,
|
|
IndexType,
|
|
OffsetType,
|
|
OutType>::Type
|
|
GenerateEmbeddingSpMDMFP8WithStrides_autovec(
|
|
int64_t block_size,
|
|
bool normalize_by_lengths,
|
|
bool is_weight_positional,
|
|
bool use_offsets,
|
|
int64_t output_stride,
|
|
int64_t input_stride,
|
|
int exponent_bits,
|
|
int exponent_bias,
|
|
bool is_bf16_out) {
|
|
if (output_stride == -1) {
|
|
output_stride = block_size;
|
|
}
|
|
if (input_stride == -1) {
|
|
input_stride = block_size;
|
|
}
|
|
return [=](int64_t output_size,
|
|
int64_t index_size,
|
|
int64_t data_size,
|
|
const uint8_t* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights,
|
|
OutType* out) {
|
|
return EmbeddingSpMDMFP8_autovec(
|
|
/*block_size=*/block_size,
|
|
/*output_size=*/output_size,
|
|
/*index_size=*/index_size,
|
|
/*data_size=*/data_size,
|
|
/*input=*/input,
|
|
/*indices=*/indices,
|
|
/*offsets_or_lengths=*/offsets_or_lengths,
|
|
/*weights=*/weights,
|
|
/*normalize_by_lengths=*/normalize_by_lengths,
|
|
/*out=*/out,
|
|
/*is_weight_positional=*/is_weight_positional,
|
|
/*use_offsets=*/use_offsets,
|
|
/*output_stride=*/output_stride,
|
|
/*input_stride=*/input_stride,
|
|
/*exponent_bits=*/exponent_bits,
|
|
/*exponent_bias=*/exponent_bias,
|
|
/*is_bf16_out=*/is_bf16_out);
|
|
};
|
|
}
|
|
|
|
template <typename InType, typename IndexType, typename OffsetType>
|
|
typename EmbeddingSpMDMRowWiseSparseKernelSignature<
|
|
InType,
|
|
IndexType,
|
|
OffsetType>::Type
|
|
GenerateEmbeddingSpMDMRowWiseSparse_autovec(
|
|
int64_t block_size,
|
|
bool has_weight,
|
|
bool normalize_by_lengths,
|
|
[[maybe_unused]] int prefetch,
|
|
bool is_weight_positional,
|
|
bool use_offsets) {
|
|
return [=](int64_t output_size,
|
|
int64_t index_size,
|
|
int64_t uncompressed_data_size,
|
|
const InType* input,
|
|
const IndexType* indices,
|
|
const OffsetType* offsets_or_lengths,
|
|
const float* weights,
|
|
float* out,
|
|
const int32_t* compressed_indices_table) {
|
|
if (!has_weight) {
|
|
weights = nullptr;
|
|
}
|
|
return EmbeddingSpMDMRowWiseSparse_autovec(
|
|
/*block_size=*/block_size,
|
|
/*output_size=*/output_size,
|
|
/*index_size=*/index_size,
|
|
/*uncompressed_data_size=*/uncompressed_data_size,
|
|
/*input=*/input,
|
|
/*indices=*/indices,
|
|
/*compressed_indices_table=*/compressed_indices_table,
|
|
/*offsets_or_lengths=*/offsets_or_lengths,
|
|
/*weights=*/weights,
|
|
/*normalize_by_lengths=*/normalize_by_lengths,
|
|
/*out=*/out,
|
|
/*is_weight_positional=*/is_weight_positional,
|
|
/*use_offsets=*/use_offsets);
|
|
};
|
|
}
|
|
|
|
#define INSTANTIATE_SPMDM_NBIT_WITH_STRIDES(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
|
|
template typename EmbeddingSpMDMKernelSignature< \
|
|
uint8_t, \
|
|
INDEX_TYPE, \
|
|
OFFSET_TYPE, \
|
|
OUT_TYPE>::Type FBGEMM_API \
|
|
GenerateEmbeddingSpMDMNBitWithStrides_autovec< \
|
|
INDEX_TYPE, \
|
|
OFFSET_TYPE, \
|
|
OUT_TYPE>( \
|
|
int input_bit_rate, \
|
|
int64_t block_size, \
|
|
bool has_weight, \
|
|
bool normalize_by_lengths, \
|
|
int prefetch, \
|
|
bool is_weight_positional, \
|
|
bool use_offsets, \
|
|
int64_t output_stride, \
|
|
int64_t input_stride, \
|
|
bool scale_bias_last, \
|
|
bool is_bf16_out, \
|
|
bool no_bag, \
|
|
int output_bit_rate);
|
|
|
|
#define INSTANTIATE_SPMDM_FP8(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
|
|
template typename EmbeddingSpMDMKernelSignature< \
|
|
uint8_t, \
|
|
INDEX_TYPE, \
|
|
OFFSET_TYPE, \
|
|
OUT_TYPE>::Type \
|
|
GenerateEmbeddingSpMDMFP8WithStrides_autovec< \
|
|
INDEX_TYPE, \
|
|
OFFSET_TYPE, \
|
|
OUT_TYPE>( \
|
|
int64_t block_size, \
|
|
bool normalize_by_lengths, \
|
|
bool is_weight_positional, \
|
|
bool use_offsets, \
|
|
int64_t output_stride, \
|
|
int64_t input_stride, \
|
|
int exponent_bits, \
|
|
int exponent_bias, \
|
|
bool is_bf16_out);
|
|
|
|
#define INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
|
|
INSTANTIATE_SPMDM_NBIT_WITH_STRIDES(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
|
|
INSTANTIATE_SPMDM_FP8(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE)
|
|
|
|
#define INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, OFFSET_TYPE) \
|
|
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float) \
|
|
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, float16) \
|
|
INSTANTIATE_SPMDM_BASE(INDEX_TYPE, OFFSET_TYPE, uint8_t)
|
|
|
|
#define INSTANTIATE_SPMDM_OFFSET_T(INDEX_TYPE) \
|
|
INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int32_t) \
|
|
INSTANTIATE_SPMDM_OUT_T(INDEX_TYPE, int64_t)
|
|
|
|
INSTANTIATE_SPMDM_OFFSET_T(int32_t)
|
|
INSTANTIATE_SPMDM_OFFSET_T(int64_t)
|
|
|
|
#undef INSTANTIATE_SPMDM_OFFSET_T
|
|
#undef INSTANTIATE_SPMDM_OUT_T
|
|
#undef INSTANTIATE_SPMDM_BASE
|
|
|
|
#define INSTANTIATE_SPMDM_ROWWISE(IN_TYPE, OFFSET_TYPE, OUT_TYPE) \
|
|
template typename EmbeddingSpMDMRowWiseSparseKernelSignature< \
|
|
IN_TYPE, \
|
|
OFFSET_TYPE, \
|
|
OUT_TYPE>::Type \
|
|
GenerateEmbeddingSpMDMRowWiseSparse_autovec<IN_TYPE, OFFSET_TYPE, OUT_TYPE>( \
|
|
int64_t block_size, \
|
|
bool has_weight, \
|
|
bool normalize_by_lengths, \
|
|
int prefetch, \
|
|
bool is_weight_positional, \
|
|
bool use_offsets);
|
|
|
|
#define INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \
|
|
template typename EmbeddingSpMDMKernelSignature< \
|
|
IN_TYPE, \
|
|
INDEX_TYPE, \
|
|
OFFSET_TYPE, \
|
|
OUT_TYPE>::Type \
|
|
GenerateEmbeddingSpMDMWithStrides_autovec< \
|
|
IN_TYPE, \
|
|
INDEX_TYPE, \
|
|
OFFSET_TYPE, \
|
|
OUT_TYPE>( \
|
|
int64_t block_size, \
|
|
bool has_weight, \
|
|
bool normalize_by_lengths, \
|
|
int prefetch, \
|
|
bool is_weight_positional, \
|
|
bool use_offsets, \
|
|
int64_t output_stride, \
|
|
int64_t input_stride, \
|
|
bool scale_bias_last, \
|
|
bool no_bag, \
|
|
bool is_bf16_out, \
|
|
bool is_bf16_in);
|
|
|
|
#define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \
|
|
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \
|
|
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float16) \
|
|
INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, std::uint8_t) \
|
|
INSTANTIATE_SPMDM_ROWWISE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE)
|
|
|
|
#define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \
|
|
INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int32_t) \
|
|
INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, std::int64_t)
|
|
|
|
#define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE) \
|
|
INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int32_t) \
|
|
INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, std::int64_t)
|
|
|
|
INSTANTIATE_SPMDM_INDEX_T(float)
|
|
INSTANTIATE_SPMDM_INDEX_T(float16)
|
|
INSTANTIATE_SPMDM_INDEX_T(std::uint8_t)
|
|
|
|
#undef INSTANTIATE_SPMDM_ROWWISE
|
|
#undef INSTANTIATE_SPMDM_INDEX_T
|
|
#undef INSTANTIATE_SPMDM_OFFSET_T
|
|
#undef INSTANTIATE_SPMDM_OUT_T
|
|
#undef INSTANTIATE_SPMDM_BASE
|
|
|
|
} // namespace fbgemm
|
|
|
|
#endif // #ifdef __linux__
|