/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include #include "bench/AlignedVec.h" #include "bench/BenchUtils.h" #include "fbgemm/Fbgemm.h" #include "fbgemm/FbgemmI8DepthwiseAvx2.h" #include "src/DirectConv.h" #include "src/OptimizedKernelsAvx2.h" #include "src/RefImplementations.h" using namespace std; namespace fbgemm { // From Xray OCR // clang-format off // conv_param_t<>(N, IC, OC, H, W, G, // /* kern */ {kernel1, kernel2}, /* stride */ {stride1, stride2}, /* //padding */ {pad, pad, pad, pad}, // /* dialation */ {1, 1}, /* otpt_pad */ {0,0}, /* trans */ transpose), // 2D conv shapes vector> shapes = { // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, // pad_h_top, pad_w_left, pad_h_bottom, pad_w_right, // (dilation_h, dilation_w, output_padding_h, output_padding_w, tranpose) // 2D convolutions // regular // Ferraris Model // Data from - // https://docs.google.com/spreadsheets/d/1VM-nglZl-pSwBdgYm3VbeLRcORc5y_vTRl9anRCUSDQ/edit#gid=1776750723 // conv_param_t<>(N, IC, OC, H, W, G, // /* kern */ {kernel1, kernel2}, /* stride */ {stride1, stride2}, /* //padding */ {pad, pad, pad, pad}, // /* dialation */ {1, 1}, /* otpt_pad */ {0,0}, /* trans */ transpose), conv_param_t<>(1, 128, 128, {2, 257}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false), conv_param_t<>(1, 16, 16, {2, 126}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false), conv_param_t<>(1, 64, 64, {2, 257}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false), }; vector> shapes_trans = { conv_param_t<>(1, 256, 176, {2, 4}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 128, 128, {4, 12}, 1, {2, 6}, {1, 1}, {0, 0, 0, 0}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 512, 64, {4, 50}, 1, {2, 6}, {1, 1}, {0, 0, 0, 0}, {1, 1}, {0, 
0}, true), }; namespace { /* class FBGemmDirectConvTest : public testing::TestWithParam> {}; */ class FBGemmDirectConvTransTest : public testing::TestWithParam> {}; class FBGemmDirectConvTransFbgemmTest : public testing::TestWithParam> {}; } // namespace template void transposeConvWeights_KwIchO8I4( const conv_param_t& conv_p, const std::int8_t* src, std::int8_t* dest) { int G = conv_p.G; int IC_per_G = conv_p.IC / conv_p.G; int OC_per_G = conv_p.OC / conv_p.G; int filter_prod = std::accumulate( conv_p.K.begin(), conv_p.K.begin() + SPATIAL_DIM, 1, std::multiplies()); // Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format. // the transposed weight layout: W[oc/8][h][w][ic/4][8][4] for (int g = 0; g < G; ++g) { for (int k = 0; k < OC_per_G; ++k) { for (int f = 0; f < filter_prod; ++f) { for (int c = 0; c < IC_per_G; ++c) { int ocB = k / 8; int ocb = k % 8; int icB = c / 4; int icb = c % 4; dest [(((ocB * filter_prod + f) * (IC_per_G / 4) + icB) * 8 + ocb) * 4 + icb] = src[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c]; } } } } } void directConvRowSum( const conv_param_t<2>& conv_p, uint8_t* A, int32_t* inSum, int32_t* rowSum) { int IN0 = conv_p.IN_DIM[0]; int IN1 = conv_p.IN_DIM[1]; int IC = conv_p.IC; int K0 = conv_p.K[0]; int K1 = conv_p.K[1]; int OUT0 = conv_p.OUT_DIM[0]; int OUT1 = conv_p.OUT_DIM[1]; int stride = conv_p.stride[1]; memset(rowSum, 0, sizeof(int32_t) * OUT0 * OUT1); for (int ih = 0; ih < IN0; ++ih) for (int iw = 0; iw < IN1; ++iw) { inSum[ih * IN1 + iw] = reduceAvx2(A + ih * IN1 * IC + iw * IC, IC); } for (int ih = 0; ih < IN0; ++ih) for (int iw = 0; iw < IN1; iw++) { for (int r = 0; r < K0; ++r) { for (int s = 0; s < K1; ++s) { rowSum[(ih + r) * OUT1 + iw * stride + s] += inSum[ih * IN1 + iw]; } } } /* compare_buffers( rowSum, rowoffsets, OUT0, OUT1, OUT1, 5); */ } void col_offsets_with_zero_pt_s8acc32_DirectConvT_ref( const conv_param_t<2>& conv_p, const int8_t* Bint8, const int32_t* B_zero_point, int32_t* col_offsets, 
int ncols_per_quant_group) { int IC = conv_p.IC; int OC = conv_p.OC; array IN_DIM = conv_p.IN_DIM; array OUT_DIM = conv_p.OUT_DIM; array K = conv_p.K; array stride = conv_p.stride; int MDim = conv_p.MB * OUT_DIM[0] * OUT_DIM[1]; int NDim = conv_p.OC / conv_p.G; // int KDim = K[0] * K[1] * conv_p.IC; std::memset(col_offsets, 0, MDim * NDim); vector count(MDim * NDim, 0); for (int oc = 0; oc < OC; oc++) { for (int ih = 0; ih < IN_DIM[0]; ih++) { for (int iw = 0; iw < IN_DIM[1]; iw++) { for (int kh = 0; kh < K[0]; kh++) { for (int kw = 0; kw < K[1]; kw++) { for (int ic = 0; ic < IC; ic++) { int oh = ih * stride[0] + kh; int ow = iw * stride[1] + kw; col_offsets[(oh * OUT_DIM[1] + ow) * OC + oc] += Bint8 [(((((oc / 8) * K[0] + kh) * K[1] + kw) * (IC / 4) + ic / 4) * 8 + (oc % 8)) * 4 + (ic % 4)]; count[(oh * OUT_DIM[1] + ow) * OC + oc]++; } } } } } } for (int oc = 0; oc < OC; oc++) { for (int oh = 0; oh < OUT_DIM[0]; oh++) { for (int ow = 0; ow < OUT_DIM[1]; ow++) { col_offsets[(oh * OUT_DIM[1] + ow) * OC + oc] -= B_zero_point[oc / ncols_per_quant_group] * count[(oh * OUT_DIM[1] + ow) * OC + oc]; } } } } void QuantizeDirectConv_ref( const conv_param_t<2>& conv_p, aligned_vector Aint8, aligned_vector Bint8, aligned_vector& Cint32_ref, aligned_vector& Cint8_ref, int32_t Aint8_zero_point, aligned_vector C_multiplier, int32_t C_zero_point, aligned_vector Bint8_zero_point) { int im_out_dim = accumulate( conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies()); int kernel_dim = accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies()); aligned_vector Bint8_tr( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); transposeConvWeights<2>(conv_p, Bint8.data(), Bint8_tr.data()); conv_ref( conv_p, Aint8.data(), Aint8_zero_point, Bint8_tr.data(), Cint32_ref.data()); // matrix dimensions after im2col int MDim = conv_p.MB * im_out_dim; int NDim = conv_p.OC / conv_p.G; int KDim = kernel_dim * conv_p.IC; int KDimPerGroup = KDim / conv_p.G; int OC_per_G = conv_p.OC / conv_p.G; 
// computing row offset vector row_offsets(MDim); vector Aint8_im2col(MDim * KDim); im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); vector row_offsets_sum(MDim, 0); vector in_row_offsets_sum(conv_p.IN_DIM[0] * conv_p.IN_DIM[1], 0); // computing column offset vector col_offsets(conv_p.OC); for (int g = 0; g < conv_p.G; ++g) { col_offsets_with_zero_pt_s8acc32_ref( KDimPerGroup, OC_per_G, OC_per_G, Bint8_tr.data() + g * KDimPerGroup * OC_per_G, Bint8_zero_point.data(), col_offsets.data() + g * OC_per_G, conv_p.OC); } for (int g = 0; g < conv_p.G; ++g) { row_offsets_u8acc32_ref( MDim, KDimPerGroup, KDim, Aint8_im2col.data() + g * KDimPerGroup, row_offsets.data()); requantize_u8acc32_ref( MDim, NDim, conv_p.G * NDim, Cint32_ref.data() + g * NDim, Cint8_ref.data() + g * NDim, C_multiplier.data() + g * NDim / conv_p.OC, C_zero_point, Aint8_zero_point, Bint8_zero_point.data() + g * NDim / conv_p.OC, row_offsets.data(), col_offsets.data() + g * NDim, nullptr, conv_p.OC); } } /* INSTANTIATE_TEST_CASE_P( InstantiationName, FBGemmDirectConvTest, ::testing::Combine( ::testing::Bool(), // a_symmetric ::testing::Bool(), // b_symmetric ::testing::Values(1, 2))); // oc_per_g TEST_P(FBGemmDirectConvTest, Test2D) { bool a_symmetric, b_symmetric; int oc_per_g; tie(a_symmetric, b_symmetric, oc_per_g) = GetParam(); for (auto conv_p : shapes) { int im_in_dim = accumulate( conv_p.IN_DIM.begin(), conv_p.IN_DIM.end(), 1, multiplies()); aligned_vector aBuf(conv_p.MB * im_in_dim * conv_p.IC); int kernel_dim = accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies()); aligned_vector bBuf( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); aligned_vector bBuf_pf( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); aligned_vector Bint8_tr( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); aligned_vector Bint8_tr_vec( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); aligned_vector C_multiplier(1); randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); int32_t 
C_zero_point = 5; int im_out_dim = accumulate( conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies()); // matrix dimensions after im2col int MDim = conv_p.MB * im_out_dim; int NDim = conv_p.OC / conv_p.G; int KDim = kernel_dim * conv_p.IC; int KDimPerGroup = KDim / conv_p.G; int OC_per_G = conv_p.OC / conv_p.G; aligned_vector Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC); aligned_vector Cint8_ref(Cint32_ref.size(), 0); aligned_vector Cint32_fb(Cint32_ref.size()); aligned_vector Cint8_fb(Cint32_ref.size(), 0); aligned_vector Cint8_fb2(Cint32_ref.size(), 0); aligned_vector Cint32_fb2(Cint32_ref.size()); DirectConvCodeGenBase::jit_micro_kernel_fp fn; // fn = GemmGetOrCreate( // true, _MB, _NB, _KB); DirectConvCodeGenBase codeObj; fn = codeObj.getOrCreateDirectConv( true, conv_p.OUT_DIM[1], conv_p.IN_DIM[1] * conv_p.IC, conv_p.stride[1] * conv_p.IC); randFill(aBuf, 0, 5); randFill(bBuf, -4, 4); randFill(bBuf_pf, -4, 4); int32_t Aint8_zero_point = 4; aligned_vector Bint8_zero_point(1); randFill(Bint8_zero_point, -3, -1); aligned_vector bBuf_tr(bBuf.size()); transposeConvWeights_KwIchO8I4<2>(conv_p, bBuf.data(), bBuf_tr.data()); for (int i = 0; i < conv_p.OC; i += 8) { fn(aBuf.data(), bBuf_tr.data() + i * kernel_dim * conv_p.IC, bBuf_pf.data(), Cint32_fb.data() + i, conv_p.IC * conv_p.K[1], conv_p.OC); } // reference quantized int8 convolution implementation QuantizeDirectConv_ref( conv_p, aBuf, bBuf, Cint32_ref, Cint8_ref, Aint8_zero_point, C_multiplier, C_zero_point, Bint8_zero_point); compare_buffers( Cint32_fb.data(), Cint32_ref.data(), conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1], conv_p.OC, conv_p.OC, 5); // computing column offset vector col_offsets(conv_p.OC); transposeConvWeights<2>(conv_p, bBuf.data(), Bint8_tr.data()); for (int g = 0; g < conv_p.G; ++g) { col_offsets_with_zero_pt_s8acc32_ref( KDimPerGroup, OC_per_G, OC_per_G, Bint8_tr.data() + g * KDimPerGroup * OC_per_G, Bint8_zero_point.data(), col_offsets.data() + g * OC_per_G, conv_p.OC); } vector 
row_offsets(MDim); vector Aint8_im2col(MDim * KDim); im2col_ref(conv_p, aBuf.data(), Aint8_zero_point, Aint8_im2col.data()); for (int g = 0; g < conv_p.G; ++g) { row_offsets_u8acc32_ref( MDim, KDimPerGroup, KDim, Aint8_im2col.data() + g * KDimPerGroup, row_offsets.data()); requantize_u8acc32_ref( MDim, NDim, conv_p.G * NDim, Cint32_fb.data() + g * NDim, Cint8_fb.data() + g * NDim, C_multiplier.data() + g * NDim / conv_p.OC, C_zero_point, Aint8_zero_point, Bint8_zero_point.data() + g * NDim / conv_p.OC, row_offsets.data(), col_offsets.data() + g * NDim, nullptr, conv_p.OC); } // correctness check for (int n = 0; n < conv_p.MB; ++n) { for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) { for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) { for (int k = 0; k < conv_p.OC; ++k) { int H_OUT = conv_p.OUT_DIM[0]; int W_OUT = conv_p.OUT_DIM[1]; int OC = conv_p.OC; int32_t expected = Cint8_ref[((n * H_OUT + h) * W_OUT + w) * OC + k]; int32_t actual = Cint8_fb[((n * H_OUT + h) * W_OUT + w) * OC + k]; EXPECT_EQ(actual, expected) << "Directconv " << conv_p.K[0] << "x" << conv_p.K[1] << " results differ at (" << n << ", " << h << ", " << w << ", " << k << ")."; } } } } } // for each shape } */ INSTANTIATE_TEST_CASE_P( InstantiationName, FBGemmDirectConvTransTest, ::testing::Combine( ::testing::Bool(), // a_symmetric ::testing::Bool(), // b_symmetric ::testing::Values(1, 2))); // oc_per_g TEST_P(FBGemmDirectConvTransTest, Test2D) { bool a_symmetric, b_symmetric; int oc_per_g; tie(a_symmetric, b_symmetric, oc_per_g) = GetParam(); for (auto conv_p : shapes_trans) { int im_in_dim = accumulate( conv_p.IN_DIM.begin(), conv_p.IN_DIM.end(), 1, multiplies()); aligned_vector aBuf(conv_p.MB * im_in_dim * conv_p.IC); int kernel_dim = accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies()); aligned_vector bBuf( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); aligned_vector bBuf_pf( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); aligned_vector Bint8_tr( kernel_dim * conv_p.IC * (conv_p.OC / 
conv_p.G)); aligned_vector Bint8_tr_vec( kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G)); aligned_vector C_multiplier(1); randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2); int32_t C_zero_point = 5; int im_out_dim = accumulate( conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies()); // matrix dimensions after im2col int MDim = conv_p.MB * im_out_dim; int NDim = conv_p.OC / conv_p.G; int KDim = kernel_dim * conv_p.IC; int KDimPerGroup = KDim / conv_p.G; int OC_per_G = conv_p.OC / conv_p.G; aligned_vector Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC); aligned_vector Cint8_ref(Cint32_ref.size(), 0); aligned_vector Cint32_fb(Cint32_ref.size(), 0); aligned_vector Cint8_fb(Cint32_ref.size(), 0); aligned_vector Cint8_fb2(Cint32_ref.size(), 0); aligned_vector Cint32_fb2(Cint32_ref.size()); randFill(aBuf, 0, 5); randFill(bBuf, -4, 4); randFill(bBuf_pf, -4, 4); int32_t Aint8_zero_point = 4; aligned_vector Bint8_zero_point(1); randFill(Bint8_zero_point, -3, -1); aligned_vector &Bint8 = bBuf; aligned_vector &Aint8 = aBuf; // reference implementation // conv_ref expects weights to be in G (R S C/G) K/G transposeConvWeights<2>(conv_p, Bint8.data(), Bint8_tr.data()); transposeConvWeights_KwIchO8I4<2>( conv_p, Bint8.data(), Bint8_tr_vec.data()); conv_ref( // DirectConvTrans_ref( conv_p, Aint8.data(), Aint8_zero_point, Bint8_tr.data(), Cint32_ref.data()); // computing row offset vector row_offsets(MDim); vector Aint8_im2col(MDim * KDim); im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); // computing column offset vector col_offsets(conv_p.OC); for (int g = 0; g < conv_p.G; ++g) { col_offsets_with_zero_pt_s8acc32_ref( KDimPerGroup, OC_per_G, OC_per_G, Bint8_tr.data() + g * KDimPerGroup * OC_per_G, Bint8_zero_point.data(), col_offsets.data() + g * OC_per_G, conv_p.OC); } for (int g = 0; g < conv_p.G; ++g) { row_offsets_u8acc32_ref( MDim, KDimPerGroup, KDim, Aint8_im2col.data() + g * KDimPerGroup, row_offsets.data()); requantize_u8acc32_ref( 
MDim, NDim, conv_p.G * NDim, Cint32_ref.data() + g * NDim, Cint8_ref.data() + g * NDim, C_multiplier.data() + g * NDim / conv_p.OC, C_zero_point, Aint8_zero_point, Bint8_zero_point.data() + g * NDim / conv_p.OC, row_offsets.data(), col_offsets.data() + g * NDim, nullptr, conv_p.OC); } // computing column offset vector col_offsetsT(conv_p.OC * MDim); for (int g = 0; g < conv_p.G; ++g) { col_offsets_with_zero_pt_s8acc32_DirectConvT_ref( conv_p, Bint8_tr_vec.data() + g * KDimPerGroup * OC_per_G, Bint8_zero_point.data(), col_offsetsT.data() + g * OC_per_G, conv_p.OC); } string runType; PackedDirectConvMatrix packedB(conv_p.IC, conv_p.OC, kernel_dim, Bint8.data()); DoNothing<> doNothingObj{}; ReQuantizeOutput outputProcObj( doNothingObj, C_multiplier.data(), C_zero_point, Aint8_zero_point, Bint8_zero_point.data(), nullptr, // row offsets col_offsetsT.data(), nullptr, // bias conv_p.OC, conv_p.G); int32_t* bias_p = nullptr; fbgemmDirectConv(conv_p, Aint8.data(), packedB, Cint8_fb.data(), Cint32_fb.data(), outputProcObj, bias_p, //bias 0, 1); /* compare_buffers( Cint8_ref.data(), Cint8_fb.data(), MDim, NDim * conv_p.G, NDim * conv_p.G, 5); */ // correctness check for (int n = 0; n < conv_p.MB; ++n) { for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) { for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) { for (int k = 0; k < conv_p.OC; ++k) { int H_OUT = conv_p.OUT_DIM[0]; int W_OUT = conv_p.OUT_DIM[1]; int OC = conv_p.OC; int32_t expected = Cint8_ref[((n * H_OUT + h) * W_OUT + w) * OC + k]; int32_t actual = Cint8_fb[((n * H_OUT + h) * W_OUT + w) * OC + k]; EXPECT_EQ(actual, expected) << "DirectconvTrans " << conv_p.K[0] << "x" << conv_p.K[1] << " results differ at (" << n << ", " << h << ", " << w << ", " << k << ")."; } } } } } // for each shape } INSTANTIATE_TEST_CASE_P( InstantiationName, FBGemmDirectConvTransFbgemmTest, ::testing::Combine( ::testing::Bool(), // a_symmetric ::testing::Bool(), // b_symmetric ::testing::Values(1, 2))); // oc_per_g 
// Compares the top-level fbgemmConv entry point (fed PackWeightsForConv<2>
// weights) against a conv_ref + requantize reference for every shape in
// shapes_trans.
TEST_P(FBGemmDirectConvTransFbgemmTest, Test2D) {
  // NOTE(review): the parameters are unpacked but never used below, so every
  // instantiation runs identical checks — TODO wire them in or drop them.
  bool a_symmetric, b_symmetric;
  int oc_per_g;
  tie(a_symmetric, b_symmetric, oc_per_g) = GetParam();

  for (const auto& conv_p : shapes_trans) {
    const int im_in_dim = accumulate(
        conv_p.IN_DIM.begin(), conv_p.IN_DIM.end(), 1, multiplies<int>());
    aligned_vector<uint8_t> aBuf(conv_p.MB * im_in_dim * conv_p.IC);

    const int kernel_dim =
        accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies<int>());
    const int weights_size = kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G);

    aligned_vector<int8_t> bBuf(weights_size);
    aligned_vector<int8_t> bBuf_pf(weights_size);
    aligned_vector<int8_t> Bint8_tr(weights_size);

    aligned_vector<float> C_multiplier(1);
    randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
    int32_t C_zero_point = 5;

    const int im_out_dim = accumulate(
        conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>());
    // matrix dimensions after im2col
    const int MDim = conv_p.MB * im_out_dim;
    const int NDim = conv_p.OC / conv_p.G;
    const int KDim = kernel_dim * conv_p.IC;
    const int KDimPerGroup = KDim / conv_p.G;
    const int OC_per_G = conv_p.OC / conv_p.G;

    aligned_vector<int32_t> Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC);
    aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
    aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0);
    aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);

    randFill<uint8_t>(aBuf, 0, 5);
    randFill<int8_t>(bBuf, -4, 4);
    // bBuf_pf is never read. Its fill is kept so the data drawn afterwards
    // stays unchanged (presumably randFill shares one generator — verify
    // before removing).
    randFill<int8_t>(bBuf_pf, -4, 4);

    int32_t Aint8_zero_point = 4;
    aligned_vector<int32_t> Bint8_zero_point(1);
    randFill(Bint8_zero_point, -3, -1);

    aligned_vector<int8_t>& Bint8 = bBuf;
    aligned_vector<uint8_t>& Aint8 = aBuf;

    // reference implementation
    // conv_ref expects weights to be in G (R S C/G) K/G
    transposeConvWeights<2>(conv_p, Bint8.data(), Bint8_tr.data());
    conv_ref(
        conv_p,
        Aint8.data(),
        Aint8_zero_point,
        Bint8_tr.data(),
        Cint32_ref.data());

    // computing row offset
    vector<int32_t> row_offsets(MDim);
    vector<uint8_t> Aint8_im2col(MDim * KDim);
    im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());

    // computing column offset
    vector<int32_t> col_offsets(conv_p.OC);
    for (int g = 0; g < conv_p.G; ++g) {
      col_offsets_with_zero_pt_s8acc32_ref(
          KDimPerGroup,
          OC_per_G,
          OC_per_G,
          Bint8_tr.data() + g * KDimPerGroup * OC_per_G,
          Bint8_zero_point.data(),
          col_offsets.data() + g * OC_per_G,
          conv_p.OC);
    }

    for (int g = 0; g < conv_p.G; ++g) {
      row_offsets_u8acc32_ref(
          MDim,
          KDimPerGroup,
          KDim,
          Aint8_im2col.data() + g * KDimPerGroup,
          row_offsets.data());

      requantize_u8acc32_ref(
          MDim,
          NDim,
          conv_p.G * NDim,
          Cint32_ref.data() + g * NDim,
          Cint8_ref.data() + g * NDim,
          C_multiplier.data() + g * NDim / conv_p.OC,
          C_zero_point,
          Aint8_zero_point,
          Bint8_zero_point.data() + g * NDim / conv_p.OC,
          row_offsets.data(),
          col_offsets.data() + g * NDim,
          nullptr,
          conv_p.OC);
    }

    // fbgemm top-level function for direct conv path
    PackWeightsForConv<2> packedB_2D(conv_p, Bint8.data());

    // If the packing produced direct-conv weights, col_offsets is overwritten
    // by that object's DirectConvT helper (the library counterpart of
    // col_offsets_with_zero_pt_s8acc32_DirectConvT_ref). Otherwise the
    // per-OC reference col_offsets computed for the reference are passed on
    // unchanged.
    auto* packedW_directconv = packedB_2D.getPackedWForDirectconv().get();
    if (packedW_directconv != nullptr) {
      packedW_directconv->col_offsets_with_zero_pt_s8acc32_DirectConvT(
          conv_p, Bint8_zero_point.data(), col_offsets, conv_p.OC);
    }

    DoNothing<> doNothingObj{};
    ReQuantizeOutput<false> outputProcObj(
        doNothingObj,
        C_multiplier.data(),
        C_zero_point,
        Aint8_zero_point,
        Bint8_zero_point.data(),
        nullptr, // row offsets
        col_offsets.data(),
        nullptr, // bias
        conv_p.OC,
        conv_p.G);

    fbgemmConv(
        conv_p,
        Aint8.data(),
        packedB_2D,
        Cint8_fb.data(),
        Cint32_fb.data(),
        outputProcObj,
        0,
        1);

    // correctness check
    for (int n = 0; n < conv_p.MB; ++n) {
      for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
        for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
          for (int k = 0; k < conv_p.OC; ++k) {
            const int H_OUT = conv_p.OUT_DIM[0];
            const int W_OUT = conv_p.OUT_DIM[1];
            const int OC = conv_p.OC;
            const int32_t expected =
                Cint8_ref[((n * H_OUT + h) * W_OUT + w) * OC + k];
            const int32_t actual =
                Cint8_fb[((n * H_OUT + h) * W_OUT + w) * OC + k];
            EXPECT_EQ(actual, expected)
                << "DirectconvTrans " << conv_p.K[0] << "x" << conv_p.K[1]
                << " results differ at (" << n << ", " << h << ", " << w
                << ", " << k << ").";
          }
        }
      }
    }
  } // for each shape
}

} // namespace fbgemm