sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/src/FbgemmSparseDense.cc

308 lines
8.8 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#define FBGEMM_EXPORTS
#include "fbgemm/FbgemmSparse.h"
#include <algorithm>
#include <array>
#include <cassert>
#include <cstdio>
#include <cstring>
#include <memory>
#include <sstream>
#include <vector>
#include "fbgemm/Utils.h"
#include "fbgemm/spmmUtils.h"
using namespace std;
namespace fbgemm {
template <typename T>
FBGEMM_API std::unique_ptr<CSRMatrix<T>>
fbgemmDenseToCSR(int R, int C, const T* inp, int ld) {
unique_ptr<CSRMatrix<T>> csr(new CSRMatrix<T>());
csr->rowPtr.push_back(0);
int nnz = 0;
for (int i = 0; i < R; ++i) {
for (int j = 0; j < C; ++j) {
if (inp[i * ld + j] != 0) {
csr->values.push_back(inp[i * ld + j]);
csr->colIdx.push_back(j);
nnz++;
}
}
csr->rowPtr.push_back(nnz);
}
return csr;
}
template <typename T>
std::unique_ptr<CSRMatrix<T>> fbgemmDenseToCSR(int R, int C, const T* inp) {
return fbgemmDenseToCSR<T>(R, C, inp, C);
}
template FBGEMM_API std::unique_ptr<CSRMatrix<int8_t>>
fbgemmDenseToCSR(int R, int C, const int8_t* inp);
template FBGEMM_API std::unique_ptr<CSRMatrix<float>>
fbgemmDenseToCSR(int R, int C, const float* inp);
template FBGEMM_API std::unique_ptr<CSRMatrix<int8_t>>
fbgemmDenseToCSR(int R, int C, const int8_t* inp, int ld);
template FBGEMM_API std::unique_ptr<CSRMatrix<float>>
fbgemmDenseToCSR(int R, int C, const float* inp, int ld);
template <typename T, int RB, int CB>
FBGEMM_API std::unique_ptr<BCSRMatrix<T, RB, CB>>
fbgemmDenseToBCSR(int R, int C, const T* inp, int ld) {
unique_ptr<BCSRMatrix<T, RB, CB>> bcsr(new BCSRMatrix<T, RB, CB>(R, C));
bcsr->pack(inp, ld);
return bcsr;
}
template <typename T, int RB, int CB>
FBGEMM_API std::unique_ptr<BCSRMatrix<T, RB, CB>>
fbgemmDenseToBCSR(int R, int C, const T* inp) {
return fbgemmDenseToBCSR<T, RB, CB>(R, C, inp, C);
}
#if __cplusplus < 201703L
template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::RB;
template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::CB;
template <typename T, int RB, int CB>
constexpr int BCSRMatrix<T, RB, CB>::COLTILE;
#endif
template <typename T, int RB, int CB>
void BCSRMatrix<T, RB, CB>::pack(const DTYPE* src, size_t ld) {
rowBPtr.push_back(0);
int nnzb = 0;
int numCOLTILEs = (C + COLTILE - 1) / COLTILE;
int rowBlocks = (R + RB - 1) / RB;
for (int jt = 0; jt < numCOLTILEs; ++jt) {
for (int i = 0; i < rowBlocks; ++i) {
int curCols = min(C - jt * COLTILE, COLTILE);
int curColBlocks = (curCols + CB - 1) / CB;
std::array<int32_t, RB> rowSum = {0};
for (int j = 0; j < curColBlocks; ++j) {
// is the whole block zero?
bool isCurrentBlockNonZero = false;
for (int ib = 0; ib < RB; ++ib) {
// break if already found a non-zero element or
// out of bounds
if (isCurrentBlockNonZero || (i * RB + ib) >= R) {
break;
}
for (int jb = 0; jb < CB; ++jb) {
// within bound?
if ((jt * COLTILE + j * CB + jb) >= C) {
continue;
} else {
if (src[(i * RB + ib) * ld + jt * COLTILE + j * CB + jb] != 0) {
isCurrentBlockNonZero = true;
break;
}
}
}
}
if (isCurrentBlockNonZero) {
for (int ib = 0; ib < RB; ++ib) {
for (int jb = 0; jb < CB; ++jb) {
if ((i * RB + ib) >= R || (jt * COLTILE + j * CB + jb) >= C) {
// zero fill
values.push_back(0);
} else {
DTYPE val =
src[(i * RB + ib) * ld + jt * COLTILE + j * CB + jb];
values.push_back(val);
rowSum[ib] += static_cast<int32_t>(val);
}
}
}
colBIdx.push_back(j);
nnzb++;
}
}
rowBPtr.push_back(nnzb);
// Note: in row_offsets we don't need to subtract the constant term
// weight_zero_point * C because it's 0 as weight_zero_point is always 0
// for sparse kernels.
for (int ib = 0; ib < RB; ++ib) {
if (jt) {
row_offsets[i * RB + ib] += rowSum[ib];
} else {
row_offsets[i * RB + ib] = rowSum[ib];
}
}
}
}
}
template <typename T, int RB, int CB>
void BCSRMatrix<T, RB, CB>::pack(const DTYPE* src) {
pack(src, C);
}
template <typename T, int RB, int CB>
void BCSRMatrix<T, RB, CB>::unpack(T* dst, size_t ld) {
// zero out destination
memset(dst, 0, R * C * sizeof(T));
int numCOLTILEs = (C + COLTILE - 1) / COLTILE;
int rowBlocks = (R + RB - 1) / RB;
for (int jt = 0; jt < numCOLTILEs; ++jt) {
for (int i = 0; i < rowBlocks; ++i) {
// For the current tile, rowBPtr starts from currentTileIdx (i.e., jt) * R
for (int r = rowBPtr[jt * R + i]; r < rowBPtr[jt * R + i + 1]; ++r) {
int curColIdx = colBIdx[r];
for (int ib = 0; ib < RB; ++ib) {
for (int jb = 0; jb < CB; ++jb) {
// Are we within bounds of destination matrix?
if ((i * RB + ib) < R && (jt * COLTILE + curColIdx * CB + jb) < C) {
dst[(i * RB + ib) * ld + jt * COLTILE + curColIdx * CB + jb] =
values[r * RB * CB + ib * CB + jb];
}
}
}
}
}
}
}
template <typename T, int RB, int CB>
void BCSRMatrix<T, RB, CB>::unpack(T* dst) {
unpack(dst, C);
}
template struct BCSRMatrix<int8_t, 1, 4>;
template struct CSRMatrix<int8_t>;
template struct CSRMatrix<float>;
template FBGEMM_API std::unique_ptr<BCSRMatrix<int8_t, 1, 4>>
fbgemmDenseToBCSR(int R, int C, const int8_t* inp);
template FBGEMM_API std::unique_ptr<BCSRMatrix<int8_t, 1, 4>>
fbgemmDenseToBCSR(int R, int C, const int8_t* inp, int ld);
void SparseDenseMM(
int M,
int N,
const int* row_ptr,
const int* col_idx,
const float* values,
const float* B,
int ldb,
float* C,
int ldc,
bool accum) {
static const auto iset = fbgemmInstructionSet();
// Run time CPU detection
if (isZmm(iset)) {
internal::SparseDenseMMAvx512(
M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
} else if (isYmm(iset)) {
internal::SparseDenseMMAvx2(
M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
} else {
sparseDenseMMRef(M, N, row_ptr, col_idx, values, B, ldb, C, ldc, accum);
}
}
template <bool FUSE_RELU, QuantizationGranularity Q_GRAN>
FBGEMM_API void fbgemmSparseDenseInt8MM(
int N,
const std::unique_ptr<BCSRMatrix<>>& bcsr,
const uint8_t* B,
int ldb,
int32_t* C_i32,
uint8_t* C_u8,
int ldc,
trRequantizationParams_t& rParams,
bool accum,
int thread_id,
int num_threads) {
static const auto iset = fbgemmInstructionSet();
// No parallelization currently
// All work is done by thread 0
if (thread_id > 0) {
return;
}
// Run time CPU detection
if (isZmm(iset)) {
internal::SparseDenseInt8MMAvx512<FUSE_RELU, Q_GRAN>(
N,
bcsr,
B,
ldb,
C_i32,
C_u8,
ldc,
rParams,
accum,
thread_id,
num_threads);
} else if (isYmm(iset)) {
internal::SparseDenseInt8MMAvx2<FUSE_RELU, Q_GRAN>(
N,
bcsr,
B,
ldb,
C_i32,
C_u8,
ldc,
rParams,
accum,
thread_id,
num_threads);
} else {
sparseDenseInt8MMRef<FUSE_RELU, Q_GRAN>(
N,
bcsr,
B,
ldb,
C_i32,
C_u8,
ldc,
rParams,
accum,
thread_id,
num_threads);
}
}
#define CREATE_INSTANCE(FUSE_RELU, QGRAN) \
template FBGEMM_API void fbgemmSparseDenseInt8MM<FUSE_RELU, QGRAN>( \
int N, \
const std::unique_ptr<BCSRMatrix<>>& bcsr, \
const uint8_t* B, \
int ldb, \
int32_t* C_i32, \
uint8_t* C_u8, \
int ldc, \
trRequantizationParams_t& rParams, \
bool accum, \
int thread_id, \
int num_threads);
CREATE_INSTANCE(true, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(true, QuantizationGranularity::OUT_CHANNEL)
CREATE_INSTANCE(false, QuantizationGranularity::TENSOR)
CREATE_INSTANCE(false, QuantizationGranularity::OUT_CHANNEL)
#undef CREATE_INSTANCE
} // namespace fbgemm