/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #define FBGEMM_EXPORTS #include #include #include #include #include #include #include #include #include #include "./CodeCache.h" #include "./EmbeddingSpMDMAutovec.h" #include "./MaskAvx2.h" #include "./RefImplementations.h" #include "fbgemm/FbgemmEmbedding.h" #include "fbgemm/SimdUtils.h" #include "fbgemm/Utils.h" namespace fbgemm { namespace { namespace x86 = asmjit::x86; template < typename inType, typename indxType, typename offsetType, typename outType, bool ROWWISE_SPARSE> class ReturnFunctionSignature {}; template < typename inType, typename indxType, typename offsetType, typename outType> class ReturnFunctionSignature { public: using jit_embedding_kernel = bool (*)( int64_t output_size, int64_t index_size, int64_t data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, outType* out, const int* mask); }; template < typename inType, typename indxType, typename offsetType, typename outType> class ReturnFunctionSignature { public: using jit_embedding_kernel = bool (*)( int64_t output_size, int64_t index_size, int64_t uncompressed_data_size, // int64_t compressed_data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, outType* out, const int32_t* compressed_indices_table, const int* mask); }; template < typename inType, typename indxType, typename offsetType, typename outType, inst_set_t instSet, bool ROWWISE_SPARSE = false, bool THREAD_LOCAL = false> class GenEmbeddingSpMDMLookup { public: GenEmbeddingSpMDMLookup() {} typename ReturnFunctionSignature< inType, indxType, offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel getOrCreate( int block_size, bool has_weight, bool is_weight_positional, bool normalize_by_lengths, int prefetch, bool use_offsets, int output_stride, int input_stride, bool scale_bias_last, bool is_bf16_out, bool is_bf16_in); private: static asmjit::JitRuntime& runtime() { static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, // depents on other static // variables. Required to prevent // initialization order fiasco return rt; } static std::mutex rtMutex_; ///< Controll access to runtime; // The hash depends on embedding dimension (block size), weighted sls, // positional weights, normalize by lenths, prefetch distance, use_offsets, // output_stride, input_stride, and scale_bias_last static CodeCache< std::tuple, typename ReturnFunctionSignature< inType, indxType, offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel, THREAD_LOCAL> codeCache_; ///< JIT Code Cache for reuse. }; // GenEmbeddingSpmDMLookup template < typename inType, typename indxType, typename offsetType, typename outType, inst_set_t instSet, bool ROWWISE_SPARSE, bool THREAD_LOCAL> std::mutex GenEmbeddingSpMDMLookup< inType, indxType, offsetType, outType, instSet, ROWWISE_SPARSE, THREAD_LOCAL>::rtMutex_; template < typename inType, typename indxType, typename offsetType, typename outType, inst_set_t instSet, bool ROWWISE_SPARSE, bool THREAD_LOCAL> CodeCache< std::tuple, typename ReturnFunctionSignature< inType, indxType, offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel, THREAD_LOCAL> GenEmbeddingSpMDMLookup< inType, indxType, offsetType, outType, instSet, ROWWISE_SPARSE, THREAD_LOCAL>::codeCache_; template < typename inType, typename indxType, typename offsetType, typename outType, inst_set_t instSet, bool ROWWISE_SPARSE, bool THREAD_LOCAL> typename ReturnFunctionSignature< inType, indxType, offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel GenEmbeddingSpMDMLookup< inType, indxType, offsetType, outType, instSet, ROWWISE_SPARSE, THREAD_LOCAL>:: getOrCreate( int block_size, bool has_weight, bool is_weight_positional, bool normalize_by_lengths, int prefetch, bool use_offsets, int output_stride, int input_stride, bool scale_bias_last, bool is_bf16_out, bool is_bf16_in) { auto kernelSig = std::make_tuple( block_size, has_weight, is_weight_positional, normalize_by_lengths, prefetch, use_offsets, output_stride, input_stride, scale_bias_last, is_bf16_out, is_bf16_in); return codeCache_.getOrCreate( kernelSig, [&]() -> typename ReturnFunctionSignature< inType, indxType, offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel { bool is_8bit_in = std::is_same::value; bool is_16bit_in = std::is_same::value; bool is_16bit_out = std::is_same::value; bool is_fp16_in = is_16bit_in && !is_bf16_in; bool is_fp16_out = is_16bit_out && !is_bf16_out; // TODO: Make this tunable int pref_dist = prefetch; bool areIndices64b = std::is_same::value; asmjit::CodeHolder code; code.init(runtime().environment()); x86::Assembler assembler(&code); x86::Emitter* a = assembler.as(); #if defined(FBGEMM_LOG_CODE) std::string filename = "embeddinglookup"; if (is_8bit_in) { filename += "_8bit"; } else if (is_fp16_in) { filename += "_fp16"; } else if (is_bf16_in) { filename += "_bf16"; } if (is_bf16_out) { filename += "_bf16_out"; } else if (is_fp16_out) { filename += "_fp16_out"; } filename += "_emd_dim_" + std::to_string(block_size); filename += areIndices64b ? "_64bit" : "_32bit"; filename += instSet == inst_set_t::avx512 ? "_avx512" : "_avx2"; if (prefetch) { filename += "_prefetch"; } if (has_weight) { filename += "_hasweight"; } if (normalize_by_lengths) { filename += "_normalize_by_lengths"; } if (!use_offsets) { filename += "_use_lengths"; } if (ROWWISE_SPARSE) { filename += "_rowwise_sparse"; } filename += "_out_stride_" + std::to_string(output_stride); if (!scale_bias_last) { filename += "_scale_bias_first"; } filename += ".txt"; FILE* codeLogFile = fopen(filename.c_str(), "w"); asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogFile); code.setLogger(codeLogger); #endif // arguments to the function created x86::Gp output_size = a->zdi(); // index_size will be overwritten to hold the end address of indices x86::Gp index_size = a->zsi(); x86::Gp data_size = a->zdx(); x86::Gp input = a->zcx(); int reg_id = 8; x86::Gp indices = a->gpz(reg_id); // 8 ++reg_id; x86::Gp lengths = a->gpz(reg_id); // 9 ++reg_id; x86::Gp weights = a->gpz(reg_id); // 10 ++reg_id; x86::Gp out = a->gpz(reg_id); // 11 x86::Gp compressed_indices_table; if (ROWWISE_SPARSE) { ++reg_id; compressed_indices_table = a->gpz(reg_id); // 12 } ++reg_id; x86::Gp scratchReg1_ = a->gpz(reg_id); // 12 or 13, also for mask ++reg_id; x86::Gpd lengths_R_ = a->gpz(reg_id).r32(); // 13 or 14 ++reg_id; x86::Gp scratchReg2_ = a->gpz(reg_id); // 14 or 15 asmjit::FuncDetail func; if (ROWWISE_SPARSE) { func.init( asmjit::FuncSignatureT< bool, int64_t, // output_size int64_t, // index_size int64_t, // uncompressed_data_size const inType*, // input uint8_t or float const indxType*, // indices const offsetType*, // offsets or lengths const float*, // weights outType*, // out const int32_t*, // compressed_indices_table and then mask const int*>(asmjit::CallConvId::kHost), a->environment()); } else { func.init( asmjit::FuncSignatureT< bool, int64_t, // output_size int64_t, // index_size int64_t, // data_size const inType*, // input uint8_t or float const indxType*, // indices const offsetType*, // offsets or lengths const float*, // weights outType*, // out and then mask const int*>(asmjit::CallConvId::kHost), a->environment()); } asmjit::FuncFrame frame; frame.init(func); if (instSet == inst_set_t::avx2) { frame.setDirtyRegs( asmjit::RegGroup::kVec, asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); } else { frame.setDirtyRegs( asmjit::RegGroup::kVec, asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) | asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) | asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31)); } frame.setDirtyRegs( asmjit::RegGroup::kGp, reg_id == 15 ? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) : asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); asmjit::FuncArgsAssignment args(&func); if (ROWWISE_SPARSE) { args.assignAll( output_size, index_size, data_size, input, indices, lengths, weights, out, compressed_indices_table, scratchReg1_); } else { args.assignAll( output_size, index_size, data_size, input, indices, lengths, weights, out, scratchReg1_); } args.updateFuncFrame(frame); frame.finalize(); a->emitProlog(frame); a->emitArgsAssignment(frame, args); constexpr int vlen = simd_info::WIDTH_32BIT_ELEMS; constexpr int NUM_VEC_REG = simd_info::NUM_VEC_REGS; int unroll_factor = NUM_VEC_REG; typedef typename simd_info::vec_reg_t vec_reg_t; int num_vec_regs_per_block = (block_size + vlen - 1) / vlen; int remainder = block_size % vlen; vec_reg_t scale_vreg; // holds scale vec_reg_t bias_vreg; // holds bias vec_reg_t w_vreg; // for weighted sls -- weights vec_reg_t vlen_inv_vreg; // used for normalize by lengths -- 1/ lengths[i] vec_reg_t src_vreg; // for holding embedding value temporarily x86::Ymm mask_vreg; // mask for avx2 x86::Xmm mask_fp16_vreg; // mask for loading fp16 in avx2 vec_reg_t ones_vreg; // 2^15 for bf16_2_fp32_rn if (is_8bit_in) { // We need 2 vec registers for 1. scale 2. bias --unroll_factor; scale_vreg = vec_reg_t(unroll_factor); --unroll_factor; bias_vreg = vec_reg_t(unroll_factor); } if (is_bf16_out) { --unroll_factor; ones_vreg = vec_reg_t(unroll_factor); a->mov(scratchReg2_, 1 << 15); a->vpinsrd(ones_vreg.xmm(), ones_vreg.xmm(), scratchReg2_, 0); a->vpbroadcastd(ones_vreg, ones_vreg.xmm()); } if (is_8bit_in || is_16bit_in || (remainder && instSet == inst_set_t::avx2)) { --unroll_factor; src_vreg = vec_reg_t(unroll_factor); } if (has_weight) { --unroll_factor; w_vreg = vec_reg_t(unroll_factor); } if (remainder && instSet == inst_set_t::avx2) { // AVX512 doesn't need to use vector register for masking --unroll_factor; mask_vreg = x86::ymm(unroll_factor); if (remainder > 1 && (is_16bit_in || is_bf16_out || is_fp16_out)) { --unroll_factor; mask_fp16_vreg = x86::xmm(unroll_factor); } } if (normalize_by_lengths) { --unroll_factor; vlen_inv_vreg = vec_reg_t(unroll_factor); } if (remainder) { if (instSet == inst_set_t::avx2) { a->vmovups( mask_vreg, x86::ymmword_ptr( scratchReg1_, (vlen - remainder) % vlen * sizeof(int32_t))); if (is_16bit_in || is_bf16_out || is_fp16_out) { if (remainder > 1) { a->vmovups( mask_fp16_vreg, x86::xmmword_ptr( scratchReg1_, (vlen - remainder / 2) * sizeof(int32_t))); } // We need to keep using the stack during the main loop a->lea( x86::rsp, x86::dword_ptr( x86::rsp, static_cast(-vlen * sizeof(int32_t)))); } } else { a->mov(scratchReg1_, (1 << remainder) - 1); a->kmovw(x86::k(1), scratchReg1_); } } // Compute the end address of indices a->lea( index_size, x86::ptr(indices, index_size, areIndices64b ? 3 : 2)); asmjit::Label exit = a->newLabel(); asmjit::Label error = a->newLabel(); asmjit::Label LoopRangeIndexBegin = a->newLabel(); asmjit::Label LoopRangeIndexEnd = a->newLabel(); // rangeIndex loop begins (iterate output_size times) a->bind(LoopRangeIndexBegin); a->dec(output_size); a->jl(LoopRangeIndexEnd); if (normalize_by_lengths) { asmjit::Label IfLengthsBegin = a->newLabel(); asmjit::Label IfLengthsEnd = a->newLabel(); a->bind(IfLengthsBegin); if (use_offsets) { a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); a->sub(lengths_R_, x86::dword_ptr(lengths)); } else { a->mov(lengths_R_, x86::dword_ptr(lengths)); } a->cmp(lengths_R_, 1); // Initialize vlen_inv as 0 in case lengths is 0 a->vxorps(vlen_inv_vreg, vlen_inv_vreg, vlen_inv_vreg); a->jl(IfLengthsEnd); // OK to use vreg0 because it's for out_vreg used in the main loop vec_reg_t temp_vreg(0); if (instSet == inst_set_t::avx2) { a->mov(scratchReg1_, 1); a->cvtsi2ss(vlen_inv_vreg.xmm(), scratchReg1_); a->cvtsi2ss(temp_vreg.xmm(), lengths_R_); a->divss(vlen_inv_vreg.xmm(), temp_vreg.xmm()); a->vpbroadcastd(vlen_inv_vreg, vlen_inv_vreg.xmm()); } else { // avx512 a->mov(scratchReg1_, 1); a->cvtsi2ss(temp_vreg.xmm(), scratchReg1_); a->vpbroadcastd(vlen_inv_vreg, temp_vreg.xmm()); a->vpbroadcastd(temp_vreg, lengths_R_); a->vcvtdq2ps(temp_vreg, temp_vreg); a->vdivps(vlen_inv_vreg, vlen_inv_vreg, temp_vreg); } a->bind(IfLengthsEnd); } for (int vec_idx = 0; vec_idx < num_vec_regs_per_block; vec_idx += unroll_factor) { int cur_unroll_factor = std::min(unroll_factor, num_vec_regs_per_block - vec_idx); // Initialize output regs for (int v = 0; v < cur_unroll_factor; ++v) { vec_reg_t out_vreg = vec_reg_t(v); a->vxorps(out_vreg, out_vreg, out_vreg); } if (use_offsets) { a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); a->sub(lengths_R_, x86::dword_ptr(lengths)); } else { a->mov(lengths_R_, x86::dword_ptr(lengths)); } // Array out of bound check a->lea( scratchReg1_, x86::ptr(indices, lengths_R_, areIndices64b ? 3 : 2)); a->cmp(scratchReg1_, index_size); a->jg(error); asmjit::Label LoopDataIndexBegin = a->newLabel(); asmjit::Label LoopDataIndexEnd = a->newLabel(); asmjit::Label ValidIndexLabel = a->newLabel(); // dataIndex loop begins (iterate lengths_R_ times) a->bind(LoopDataIndexBegin); a->dec(lengths_R_); a->jl(LoopDataIndexEnd); // Array out of bound check if (areIndices64b) { a->mov(scratchReg1_, x86::qword_ptr(indices)); } else { a->mov(scratchReg1_.r32(), x86::dword_ptr(indices)); } if (!scale_bias_last) { // When scale_bias_last == false, assume this is for table batched // embedding (TBE) that can get -1 for pruned rows. if (areIndices64b) { a->cmp(scratchReg1_, static_cast(-1)); } else { a->cmp(scratchReg1_.r32(), static_cast(-1)); } a->jne(ValidIndexLabel); a->add(indices, static_cast(sizeof(indxType))); if (has_weight) { a->add(weights, static_cast(sizeof(float))); } a->jmp(LoopDataIndexBegin); a->bind(ValidIndexLabel); } // A trick to check x >= data_size or x < 0 in one shot by treating // scratchReg1_ as if it has unsigned value // (https://stackoverflow.com/a/34072155). a->cmp(scratchReg1_, data_size); a->jae(error); if (ROWWISE_SPARSE) { a->mov( scratchReg1_.r32(), x86::dword_ptr( compressed_indices_table, scratchReg1_, 2)); // use of 2 is to multiply by 4 } int fused_block_size = input_stride * sizeof(inType); if (pref_dist) { asmjit::Label pref_dist_reset_start = a->newLabel(); asmjit::Label pref_dist_reset_end = a->newLabel(); // out of bound handling for prefetch a->lea( scratchReg2_, x86::ptr(indices, pref_dist * sizeof(indxType))); a->cmp(scratchReg2_, index_size); a->jge(pref_dist_reset_start); if (areIndices64b) { a->mov( scratchReg2_, x86::qword_ptr(indices, pref_dist * sizeof(indxType))); } else { a->mov( scratchReg2_.r32(), x86::dword_ptr(indices, pref_dist * sizeof(indxType))); } a->jmp(pref_dist_reset_end); a->bind(pref_dist_reset_start); // things are not okay just get the current row // this can be improved to getting the max dist row. if (areIndices64b) { a->mov(scratchReg2_, x86::qword_ptr(indices)); } else { a->mov(scratchReg2_.r32(), x86::dword_ptr(indices)); } a->bind(pref_dist_reset_end); if (ROWWISE_SPARSE) { asmjit::Label rowwise_sparse_pref_corner_case_begin = a->newLabel(); asmjit::Label rowwise_sparse_pref_corner_case_end = a->newLabel(); a->cmp(scratchReg2_, data_size); a->jae(rowwise_sparse_pref_corner_case_begin); a->mov( scratchReg2_.r32(), x86::dword_ptr( compressed_indices_table, scratchReg2_, 2)); // use of 2 is to multiply by 4 a->test(scratchReg2_.r32(), scratchReg2_.r32()); // Check negative a->jns(rowwise_sparse_pref_corner_case_end); a->bind(rowwise_sparse_pref_corner_case_begin); // For corner case, just set prefetch row id to 0. a->xor_(scratchReg2_.r32(), scratchReg2_.r32()); a->bind(rowwise_sparse_pref_corner_case_end); } a->imul(scratchReg2_, static_cast(fused_block_size)); } a->add(indices, static_cast(sizeof(indxType))); if (has_weight) { a->vbroadcastss(w_vreg, x86::dword_ptr(weights)); a->add(weights, static_cast(sizeof(float))); } if (ROWWISE_SPARSE) { a->cmp(scratchReg1_.r32(), static_cast(-1)); a->je(LoopDataIndexBegin); } a->imul(scratchReg1_, static_cast(fused_block_size)); // broadcast the scale x86::Mem scale_src, bias_src; constexpr unsigned int CACHE_LINE_LEN = 64; if (is_8bit_in) { if (scale_bias_last) { scale_src = x86::dword_ptr( input, scratchReg1_, 0, block_size * sizeof(uint8_t)); bias_src = x86::dword_ptr( input, scratchReg1_, 0, block_size * sizeof(uint8_t) + sizeof(float)); a->vbroadcastss(scale_vreg, scale_src); a->vbroadcastss(bias_vreg, bias_src); } else { scale_src = x86::word_ptr(input, scratchReg1_); bias_src = x86::word_ptr(input, scratchReg1_, 0, sizeof(uint16_t)); a->vpbroadcastw(scale_vreg.half(), scale_src); a->vpbroadcastw(bias_vreg.half(), bias_src); a->vcvtph2ps(scale_vreg, scale_vreg.half()); a->vcvtph2ps(bias_vreg, bias_vreg.half()); } if (pref_dist && fused_block_size % CACHE_LINE_LEN > 0 && fused_block_size % CACHE_LINE_LEN <= 2 * sizeof(float)) { a->prefetcht0(x86::dword_ptr( input, scratchReg2_, 0, fused_block_size / CACHE_LINE_LEN * CACHE_LINE_LEN)); } } if (has_weight && is_8bit_in) { a->vmulps(scale_vreg, scale_vreg, w_vreg); a->vmulps(bias_vreg, bias_vreg, w_vreg); } // The main computation int src_addr_offset = is_8bit_in && !scale_bias_last ? 2 * sizeof(uint16_t) : 0; for (int v = 0; v < cur_unroll_factor; ++v) { constexpr int BYTES_PER_VLOAD = vlen * sizeof(inType); auto src_addr = x86::dword_ptr( input, scratchReg1_, 0, src_addr_offset + (vec_idx + v) * BYTES_PER_VLOAD); vec_reg_t out_vreg = vec_reg_t(v); // For 8bit SLS convert usigned 8-bit to 32bit int, then to float // multiply with scale and then add with bias if (is_8bit_in) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1 && instSet == inst_set_t::avx512) { a->k(x86::k(1)).z().vpmovzxbd(src_vreg, src_addr); } else { // We don't use a mask for AVX2 since we can use the extra // "padding" of the 2 floats (= 8 chars) scale and bias // this ensures we never access out of bound data a->vpmovzxbd(src_vreg, src_addr); } a->vcvtdq2ps(src_vreg, src_vreg); a->vaddps(out_vreg, out_vreg, bias_vreg); a->vfmadd231ps(out_vreg, src_vreg, scale_vreg); } else if (is_16bit_in) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (instSet == inst_set_t::avx2) { if (remainder % 2 == 0) { a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr); } else { a->vpbroadcastw( src_vreg.xmm(), x86::word_ptr( input, scratchReg1_, 0, src_addr_offset + (vec_idx + v) * BYTES_PER_VLOAD + (remainder - 1) * sizeof(inType))); if (remainder > 1) { // AVX2 can't do masking for the last 16-bit so we store // them to a stack and reload. // First put broadcasted last 16-bit element a->vmovups(x86::xmmword_ptr(x86::rsp), src_vreg.xmm()); // Mask store the remaining 16-bit elements a->vmaskmovps(src_vreg.xmm(), mask_fp16_vreg, src_addr); a->vmaskmovps( x86::xmmword_ptr(x86::rsp), mask_fp16_vreg, src_vreg.xmm()); // Load combined 16-bit elements a->vmovups(src_vreg.xmm(), x86::xmmword_ptr(x86::rsp)); } // remainder > 1 } // remainder % 2 if (is_fp16_in) { a->vcvtph2ps(src_vreg.ymm(), src_vreg.xmm()); } else if (is_bf16_in) { // bf16 a->vpmovzxwd(src_vreg.ymm(), src_vreg.xmm()); a->vpslld(src_vreg.ymm(), src_vreg.ymm(), 16); } } else { // avx512 if (is_fp16_in) { a->k(x86::k(1)).z().vcvtph2ps(src_vreg, src_addr); } else if (is_bf16_in) { // bf16 a->k(x86::k(1)).z().vpmovzxwd(src_vreg, src_addr); a->k(x86::k(1)).z().vpslld(src_vreg, src_vreg, 16); } } } else { // no remainder if (is_fp16_in) { a->vcvtph2ps(src_vreg, src_addr); } else if (is_bf16_in) { // bf16 a->vpmovzxwd(src_vreg, src_addr); a->vpslld(src_vreg, src_vreg, 16); } } if (has_weight) { a->vfmadd231ps(out_vreg, w_vreg, src_vreg); } else { a->vaddps(out_vreg, out_vreg, src_vreg); } } else { // This part for FP32 SLS if (remainder && vec_idx + v == num_vec_regs_per_block - 1 && instSet == inst_set_t::avx2) { a->vmaskmovps(src_vreg.ymm(), mask_vreg.ymm(), src_addr); } if (has_weight) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (instSet == inst_set_t::avx2) { a->vfmadd231ps(out_vreg, w_vreg, src_vreg); } else { a->k(x86::k(1)).vfmadd231ps(out_vreg, w_vreg, src_addr); } } else { a->vfmadd231ps(out_vreg, w_vreg, src_addr); } } else { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (instSet == inst_set_t::avx2) { a->vaddps(out_vreg, out_vreg, src_vreg); } else { a->k(x86::k(1)).vaddps(out_vreg, out_vreg, src_addr); } } else { a->vaddps(out_vreg, out_vreg, src_addr); } } } constexpr int VLOAD_PER_CACHE_LINE = CACHE_LINE_LEN / BYTES_PER_VLOAD; if (pref_dist && (vec_idx + v) % VLOAD_PER_CACHE_LINE == 0) { a->prefetcht0(x86::dword_ptr( input, scratchReg2_, 0, (vec_idx + v) * BYTES_PER_VLOAD)); } } a->jmp(LoopDataIndexBegin); a->bind(LoopDataIndexEnd); // This loop is for writing back out_vreg (results) // back to memory for (int v = 0; v < cur_unroll_factor; ++v) { auto dst_addr = x86::dword_ptr(out, (vec_idx + v) * vlen * sizeof(outType)); vec_reg_t out_vreg = vec_reg_t(v); if (normalize_by_lengths) { a->vmulps(out_vreg, out_vreg, vlen_inv_vreg); } if (std::is_same::value) { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (instSet == inst_set_t::avx2) { a->vmaskmovps(dst_addr, mask_vreg, out_vreg.ymm()); } else { a->k(x86::k(1)).vmovups(dst_addr, out_vreg); } } else { a->vmovups(dst_addr, out_vreg); } } else { // fp16/bf16 output if (instSet == inst_set_t::avx2) { // round nearest with no exception if (is_fp16_out) { a->vcvtps2ph(out_vreg.xmm(), out_vreg, 8); } else if (is_bf16_out) { a->vpaddd(out_vreg, out_vreg, ones_vreg); a->vpsrld(out_vreg, out_vreg, 16); a->vpackusdw(out_vreg, out_vreg, out_vreg); a->vpermq(out_vreg, out_vreg, 0xd8); } if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (remainder > 1) { a->vmaskmovps(dst_addr, mask_fp16_vreg, out_vreg.xmm()); } if (remainder % 2 != 0) { a->vmovups(x86::xmmword_ptr(x86::rsp), out_vreg.xmm()); a->mov( scratchReg1_.r16(), x86::word_ptr( x86::rsp, (remainder - 1) * sizeof(outType))); a->mov( x86::word_ptr( out, ((vec_idx + v) * vlen + (remainder - 1)) * sizeof(outType)), scratchReg1_.r16()); } } else { a->vmovups(dst_addr, out_vreg.xmm()); } } else { if (remainder && vec_idx + v == num_vec_regs_per_block - 1) { if (is_fp16_out) { a->k(x86::k(1)).vcvtps2ph(dst_addr, out_vreg, 8); } else if (is_bf16_out) { // bf16 a->k(x86::k(1)).vpaddd(out_vreg, out_vreg, ones_vreg); a->k(x86::k(1)).vpsrld(out_vreg, out_vreg, 16); a->k(x86::k(1)).vpmovdw(dst_addr, out_vreg); } } else { if (is_fp16_out) { a->vcvtps2ph(dst_addr, out_vreg, 8); } else if (is_bf16_out) { // bf16 a->vpaddd(out_vreg, out_vreg, ones_vreg); a->vpsrld(out_vreg, out_vreg, 16); a->vpmovdw(dst_addr, out_vreg); } } } } } if (vec_idx + unroll_factor < num_vec_regs_per_block || (has_weight && is_weight_positional)) { // Reset lengths_R_, indices, weights to run the dataIndex loop // again if (use_offsets) { a->mov(lengths_R_, x86::dword_ptr(lengths, sizeof(offsetType))); a->sub(lengths_R_, x86::dword_ptr(lengths)); } else { a->mov(lengths_R_, x86::dword_ptr(lengths)); } if (has_weight) { a->imul( scratchReg1_, lengths_R_, static_cast(sizeof(float))); a->sub(weights, scratchReg1_); if (vec_idx + unroll_factor < num_vec_regs_per_block) { a->imul( scratchReg1_, static_cast(sizeof(indxType) / sizeof(float))); a->sub(indices, scratchReg1_); } } else { a->imul( scratchReg1_, lengths_R_, static_cast(sizeof(indxType))); a->sub(indices, scratchReg1_); } } } a->add(lengths, static_cast(sizeof(offsetType))); a->add(out, static_cast(output_stride * sizeof(outType))); a->jmp(LoopRangeIndexBegin); a->bind(LoopRangeIndexEnd); a->cmp(indices, index_size); a->jne(error); a->mov(x86::eax, true); a->jmp(exit); a->bind(error); a->mov(x86::eax, false); a->bind(exit); if (remainder && instSet == inst_set_t::avx2 && (is_16bit_in || is_bf16_out || is_fp16_out)) { a->lea(x86::rsp, x86::ymmword_ptr(x86::rsp, vlen * sizeof(int32_t))); } a->emitEpilog(frame); // jit_fused8bitembedding_kernel fn; typename ReturnFunctionSignature< inType, indxType, offsetType, outType, ROWWISE_SPARSE>::jit_embedding_kernel fn; asmjit::Error err; { std::unique_lock lock(rtMutex_); err = runtime().add(&fn, &code); } if (err) { std::cout << "Error: in fn add" << std::endl; return nullptr; } #if defined(FBGEMM_LOG_CODE) fclose(codeLogFile); delete codeLogger; #endif return fn; }); } } // namespace template < typename inType, typename indxType, typename offsetType, typename outType, bool THREAD_LOCAL> typename EmbeddingSpMDMKernelSignature:: Type GenerateEmbeddingSpMDMWithStrides( const int64_t block_size, [[maybe_unused]] bool has_weight, bool normalize_by_lengths, [[maybe_unused]] int prefetch, bool is_weight_positional, bool use_offsets, int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, bool scale_bias_last /*=true*/, bool no_bag /*=false*/, bool is_bf16_out /*=false*/, bool is_bf16_in /*=false*/) { #if defined(__APPLE__) || defined(_WIN32) if (std::is_same::value && is_bf16_in && std::is_same::value) { throw std::runtime_error( "Bfloat16 input with float32 output is not yet supported on Apple or Windows"); } #endif if (output_stride == -1) { output_stride = block_size; } if (input_stride == -1) { if (std::is_same::value) { const auto scale_bias_offset = 2 * (scale_bias_last ? sizeof(float) : sizeof(uint16_t)); input_stride = block_size + scale_bias_offset; } else { input_stride = block_size; } } #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 if (!no_bag) { if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } const inst_set_t isa = fbgemmInstructionSet(); if ((std::is_same::value || std::is_same::value) && block_size == 1 && isYmm(isa) && output_stride == block_size && input_stride == block_size && std::is_same::value && !is_asmjit_disabled()) { return [=](int64_t output_size, int64_t index_size, int64_t data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, // optional, can be null for non-weighted sum outType* out) { return internal::EmbeddingSpMDMBlockSize1_( output_size, index_size, data_size, input, indices, offsets_or_lengths, weights, normalize_by_lengths, reinterpret_cast(out), is_weight_positional, use_offsets, is_bf16_out); }; } if (isZmm(isa) && !is_asmjit_disabled()) { static GenEmbeddingSpMDMLookup< inType, indxType, offsetType, outType, inst_set_t::avx512, /*ROWWISE_SPARSE=*/false, THREAD_LOCAL> kernel_generator; const auto original_func = kernel_generator.getOrCreate( block_size, has_weight, is_weight_positional, normalize_by_lengths, prefetch, use_offsets, output_stride, input_stride, scale_bias_last, is_bf16_out, is_bf16_in); return [=](int64_t output_size, int64_t index_size, int64_t data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, outType* out) { return original_func( output_size, index_size, data_size, input, indices, offsets_or_lengths, weights, out, nullptr /* mask not used in avx512 */); }; } if (isYmm(isa) && !is_asmjit_disabled()) { static GenEmbeddingSpMDMLookup< inType, indxType, offsetType, outType, inst_set_t::avx2, /*ROWWISE_SPARSE=*/false, THREAD_LOCAL> kernel_generator; const auto original_func = kernel_generator.getOrCreate( block_size, has_weight, is_weight_positional, normalize_by_lengths, prefetch, use_offsets, output_stride, input_stride, scale_bias_last, is_bf16_out, is_bf16_in); return [=](int64_t output_size, int64_t index_size, int64_t data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, outType* out) { return original_func( output_size, index_size, data_size, input, indices, offsets_or_lengths, weights, out, internal::avx2_ps_or_epi32_combined_mask); }; } } #endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 #ifdef FBGEMM_AUTOVEC_AVAILABLE if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } if ((is_autovec_forced() || fbgemmHasArmSve2Support()) && !is_autovec_disabled()) { return GenerateEmbeddingSpMDMWithStrides_autovec< /*InType=*/inType, /*IndexType=*/indxType, /*OffsetType=*/offsetType, /*OutType=*/outType>( /*block_size=*/block_size, /*has_weight=*/has_weight, /*normalize_by_lengths=*/normalize_by_lengths, /*prefetch=*/prefetch, /*is_weight_positional=*/is_weight_positional, /*use_offsets=*/use_offsets, /*output_stride=*/output_stride, /*input_stride=*/input_stride, /*scale_bias_last=*/scale_bias_last, /*no_bag=*/no_bag, /*is_bf16_out=*/is_bf16_out, /*is_bf16_in=*/is_bf16_in); } #endif #ifdef VLOG VLOG(0) << "AVX2 or AVX512 not found, taking the slow path"; #endif return [=](int64_t output_size, int64_t index_size, int64_t data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, outType* out) { return EmbeddingSpMDM_ref( block_size, output_size, index_size, data_size, input, indices, offsets_or_lengths, weights, normalize_by_lengths, out, is_weight_positional, use_offsets, output_stride, input_stride, scale_bias_last, no_bag, is_bf16_out, is_bf16_in); }; } template < typename inType, typename indxType, typename offsetType, typename outType, bool THREAD_LOCAL> typename EmbeddingSpMDMKernelSignature:: Type GenerateEmbeddingSpMDM( const int64_t block_size, bool has_weight, bool normalize_by_lengths, int prefetch, bool is_weight_positional, bool use_offsets, bool is_bf16_out, bool is_bf16_in) { return GenerateEmbeddingSpMDMWithStrides< inType, indxType, offsetType, outType, THREAD_LOCAL>( block_size, has_weight, normalize_by_lengths, prefetch, is_weight_positional, use_offsets, /*output_stride=*/-1, /*input_stride=*/-1, /*scale_bias_last=*/true, /*no_bag=*/false, is_bf16_out, is_bf16_in); } template typename EmbeddingSpMDMKernelSignature:: Type GenerateEmbeddingSpMDMFP8WithStrides( const int64_t block_size, bool normalize_by_lengths, bool is_weight_positional, bool use_offsets, int64_t output_stride /*=-1*/, int64_t input_stride /*=-1*/, int exponent_bits, int exponent_bias, bool is_bf16_out) { if (output_stride == -1) { output_stride = block_size; } if (input_stride == -1) { input_stride = block_size; } #ifdef FBGEMM_AUTOVEC_AVAILABLE if (!is_autovec_disabled()) { // There is only the reference implementation for FP8 embedding return GenerateEmbeddingSpMDMFP8WithStrides_autovec< /*IndexType=*/indxType, /*OffsetType=*/offsetType, /*OutType=*/outType>( /*block_size=*/block_size, /*normalize_by_lengths=*/normalize_by_lengths, /*is_weight_positional=*/is_weight_positional, /*use_offsets=*/use_offsets, /*output_stride=*/output_stride, /*input_stride=*/input_stride, /*exponent_bits=*/exponent_bits, /*exponent_bias=*/exponent_bias, /*is_bf16_out=*/is_bf16_out); } #endif // There is only the reference implementation for FP8 embedding return [=](int64_t output_size, int64_t index_size, int64_t data_size, const uint8_t* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, outType* out) { return EmbeddingSpMDMFP8_ref( // changed this from ref -> autovec. block_size, output_size, index_size, data_size, input, indices, offsets_or_lengths, weights, normalize_by_lengths, out, is_weight_positional, use_offsets, output_stride, input_stride, exponent_bits, exponent_bias, is_bf16_out); }; } template typename EmbeddingSpMDMRowWiseSparseKernelSignature< inType, indxType, offsetType>::Type GenerateEmbeddingSpMDMRowWiseSparse( const int64_t block_size, bool has_weight, bool normalize_by_lengths, int prefetch, bool is_weight_positional, bool use_offsets) { #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 int64_t input_stride = block_size; if (std::is_same::value) { const auto scale_bias_offset = 2 * sizeof(float); input_stride = block_size + scale_bias_offset; } if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } inst_set_t isa = fbgemmInstructionSet(); if (isZmm(isa)) { static GenEmbeddingSpMDMLookup< inType, indxType, offsetType, /*outType=*/float, inst_set_t::avx512, /*rowwise_sparse=*/true> kernel_generator; const auto original_func = kernel_generator.getOrCreate( block_size, has_weight, is_weight_positional, normalize_by_lengths, prefetch, use_offsets, /*output_stride=*/block_size, input_stride, /*scale_bias_last=*/true, /*is_bf16_out=*/false, /*is_bf16_in=*/false); return [=](int64_t output_size, int64_t index_size, int64_t uncompressed_data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, float* out, const int32_t* compressed_indices_table) { return original_func( output_size, index_size, uncompressed_data_size, input, indices, offsets_or_lengths, weights, out, compressed_indices_table, nullptr /* mask not used in avx512 */); }; } if (isYmm(isa)) { static GenEmbeddingSpMDMLookup< inType, indxType, offsetType, /*outType=*/float, inst_set_t::avx2, /*rowwise_sparse=*/true> kernel_generator; const auto original_func = kernel_generator.getOrCreate( block_size, has_weight, is_weight_positional, normalize_by_lengths, prefetch, use_offsets, /*output_stride=*/block_size, input_stride, /*scale_bias_last=*/true, /*is_bf16_out=*/false, /*is_bf16_in=*/false); return [=](int64_t output_size, int64_t index_size, int64_t uncompressed_data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, float* out, const int32_t* compressed_indices_table) { return original_func( output_size, index_size, uncompressed_data_size, input, indices, offsets_or_lengths, weights, out, compressed_indices_table, internal::avx2_ps_or_epi32_combined_mask); }; } #endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 #ifdef VLOG VLOG(0) << "AVX2 or AVX512 not found, taking the slow path"; #endif #ifdef FBGEMM_AUTOVEC_AVAILABLE if (is_autovec_forced()) { return GenerateEmbeddingSpMDMRowWiseSparse_autovec< /*InType=*/inType, /*IndexType=*/indxType, /*OffsetType=*/offsetType>( /*block_size=*/block_size, /*has_weight=*/has_weight, /*normalize_by_lengths=*/normalize_by_lengths, /*prefetch=*/prefetch, /*is_weight_positional=*/is_weight_positional, /*use_offsets=*/use_offsets); } #endif return [=](int64_t output_size, int64_t index_size, int64_t uncompressed_data_size, const inType* input, const indxType* indices, const offsetType* offsets_or_lengths, const float* weights, // optional, can be null for non-weighted sum float* out, const int32_t* compressed_indices_table) { return EmbeddingSpMDMRowWiseSparse_ref( block_size, output_size, index_size, uncompressed_data_size, // compressed_data_size, input, indices, compressed_indices_table, offsets_or_lengths, weights, normalize_by_lengths, out, is_weight_positional, use_offsets); }; } #define INSTANTIATE_SPMDM_BASE( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL) \ template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ IN_TYPE, \ INDEX_TYPE, \ OFFSET_TYPE, \ OUT_TYPE>::Type \ GenerateEmbeddingSpMDMWithStrides< \ IN_TYPE, \ INDEX_TYPE, \ OFFSET_TYPE, \ OUT_TYPE, \ THREAD_LOCAL>( \ const int64_t block_size, \ bool has_weight, \ bool normalize_by_lengths, \ int prefetch, \ bool is_weight_positional, \ bool use_offsets, \ int64_t output_stride, \ int64_t input_stride, \ bool scale_bias_last, \ bool no_bag, \ bool is_bf16_out, \ bool is_bf16_in); #define INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ uint8_t, \ INDEX_TYPE, \ OFFSET_TYPE, \ OUT_TYPE>::Type \ GenerateEmbeddingSpMDMFP8WithStrides( \ const int64_t block_size, \ bool normalize_by_lengths, \ bool is_weight_positional, \ bool use_offsets, \ int64_t output_stride, \ int64_t input_stride, \ int exponent_bits, \ int exponent_bias, \ bool is_bf16_out); #define INSTANTIATE_SPMDM_NOSTRIDE_BASE( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, THREAD_LOCAL) \ template FBGEMM_API typename EmbeddingSpMDMKernelSignature< \ IN_TYPE, \ INDEX_TYPE, \ OFFSET_TYPE, \ OUT_TYPE>::Type \ GenerateEmbeddingSpMDM< \ IN_TYPE, \ INDEX_TYPE, \ OFFSET_TYPE, \ OUT_TYPE, \ THREAD_LOCAL>( \ const int64_t block_size, \ bool has_weight, \ bool normalize_by_lengths, \ int prefetch, \ bool is_weight_positional, \ bool use_offsets, \ bool is_bf16_out, \ bool is_bf16_in); #define INSTANTIATE_SPMDM_ROWWISE_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \ template FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature< \ IN_TYPE, \ INDEX_TYPE, \ OFFSET_TYPE>::Type \ GenerateEmbeddingSpMDMRowWiseSparse( \ const int64_t block_size, \ bool has_weight, \ bool normalize_by_lengths, \ int prefetch, \ bool is_weight_positional, \ bool use_offsets); #define INSTANTIATE_SPMDMFP8_BASE_uint8_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ INSTANTIATE_SPMDMFP8_BASE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) #define INSTANTIATE_SPMDMFP8_BASE_float(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) #define INSTANTIATE_SPMDMFP8_BASE_uint16_t(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) #define INSTANTIATE_SPMDM_BASE_THREAD_LOCAL( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \ INSTANTIATE_SPMDM_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) #define INSTANTIATE_SPMDM_NON_BASE_THREAD_LOCAL( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) \ INSTANTIATE_SPMDM_NOSTRIDE_BASE( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, true) \ INSTANTIATE_SPMDM_NOSTRIDE_BASE( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, OUT_TYPE, false) \ INSTANTIATE_SPMDMFP8_BASE_##IN_TYPE(INDEX_TYPE, OFFSET_TYPE, OUT_TYPE) #define INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) \ INSTANTIATE_SPMDM_BASE_THREAD_LOCAL(IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \ INSTANTIATE_SPMDM_BASE_THREAD_LOCAL( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint16_t) \ INSTANTIATE_SPMDM_BASE_THREAD_LOCAL( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint8_t) \ INSTANTIATE_SPMDM_NON_BASE_THREAD_LOCAL( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, float) \ INSTANTIATE_SPMDM_NON_BASE_THREAD_LOCAL( \ IN_TYPE, INDEX_TYPE, OFFSET_TYPE, uint16_t) \ INSTANTIATE_SPMDM_ROWWISE_BASE(IN_TYPE, INDEX_TYPE, OFFSET_TYPE) #define INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, INDEX_TYPE) \ INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, int32_t) \ INSTANTIATE_SPMDM_OUT_T(IN_TYPE, INDEX_TYPE, int64_t) #define INSTANTIATE_SPMDM_INDEX_T(IN_TYPE) \ INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, int32_t) \ INSTANTIATE_SPMDM_OFFSET_T(IN_TYPE, int64_t) INSTANTIATE_SPMDM_INDEX_T(float) INSTANTIATE_SPMDM_INDEX_T(uint16_t) INSTANTIATE_SPMDM_INDEX_T(uint8_t) #undef INSTANTIATE_SPMDM_INDEX_T #undef INSTANTIATE_SPMDM_OFFSET_T #undef INSTANTIATE_SPMDM_OUT_T #undef INSTANTIATE_SPMDM_THREAD_LOCAL #undef INSTANTIATE_SPMDM_BASE #undef INSTANTIATE_SPMDMFP8_BASE #undef INSTANTIATE_SPMDM_NOSTRIDE_BASE #undef INSTANTIATE_SPMDM_ROWWISE_BASE template void compressed_indices_remap( std::int32_t offsets_len, const IndexType* indices, const int32_t* compressed_indices_mapping, const IndexType* offsets, const float* weights, // optional, can be null, IndexType* out_indices, IndexType* out_offsets, float* out_weights) { #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } #ifndef NO_AVX512 const inst_set_t isa = fbgemmInstructionSet(); if (isZmm(isa)) { #ifndef USE_ROCM if (weights == nullptr) { internal::compressed_indices_remap_avx512( offsets_len, indices, compressed_indices_mapping, offsets, weights, out_indices, out_offsets, out_weights); return; } else { internal::compressed_indices_remap_avx512( offsets_len, indices, compressed_indices_mapping, offsets, weights, out_indices, out_offsets, out_weights); return; } #endif // USE_ROCM } #endif // NO_AVX512 #endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 // Non-vectorized fallback implementation compressed_indices_remap_ref( offsets_len, indices, compressed_indices_mapping, offsets, weights, out_indices, out_offsets, out_weights); } #define INSTANTIATE_REMAP_BASE(INDEX_TYPE) \ template FBGEMM_API void compressed_indices_remap( \ std::int32_t offsets_numel, \ const INDEX_TYPE* indices, \ const int32_t* compressed_indices_mapping, \ const INDEX_TYPE* offsets, \ const float* weights, \ INDEX_TYPE* out_indices, \ INDEX_TYPE* out_offsets, \ float* out_weights); INSTANTIATE_REMAP_BASE(int32_t) INSTANTIATE_REMAP_BASE(int64_t) #undef INSTANTIATE_REMAP_BASE } // namespace fbgemm