/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #pragma once #include #include #include #include #include #include #include #include "./CodeCache.h" #include "fbgemm/Fbgemm.h" #include "fbgemm/SimdUtils.h" // #define FBGEMM_LOG_CODE 1 namespace fbgemm { namespace x86 = asmjit::x86; /** * @brief Generate instructions for initializing the C registers to 0. */ void initCRegs(x86::Emitter* a, int rowRegs, int colRegs); /** * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator. * @tparam TA Type of matrix A. * @tparam TB Type of matrix B. * @tparam TC Type of matrix C. * @tparam accT Accumulation type, currently we support 16-bit (std::int16_t) or * 32-bit (std::int32_t) accumulation. */ template class CodeGenBase { public: using jit_micro_kernel_fp = void (*)( const TA* bufferA, const TB* bufferB, const TB* b_pf, TC* bufferC, int kc, int ldc); /** * @brief Constructor for initializing AVX2/AVX512 registers. */ CodeGenBase(const BlockingFactors* params = nullptr) : blocking_params(params) {} /** * @brief Get or Create the instructions for macro-kernel. * * If the problem size (mc, nc) and accumulation flag (accum) can be found in * the code cache (a hash map), then get the macro-kernel instructions * directly from it. Otherwise, create the instructions for macro-kernel, and * store that into the code cache. */ template jit_micro_kernel_fp getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc); /** * @brief Generate instructions for computing block in the rank-k update. */ template void genComputeBlock( x86::Emitter* a, x86::Gp buffer_A, x86::Gp buffer_B, x86::Gp B_pf, int rowRegs, int colRegs, int lda); /** * @brief Generate instructions for storing the C registers back to the * memory. */ template void storeCRegs( x86::Emitter* a, int rowRegs, int colRegs, x86::Gp C_Offset, x86::Gp ldcReg, bool accum); const BlockingFactors* blocking_params; /** * @brief Generate filename to dump generated code * (debug-only) */ template static std::string getCodeLoggingFile( bool accum, int mc, int nc, int NCB, int KCB, int MR, int NR) { std::ostringstream oss; oss << "gemm_"; if (std::is_same::value) { oss << "acc16_"; } else if (std::is_same::value) { oss << "acc32_"; } else { oss << "unknown_"; } oss << "accum-" + std::to_string(accum) << "_MC-" + std::to_string(mc) << "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB) << "_KCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR) << "_NR-" + std::to_string(NR); if (instSet == inst_set_t::avx512_vnni) { oss << "_avx512vnni"; } else if (instSet == inst_set_t::avx512) { oss << "_avx512"; } else if (instSet == inst_set_t::avx512_ymm) { oss << "_avx512_ymm"; } else if (instSet == inst_set_t::avx2) { oss << "_avx2"; } oss << ".txt"; return oss.str(); } private: static asmjit::JitRuntime& runtime() { static asmjit::JitRuntime rt; //< JIT Runtime for asmjit, // depents on other static // variables. Required to prevent // initialization order fiasco return rt; } static std::mutex rtMutex_; ///< Controll access to runtime; // The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr static CodeCache< std::tuple, jit_micro_kernel_fp> codeCache_; ///< JIT Code Cache for reuse. }; template std::mutex CodeGenBase::rtMutex_; template CodeCache< std::tuple, typename CodeGenBase::jit_micro_kernel_fp> CodeGenBase::codeCache_; } // namespace fbgemm