sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/bench/SparseDenseMMFP32Benchmark.cc

134 lines
3.8 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "bench/BenchUtils.h"
#include "fbgemm/FbgemmSparse.h"
#include "fbgemm/Utils.h"
#include "fbgemm/spmmUtils.h"
#include "src/RefImplementations.h"
#include <iomanip>
#include <iostream>
using namespace std;
using namespace fbgemm;
int main(int, char**) {
vector<vector<int>> shapes = getSparseMatrixShapes();
// C is MxN -> CT is NxM
// A is MxK -> AT is KxM
// B is KxN -> BT is NxK
cout << setw(7) << "index" << setw(7) << "m" << setw(7) << "n" << setw(7)
<< "k" << setw(7) << "fnz" << setw(15) << "eff_GFLOPS" << setw(15)
<< "real_GFLOPS" << endl;
int index = 0;
// for (int s = 64; s <= 128; s *= 2)
for (auto const& s : shapes) {
int m = s[0];
int n = s[1];
int k = s[2];
for (float fnz = 0.20; fnz >= 0.20; fnz -= 0.01) {
auto aData = getRandomSparseVector(m * k);
auto bData = getRandomSparseVector(k * n, fnz);
auto cData = getRandomSparseVector(m * n);
aligned_vector<float> atData(k * m);
aligned_vector<float> btData(n * k);
aligned_vector<float> ctData(n * m);
aligned_vector<float> ctDataRef(n * m);
aligned_vector<float> ctDataIntrin(n * m);
transpose_matrix(m, k, aData.data(), k, atData.data(), m);
transpose_matrix(k, n, bData.data(), n, btData.data(), k);
unique_ptr<CSRMatrix<float>> csr = fbgemmDenseToCSR(n, k, btData.data());
// We calculate C^T = B^T x A^T
int ldat = m;
// int ldbt = k;
int ldct = m;
double effective_flop = m * n * k * 2;
constexpr int NWARMUP = 20;
constexpr int NITER = 100;
auto secs_intrin = measureWithWarmup(
[&]() {
SparseDenseMM(
n,
m,
csr->rowPtr.data(),
csr->colIdx.data(),
csr->values.data(),
atData.data(),
ldat,
ctDataIntrin.data(),
ldct);
},
NWARMUP,
NITER,
[&]() {
cache_evict(atData);
cache_evict(csr->rowPtr);
cache_evict(csr->colIdx);
cache_evict(csr->values);
cache_evict(ctDataIntrin);
});
// printMatrix(matrix_op_t::NoTranspose, btData.data(), n, k, k,
// "btData");
// printMatrix(matrix_op_t::NoTranspose, atData.data(), k, m, m,
// "atData");
// printMatrix(matrix_op_t::NoTranspose, ctData.data(), n, m, m,
// "ctData");
cblas_sgemm_ref(
matrix_op_t::NoTranspose,
matrix_op_t::NoTranspose,
m,
n,
k,
1.0f,
aData.data(),
k,
bData.data(),
n,
0.0f,
cData.data(),
n);
transpose_matrix(m, n, cData.data(), n, ctDataRef.data(), m);
// printMatrix(matrix_op_t::NoTranspose, ctDataRef.data(), n, m, m,
// "ctData_Ref");
//
// Compare results
for (size_t i = 0; i < ctDataRef.size(); i++) {
if (std::abs(ctDataRef[i] - ctDataIntrin[i]) > 1e-3) {
fprintf(
stderr,
"Error: Results differ ref %f and test %f at %ld\n",
ctDataRef[i],
ctDataIntrin[i],
i);
return 1;
}
}
double effective_gflops_intrin = effective_flop / secs_intrin / 1e9;
cout << "[" << setw(5) << index << "]" << setw(7) << m << setw(7) << n
<< setw(7) << k << fixed << setw(7) << setprecision(2) << fnz
<< setw(15) << setprecision(5) << effective_gflops_intrin << setw(15)
<< setprecision(5) << fnz * effective_gflops_intrin << endl;
++index;
}
}
}