sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/bench/RowwiseAdagradFusedBenchmar...

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <algorithm>
#include <cassert>
#include <chrono>
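#include <cmath> // std::sqrt, used by the reference sketch below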
#include <cstdint>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric> // std::iota
#include <random>
#include <set>
#include <vector>
#include "./BenchUtils.h"
#include "fbgemm/Fbgemm.h"
#include "src/RefImplementations.h"
using namespace std;
using namespace fbgemm;
static vector<vector<int>> GetInputs_() {
vector<vector<int>> input_dims = {
// batch size, number of rows of table, emb dim, avg length
// TODO: Add more inputs
// Use these -- but they are slow.
{10, 4000000, 32, 100},
{10, 4000000, 64, 100},
{10, 4000000, 128, 100},
{10, 4000000, 256, 100},
// Use these for debugging
// {2, 16, 128, 10},
// {10, 4000, 128, 100},
// {10, 4000, 128, 100},
// {10, 4000, 128, 100},
};
return input_dims;
}
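// A minimal scalar sketch (for illustration only; this is not fbgemm's
// reference implementation, which lives in src/RefImplementations.h, and the
// helper name is ours) of the rowwise AdaGrad update the fused kernel is
// expected to perform. For every sample b and every index i in
// [offsets[b], offsets[b + 1]), with row = indices[i]:
//   h[row]    += sum_d g[b][d]^2 / embedding_dim
//   w[row][d] += lr / (sqrt(h[row]) + epsilon) * g[b][d]
// It can be run on the h_ref/w_ref copies below and compared with the
// kernel's h/w output.
void rowwise_adagrad_scalar_sketch(
    int batch_size,
    int embedding_dim,
    float* w, // [num_rows][embedding_dim] parameters, updated in place
    const float* g, // [batch_size][embedding_dim] output gradients
    float* h, // [num_rows] rowwise momentum, updated in place
    const int64_t* indices,
    const int* offsets, // [batch_size + 1]
    float epsilon,
    float lr) {
  for (int b = 0; b < batch_size; ++b) {
    for (int i = offsets[b]; i < offsets[b + 1]; ++i) {
      const int64_t row = indices[i];
      // Rowwise AdaGrad: accumulate the mean squared gradient of the sample
      // into the single momentum value for this row.
      float g_sum_sq = 0.0f;
      for (int d = 0; d < embedding_dim; ++d) {
        const float gd = g[b * embedding_dim + d];
        g_sum_sq += gd * gd;
      }
      h[row] += g_sum_sq / embedding_dim;
      const float step = lr / (std::sqrt(h[row]) + epsilon);
      for (int d = 0; d < embedding_dim; ++d) {
        w[row * embedding_dim + d] += step * g[b * embedding_dim + d];
      }
    }
  }
}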
void run_benchmark(
int batch_size,
int num_rows,
int embedding_dim,
int average_len,
bool use_32_bit_indices = false,
bool prefetch = false) {
vector<char> llc(64L * 1024L * 1024L, 1); // scratch buffer for flushing the LLC
vector<float> g(batch_size * embedding_dim); // gradients
vector<float> h(num_rows); // input momentums
vector<float> w(num_rows * embedding_dim); // input params
vector<float> h_ref(h.size());
vector<float> w_ref(w.size());
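// h_ref/w_ref start as exact copies of h/w so that a scalar implementation
// (e.g. the sketch above) can be run on them and compared with the kernel.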
default_random_engine generator;
// normal_distribution<float> h_w_distribution;
// TODO: check appropriate vals for g,h,w
for (size_t i = 0; i < g.size(); ++i) {
g[i] = 4 + i; // h_w_distribution(generator);
}
for (size_t i = 0; i < h.size(); ++i) {
h_ref[i] = h[i] = 2 + i; // h_w_distribution(generator);
}
for (size_t i = 0; i < w.size(); ++i) {
w_ref[i] = w[i] = 3 + i; // h_w_distribution(generator);
}
// Generate lengths
uniform_int_distribution<int> length_distribution(
1, std::min(2 * average_len + 1, num_rows));
vector<int> offsets(batch_size + 1);
offsets[0] = 0;
for (int i = 0; i < batch_size; ++i) {
offsets[i + 1] = offsets[i] + length_distribution(generator);
}
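// offsets uses the usual CSR-style convention: sample i covers indices
// [offsets[i], offsets[i + 1]), so offsets[batch_size] is the total count.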
// Compute the number of indices
int lengths_sum = offsets[batch_size];
cout << "lengths_sum " << lengths_sum << endl;
// Generate indices
vector<int64_t> indices;
vector<int32_t> indices_32;
vector<int> container(num_rows);
// note: indices are unique within each sample (no repeats in a pooling segment)
for (int i = 0; i < batch_size; ++i) {
iota(container.begin(), container.end(), 0);
shuffle(container.begin(), container.end(), generator);
copy(
container.begin(),
container.begin() + (offsets[i + 1] - offsets[i]),
back_inserter(indices));
}
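// Mirroring into 32-bit is safe: every index is < num_rows, which fits in
// int32_t for all configurations above.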
copy(begin(indices), end(indices), back_inserter(indices_32));
float epsilon = 1e-5f;
float lr = 0.5f;
constexpr int NUM_WARMUP = 4;
constexpr int NUM_ITER = 10;
// Only count the bytes for reading/writing the embedding table and ignore
// the rest. Should be good enough as long as embedding_dim is big enough.
double bytes = lengths_sum *
((embedding_dim + 1) * sizeof(float) * 2 +
(use_32_bit_indices ? 4 : 8)) +
batch_size * (embedding_dim * sizeof(float) + sizeof(int));
double bytes_padded = lengths_sum *
(((embedding_dim * sizeof(float) + 63) / 64 + 1) * 64 * 2 +
(use_32_bit_indices ? 4 : 8)) +
batch_size * (embedding_dim * sizeof(float) + sizeof(int));
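// bytes_padded models cache-line granularity: each embedding row is rounded
// up to whole 64-byte lines, plus one extra line for the momentum h, and the
// "* 2" counts both the read and the write. E.g. for embedding_dim = 32 a
// row is 128 B of payload but ((128 + 63) / 64 + 1) * 64 = 192 B of line
// traffic per access.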
auto kernel_i32 = GenerateRowWiseSparseAdaGradFused<int32_t>(
embedding_dim, prefetch ? 16 : 0);
auto kernel_i64 = GenerateRowWiseSparseAdaGradFused<int64_t>(
embedding_dim, prefetch ? 16 : 0);
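// A prefetch distance of 16 is used when software prefetching is enabled;
// passing 0 generates the kernels without prefetch instructions.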
for (bool flush_cache : {false, true}) {
double t = measureWithWarmup(
[&]() {
if (use_32_bit_indices) {
kernel_i32(
batch_size,
lengths_sum,
num_rows,
w.data(),
g.data(),
h.data(),
indices_32.data(),
offsets.data(),
epsilon,
lr);
} else {
kernel_i64(
batch_size,
lengths_sum,
num_rows,
w.data(),
g.data(),
h.data(),
indices.data(),
offsets.data(),
epsilon,
lr);
}
},
NUM_WARMUP,
NUM_ITER,
        [&]() {
          if (flush_cache) {
            llc_flush(llc);
          }
        });
if (flush_cache) {
cout << setw(20) << "cache flushed";
} else {
cout << setw(20) << "cache not flushed";
}
if (prefetch) {
cout << setw(16) << "prefetch on";
} else {
cout << setw(16) << "prefetch off";
}
cout << setw(8) << "b/w" << setw(10) << bytes / 1e9 / t << " GB/s"
     << setw(20) << "effective b/w: " << setw(16) << bytes_padded / 1e9 / t
     << " GB/s" << setw(8) << " time " << setw(16) << t << endl;
}
}
int main() {
vector<vector<int>> inputs(GetInputs_());
for (auto& input : inputs) {
assert(input.size() > 3);
int batch_size = input[0];
int num_rows = input[1];
int embedding_dim = input[2];
int average_len = input[3];
cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows"
<< setw(16) << num_rows << setw(10) << "emb dim" << setw(6)
<< embedding_dim << setw(16) << "avg length" << setw(6) << average_len
<< endl;
for (bool use_32_bit_indices : {false, true}) {
for (bool prefetch : {false, true}) {
// args: batch sz, num rows, emb dim, avg len, use 32b, prefetch
cout << (use_32_bit_indices ? " 32" : " 64") << " bit indices";
if (prefetch) {
cout << " with prefetching";
}
cout << ", ";
run_benchmark(
batch_size,
num_rows,
embedding_dim,
average_len,
use_32_bit_indices,
prefetch);
} // prefetch
} // use_32_bit_indices
} // for each input
return 0;
}