sglang_v0.5.2/pytorch_2.8.0/third_party/fbgemm/test/I64Test.cc

104 lines
2.5 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <gtest/gtest.h>
#include <limits>
#include <random>
#include "./TestUtils.h"
#include "bench/BenchUtils.h"
#include "fbgemm/FbgemmI64.h"
#include "src/RefImplementations.h"
using namespace std;
using namespace fbgemm;
namespace {
/**
 * Fixture for the int64 GEMM test; supplies randomly generated problem shapes.
 */
class Int64GemmTest : public testing::Test {
 protected:
  /**
   * Produces 256 shapes, each an array of three ints drawn uniformly from
   * [1, 128]. The engine is seeded from std::random_device, so the shapes
   * differ from one run to the next.
   *
   * @return the generated shapes; the test in this file reads the three
   *         entries of each shape as m, n and k, in that order.
   */
  vector<array<int, 3>> GenParams() {
    random_device seed_source;
    default_random_engine engine(seed_source());
    uniform_int_distribution<int> pick_dim(1, 128);

    vector<array<int, 3>> result(256);
    for (auto& shape : result) {
      // Fill in index order so the values are drawn from the engine in the
      // order entry 0, entry 1, entry 2 for every shape.
      for (int& dim : shape) {
        dim = pick_dim(engine);
      }
    }
    return result;
  }
};
} // anonymous namespace
/**
 * Compares cblas_gemm_i64_i64acc against cblas_gemm_i64_i64acc_ref on the
 * shapes returned by GenParams(), for every combination of transposed and
 * non-transposed A and B. For each combination the GEMM is run twice on the
 * same output buffers: first with accumulate == false, then with
 * accumulate == true on top of the previous result.
 */
TEST_F(Int64GemmTest, test) {
  const auto shapes = GenParams();
  // Bind by const reference; the shape does not need to be copied each pass.
  for (const auto& s : shapes) {
    const int m = s[0];
    const int n = s[1];
    const int k = s[2];
    aligned_vector<int64_t> A(m * k);
    aligned_vector<int64_t> B(k * n);
    for (matrix_op_t transa :
         {matrix_op_t::NoTranspose, matrix_op_t::Transpose}) {
      // Leading dimension of A: m when A is passed transposed, k otherwise.
      const int lda = transa == matrix_op_t::Transpose ? m : k;
      for (matrix_op_t transb :
           {matrix_op_t::NoTranspose, matrix_op_t::Transpose}) {
        // Leading dimension of B: k when B is passed transposed, n otherwise.
        const int ldb = transb == matrix_op_t::Transpose ? k : n;
        // C_ref starts as a copy of C so that both implementations work from
        // the same initial output contents.
        aligned_vector<int64_t> C(m * n);
        aligned_vector<int64_t> C_ref = C;
        // Inputs are refilled over the full int64_t range for every
        // (transa, transb) combination.
        randFill(
            A,
            numeric_limits<int64_t>::lowest(),
            numeric_limits<int64_t>::max());
        randFill(
            B,
            numeric_limits<int64_t>::lowest(),
            numeric_limits<int64_t>::max());
        for (const bool accumulate : {false, true}) {
          // The shapes are random and differ between runs, so attach the
          // current configuration to any failure reported in this scope.
          SCOPED_TRACE(
              testing::Message()
              << "m=" << m << " n=" << n << " k=" << k << " transa="
              << (transa == matrix_op_t::Transpose ? "T" : "N") << " transb="
              << (transb == matrix_op_t::Transpose ? "T" : "N")
              << " accumulate=" << accumulate);
          cblas_gemm_i64_i64acc(
              transa,
              transb,
              m,
              n,
              k,
              A.data(),
              lda,
              B.data(),
              ldb,
              accumulate,
              C.data(),
              n);
          cblas_gemm_i64_i64acc_ref(
              transa,
              transb,
              m,
              n,
              k,
              A.data(),
              lda,
              B.data(),
              ldb,
              accumulate,
              C_ref.data(),
              n);
          // NOTE(review): the trailing 0L is presumably an absolute tolerance
          // (i.e. exact match required) -- confirm against TestUtils.h.
          compare_validate_buffers<int64_t>(
              C_ref.data(), C.data(), m, n, n, 0L);
        }
      } // transb
    } // transa
  } // for each shape
}