/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmSparse.h"

#include <algorithm>
#include <array>
#include <cstdint>
#include <cstring>
#include <memory>

#include "fbgemm/Utils.h"
#include "fbgemm/spmmUtils.h"

using namespace std;

namespace fbgemm {

template <typename T>
FBGEMM_API std::unique_ptr<CSRMatrix<T>>
fbgemmDenseToCSR(int R, int C, const T* inp, int ld) {
  unique_ptr<CSRMatrix<T>> csr(new CSRMatrix<T>());
  csr->rowPtr.push_back(0);
  int nnz = 0;
  for (int i = 0; i < R; ++i) {
    for (int j = 0; j < C; ++j) {
      if (inp[i * ld + j] != 0) {
        csr->values.push_back(inp[i * ld + j]);
        csr->colIdx.push_back(j);
        nnz++;
      }
    }
    csr->rowPtr.push_back(nnz);
  }
  return csr;
}

template <typename T>
std::unique_ptr<CSRMatrix<T>> fbgemmDenseToCSR(int R, int C, const T* inp) {
  return fbgemmDenseToCSR(R, C, inp, C);
}

template FBGEMM_API std::unique_ptr<CSRMatrix<int8_t>>
fbgemmDenseToCSR(int R, int C, const int8_t* inp);
template FBGEMM_API std::unique_ptr<CSRMatrix<float>>
fbgemmDenseToCSR(int R, int C, const float* inp);
template FBGEMM_API std::unique_ptr<CSRMatrix<int8_t>>
fbgemmDenseToCSR(int R, int C, const int8_t* inp, int ld);
template FBGEMM_API std::unique_ptr<CSRMatrix<float>>
fbgemmDenseToCSR(int R, int C, const float* inp, int ld);

template <typename T>
FBGEMM_API std::unique_ptr<BCSRMatrix<T>>
fbgemmDenseToBCSR(int R, int C, const T* inp, int ld) {
  unique_ptr<BCSRMatrix<T>> bcsr(new BCSRMatrix<T>(R, C));
  bcsr->pack(inp, ld);
  return bcsr;
}

template <typename T>
FBGEMM_API std::unique_ptr<BCSRMatrix<T>>
fbgemmDenseToBCSR(int R, int C, const T* inp) {
  return fbgemmDenseToBCSR(R, C, inp, C);
}

#if __cplusplus < 201703L
template <typename T>
constexpr int BCSRMatrix<T>::RB;
template <typename T>
constexpr int BCSRMatrix<T>::CB;
template <typename T>
constexpr int BCSRMatrix<T>::COLTILE;
#endif

template <typename T>
void BCSRMatrix<T>::pack(const DTYPE* src, size_t ld) {
  rowBPtr.push_back(0);
  int nnzb = 0;
  int numCOLTILEs = (C + COLTILE - 1) / COLTILE;
  int rowBlocks = (R + RB - 1) / RB;
  for (int jt = 0; jt < numCOLTILEs; ++jt) {
    for (int i = 0; i < rowBlocks; ++i) {
      int curCols = min(C - jt * COLTILE, COLTILE);
      int curColBlocks = (curCols + CB - 1) / CB;
      std::array<int32_t, RB> rowSum = {0};
      for (int j = 0; j < curColBlocks; ++j) {
        // is the whole block zero?
        bool isCurrentBlockNonZero = false;
        for (int ib = 0; ib < RB; ++ib) {
          // break if already found a non-zero element or
          // out of bounds
          if (isCurrentBlockNonZero || (i * RB + ib) >= R) {
            break;
          }
          for (int jb = 0; jb < CB; ++jb) {
            // within bound?
            if ((jt * COLTILE + j * CB + jb) >= C) {
              continue;
            } else {
              if (src[(i * RB + ib) * ld + jt * COLTILE + j * CB + jb] != 0) {
                isCurrentBlockNonZero = true;
                break;
              }
            }
          }
        }
        if (isCurrentBlockNonZero) {
          for (int ib = 0; ib < RB; ++ib) {
            for (int jb = 0; jb < CB; ++jb) {
              if ((i * RB + ib) >= R || (jt * COLTILE + j * CB + jb) >= C) {
                // zero fill
                values.push_back(0);
              } else {
                DTYPE val =
                    src[(i * RB + ib) * ld + jt * COLTILE + j * CB + jb];
                values.push_back(val);
                rowSum[ib] += static_cast<int32_t>(val);
              }
            }
          }
          colBIdx.push_back(j);
          nnzb++;
        }
      }
      rowBPtr.push_back(nnzb);
      // Note: in row_offsets we don't need to subtract the constant term
      // weight_zero_point * C because it's 0 as weight_zero_point is always 0
      // for sparse kernels.
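      // The loop below folds this tile's per-row sums into row_offsets:
      // the first column tile (jt == 0) initializes each entry, and later
      // tiles accumulate into it, so after the last tile row_offsets holds
      // the complete row sums of the packed weights.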
      for (int ib = 0; ib < RB; ++ib) {
        if (jt) {
          row_offsets[i * RB + ib] += rowSum[ib];
        } else {
          row_offsets[i * RB + ib] = rowSum[ib];
        }
      }
    }
  }
}

template <typename T>
void BCSRMatrix<T>::pack(const DTYPE* src) {
  pack(src, C);
}

template <typename T>
void BCSRMatrix<T>::unpack(T* dst, size_t ld) {
  // zero out destination
  memset(dst, 0, R * C * sizeof(T));

  int numCOLTILEs = (C + COLTILE - 1) / COLTILE;
  int rowBlocks = (R + RB - 1) / RB;
  for (int jt = 0; jt < numCOLTILEs; ++jt) {
    for (int i = 0; i < rowBlocks; ++i) {
      // For the current tile, rowBPtr starts from currentTileIdx (i.e., jt) * R
      for (int r = rowBPtr[jt * R + i]; r < rowBPtr[jt * R + i + 1]; ++r) {
        int curColIdx = colBIdx[r];
        for (int ib = 0; ib < RB; ++ib) {
          for (int jb = 0; jb < CB; ++jb) {
            // Are we within bounds of destination matrix?
            if ((i * RB + ib) < R &&
                (jt * COLTILE + curColIdx * CB + jb) < C) {
              dst[(i * RB + ib) * ld + jt * COLTILE + curColIdx * CB + jb] =
                  values[r * RB * CB + ib * CB + jb];
            }
          }
        }
      }
    }
  }
}

template <typename T>
void BCSRMatrix<T>::unpack(T* dst) {
  unpack(dst, C);
}

template struct BCSRMatrix<int8_t>;
template struct CSRMatrix<int8_t>;
template struct CSRMatrix<float>;

template FBGEMM_API std::unique_ptr<BCSRMatrix<int8_t>>
fbgemmDenseToBCSR(int R, int C, const int8_t* inp);
template FBGEMM_API std::unique_ptr<BCSRMatrix<int8_t>>
fbgemmDenseToBCSR(int R, int C, const int8_t* inp, int ld);

void SparseDenseMM(
    int M,
    int N,
    const int* row_ptr,
    const int* col_idx,
    const float* values,
    const float* B,
    int ldb,
    float* C,
    int ldc,
    bool accum) {
  static const auto iset = fbgemmInstructionSet();
  // Run time CPU detection
  if (isZmm(iset)) {
    internal::SparseDenseMMAvx512(
        M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
  } else if (isYmm(iset)) {
    internal::SparseDenseMMAvx2(
        M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
  } else {
    sparseDenseMMRef(M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
  }
}

template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
FBGEMM_API void fbgemmSparseDenseInt8MM(
    int N,
    const std::unique_ptr<BCSRMatrix<int8_t>>& bcsr,
    const uint8_t* B,
    int ldb,
    int32_t* C_i32,
    uint8_t* C_u8,
    int ldc,
    trRequantizationParams_t& rParams,
    bool accum,
    int thread_id,
    int num_threads) {
  static const auto iset = fbgemmInstructionSet();
  // No parallelization currently
  // All work is done by thread 0
  if (thread_id > 0) {
    return;
  }
  // Run time CPU detection
  if (isZmm(iset)) {
    internal::SparseDenseInt8MMAvx512<FUSE_RELU, Q_GRAN>(
        N, bcsr, B, ldb, C_i32, C_u8, ldc, rParams, accum, thread_id,
        num_threads);
  } else if (isYmm(iset)) {
    internal::SparseDenseInt8MMAvx2<FUSE_RELU, Q_GRAN>(
        N, bcsr, B, ldb, C_i32, C_u8, ldc, rParams, accum, thread_id,
        num_threads);
  } else {
    sparseDenseInt8MMRef<FUSE_RELU, Q_GRAN>(
        N, bcsr, B, ldb, C_i32, C_u8, ldc, rParams, accum, thread_id,
        num_threads);
  }
}

#define CREATE_INSTANCE(FUSE_RELU, QGRAN)                             \
  template FBGEMM_API void fbgemmSparseDenseInt8MM<FUSE_RELU, QGRAN>( \
      int N,                                                          \
      const std::unique_ptr<BCSRMatrix<int8_t>>& bcsr,                \
      const uint8_t* B,                                               \
      int ldb,                                                        \
      int32_t* C_i32,                                                 \
      uint8_t* C_u8,                                                  \
      int ldc,                                                        \
      trRequantizationParams_t& rParams,                              \
      bool accum,                                                     \
      int thread_id,                                                  \
      int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE

} // namespace fbgemm
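
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only; `R`, `C`, and `W_int8` are placeholder
// names for the caller's dense int8 weight matrix, stored row-major with
// leading dimension C):
//
//   // Convert the dense weight matrix into the blocked-CSR layout that the
//   // sparse int8 kernels consume.
//   std::unique_ptr<fbgemm::BCSRMatrix<int8_t>> bcsr =
//       fbgemm::fbgemmDenseToBCSR<int8_t>(R, C, W_int8);
//
//   // `bcsr` is then passed to fbgemmSparseDenseInt8MM() together with the
//   // dense uint8 activation matrix, the int32/uint8 output buffers, and the
//   // requantization parameters (trRequantizationParams_t).
// ---------------------------------------------------------------------------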