/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include #include #include #include "./BenchUtils.h" #include "fbgemm/Fbgemm.h" #include "src/RefImplementations.h" using namespace std; using namespace fbgemm; static vector> GetInputs_() { vector> input_dims = { // batch size, number of rows of table, emb dim , avg lengthl // TODO: Add more inputs // Use these -- but they are slow. {10, 4000000, 32, 100}, {10, 4000000, 64, 100}, {10, 4000000, 128, 100}, {10, 4000000, 256, 100}, // Use these for debugging // {2, 16, 128, 10}, // {10, 4000, 128, 100}, // {10, 4000, 128, 100}, // {10, 4000, 128, 100}, }; return input_dims; } void run_benchmark( int batch_size, int num_rows, int embedding_dim, int average_len, bool use_32_bit_indices = false, bool prefetch = false) { vector llc(64L * 1024L * 1024L, 1.0); vector g(batch_size * embedding_dim); // gradients vector h(num_rows); // input momentums vector w(num_rows * embedding_dim); // input params vector h_ref(h.size()); vector w_ref(w.size()); default_random_engine generator; // normal_distribution h_w_distribution; // TODO: check appropriate vals for g,h,w for (size_t i = 0; i < g.size(); ++i) { g[i] = 4 + i; // h_w_distribution(generator); } for (size_t i = 0; i < h.size(); ++i) { h_ref[i] = h[i] = 2 + i; // h_w_distribution(generator); } for (size_t i = 0; i < w.size(); ++i) { w_ref[i] = w[i] = 3 + i; // h_w_distribution(generator); } // Generate lengths uniform_int_distribution length_distribution( 1, std::min(2 * average_len + 1, num_rows)); vector offsets(batch_size + 1); offsets[0] = 0; for (int i = 0; i < batch_size; ++i) { offsets[i + 1] = offsets[i] + length_distribution(generator); } // Compute the number of indices int lengths_sum = offsets[batch_size]; cout << "lengths_sum " << lengths_sum << endl; // Generate indices vector indices; vector indices_32; vector container(num_rows); // please note we generate unique indices for (int i = 0; i < batch_size; ++i) { iota(container.begin(), container.end(), 0); shuffle(container.begin(), container.end(), generator); copy( container.begin(), container.begin() + (offsets[i + 1] - offsets[i]), back_inserter(indices)); } copy(begin(indices), end(indices), back_inserter(indices_32)); float epsilon = 1e-5; float lr = 0.5; constexpr int NUM_WARMUP = 4; constexpr int NUM_ITER = 10; // Only counts the number of bytes for reading embedding table and ignore // others. Should be good enough as long as embdding_dim is big enough. double bytes = lengths_sum * ((embedding_dim + 1) * sizeof(float) * 2 + (use_32_bit_indices ? 4 : 8)) + batch_size * (embedding_dim * sizeof(float) + sizeof(int)); double bytes_padded = lengths_sum * (((embedding_dim * sizeof(float) + 63) / 64 + 1) * 64 * 2 + (use_32_bit_indices ? 4 : 8)) + batch_size * (embedding_dim * sizeof(float) + sizeof(int)); auto kernel_i32 = GenerateRowWiseSparseAdaGradFused( embedding_dim, prefetch ? 16 : 0); auto kernel_i64 = GenerateRowWiseSparseAdaGradFused( embedding_dim, prefetch ? 16 : 0); for (bool flush_cache : {false, true}) { double t = measureWithWarmup( [&]() { if (use_32_bit_indices) { kernel_i32( batch_size, lengths_sum, num_rows, w.data(), g.data(), h.data(), indices_32.data(), offsets.data(), epsilon, lr); } else { kernel_i64( batch_size, lengths_sum, num_rows, w.data(), g.data(), h.data(), indices.data(), offsets.data(), epsilon, lr); } }, NUM_WARMUP, NUM_ITER, [&]() { llc_flush(llc); }); if (flush_cache) { cout << setw(20) << "cache flushed"; } else { cout << setw(20) << "cache not flushed"; } if (prefetch) { cout << setw(16) << "prefetch on"; } else { cout << setw(16) << "prefetch off"; } cout << setw(8) << "b/w" << setw(10) << bytes / 1e9 / t << " GB/s" << setw(20) << "effective b/w: " << setw(16) << bytes_padded / 1e9 / t << "GB/s" << setw(8) << " time " << setw(16) << t << endl; } } int main() { vector> inputs(GetInputs_()); for (auto& input : inputs) { assert(input.size() > 3); int batch_size = input[0]; int num_rows = input[1]; int embedding_dim = input[2]; int average_len = input[3]; cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows" << setw(16) << num_rows << setw(10) << "emb dim" << setw(6) << embedding_dim << setw(16) << "avg length" << setw(6) << average_len << endl; for (bool use_32_bit_indices : {false, true}) { for (bool prefetch : {false, true}) { // args: batch sz, num rows, emb dim, avg len, use 32b, prefetch cout << (use_32_bit_indices ? " 32" : " 64") << " bit indices"; if (prefetch) { cout << " with prefetching"; } cout << ", "; run_benchmark( batch_size, num_rows, embedding_dim, average_len, use_32_bit_indices, prefetch); } // prefetch } // use_32_bit_indices } // for each input return 0; }