/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include "bench/BenchUtils.h" #include "fbgemm/FbgemmSparse.h" #include "fbgemm/Utils.h" #include "fbgemm/spmmUtils.h" #include "src/RefImplementations.h" using namespace std; using namespace fbgemm; vector qGranularityVals{ QuantizationGranularity::TENSOR, QuantizationGranularity::OUT_CHANNEL}; // tuple represents M, N, K, fnz, fuse_relu and QuantizationGranularity class SPMMInt8Test : public testing::TestWithParam< tuple> {}; INSTANTIATE_TEST_CASE_P( InstantiationName, SPMMInt8Test, ::testing::Combine( ::testing::ValuesIn({1, 2, 3, 4, 7, 13, 16, 20, 32}), // M ::testing::ValuesIn({1, 2, 3, 4, 7, 13, 16, 20, 32}), // N ::testing::ValuesIn( {1, 2, 3, 4, 7, 8, 14, 24, 4000, 4001, 4096, 5000}), // K ::testing::ValuesIn({0.1f, 0.2f, 0.3f}), // fnz ::testing::Bool(), // fuse relu ::testing::ValuesIn(qGranularityVals))); // QuantizationGranularity /** * Test for sparse-dense matrix-matrix multiplication (int8) */ TEST_P(SPMMInt8Test, spInt8) { int M, N, K; float fnz; bool fuse_relu; QuantizationGranularity qGran; tie(M, N, K, fnz, fuse_relu, qGran) = GetParam(); auto aData = getRandomBlockSparseMatrix( M, K, 1.0, 1 /* rowBlockSize */, 1 /* colBlockSize */); auto bData = getRandomBlockSparseMatrix(K, N, fnz); auto cData = getRandomBlockSparseMatrix( M, N, 1.0, 1 /* rowBlockSize */, 1 /* colBlockSize */); aligned_vector atData(K * M); aligned_vector btData(N * K); aligned_vector ctDataRef(N * M, 5); aligned_vector ctDataRef_u8(N * M, 7); aligned_vector ctDataIntrin_i32(N * M, 9); aligned_vector ctDataIntrin_u8(N * M, 11); transpose_matrix(M, K, aData.data(), K, atData.data(), M); transpose_matrix(K, N, bData.data(), N, btData.data(), K); unique_ptr> bcsr = fbgemmDenseToBCSR(N, K, btData.data()); // output scale and zero point float scale = 128.0f; int32_t zero_point = 2; int32_t act_zero_point = 2; // symmetric quant for weights aligned_vector weight_zero_point(N); randFill(weight_zero_point, 0, 0); // Each row of weight matrix has it's own scale // The following is a multiplication activation scale with // weight scales. aligned_vector act_times_w_scale(N); randFill(act_times_w_scale, -8.0f, 8.0f); trRequantizationParams_t reqParams = { act_zero_point, weight_zero_point.data(), zero_point, scale, bcsr->row_offsets.data(), nullptr, nullptr, act_times_w_scale.data()}; int ldat = M; int ldct = M; if (fuse_relu) { if (qGran == QuantizationGranularity::TENSOR) { fbgemmSparseDenseInt8MM( M, bcsr, atData.data(), ldat, ctDataIntrin_i32.data(), ctDataIntrin_u8.data(), ldct, reqParams); } else { fbgemmSparseDenseInt8MM( M, bcsr, atData.data(), ldat, ctDataIntrin_i32.data(), ctDataIntrin_u8.data(), ldct, reqParams); } } else { if (qGran == QuantizationGranularity::TENSOR) { fbgemmSparseDenseInt8MM( M, bcsr, atData.data(), ldat, ctDataIntrin_i32.data(), ctDataIntrin_u8.data(), ldct, reqParams); } else { fbgemmSparseDenseInt8MM( M, bcsr, atData.data(), ldat, ctDataIntrin_i32.data(), ctDataIntrin_u8.data(), ldct, reqParams); } } matmul_u8i8acc32_ref( M, N, K, K, // lda N, // ldb N, // ldc aData.data(), bData.data(), cData.data()); transpose_matrix(M, N, cData.data(), N, ctDataRef.data(), M); // ctDataRef is nxm block_type_t block{0, N, 0, M}; if (fuse_relu) { if (qGran == QuantizationGranularity::TENSOR) { trRequantizeRef( ctDataRef_u8.data(), ctDataRef.data(), block, M, M, reqParams); } else { trRequantizeRef( ctDataRef_u8.data(), ctDataRef.data(), block, M, M, reqParams); } } else { if (qGran == QuantizationGranularity::TENSOR) { trRequantizeRef( ctDataRef_u8.data(), ctDataRef.data(), block, M, M, reqParams); } else { trRequantizeRef( ctDataRef_u8.data(), ctDataRef.data(), block, M, M, reqParams); } } // printMatrix(matrix_op_t::NoTranspose, ctDataRef_u8.data(), n, m, m, // "ctDataRef_u8"); // printMatrix(matrix_op_t::NoTranspose, ctDataIntrin_u8.data(), n, m, m, // "ctDataIntrin_u8"); // // Compare results for (size_t i = 0; i < ctDataRef.size(); i++) { EXPECT_EQ(ctDataRef_u8[i], ctDataIntrin_u8[i]) << "Results differ ref " << ctDataRef_u8[i] << " and test " << ctDataIntrin_u8[i] << " at " << i; } }