sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/bench/EmbeddingIndexRemappingBenc...

162 lines
4.2 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <algorithm>
#include <array>
#include <chrono>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "./BenchUtils.h"
#include "fbgemm/Fbgemm.h"
#include "src/RefImplementations.h"
#include "test/EmbeddingSpMDMTestUtils.h"
using namespace std;
using namespace fbgemm;
// Builds the list of benchmark configurations.
// Each entry is {batch size, number of rows of table, avg length}.
static std::vector<std::vector<int>> GetInputs_() {
  constexpr int kAvgLength = 100;

  std::vector<std::vector<int>> configs;
  configs.reserve(8);
  // Table size varies in the outer loop and batch size in the inner loop, so
  // every batch size for the smaller table is listed before the larger table.
  for (const int table_rows : {4000000, 10000000}) {
    for (const int batch : {10, 20, 40, 50}) {
      configs.push_back({batch, table_rows, kAvgLength});
    }
  }
  return configs;
}
// Times the reference and the optimized compressed-index remapping routines
// on generated inputs and prints both timings on one line.
//
// batch_size, num_rows and average_len are forwarded to
// GenerateLengthsIndicesWeights; num_rows is also passed to
// CreateMappingTableForRowWiseSparsity. use_32_bit_indices selects the
// int32_t instantiation of the remap routines instead of the int64_t one.
// Always returns 0.
int run_benchmark(
    int batch_size,
    int num_rows,
    int average_len,
    bool use_32_bit_indices = false) {
  constexpr int NWARMUP = 4;
  constexpr int NITER = 100;
  constexpr float sparsity = 0.5;
  const int offset_numel = batch_size + 1;

  // Inputs are generated in both 64-bit and 32-bit form so either index width
  // can be benchmarked from the same call.
  vector<int64_t> lengths, offsets, indices;
  vector<int32_t> lengths_32, offsets_32, indices_32;
  vector<float> weights;
  GenerateLengthsIndicesWeights(
      lengths,
      lengths_32,
      offsets,
      offsets_32,
      indices,
      indices_32,
      weights,
      batch_size,
      num_rows,
      average_len, // average number of indices in a batch
      EmbeddingSpMDMCornerCase::NONE);

  // Create mapping table for rowwise sparsity
  vector<int32_t> mapping_table;
  CreateMappingTableForRowWiseSparsity(mapping_table, num_rows, sparsity);

  // Output buffers, sized like their inputs and shared by both timed runs.
  vector<int64_t> out_indices(indices.size(), 0);
  vector<int64_t> out_offsets(offsets.size(), 0);
  vector<int32_t> out_indices_32(indices_32.size(), 0);
  vector<int32_t> out_offsets_32(offsets_32.size(), 0);
  vector<float> out_weights(weights.size(), 0);

  // Reference implementation; the index width is chosen on every invocation.
  auto remap_reference = [&]() {
    if (!use_32_bit_indices) {
      compressed_indices_remap_ref<int64_t>(
          offset_numel,
          indices.data(),
          mapping_table.data(),
          offsets.data(),
          weights.data(),
          out_indices.data(),
          out_offsets.data(),
          out_weights.data());
      return;
    }
    compressed_indices_remap_ref<int32_t>(
        offset_numel,
        indices_32.data(),
        mapping_table.data(),
        offsets_32.data(),
        weights.data(),
        out_indices_32.data(),
        out_offsets_32.data(),
        out_weights.data());
  };

  // Optimized implementation, called with the same arguments as the reference.
  auto remap_optimized = [&]() {
    if (!use_32_bit_indices) {
      compressed_indices_remap<int64_t>(
          offset_numel,
          indices.data(),
          mapping_table.data(),
          offsets.data(),
          weights.data(),
          out_indices.data(),
          out_offsets.data(),
          out_weights.data());
      return;
    }
    compressed_indices_remap<int32_t>(
        offset_numel,
        indices_32.data(),
        mapping_table.data(),
        offsets_32.data(),
        weights.data(),
        out_indices_32.data(),
        out_offsets_32.data(),
        out_weights.data());
  };

  // The reference is timed first, then the optimized version.
  const double ref_time = measureWithWarmup(remap_reference, NWARMUP, NITER);
  const double opt_time = measureWithWarmup(remap_optimized, NWARMUP, NITER);

  // The measured values are scaled by 1e6 and labelled as microseconds.
  cout << "reference:" << ref_time * 1e6 << " (us), "
       << "Opt:" << opt_time * 1e6 << " (us) " << endl;
  return 0;
}
// Runs the benchmark for every configuration returned by GetInputs_(): prints
// a header line describing the configuration, then one timing line for
// 64-bit indices and one for 32-bit indices, followed by a blank line.
int main() {
  for (const auto& config : GetInputs_()) {
    // Each configuration is {batch size, num rows, avg length}.
    assert(config.size() == 3);
    const int batch_size = config[0];
    const int num_rows = config[1];
    const int average_len = config[2];

    cout << "batch size" << setw(6) << batch_size;
    cout << setw(10) << "num rows" << setw(14) << num_rows;
    cout << setw(16) << "avg length" << setw(6) << average_len << endl;

    cout << "64 bit indices, ";
    run_benchmark(batch_size, num_rows, average_len);

    cout << "32 bit indices, ";
    run_benchmark(
        batch_size, num_rows, average_len, /*use_32_bit_indices=*/true);

    cout << endl;
  }
  return 0;
}