/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include "TestUtils.h" #include "fbgemm/Fbgemm.h" #include "src/RefImplementations.h" using namespace std; using namespace fbgemm; static vector> GetInputs_() { vector> input_dims = { // num_rows, block_size {150, 1}, {150, 4}, {10, 8}, {150, 16}, {1, 8}, {1, 16}, {150, 24}, {150, 32}, {150, 40}, {150, 64}, {150, 80}, {150, 128}, {150, 384}, {10, 385}, {10, 769}, }; return input_dims; } vector prefetch_distances{0, 16, 1000000}; namespace { class SparseAdagradTest : public testing::TestWithParam> {}; }; // namespace constexpr float DEFAULT_TOL = 1.0e-6; // Test: INSTANTIATE_TEST_CASE_P( InstantiationName, SparseAdagradTest, ::testing::Combine( ::testing::Bool(), // 64 bit indices ::testing::ValuesIn(prefetch_distances), ::testing::Bool(), // out of bound indices ::testing::Bool(), // use weight decay ::testing::Bool())); // adjust weight decay TEST_P(SparseAdagradTest, basicTest_two_stages) { vector> inputs(GetInputs_()); bool isIndex64b, out_of_bounds, use_weight_decay, adjust_weight_decay; int prefetch; tie(isIndex64b, prefetch, out_of_bounds, use_weight_decay, adjust_weight_decay) = GetParam(); for (auto input : inputs) { int num_rows = input[0]; int block_size = input[1]; int param_size = num_rows * block_size; vector g(param_size); // gradients vector h(param_size); // input momentums vector w(param_size); // input params vector h_ref(param_size); vector w_ref(param_size); random_device r; default_random_engine generator(r()); normal_distribution h_w_distribution; uniform_real_distribution values_gen(0, 10); for (int i = 0; i < param_size; i++) { h_ref[i] = h[i] = values_gen(generator); } for (int i = 0; i < param_size; i++) { w_ref[i] = w[i] = values_gen(generator); } for (int i = 0; i < param_size; i++) { g[i] = values_gen(generator); } vector indices(num_rows); vector indices_32(num_rows); float epsilon = 1e-5; float lr = 0.5; float weight_decay = use_weight_decay ? 0.1f : 0.0f; uniform_int_distribution index_distribution(0, num_rows - 1); for (int i = 0; i < num_rows; ++i) { indices_32[i] = indices[i] = index_distribution(generator); } if (out_of_bounds) { int idx = index_distribution(generator); indices_32[idx] = indices[idx] = num_rows; } vector counters; constexpr int64_t counter_halflife = 1e6; if (adjust_weight_decay) { uniform_real_distribution<> counter_distribution(0, 2 * counter_halflife); counters.resize(num_rows); for (int i = 0; i < num_rows; ++i) { counters[i] = counter_distribution(generator); } } int ret_fbgemm, ret_ref; if (isIndex64b) { ret_ref = sparse_adagrad_ref( num_rows, // number of rows reading block_size, // number of parameters per rows param_size, // total number of parameters w_ref.data(), // input parameters g.data(), // input gradients h_ref.data(), // input momentums indices.data(), // indices of each row epsilon, lr, weight_decay, adjust_weight_decay ? counters.data() : nullptr, counter_halflife); auto fn_fbgemm = GenerateSparseAdaGrad( block_size, false, prefetch, use_weight_decay); ret_fbgemm = fn_fbgemm( num_rows, // number of rows reading param_size, // total number of parameters w.data(), // input parameters g.data(), // input gradients h.data(), // input momentums indices.data(), // indices of each row epsilon, lr, weight_decay, adjust_weight_decay ? counters.data() : nullptr, counter_halflife); } else { // 32 bit indices ret_ref = sparse_adagrad_ref( num_rows, // number of rows reading block_size, // number of parameters per rows param_size, // total number of parameters w_ref.data(), // input parameters g.data(), // input gradients h_ref.data(), // input momentums indices_32.data(), // indices of each row epsilon, lr, weight_decay, adjust_weight_decay ? counters.data() : nullptr, counter_halflife); auto fn_fbgemm = GenerateSparseAdaGrad( block_size, false, prefetch, use_weight_decay); ret_fbgemm = fn_fbgemm( num_rows, // number of rows reading param_size, // total number of parameters w.data(), // input parameters g.data(), // input gradients h.data(), // input momentums indices_32.data(), // indices of each row epsilon, lr, weight_decay, adjust_weight_decay ? counters.data() : nullptr, counter_halflife); } EXPECT_EQ(ret_fbgemm, ret_ref) << "return vals differ, reference is: " << ret_ref << " ,fbgemm is: " << ret_fbgemm; EXPECT_TRUE(floatCloseAll(h, h_ref, DEFAULT_TOL, DEFAULT_TOL)); EXPECT_TRUE(floatCloseAll(w, w_ref, DEFAULT_TOL, DEFAULT_TOL)); } } TEST_P(SparseAdagradTest, rowwiseTest_two_stages) { vector> inputs(GetInputs_()); bool isIndex64b, out_of_bounds, use_weight_decay, adjust_weight_decay; int prefetch; tie(isIndex64b, prefetch, out_of_bounds, use_weight_decay, adjust_weight_decay) = GetParam(); for (auto input : inputs) { int num_rows = input[0]; int block_size = input[1]; int param_size = num_rows * block_size; vector g(param_size); // gradients vector h(param_size); // input momentums vector w(param_size); // input params vector h_ref(param_size); vector w_ref(param_size); random_device r; default_random_engine generator(r()); uniform_real_distribution values_gen(0, 2); for (int i = 0; i < param_size; i++) { h_ref[i] = h[i] = values_gen(generator); } for (int i = 0; i < param_size; i++) { w_ref[i] = w[i] = values_gen(generator); } for (int i = 0; i < param_size; i++) { g[i] = values_gen(generator); } vector indices(num_rows); vector indices_32(num_rows); float epsilon = 1e-5; float lr = 0.5; float weight_decay = use_weight_decay ? 0.1f : 0.0f; uniform_int_distribution index_distribution(0, num_rows - 1); for (int i = 0; i < num_rows; ++i) { indices_32[i] = indices[i] = index_distribution(generator); } if (out_of_bounds) { int idx = index_distribution(generator); indices_32[idx] = indices[idx] = num_rows; } vector counters; constexpr int64_t counter_halflife = 1e6; if (adjust_weight_decay) { uniform_real_distribution<> counter_distribution(0, 2 * counter_halflife); counters.resize(num_rows); for (int i = 0; i < num_rows; ++i) { counters[i] = counter_distribution(generator); } } int ret_fbgemm, ret_ref; if (isIndex64b) { ret_ref = rowwise_sparse_adagrad_ref( num_rows, // number of rows reading block_size, // number of parameters per rows param_size, // total number of parameters w_ref.data(), // input parameters g.data(), // input gradients h_ref.data(), // input momentums indices.data(), // indices of each row epsilon, lr, weight_decay, adjust_weight_decay ? counters.data() : nullptr, counter_halflife); auto fn_fbgemm = GenerateSparseAdaGrad( block_size, true, prefetch, use_weight_decay); ret_fbgemm = fn_fbgemm( num_rows, // number of rows reading param_size, // total number of parameters w.data(), // input parameters g.data(), // input gradients h.data(), // input momentums indices.data(), // indices of each row epsilon, lr, weight_decay, adjust_weight_decay ? counters.data() : nullptr, counter_halflife); } else { // 32 bit indices ret_ref = rowwise_sparse_adagrad_ref( num_rows, // number of rows reading block_size, // number of parameters per rows param_size, // total number of parameters w_ref.data(), // input parameters g.data(), // input gradients h_ref.data(), // input momentums indices_32.data(), // indices of each row epsilon, lr, weight_decay, adjust_weight_decay ? counters.data() : nullptr, counter_halflife); auto fn_fbgemm = GenerateSparseAdaGrad( block_size, true, prefetch, use_weight_decay); ret_fbgemm = fn_fbgemm( num_rows, // number of rows reading param_size, // total number of parameters w.data(), // input parameters g.data(), // input gradients h.data(), // input momentums indices_32.data(), // indices of each row epsilon, lr, weight_decay, adjust_weight_decay ? counters.data() : nullptr, counter_halflife); } EXPECT_EQ(ret_fbgemm, ret_ref) << "return vals differ, reference is: " << ret_ref << " ,fbgemm is: " << ret_fbgemm; // Set the absolute tolerance of rowwise momentum to 1e-3 because it a // product of square, add, div which the rounding error can be very high EXPECT_TRUE(floatCloseAll(h, h_ref, 1.0e-3, 1.0e-3)); EXPECT_TRUE(floatCloseAll(w, w_ref, DEFAULT_TOL, DEFAULT_TOL)); } }