210 lines
6.0 KiB
C++
210 lines
6.0 KiB
C++
/*
|
|
* Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
* All rights reserved.
|
|
*
|
|
* This source code is licensed under the BSD-style license found in the
|
|
* LICENSE file in the root directory of this source tree.
|
|
*/
|
|
|
|
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <map>
#include <numeric>
#include <random>
#include <set>
#include <vector>

#include "./BenchUtils.h"
#include "fbgemm/Fbgemm.h"
#include "src/RefImplementations.h"
|
|
|
|
using namespace std;
|
|
using namespace fbgemm;
|
|
|
|
// Returns the benchmark shapes, one per entry, each laid out as
// {batch_size, num_rows, embedding_dim, average_len}.
//
// The production-sized shapes (4M-row table) are slow to set up. For quick
// debugging, substitute small shapes such as {2, 16, 128, 10} or
// {10, 4000, 128, 100}.
// TODO: Add more inputs.
static vector<vector<int>> GetInputs_() {
  constexpr int kBatchSize = 10;
  constexpr int kNumRows = 4000000;
  constexpr int kAverageLen = 100;

  vector<vector<int>> shapes;
  shapes.reserve(4);
  for (const int emb_dim : {32, 64, 128, 256}) {
    shapes.push_back({kBatchSize, kNumRows, emb_dim, kAverageLen});
  }
  return shapes;
}
|
|
|
|
void run_benchmark(
|
|
int batch_size,
|
|
int num_rows,
|
|
int embedding_dim,
|
|
int average_len,
|
|
bool use_32_bit_indices = false,
|
|
bool prefetch = false) {
|
|
vector<char> llc(64L * 1024L * 1024L, 1.0);
|
|
vector<float> g(batch_size * embedding_dim); // gradients
|
|
vector<float> h(num_rows); // input momentums
|
|
vector<float> w(num_rows * embedding_dim); // input params
|
|
vector<float> h_ref(h.size());
|
|
vector<float> w_ref(w.size());
|
|
|
|
default_random_engine generator;
|
|
// normal_distribution<float> h_w_distribution;
|
|
|
|
// TODO: check appropriate vals for g,h,w
|
|
for (size_t i = 0; i < g.size(); ++i) {
|
|
g[i] = 4 + i; // h_w_distribution(generator);
|
|
}
|
|
for (size_t i = 0; i < h.size(); ++i) {
|
|
h_ref[i] = h[i] = 2 + i; // h_w_distribution(generator);
|
|
}
|
|
for (size_t i = 0; i < w.size(); ++i) {
|
|
w_ref[i] = w[i] = 3 + i; // h_w_distribution(generator);
|
|
}
|
|
|
|
// Generate lengths
|
|
uniform_int_distribution<int> length_distribution(
|
|
1, std::min(2 * average_len + 1, num_rows));
|
|
vector<int> offsets(batch_size + 1);
|
|
offsets[0] = 0;
|
|
for (int i = 0; i < batch_size; ++i) {
|
|
offsets[i + 1] = offsets[i] + length_distribution(generator);
|
|
}
|
|
|
|
// Compute the number of indices
|
|
int lengths_sum = offsets[batch_size];
|
|
cout << "lengths_sum " << lengths_sum << endl;
|
|
|
|
// Generate indices
|
|
vector<int64_t> indices;
|
|
vector<int32_t> indices_32;
|
|
|
|
vector<int> container(num_rows);
|
|
|
|
// please note we generate unique indices
|
|
for (int i = 0; i < batch_size; ++i) {
|
|
iota(container.begin(), container.end(), 0);
|
|
shuffle(container.begin(), container.end(), generator);
|
|
copy(
|
|
container.begin(),
|
|
container.begin() + (offsets[i + 1] - offsets[i]),
|
|
back_inserter(indices));
|
|
}
|
|
copy(begin(indices), end(indices), back_inserter(indices_32));
|
|
|
|
float epsilon = 1e-5;
|
|
float lr = 0.5;
|
|
|
|
constexpr int NUM_WARMUP = 4;
|
|
constexpr int NUM_ITER = 10;
|
|
// Only counts the number of bytes for reading embedding table and ignore
|
|
// others. Should be good enough as long as embdding_dim is big enough.
|
|
double bytes = lengths_sum *
|
|
((embedding_dim + 1) * sizeof(float) * 2 +
|
|
(use_32_bit_indices ? 4 : 8)) +
|
|
batch_size * (embedding_dim * sizeof(float) + sizeof(int));
|
|
double bytes_padded = lengths_sum *
|
|
(((embedding_dim * sizeof(float) + 63) / 64 + 1) * 64 * 2 +
|
|
(use_32_bit_indices ? 4 : 8)) +
|
|
batch_size * (embedding_dim * sizeof(float) + sizeof(int));
|
|
|
|
auto kernel_i32 = GenerateRowWiseSparseAdaGradFused<int32_t>(
|
|
embedding_dim, prefetch ? 16 : 0);
|
|
auto kernel_i64 = GenerateRowWiseSparseAdaGradFused<int64_t>(
|
|
embedding_dim, prefetch ? 16 : 0);
|
|
|
|
for (bool flush_cache : {false, true}) {
|
|
double t = measureWithWarmup(
|
|
[&]() {
|
|
if (use_32_bit_indices) {
|
|
kernel_i32(
|
|
batch_size,
|
|
lengths_sum,
|
|
num_rows,
|
|
w.data(),
|
|
g.data(),
|
|
h.data(),
|
|
indices_32.data(),
|
|
offsets.data(),
|
|
epsilon,
|
|
lr);
|
|
} else {
|
|
kernel_i64(
|
|
batch_size,
|
|
lengths_sum,
|
|
num_rows,
|
|
w.data(),
|
|
g.data(),
|
|
h.data(),
|
|
indices.data(),
|
|
offsets.data(),
|
|
epsilon,
|
|
lr);
|
|
}
|
|
},
|
|
NUM_WARMUP,
|
|
NUM_ITER,
|
|
[&]() { llc_flush(llc); });
|
|
|
|
if (flush_cache) {
|
|
cout << setw(20) << "cache flushed";
|
|
} else {
|
|
cout << setw(20) << "cache not flushed";
|
|
}
|
|
if (prefetch) {
|
|
cout << setw(16) << "prefetch on";
|
|
} else {
|
|
cout << setw(16) << "prefetch off";
|
|
}
|
|
|
|
cout << setw(8) << "b/w" << setw(10) << bytes / 1e9 / t << " GB/s"
|
|
<< setw(20) << "effective b/w: " << setw(16) << bytes_padded / 1e9 / t
|
|
<< "GB/s" << setw(8) << " time " << setw(16) << t << endl;
|
|
}
|
|
}
|
|
|
|
int main() {
|
|
vector<vector<int>> inputs(GetInputs_());
|
|
|
|
for (auto& input : inputs) {
|
|
assert(input.size() > 3);
|
|
int batch_size = input[0];
|
|
int num_rows = input[1];
|
|
int embedding_dim = input[2];
|
|
int average_len = input[3];
|
|
|
|
cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows"
|
|
<< setw(16) << num_rows << setw(10) << "emb dim" << setw(6)
|
|
<< embedding_dim << setw(16) << "avg length" << setw(6) << average_len
|
|
<< endl;
|
|
|
|
for (bool use_32_bit_indices : {false, true}) {
|
|
for (bool prefetch : {false, true}) {
|
|
// args: batch sz, num rows, emb dim, avg len, use 32b, prefetch
|
|
cout << (use_32_bit_indices ? " 32" : " 64") << " bit indices";
|
|
if (prefetch) {
|
|
cout << " with prefetching";
|
|
}
|
|
cout << ", ";
|
|
run_benchmark(
|
|
batch_size,
|
|
num_rows,
|
|
embedding_dim,
|
|
average_len,
|
|
use_32_bit_indices,
|
|
prefetch);
|
|
} // prefetch
|
|
} // use_32_bit_indices
|
|
} // for each input
|
|
|
|
return 0;
|
|
}
|