/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include #include #include #include "./TestUtils.h" #include "bench/BenchUtils.h" #include "fbgemm/FbgemmI64.h" #include "src/RefImplementations.h" using namespace std; using namespace fbgemm; namespace { class Int64GemmTest : public testing::Test { protected: vector> GenParams() { vector> shapes; random_device r; default_random_engine generator(r()); uniform_int_distribution dist_dim(1, 128); for (int i = 0; i < 256; ++i) { shapes.push_back( {dist_dim(generator), dist_dim(generator), dist_dim(generator)}); } return shapes; } }; } // anonymous namespace TEST_F(Int64GemmTest, test) { const auto shapes = GenParams(); for (const auto s : shapes) { const int m = s[0]; const int n = s[1]; const int k = s[2]; aligned_vector A(m * k); aligned_vector B(k * n); for (matrix_op_t transa : {matrix_op_t::NoTranspose, matrix_op_t::Transpose}) { const int lda = transa == matrix_op_t::Transpose ? m : k; for (matrix_op_t transb : {matrix_op_t::NoTranspose, matrix_op_t::Transpose}) { const int ldb = transb == matrix_op_t::Transpose ? k : n; aligned_vector C(m * n); aligned_vector C_ref = C; randFill( A, numeric_limits::lowest(), numeric_limits::max()); randFill( B, numeric_limits::lowest(), numeric_limits::max()); for (const bool accumulate : {false, true}) { cblas_gemm_i64_i64acc( transa, transb, m, n, k, A.data(), lda, B.data(), ldb, accumulate, C.data(), n); cblas_gemm_i64_i64acc_ref( transa, transb, m, n, k, A.data(), lda, B.data(), ldb, accumulate, C_ref.data(), n); compare_validate_buffers( C_ref.data(), C.data(), m, n, n, 0L); } } // transb } // transa } // for each shape }