/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include "./QuantizationHelpers.h" #include "./TestUtils.h" #include "bench/BenchUtils.h" #include "fbgemm/Fbgemm.h" #include "src/RefImplementations.h" using namespace std; using namespace fbgemm; vector qGranularityVals{ QuantizationGranularity::TENSOR, QuantizationGranularity::GROUP, QuantizationGranularity::OUT_CHANNEL}; // clang-format off template static typename std::enable_if>>::type GetShapes_() { vector> shapes = { // MB, IC, OC, {IW}, G, {KW}, {stride_w}, {pad_l,pad_r}, {dilation_w} // Regular conv_param_t<1>(1, 16, 16, {30}, 1, {3}, {1}, {1, 1}), conv_param_t<1>(1, 32, 32, {30}, 1, {3}, {1}, {1, 1}), conv_param_t<1>(1, 32, 16, {30}, 1, {3}, {1}, {0, 0}, {2}), // deconv shapes // MB, IC, OC, {IW}, G, {KW}, {stride_w}, {pad_l,pad_r}, {dilation_w}, {output_padding}, transposed // Regular conv_param_t<1>(1, 16, 16, {30}, 1, {3}, {1}, {1, 1}, {1}, {0}, true), conv_param_t<1>(1, 32, 32, {30}, 1, {3}, {1}, {1, 1}, {1}, {0}, true), conv_param_t<1>(1, 32, 16, {30}, 1, {3}, {1}, {0, 0}, {1}, {0}, true), conv_param_t<1>(1, 16, 16, {30}, 1, {3}, {1}, {1, 1}, {2}, {0}, true), conv_param_t<1>(1, 32, 32, {30}, 1, {3}, {1}, {1, 1}, {2}, {0}, true), conv_param_t<1>(1, 32, 16, {30}, 1, {3}, {1}, {0, 0}, {2}, {0}, true), conv_param_t<1>(1, 16, 16, {30}, 1, {3}, {1}, {1, 1}, {1}, {1}, true), conv_param_t<1>(1, 32, 32, {30}, 1, {3}, {1}, {1, 1}, {1}, {1}, true), conv_param_t<1>(1, 32, 16, {30}, 1, {3}, {1}, {0, 0}, {1}, {1}, true), conv_param_t<1>(1, 32, 32, {30}, 1, {3}, {2}, {1, 1}, {2}, {0}, true), // some example deconv shapes conv_param_t<1>(1, 96, 48, {30}, 1, {16}, {8}, {4, 4}, {1}, {0}, true), conv_param_t<1>(1, 48, 24, {30}, 1, {8}, {4}, {2, 2}, {1}, {0}, true), conv_param_t<1>(1, 24, 12, {30}, 1, {4}, {2}, {1, 1}, {1}, {0}, true), // groupwise conv_param_t<1>(1, 32, 16, {30}, 8, {3}, {1}, {0, 0}, {1}, {1}, true), conv_param_t<1>(1, 32, 32, {30}, 8, {3}, {2}, {1, 1}, {2}, {0}, true), }; return shapes; } // clang-format on // clang-format off template static typename std::enable_if>>::type GetShapes_() { vector> shapes = { // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l, // pad_b, pad_r}, {dilation_h, dilation_w} // Regular conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 16, 32, {30, 10}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}, {2, 2}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {3, 3}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 2}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {2, 1}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 2}, {1, 1, 1, 1}, {1, 2}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {2, 1, 2, 1}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 2, 1, 2}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 2}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {2, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 16, 16, {10, 30}, 1, {3, 5}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 16, 16, {10, 30}, 1, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), // groupwise conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 16, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 16, 32, {10, 30}, 8, {3, 3}, {2, 2}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 2}, {2, 1, 2, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 2}, {2, 1, 2, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 1}, {2, 1, 2, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 5}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), // DW conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 1}, {1, 2, 1, 2}), conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {2, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 3}, {1, 2}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 32, {3, 5}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 5}, {1, 1}, {1, 1, 1, 1}), conv_param_t<>(1, 32, 32, {10, 30}, 32, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}), conv_param_t<>(1, 128, 256, {32, 100}, 128, {3, 3}, {1, 1}, {1, 1, 1, 1}), // Pointwise conv_param_t<>(1, 32, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), conv_param_t<>(1, 16, 32, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 1}, {0, 0, 0, 0}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 2}, {0, 0, 0, 0}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {1, 2}, {0, 0, 0, 0}), conv_param_t<>(1, 32, 16, {10, 30}, 1, {1, 1}, {2, 1}, {0, 0, 0, 0}), // deconv shapes // MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, {pad_t, pad_l, // pad_b, pad_r}, {dilation_h, dilation_w}, {output_padding_h, output_padding_w}, transposed // Regular conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {0, 0, 0, 0}, {2, 2}, {0, 0}, true), conv_param_t<>(1, 32, 16, {10, 30}, 1, {3, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}, {0, 0}, true), // groupwise conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 32, 16, {10, 30}, 8, {3, 3}, {1, 1}, {1, 1, 1, 1}, {1, 1}, {1, 0}, true), conv_param_t<>(1, 16, 32, {10, 30}, 8, {3, 3}, {2, 2}, {1, 1, 1, 1}, {1, 1}, {0, 1}, true), conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 2}, {2, 1, 2, 1}, {1, 1}, {1, 1}, true), conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {1, 2}, {2, 1, 2, 1}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 3}, {2, 1}, {2, 1, 2, 1}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 5}, {1, 1}, {1, 1, 1, 1}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}, {0, 0}, true), // directconv conv_param_t<>(1, 256, 176, {2, 4}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 128, 128, {4, 12}, 1, {2, 6}, {1, 1}, {0, 0, 0, 0}, {1, 1}, {0, 0}, true), conv_param_t<>(1, 512, 64, {4, 50}, 1, {2, 6}, {1, 1}, {0, 0, 0, 0}, {1, 1}, {0, 0}, true), }; return shapes; } // clang-format on namespace { // tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad class uniConvTest : public testing::TestWithParam< tuple> {}; // tuple represents QuantizationGranularity, A symmetric, B symmetric, // test_bias, test_float_bias class UniConvQGranTest : public testing::TestWithParam< tuple> {}; }; // namespace // Combine only allows at most 10 generators. INSTANTIATE_TEST_CASE_P( InstantiationName, uniConvTest, ::testing::Combine( ::testing::ValuesIn({1, 2}), // MB ::testing::ValuesIn({16, 32}), // IC ::testing::ValuesIn({16, 32}), // OC ::testing::ValuesIn({17}), // IT ::testing::ValuesIn({10, 30}), // IH ::testing::ValuesIn({10, 30, 55}), // IW ::testing::ValuesIn({1, 4, 16}), // G ::testing::ValuesIn({1, 3, 5, 7}), // kernel ::testing::ValuesIn({1, 2}), // stride ::testing::ValuesIn({0, 1, 2}))); // pad INSTANTIATE_TEST_CASE_P( InstantiationName, UniConvQGranTest, ::testing::Combine( ::testing::ValuesIn(qGranularityVals), ::testing::Bool(), // A symmetric ::testing::Bool(), // B symmetric ::testing::Bool(), // test_bias ::testing::Bool())); // test_float_bias /** * Test for conv packing */ TEST_P(uniConvTest, packingTest) { int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); conv_param_t<1> conv_p_1d( MB, IC, OC, {IW}, G, {kernel}, {stride}, {pad, pad}); int kernel_dim_1d = kernel; aligned_vector Bint8_1d( kernel_dim_1d * conv_p_1d.IC * (conv_p_1d.OC / conv_p_1d.G)); PackWeightsForConv<1> packedB_1D(conv_p_1d, Bint8_1d.data()); switch (ConvFastPath<1, int32_t>(conv_p_1d)) { case optimized_conv_t::depthwise: { ASSERT_EQ(packedB_1D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_1D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix is null"; ASSERT_EQ(packedB_1D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::groupwise: { ASSERT_EQ(packedB_1D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_1D.getPackedWForGroupwise(), nullptr) << "Groupwise packed matrix is null"; ASSERT_EQ(packedB_1D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::pointwise: { ASSERT_EQ(packedB_1D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should null"; ASSERT_EQ(packedB_1D.getPackedWForGroupwise(), nullptr) << "Groupwise packed matrix should be null"; ASSERT_NE(packedB_1D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix is null"; ASSERT_EQ(packedB_1D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::directconv: { ASSERT_EQ(packedB_1D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_1D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix is null"; ASSERT_EQ(packedB_1D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; break; } case optimized_conv_t::fastpath1d: { break; } case optimized_conv_t::im2col: { ASSERT_EQ(packedB_1D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_EQ(packedB_1D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; ASSERT_NE(packedB_1D.getPackedWForIm2col(), nullptr) << "im2col packed matrix is null"; break; } } conv_param_t<2> conv_p_2d( MB, IC, OC, {IH, IW}, G, {kernel, kernel}, {stride, stride}, {pad, pad, pad, pad}); int kernel_dim_2d = kernel * kernel; aligned_vector Bint8_2d( kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); switch (ConvFastPath<2, int32_t>(conv_p_2d)) { case optimized_conv_t::depthwise: { ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_2D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix is null"; ASSERT_EQ(packedB_2D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::groupwise: { ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr) << "Groupwise packed matrix is null"; ASSERT_EQ(packedB_2D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::pointwise: { ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should null"; ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) << "Groupwise packed matrix should be null"; ASSERT_NE(packedB_2D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix is null"; ASSERT_EQ(packedB_2D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::directconv: { ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_2D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix is null"; ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; break; } case optimized_conv_t::fastpath1d: { break; } case optimized_conv_t::im2col: { ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr) << "im2col packed matrix is null"; ASSERT_EQ(packedB_2D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } } conv_param_t<3> conv_p_3d( MB, IC, OC, {IT, IH, IW}, G, {kernel, kernel, kernel}, {stride, stride, stride}, {pad, pad, pad, pad, pad, pad}); int kernel_dim_3d = kernel * kernel * kernel; aligned_vector Bint8_3d( kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); switch (ConvFastPath<3, int32_t>(conv_p_3d)) { case optimized_conv_t::depthwise: { ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_3D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix is null"; ASSERT_EQ(packedB_3D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::groupwise: { ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_NE(packedB_3D.getPackedWForGroupwise(), nullptr) << "Groupwise packed matrix is null"; ASSERT_EQ(packedB_3D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::pointwise: { ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; ASSERT_NE(packedB_3D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix is null"; ASSERT_EQ(packedB_3D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } case optimized_conv_t::directconv: { ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_3D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix is null"; ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr) << "im2col packed matrix should be null"; break; } case optimized_conv_t::fastpath1d: { break; } case optimized_conv_t::im2col: { ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr) << "depthwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr) << "groupwise packed matrix should be null"; ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr) << "pointwise packed matrix should be null"; ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr) << "im2col packed matrix is null"; ASSERT_EQ(packedB_3D.getPackedWForDirectconv(), nullptr) << "directconv packed matrix should be null"; break; } } } /** * Test for packing/unpacking */ TEST_P(uniConvTest, packUnpackTest) { int MB, IC, OC, IT, IH, IW, G, kernel, stride, pad; tie(MB, IC, OC, IT, IH, IW, G, kernel, stride, pad) = GetParam(); conv_param_t<1> conv_p_1d( MB, IC, OC, {IW}, G, {kernel}, {stride}, {pad, pad}); int kernel_dim_1d = kernel; aligned_vector Bint8_1d( kernel_dim_1d * conv_p_1d.IC * (conv_p_1d.OC / conv_p_1d.G)); aligned_vector Bint8_1d_unpacked( kernel_dim_1d * conv_p_1d.IC * (conv_p_1d.OC / conv_p_1d.G)); PackWeightsForConv<1> packedB_1D(conv_p_1d, Bint8_1d.data()); packedB_1D.unpack(Bint8_1d_unpacked.data()); ASSERT_EQ(Bint8_1d, Bint8_1d_unpacked) << "Original and unpacked data elements are not the same [1D]"; conv_param_t<2> conv_p_2d( MB, IC, OC, {IH, IW}, G, {kernel, kernel}, {stride, stride}, {pad, pad, pad, pad}); int kernel_dim_2d = kernel * kernel; aligned_vector Bint8_2d( kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); aligned_vector Bint8_2d_unpacked( kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); packedB_2D.unpack(Bint8_2d_unpacked.data()); ASSERT_EQ(Bint8_2d_unpacked, Bint8_2d) << "Original and unpacked data elements are not the same [2D]"; conv_param_t<3> conv_p_3d( MB, IC, OC, {IT, IH, IW}, G, {kernel, kernel, kernel}, {stride, stride, stride}, {pad, pad, pad, pad, pad, pad}); int kernel_dim_3d = kernel * kernel * kernel; aligned_vector Bint8_3d( kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); aligned_vector Bint8_3d_unpacked( kernel_dim_3d * conv_p_3d.IC * (conv_p_3d.OC / conv_p_3d.G)); PackWeightsForConv<3> packedB_3D(conv_p_3d, Bint8_3d.data()); packedB_3D.unpack(Bint8_3d_unpacked.data()); ASSERT_EQ(Bint8_3d_unpacked, Bint8_3d) << "Original and unpacked data elements are not the same [3D]"; } TEST(uniConvTest, cornerCases) { int stride = 1; conv_param_t<2> conv_p_2d( 1, // mini-batch 16, // input channels 32, // output channels {28, 28}, // input height/width 4, // groups {3, 3}, // kernel height/width {stride, stride}, // strides {1, 1, 1, 1}); // padding int kernel_dim_2d = conv_p_2d.K[0] * conv_p_2d.K[1]; aligned_vector Aint8( conv_p_2d.MB * conv_p_2d.IN_DIM[0] * conv_p_2d.IN_DIM[1] * conv_p_2d.IC); aligned_vector Bint8_2d( kernel_dim_2d * conv_p_2d.IC * (conv_p_2d.OC / conv_p_2d.G)); aligned_vector Cint32_fb( conv_p_2d.MB * conv_p_2d.OUT_DIM[0] * conv_p_2d.OUT_DIM[1] * conv_p_2d.OC); aligned_vector Cint8_fb(Cint32_fb.size(), 0); // A matrix (input activations) randFill(Aint8, 0, 5); int32_t Aint8_zero_point = 4; // B matrix (weights) randFill(Bint8_2d, -4, 4); aligned_vector Bint8_zero_point(1); randFill(Bint8_zero_point, -3, -1); aligned_vector C_multiplier(Bint8_zero_point.size()); randFill(C_multiplier, 0.1234f / 2, 0.1234f * 3 / 2); int32_t C_zero_point = 5; PackWeightsForConv<2> packedB_2D(conv_p_2d, Bint8_2d.data()); vector col_offsets(conv_p_2d.OC); DoNothing<> doNothingObj{}; ReQuantizeOutput outputProcObj( doNothingObj, C_multiplier.data(), C_zero_point, Aint8_zero_point, Bint8_zero_point.data(), nullptr, // row offsets col_offsets.data(), nullptr, // bias conv_p_2d.OC, conv_p_2d.G); try { conv_p_2d.stride[0] = 2; fbgemmConv( conv_p_2d, Aint8.data(), packedB_2D, Cint8_fb.data(), Cint32_fb.data(), outputProcObj, 0, 1); } catch (std::logic_error const& err) { std::string s(err.what()); EXPECT_TRUE(s.rfind("[FBGEMM_CONV_ERROR]", 0) == 0); } // reset conv_p_2d.stride[0] = stride; // this should run fine fbgemmConv( conv_p_2d, Aint8.data(), packedB_2D, Cint8_fb.data(), Cint32_fb.data(), outputProcObj, 0, 1); } template bool takeDirectConvPath(const conv_param_t& conv_p) { // Note: Direct convolutions (2D) are optimized for // filter size: 2 x 1 to 2 x 6, transposed conv, // in_channel % 8 == 0, out_channel % 8 == 0 // stride = 1 or 2 // padding = 0 ( non-zero padding will be supported soon) bool ret = std::is_same::value && conv_p.transposed && conv_p.G == 1 && conv_p.IC % 8 == 0 && conv_p.OC % 8 == 0 && std::all_of( conv_p.stride.begin(), conv_p.stride.end(), [](int i) { return i == 1 || i == 2; }) && SPATIAL_DIM == 2 && conv_p.K[SPATIAL_DIM - 2] == 2 && conv_p.K[SPATIAL_DIM - 1] <= 6 && std::all_of(conv_p.dilation.begin(), conv_p.dilation.end(), [](int i) { return i == 1; }); // Check pads: zero padding for (int i = 0; i < SPATIAL_DIM; ++i) { if (conv_p.pad[i] != 0) { ret = false; } } return ret; } /** * @brief Unit test for uint8 activations, int8 weights, and 32-bit * accumulation. Output processing: requantization -> nothing */ template void runRequantizeTest( QuantizationGranularity q_granularity, bool a_symmetric, bool b_symmetric, bool test_bias, bool test_float_bias) { vector> shapes(GetShapes_()); for (auto conv_p : shapes) { int R = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2]; int S = conv_p.K[SPATIAL_DIM - 1]; int G = conv_p.G; int OC = conv_p.OC; int OH = SPATIAL_DIM == 1 ? 1 : conv_p.OUT_DIM[SPATIAL_DIM - 2]; int OW = conv_p.OUT_DIM[SPATIAL_DIM - 1]; int IC_per_G = conv_p.IC / conv_p.G; int OC_per_G = conv_p.OC / conv_p.G; int IH = SPATIAL_DIM == 1 ? 1 : conv_p.IN_DIM[SPATIAL_DIM - 2]; int IW = conv_p.IN_DIM[SPATIAL_DIM - 1]; // activations aligned_vector Aint8(conv_p.MB * IH * IW * conv_p.IC, 0); // weights // The weight matrix is in layout G K/G (R S C/G) aligned_vector Bint8(R * S * G * IC_per_G * OC_per_G, 0); aligned_vector Bint8_tr(Bint8.size(), 0); aligned_vector Cint32_ref(conv_p.MB * OH * OW * OC, 0); aligned_vector Cint32_fb(Cint32_ref.size(), 0); aligned_vector Cint8_ref(Cint32_ref.size(), 0); aligned_vector Cint8_fb(Cint32_ref.size(), 0); randFill(Aint8, 0, 5); int32_t Aint8_zero_point = a_symmetric ? 0 : 4; randFill(Bint8, -4, 4); // computing column offset vector col_offsets(OC); int ncols_per_quant_group = OC; if (q_granularity == QuantizationGranularity::GROUP) { ncols_per_quant_group = OC_per_G; } else if (q_granularity == QuantizationGranularity::OUT_CHANNEL) { ncols_per_quant_group = 1; } aligned_vector Bint8_zero_point(OC / ncols_per_quant_group); if (b_symmetric) { randFill(Bint8_zero_point, 0, 0); } else { randFill(Bint8_zero_point, -3, 3); } // matrix dimensions after im2col for each GEMM. // For each group, there is one GEMM of the following dimensions int MDim = conv_p.MB * OH * OW; int NDim = OC_per_G; int KDim = R * S * IC_per_G; vector Aint8_im2col(MDim * KDim * G); im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data()); vector row_offsets(MDim); // activation_scale * weight_scale aligned_vector act_times_w_scale(Bint8_zero_point.size()); randFill(act_times_w_scale, 0.1234f / 2, 0.1234f * 3 / 2); float out_scale = 2.0f; aligned_vector C_multiplier(Bint8_zero_point.size()); transform( act_times_w_scale.begin(), act_times_w_scale.end(), C_multiplier.begin(), [&out_scale](float i) { return i / out_scale; }); int32_t C_zero_pt = 5; // initialize bias aligned_vector bias_int32(OC); aligned_vector bias_fp32(OC); if (test_bias) { randFill(bias_int32, -8, 8); } // floating point bias if (test_float_bias) { if (q_granularity == QuantizationGranularity::TENSOR) { transform( bias_int32.begin(), bias_int32.end(), bias_fp32.begin(), [&act_times_w_scale](float i) { return i * act_times_w_scale[0]; }); } else if (q_granularity == QuantizationGranularity::GROUP) { for (int g = 0; g < G; ++g) { for (int c = 0; c < OC_per_G; ++c) { bias_fp32[g * OC_per_G + c] = act_times_w_scale[g] * static_cast(bias_int32[g * OC_per_G + c]); } } } else { // OUT_CHANNEL transform( act_times_w_scale.begin(), act_times_w_scale.end(), bias_int32.begin(), bias_fp32.begin(), multiplies()); } } // reference implementation // conv_ref expects weights to be in G (R S C/G) K/G transposeConvWeights(conv_p, Bint8.data(), Bint8_tr.data()); int8_t* rightBData = Bint8_tr.data(); for (int g = 0; g < G; ++g) { col_offsets_with_zero_pt_s8acc32_ref( R * S * IC_per_G, OC_per_G, OC_per_G, rightBData + g * R * S * IC_per_G * OC_per_G, Bint8_zero_point.data() + g * OC_per_G / ncols_per_quant_group, col_offsets.data() + g * OC_per_G, ncols_per_quant_group); } conv_ref( conv_p, Aint8.data(), Aint8_zero_point, rightBData, Cint32_ref.data()); for (int g = 0; g < G; ++g) { row_offsets_u8acc32_ref( MDim, KDim, KDim * G, Aint8_im2col.data() + g * KDim, row_offsets.data()); requantize_u8acc32_ref( MDim, NDim, G * NDim, Cint32_ref.data() + g * NDim, Cint8_ref.data() + g * NDim, C_multiplier.data() + g * NDim / ncols_per_quant_group, C_zero_pt, Aint8_zero_point, Bint8_zero_point.data() + g * NDim / ncols_per_quant_group, row_offsets.data(), col_offsets.data() + g * NDim, test_bias ? bias_int32.data() + g * NDim : nullptr, ncols_per_quant_group); } PackWeightsForConv packedWeights(conv_p, Bint8.data()); // DirectConv col_offsets is handled differently if (packedWeights.getPackedWForDirectconv().get()) { packedWeights.getPackedWForDirectconv() .get() ->col_offsets_with_zero_pt_s8acc32_DirectConvT( conv_p, Bint8_zero_point.data(), col_offsets, ncols_per_quant_group); } // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv // #ifdef _OPENMP // #pragma omp parallel // #endif { DoNothing<> doNothingObj{}; int num_threads = fbgemm_get_num_threads(); int tid = fbgemm_get_thread_num(); if (q_granularity == QuantizationGranularity::TENSOR) { if (test_float_bias) { ReQuantizeOutput reqObj( doNothingObj, C_multiplier.data(), C_zero_pt, Aint8_zero_point, Bint8_zero_point.data(), nullptr, /* row offset buffer */ col_offsets.data(), test_bias ? bias_fp32.data() : nullptr, G * NDim, G, act_times_w_scale.data()); fbgemmConv( conv_p, Aint8.data(), packedWeights, Cint8_fb.data(), Cint32_fb.data(), reqObj, tid, num_threads); } else { ReQuantizeOutput reqObj( doNothingObj, C_multiplier.data(), C_zero_pt, Aint8_zero_point, Bint8_zero_point.data(), nullptr, /* row offset buffer */ col_offsets.data(), test_bias ? bias_int32.data() : nullptr, G * NDim, G); fbgemmConv( conv_p, Aint8.data(), packedWeights, Cint8_fb.data(), Cint32_fb.data(), reqObj, tid, num_threads); } } else if (q_granularity == QuantizationGranularity::GROUP) { if (test_float_bias) { ReQuantizeOutput reqObj( doNothingObj, C_multiplier.data(), C_zero_pt, Aint8_zero_point, Bint8_zero_point.data(), nullptr, /* row offset buffer */ col_offsets.data(), test_bias ? bias_fp32.data() : nullptr, G * NDim, G, act_times_w_scale.data()); fbgemmConv( conv_p, Aint8.data(), packedWeights, Cint8_fb.data(), Cint32_fb.data(), reqObj, tid, num_threads); } else { ReQuantizeOutput reqObj( doNothingObj, C_multiplier.data(), C_zero_pt, Aint8_zero_point, Bint8_zero_point.data(), nullptr, /* row offset buffer */ col_offsets.data(), test_bias ? bias_int32.data() : nullptr, G * NDim, G); fbgemmConv( conv_p, Aint8.data(), packedWeights, Cint8_fb.data(), Cint32_fb.data(), reqObj, tid, num_threads); } } else { if (test_float_bias) { ReQuantizeOutput reqObj( doNothingObj, C_multiplier.data(), C_zero_pt, Aint8_zero_point, Bint8_zero_point.data(), nullptr, /* row offset buffer */ col_offsets.data(), test_bias ? bias_fp32.data() : nullptr, G * NDim, G, act_times_w_scale.data()); fbgemmConv( conv_p, Aint8.data(), packedWeights, Cint8_fb.data(), Cint32_fb.data(), reqObj, tid, num_threads); } else { ReQuantizeOutput reqObj( doNothingObj, C_multiplier.data(), C_zero_pt, Aint8_zero_point, Bint8_zero_point.data(), nullptr, /* row offset buffer */ col_offsets.data(), test_bias ? bias_int32.data() : nullptr, G * NDim, G); fbgemmConv( conv_p, Aint8.data(), packedWeights, Cint8_fb.data(), Cint32_fb.data(), reqObj, tid, num_threads); } } } // omp parallel compare_validate_buffers( Cint8_ref.data(), Cint8_fb.data(), MDim, NDim * G, NDim * G, static_cast(0)); } // for each shape } TEST_P(UniConvQGranTest, requantizeTest) { QuantizationGranularity q_granularity; bool a_symmetric, b_symmetric; bool test_bias, test_float_bias; tie(q_granularity, a_symmetric, b_symmetric, test_bias, test_float_bias) = GetParam(); runRequantizeTest<1>( q_granularity, a_symmetric, b_symmetric, test_bias, test_float_bias); runRequantizeTest<2>( q_granularity, a_symmetric, b_symmetric, test_bias, test_float_bias); }