/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include #include #include #include #include "./BenchUtils.h" #include "fbgemm/Fbgemm.h" #include "src/RefImplementations.h" #include "test/EmbeddingSpMDMTestUtils.h" using namespace std; using namespace fbgemm; static vector> GetInputs_() { vector> input_dims = { // batch size, number of rows of table, avg lengthl {10, 4000000, 100}, {20, 4000000, 100}, {40, 4000000, 100}, {50, 4000000, 100}, {10, 10000000, 100}, {20, 10000000, 100}, {40, 10000000, 100}, {50, 10000000, 100}, }; return input_dims; } int run_benchmark( int batch_size, int num_rows, int average_len, bool use_32_bit_indices = false) { constexpr int NWARMUP = 4; constexpr int NITER = 100; int offset_numel = batch_size + 1; constexpr float sparsity = 0.5; vector lengths, offsets, indices; vector lengths_32, offsets_32, indices_32; vector weights; GenerateLengthsIndicesWeights( lengths, lengths_32, offsets, offsets_32, indices, indices_32, weights, batch_size, num_rows, average_len, // average number of indices in a batch EmbeddingSpMDMCornerCase::NONE); // Create mapping table for rowwise sparsity vector mapping_table; CreateMappingTableForRowWiseSparsity(mapping_table, num_rows, sparsity); vector out_indices_32(indices_32.size(), 0); vector out_offsets_32(offsets_32.size(), 0); vector out_weights(weights.size(), 0); vector out_indices(indices.size(), 0); vector out_offsets(offsets.size(), 0); double duration_ref = measureWithWarmup( [&]() { if (use_32_bit_indices) { compressed_indices_remap_ref( offset_numel, indices_32.data(), mapping_table.data(), offsets_32.data(), weights.data(), out_indices_32.data(), out_offsets_32.data(), out_weights.data()); } else { compressed_indices_remap_ref( offset_numel, indices.data(), mapping_table.data(), offsets.data(), weights.data(), out_indices.data(), out_offsets.data(), out_weights.data()); } }, NWARMUP, NITER); double duration = measureWithWarmup( [&]() { if (use_32_bit_indices) { compressed_indices_remap( offset_numel, indices_32.data(), mapping_table.data(), offsets_32.data(), weights.data(), out_indices_32.data(), out_offsets_32.data(), out_weights.data()); } else { compressed_indices_remap( offset_numel, indices.data(), mapping_table.data(), offsets.data(), weights.data(), out_indices.data(), out_offsets.data(), out_weights.data()); } }, NWARMUP, NITER); cout << "reference:" << duration_ref * 1e6 << " (us), "; cout << "Opt:" << duration * 1e6 << " (us) " << endl; return 0; } int main() { int batch_size; int num_rows; int average_len; vector> inputs(GetInputs_()); for (auto& input : inputs) { assert(input.size() == 3); batch_size = input[0]; num_rows = input[1]; average_len = input[2]; cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows" << setw(14) << num_rows << setw(16) << "avg length" << setw(6) << average_len << endl; cout << "64 bit indices, "; run_benchmark(batch_size, num_rows, average_len); cout << "32 bit indices, "; run_benchmark(batch_size, num_rows, average_len, true); cout << endl; } return 0; }