sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/test/SparseAdagradTest.cc

331 lines
10 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <algorithm>
#include <ostream>
#include <random>
#include <stdexcept>
#include <type_traits>

#include <gtest/gtest.h>

#include "TestUtils.h"
#include "fbgemm/Fbgemm.h"
#include "src/RefImplementations.h"
using namespace std;
using namespace fbgemm;
// Shapes exercised by every parameterized test in this file.
// Each entry is {num_rows, block_size}. The list includes single-row inputs
// ({1, 8}, {1, 16}) and block sizes such as 1, 385 and 769 — presumably to
// cover remainder/tail handling in the generated kernels.
static vector<vector<int>> GetInputs_() {
  return vector<vector<int>>{
      {150, 1},
      {150, 4},
      {10, 8},
      {150, 16},
      {1, 8},
      {1, 16},
      {150, 24},
      {150, 32},
      {150, 40},
      {150, 64},
      {150, 80},
      {150, 128},
      {150, 384},
      {10, 385},
      {10, 769},
  };
}
// Prefetch distances passed to GenerateSparseAdaGrad through the test
// parameters. Read-only and file-local (static) so it cannot be mutated by a
// test and cannot clash with a same-named global in another test translation
// unit.
static const vector<int> prefetch_distances{0, 16, 1000000};
namespace {
// Value-parameterized fixture. The tuple elements are, in order:
//   0: use 64-bit indices (otherwise 32-bit),
//   1: prefetch distance handed to GenerateSparseAdaGrad,
//   2: replace one index with an out-of-bounds value (num_rows),
//   3: use weight decay,
//   4: adjust weight decay with per-row counters.
class SparseAdagradTest
    : public testing::TestWithParam<tuple<bool, int, bool, bool, bool>> {};
} // namespace

// Tolerance passed as both tolerance arguments of floatCloseAll when comparing
// the generated kernels against the reference implementations.
constexpr float DEFAULT_TOL = 1.0e-6;
// Run every TEST_P below over the full cartesian product of these parameters:
// 2 index widths x 3 prefetch distances x 2 x 2 x 2 = 48 combinations.
// NOTE(review): INSTANTIATE_TEST_CASE_P is the older googletest spelling;
// newer releases prefer INSTANTIATE_TEST_SUITE_P — migrate only once the
// minimum supported googletest version allows it.
INSTANTIATE_TEST_CASE_P(
    InstantiationName,
    SparseAdagradTest,
    ::testing::Combine(
        ::testing::Bool(), // 64 bit indices
        ::testing::ValuesIn(prefetch_distances),
        ::testing::Bool(), // out of bound indices
        ::testing::Bool(), // use weight decay
        ::testing::Bool())); // adjust weight decay
// Compares the JIT-generated element-wise SparseAdaGrad kernel
// (GenerateSparseAdaGrad with rowwise=false) against sparse_adagrad_ref.
//
// For every {num_rows, block_size} shape from GetInputs_(), both
// implementations start from identical copies of the parameters `w` and the
// momentums `h`. They also consume the same gradients, indices and optional
// per-row counters. The two return values must be equal, and the updated
// `w`/`h` must match within DEFAULT_TOL.
TEST_P(SparseAdagradTest, basicTest_two_stages) {
  bool isIndex64b, out_of_bounds, use_weight_decay, adjust_weight_decay;
  int prefetch;
  tie(isIndex64b,
      prefetch,
      out_of_bounds,
      use_weight_decay,
      adjust_weight_decay) = GetParam();

  // Seed once per test and report the seed, so that a failing run can be
  // reproduced by hard-coding it here.
  random_device seed_source;
  const auto seed = seed_source();
  SCOPED_TRACE(::testing::Message() << "random seed: " << seed);
  default_random_engine generator(seed);

  for (const auto& input : GetInputs_()) {
    const int num_rows = input[0];
    const int block_size = input[1];
    const int param_size = num_rows * block_size;
    SCOPED_TRACE(
        ::testing::Message() << "num_rows: " << num_rows
                             << ", block_size: " << block_size);

    // Draw momentums, then parameters, then gradients from [0, 10).
    uniform_real_distribution<float> values_gen(0, 10);
    vector<float> h(param_size); // input momentums
    vector<float> w(param_size); // input params
    vector<float> g(param_size); // gradients
    for (auto& v : h) {
      v = values_gen(generator);
    }
    for (auto& v : w) {
      v = values_gen(generator);
    }
    for (auto& v : g) {
      v = values_gen(generator);
    }
    // The reference implementation updates its own copies of w and h.
    vector<float> h_ref(h);
    vector<float> w_ref(w);

    const float epsilon = 1e-5f;
    const float lr = 0.5f;
    const float weight_decay = use_weight_decay ? 0.1f : 0.0f;

    // Row indices in [0, num_rows), kept in both widths so either kernel
    // instantiation can be exercised with identical data.
    vector<std::int64_t> indices(num_rows);
    vector<std::int32_t> indices_32(num_rows);
    uniform_int_distribution<std::int64_t> index_distribution(0, num_rows - 1);
    for (int i = 0; i < num_rows; ++i) {
      indices[i] = index_distribution(generator);
      indices_32[i] = static_cast<std::int32_t>(indices[i]);
    }
    if (out_of_bounds) {
      // Overwrite one random position with num_rows (one past the last row).
      // Both implementations must then return the same value.
      const auto pos = index_distribution(generator);
      indices[pos] = num_rows;
      indices_32[pos] = num_rows;
    }

    constexpr std::int64_t counter_halflife = 1000000;
    vector<double> counters;
    if (adjust_weight_decay) {
      uniform_real_distribution<> counter_distribution(0, 2 * counter_halflife);
      counters.resize(num_rows);
      for (auto& c : counters) {
        c = counter_distribution(generator);
      }
    }
    // nullptr disables the counter-based weight-decay adjustment.
    double* counter_ptr = adjust_weight_decay ? counters.data() : nullptr;

    // Runs the reference and the generated kernel on one index array. The
    // lambda is generic over the index type, so the 32- and 64-bit paths run
    // exactly the same code.
    auto run_and_compare = [&](auto* index_data) {
      using IndexType = std::remove_pointer_t<decltype(index_data)>;
      const int ret_ref = sparse_adagrad_ref(
          num_rows, // number of rows reading
          block_size, // number of parameters per rows
          param_size, // total number of parameters
          w_ref.data(), // input parameters
          g.data(), // input gradients
          h_ref.data(), // input momentums
          index_data, // indices of each row
          epsilon,
          lr,
          weight_decay,
          counter_ptr,
          counter_halflife);
      auto fn_fbgemm = GenerateSparseAdaGrad<IndexType>(
          block_size, /*rowwise=*/false, prefetch, use_weight_decay);
      const int ret_fbgemm = fn_fbgemm(
          num_rows, // number of rows reading
          param_size, // total number of parameters
          w.data(), // input parameters
          g.data(), // input gradients
          h.data(), // input momentums
          index_data, // indices of each row
          epsilon,
          lr,
          weight_decay,
          counter_ptr,
          counter_halflife);
      EXPECT_EQ(ret_fbgemm, ret_ref)
          << "return vals differ, reference is: " << ret_ref
          << " ,fbgemm is: " << ret_fbgemm;
    };
    if (isIndex64b) {
      run_and_compare(indices.data());
    } else {
      run_and_compare(indices_32.data());
    }

    EXPECT_TRUE(floatCloseAll(h, h_ref, DEFAULT_TOL, DEFAULT_TOL));
    EXPECT_TRUE(floatCloseAll(w, w_ref, DEFAULT_TOL, DEFAULT_TOL));
  }
}
// Compares the JIT-generated row-wise SparseAdaGrad kernel
// (GenerateSparseAdaGrad with rowwise=true) against rowwise_sparse_adagrad_ref.
//
// For every {num_rows, block_size} shape from GetInputs_(), both
// implementations start from identical copies of the parameters `w` and the
// momentums `h`. They also consume the same gradients, indices and optional
// per-row counters. The two return values must be equal, `w` must match within
// DEFAULT_TOL, and `h` must match within a looser 1e-3.
TEST_P(SparseAdagradTest, rowwiseTest_two_stages) {
  bool isIndex64b, out_of_bounds, use_weight_decay, adjust_weight_decay;
  int prefetch;
  tie(isIndex64b,
      prefetch,
      out_of_bounds,
      use_weight_decay,
      adjust_weight_decay) = GetParam();

  // Seed once per test and report the seed, so that a failing run can be
  // reproduced by hard-coding it here.
  random_device seed_source;
  const auto seed = seed_source();
  SCOPED_TRACE(::testing::Message() << "random seed: " << seed);
  default_random_engine generator(seed);

  for (const auto& input : GetInputs_()) {
    const int num_rows = input[0];
    const int block_size = input[1];
    const int param_size = num_rows * block_size;
    SCOPED_TRACE(
        ::testing::Message() << "num_rows: " << num_rows
                             << ", block_size: " << block_size);

    // Draw momentums, then parameters, then gradients from [0, 2).
    uniform_real_distribution<float> values_gen(0, 2);
    vector<float> h(param_size); // input momentums
    vector<float> w(param_size); // input params
    vector<float> g(param_size); // gradients
    for (auto& v : h) {
      v = values_gen(generator);
    }
    for (auto& v : w) {
      v = values_gen(generator);
    }
    for (auto& v : g) {
      v = values_gen(generator);
    }
    // The reference implementation updates its own copies of w and h.
    vector<float> h_ref(h);
    vector<float> w_ref(w);

    const float epsilon = 1e-5f;
    const float lr = 0.5f;
    const float weight_decay = use_weight_decay ? 0.1f : 0.0f;

    // Row indices in [0, num_rows), kept in both widths so either kernel
    // instantiation can be exercised with identical data.
    vector<std::int64_t> indices(num_rows);
    vector<std::int32_t> indices_32(num_rows);
    uniform_int_distribution<std::int64_t> index_distribution(0, num_rows - 1);
    for (int i = 0; i < num_rows; ++i) {
      indices[i] = index_distribution(generator);
      indices_32[i] = static_cast<std::int32_t>(indices[i]);
    }
    if (out_of_bounds) {
      // Overwrite one random position with num_rows (one past the last row).
      // Both implementations must then return the same value.
      const auto pos = index_distribution(generator);
      indices[pos] = num_rows;
      indices_32[pos] = num_rows;
    }

    constexpr std::int64_t counter_halflife = 1000000;
    vector<double> counters;
    if (adjust_weight_decay) {
      uniform_real_distribution<> counter_distribution(0, 2 * counter_halflife);
      counters.resize(num_rows);
      for (auto& c : counters) {
        c = counter_distribution(generator);
      }
    }
    // nullptr disables the counter-based weight-decay adjustment.
    double* counter_ptr = adjust_weight_decay ? counters.data() : nullptr;

    // Runs the reference and the generated kernel on one index array. The
    // lambda is generic over the index type, so the 32- and 64-bit paths run
    // exactly the same code.
    auto run_and_compare = [&](auto* index_data) {
      using IndexType = std::remove_pointer_t<decltype(index_data)>;
      const int ret_ref = rowwise_sparse_adagrad_ref(
          num_rows, // number of rows reading
          block_size, // number of parameters per rows
          param_size, // total number of parameters
          w_ref.data(), // input parameters
          g.data(), // input gradients
          h_ref.data(), // input momentums
          index_data, // indices of each row
          epsilon,
          lr,
          weight_decay,
          counter_ptr,
          counter_halflife);
      auto fn_fbgemm = GenerateSparseAdaGrad<IndexType>(
          block_size, /*rowwise=*/true, prefetch, use_weight_decay);
      const int ret_fbgemm = fn_fbgemm(
          num_rows, // number of rows reading
          param_size, // total number of parameters
          w.data(), // input parameters
          g.data(), // input gradients
          h.data(), // input momentums
          index_data, // indices of each row
          epsilon,
          lr,
          weight_decay,
          counter_ptr,
          counter_halflife);
      EXPECT_EQ(ret_fbgemm, ret_ref)
          << "return vals differ, reference is: " << ret_ref
          << " ,fbgemm is: " << ret_fbgemm;
    };
    if (isIndex64b) {
      run_and_compare(indices.data());
    } else {
      run_and_compare(indices_32.data());
    }

    // Row-wise momentum uses a looser 1e-3 tolerance. It is computed from
    // squares, sums and a division, so rounding error can be much larger than
    // for the parameters.
    EXPECT_TRUE(floatCloseAll(h, h_ref, 1.0e-3, 1.0e-3));
    EXPECT_TRUE(floatCloseAll(w, w_ref, DEFAULT_TOL, DEFAULT_TOL));
  }
}