/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <algorithm>
#include <numeric> // for accumulate and iota
#include <ostream>
#include <random>
#include <stdexcept>

#include <gtest/gtest.h>

#include "./EmbeddingSpMDMTestUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmConvert.h"
#include "src/RefImplementations.h"

using namespace std;
using namespace fbgemm;

static vector<vector<int>> GetInputs_() {
  vector<vector<int>> input_dims = {
      // batch size, number of rows of table, emb dim, avg length
      {1, 8, 8, 4},
      {2, 8, 16, 4},
      {10, 4000, 32, 100},
      {100, 4000, 32, 100},
      {10, 4000, 64, 100},
      {10, 4000, 128, 100},
      {4, 400, 256, 10},
      {10, 4000, 48, 100},
      {10, 4000, 40, 100},
      {10, 4000, 56, 100},
      {10, 4000, 1, 100},
      {10, 4000, 4, 100},
      // These were from C2 tests
      {10, 40, 16, 10},
      {10, 40, 85, 10},
      {10, 40, 8, 10},
      {10, 40, 96, 10},
      {10, 40, 163, 10},
  };
  return input_dims;
}

namespace {

class EmbeddingSpMDMTest : public testing::TestWithParam<tuple<
                               int,
                               EmbeddingSpMDMWeightChoice,
                               EmbeddingSpMDMCornerCase,
                               EmbeddingSpMDMInputDtypeChoice,
                               EmbeddingSpMDMOutputDtypeChoice>> {};
class rowwiseSparseEmbeddingSpMDMTest
    : public testing::TestWithParam<
          tuple<int, EmbeddingSpMDMWeightChoice, EmbeddingSpMDMCornerCase>> {};
class IndexRemapTest
    : public testing::TestWithParam<tuple<int, int, int, bool, bool>> {};

} // namespace

vector<int> prefetch_distances = {0, 16, 1000000};

INSTANTIATE_TEST_CASE_P(
    InstantiationName,
    EmbeddingSpMDMTest,
    ::testing::Combine(
        ::testing::ValuesIn(prefetch_distances),
        ::testing::Values(
            UNWEIGHTED,
            WEIGHTED,
            POSITIONAL_WEIGHTED), // use_weight
        ::testing::Values(
            NONE,
            EMPTY_INDICES,
            OUT_OF_BOUND_INDICES,
            UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM),
        ::testing::Values(FLOAT, FLOAT16, BFLOAT16),
        ::testing::Values(FLOAT, FLOAT16, BFLOAT16)));

INSTANTIATE_TEST_CASE_P(
    InstantiationName,
    rowwiseSparseEmbeddingSpMDMTest,
    ::testing::Combine(
        ::testing::ValuesIn(prefetch_distances),
        ::testing::Values(
            UNWEIGHTED,
            WEIGHTED,
            POSITIONAL_WEIGHTED), // use_weight
        ::testing::Values(
            NONE,
            EMPTY_INDICES,
            OUT_OF_BOUND_INDICES,
            UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM)));

INSTANTIATE_TEST_CASE_P(
    InstantiationName,
    IndexRemapTest,
    ::testing::Combine(
        ::testing::ValuesIn({1, 2, 5, 10}), // batch size
        ::testing::ValuesIn({1, 50, 100, 1000}), // number of rows
        ::testing::ValuesIn({1, 5, 16}), // avg len
        ::testing::Bool(), // is index 64 bit?
        ::testing::Bool())); // per sample weights?

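// basicTest sweeps the cross product of the parameters instantiated above
// (prefetch distance x weight mode x corner case x input dtype x output
// dtype), and additionally randomizes orthogonal knobs per run: index/offset
// bit width, length normalization, output/input strides, and thread-local
// buffers.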
TEST_P(EmbeddingSpMDMTest, basicTest) {
  vector<vector<int>> inputs(GetInputs_());
  random_device r;
  default_random_engine generator(r());
  uniform_int_distribution<> bool_dist(0, 1);

  bool isIndex64b = bool_dist(generator);
  bool isOffset64b = bool_dist(generator);
  bool normalize_by_lengths = bool_dist(generator);
  bool use_offsets = bool_dist(generator);
  bool use_output_input_stride = bool_dist(generator);
  bool test_thread_local = bool_dist(generator);

  int prefetch;
  EmbeddingSpMDMWeightChoice weight_choice;
  EmbeddingSpMDMCornerCase corner_case;
  EmbeddingSpMDMInputDtypeChoice in_type;
  EmbeddingSpMDMOutputDtypeChoice out_type;
  tie(prefetch, weight_choice, corner_case, in_type, out_type) = GetParam();
  bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED;
  bool use_weight = weight_choice != UNWEIGHTED;
  bool isFp16 = in_type == FLOAT16;
  bool isBf16 = in_type == BFLOAT16;
  bool is_output_float = (out_type == FLOAT);
  bool is_output_bfloat16 = (out_type == BFLOAT16);

  if (corner_case != NONE || is_wt_positional) {
    // Check corner case only for subset of tests.
    if (isFp16 || normalize_by_lengths || use_output_input_stride ||
        !is_output_float || test_thread_local) {
      return;
    }
  }
  if (is_wt_positional && !use_weight) {
    // weight positional only makes sense when use_weight is true
    return;
  }

#if defined(__APPLE__) || defined(_WIN32)
  if (in_type == BFLOAT16 && out_type == FLOAT) {
    return;
  }
#endif

  for (auto input : inputs) {
    int batch_size = input[0];
    int num_rows = input[1];
    int embedding_dim = input[2];
    int average_len = input[3];

    int output_stride = use_output_input_stride ? embedding_dim * 2 + 3 : -1;
    int input_stride = use_output_input_stride ? embedding_dim * 2 + 3 : -1;

    // Create embedding table
    vector<float> embedding_table(
        num_rows * (use_output_input_stride ? input_stride : embedding_dim));
    normal_distribution<float> embedding_distribution;
    for (int i = 0; i < num_rows; ++i) {
      for (int j = 0; j < embedding_dim; ++j) {
        embedding_table
            [i * (use_output_input_stride ? input_stride : embedding_dim) +
             j] = embedding_distribution(generator);
      }
    }
    vector<float16> embedding_table_fp16;
    if (isFp16) {
      embedding_table_fp16.resize(embedding_table.size());
      FloatToFloat16_simd(
          embedding_table.data(),
          embedding_table_fp16.data(),
          embedding_table.size());
    }
    vector<bfloat16> embedding_table_bf16;
    if (isBf16) {
      embedding_table_bf16.resize(embedding_table.size());
      FloatToBfloat16_simd(
          embedding_table.data(),
          embedding_table_bf16.data(),
          embedding_table.size());
    }

    vector<int64_t> lengths, offsets, indices;
    vector<int32_t> lengths_32, offsets_32, indices_32;
    vector<float> weights;
    int lengths_sum = GenerateLengthsIndicesWeights(
        lengths,
        lengths_32,
        offsets,
        offsets_32,
        indices,
        indices_32,
        weights,
        batch_size,
        num_rows,
        average_len,
        corner_case);
    const int64_t* offsets_or_lengths =
        (use_offsets ? offsets : lengths).data();
    const int32_t* offsets_or_lengths_32 =
        (use_offsets ? offsets_32 : lengths_32).data();

    // Sentries at the end to make sure masking is done correctly not to write
    // out of bounds.
    constexpr int num_sentries = 10;
    const float sentry_value = 1.0f;
    int output_size_wo_sentries =
        batch_size * (use_output_input_stride ? output_stride : embedding_dim);
    vector<float> output_ref(output_size_wo_sentries + num_sentries);
    vector<float> output(output_ref.size());
    vector<float16> output_ref_fp16(output.size()), output_fp16(output.size());
    vector<bfloat16> output_ref_bf16(output.size()),
        output_bf16(output.size());
    for (size_t i = output_size_wo_sentries; i < output.size(); ++i) {
      output_ref[i] = sentry_value;
      output[i] = sentry_value;
      output_ref_fp16[i] = cpu_float2half_rn(sentry_value);
      output_fp16[i] = cpu_float2half_rn(sentry_value);
      FloatToBfloat16_ref(&sentry_value, &output_ref_bf16[i], 1);
      FloatToBfloat16_ref(&sentry_value, &output_bf16[i], 1);
    }

    bool success, success_ref;

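// The TEST_* macros below form a dispatch ladder: TEST_INDEX_TYPE picks the
// index type, TEST_OFFSET_TYPE the offset type, TEST_OUT_TYPE the output
// type, and TEST_THREAD_LOCAL the THREAD_LOCAL flag; TEST_BASE finally runs
// the reference implementation and the generated kernel on the same inputs.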
#define TEST_BASE( \
    table, \
    indices, \
    offsets_or_lengths, \
    output_ref, \
    output, \
    InType, \
    IndexType, \
    OffsetType, \
    OutType, \
    THREAD_LOCAL) \
  success_ref = EmbeddingSpMDM_ref( \
      embedding_dim, \
      batch_size, \
      lengths_sum, \
      num_rows, \
      table.data(), \
      corner_case == EMPTY_INDICES ? nullptr : indices.data(), \
      offsets_or_lengths, \
      use_weight ? weights.data() : nullptr, \
      normalize_by_lengths, \
      output_ref.data(), \
      is_wt_positional, \
      use_offsets, \
      output_stride, \
      input_stride, \
      true, \
      false, \
      is_output_bfloat16, \
      isBf16); \
 \
  auto kernel = GenerateEmbeddingSpMDMWithStrides< \
      InType, \
      IndexType, \
      OffsetType, \
      OutType, \
      THREAD_LOCAL>( \
      embedding_dim, \
      use_weight, \
      normalize_by_lengths, \
      prefetch, \
      is_wt_positional, \
      use_offsets, \
      output_stride, \
      input_stride, \
      true, \
      false, \
      is_output_bfloat16, \
      isBf16); \
  success = kernel( \
      batch_size, \
      lengths_sum, \
      num_rows, \
      table.data(), \
      corner_case == EMPTY_INDICES ? nullptr : indices.data(), \
      offsets_or_lengths, \
      use_weight ? weights.data() : nullptr, \
      output.data());

#define TEST_THREAD_LOCAL( \
    table, \
    indices, \
    offsets_or_lengths, \
    output_ref, \
    output, \
    InType, \
    IndexType, \
    OffsetType, \
    OutType) \
  if (test_thread_local) { \
    TEST_BASE( \
        table, \
        indices, \
        offsets_or_lengths, \
        output_ref, \
        output, \
        InType, \
        IndexType, \
        OffsetType, \
        OutType, \
        true); \
  } else { \
    TEST_BASE( \
        table, \
        indices, \
        offsets_or_lengths, \
        output_ref, \
        output, \
        InType, \
        IndexType, \
        OffsetType, \
        OutType, \
        false); \
  }

#define TEST_OUT_TYPE( \
    table, indices, offsets_or_lengths, InType, IndexType, OffsetType) \
  if (is_output_float) { \
    TEST_THREAD_LOCAL( \
        table, \
        indices, \
        offsets_or_lengths, \
        output_ref, \
        output, \
        InType, \
        IndexType, \
        OffsetType, \
        float); \
  } else if (is_output_bfloat16) { \
    TEST_THREAD_LOCAL( \
        table, \
        indices, \
        offsets_or_lengths, \
        output_ref_bf16, \
        output_bf16, \
        InType, \
        IndexType, \
        OffsetType, \
        bfloat16); \
  } else { \
    TEST_THREAD_LOCAL( \
        table, \
        indices, \
        offsets_or_lengths, \
        output_ref_fp16, \
        output_fp16, \
        InType, \
        IndexType, \
        OffsetType, \
        float16); \
  }

#define TEST_OFFSET_TYPE(table, indices, InType, IndexType) \
  if (isOffset64b) { \
    TEST_OUT_TYPE( \
        table, indices, offsets_or_lengths, InType, IndexType, int64_t); \
  } else { \
    TEST_OUT_TYPE( \
        table, indices, offsets_or_lengths_32, InType, IndexType, int32_t); \
  }

#define TEST_INDEX_TYPE(table, InType) \
  if (isIndex64b) { \
    TEST_OFFSET_TYPE(table, indices, InType, int64_t); \
  } else { \
    TEST_OFFSET_TYPE(table, indices_32, InType, int32_t); \
  }

    if (isFp16) {
      TEST_INDEX_TYPE(embedding_table_fp16, float16);
    } else if (isBf16) {
      TEST_INDEX_TYPE(embedding_table_bf16, bfloat16);
    } else {
      TEST_INDEX_TYPE(embedding_table, float);
    }

#undef TEST_INDEX_TYPE
#undef TEST_OFFSET_TYPE
#undef TEST_OUT_TYPE
#undef TEST_THREAD_LOCAL
#undef TEST_BASE

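    // Note: the checks below use exact (bitwise) equality. Presumably this is
    // safe because the reference implementation and the generated kernel are
    // expected to perform the same operations in the same order, including
    // identical float16/bfloat16 rounding.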
    // Check correctness
    EXPECT_EQ(success, success_ref)
        << "Reference and JIT impl did not both succeed";
    if (corner_case == OUT_OF_BOUND_INDICES ||
        corner_case == UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM) {
      EXPECT_EQ(success, false);
    }

    auto get_actual = [&](int offset) {
      if (is_output_float)
        return output[offset];
      else if (is_output_bfloat16) {
        float v;
        Bfloat16ToFloat_ref(&output_bf16[offset], &v, 1);
        return v;
      } else
        return cpu_half2float(output_fp16[offset]);
    };

    auto get_expected = [&](int offset) {
      if (is_output_float)
        return output_ref[offset];
      else if (is_output_bfloat16) {
        float v;
        Bfloat16ToFloat_ref(&output_ref_bf16[offset], &v, 1);
        return v;
      } else
        return cpu_half2float(output_ref_fp16[offset]);
    };

    if (success) {
      for (int i = 0; i < batch_size; ++i) {
        for (int j = 0; j < embedding_dim; ++j) {
          int offset =
              i * (use_output_input_stride ? output_stride : embedding_dim) +
              j;
          float actual = get_actual(offset);
          float expected = get_expected(offset);
          EXPECT_EQ(actual, expected)
              << "results differ at (" << i << ") reference: " << expected
              << ", FBGEMM: " << actual << " emb dim :" << embedding_dim;
        }
      }
      for (int offset = output_size_wo_sentries;
           offset < output_size_wo_sentries + num_sentries;
           ++offset) {
        float actual = get_actual(offset);
        float expected = get_expected(offset);
        EXPECT_EQ(actual, expected)
            << "results differ at (" << offset << ") reference: " << expected
            << ", FBGEMM: " << actual << " emb dim :" << embedding_dim;
      }
    }
  } // end for input
}

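// rowwiseSparseTest exercises the row-wise sparse variant: the embedding
// table stores only the rows that survive pruning, and mapping_table
// translates an original row id into its compressed row id (entries for
// pruned rows are expected to be sentinels such as -1 so lookups can skip
// them).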
TEST_P(rowwiseSparseEmbeddingSpMDMTest, rowwiseSparseTest) {
  vector<vector<int>> inputs(GetInputs_());
  random_device r;
  default_random_engine generator(r());
  uniform_int_distribution<> bool_dist(0, 1);

  bool isFp16 = bool_dist(generator);
  bool isIndex64b = bool_dist(generator);
  bool isOffset64b = bool_dist(generator);
  bool normalize_by_lengths = bool_dist(generator);
  bool use_offsets = bool_dist(generator);
  bool is_output_float = bool_dist(generator);

  int prefetch;
  EmbeddingSpMDMWeightChoice weight_choice;
  EmbeddingSpMDMCornerCase corner_case;
  tie(prefetch, weight_choice, corner_case) = GetParam();
  bool is_wt_positional = weight_choice == POSITIONAL_WEIGHTED;
  bool use_weight = weight_choice != UNWEIGHTED;

  if (!is_output_float) {
    // Don't test is_output_float for row-wise sparse embedding spmdm
    return;
  }

  constexpr float sparsity = 0.7;

  for (auto input : inputs) {
    int batch_size = input[0];
    int num_rows = input[1];
    int embedding_dim = input[2];
    int average_len = input[3];

    // Create mapping table for rowwise sparsity
    vector<int32_t> mapping_table;
    int num_compressed_rows = CreateMappingTableForRowWiseSparsity(
        mapping_table, num_rows, sparsity);

    // Create embedding table
    vector<float> embedding_table(num_compressed_rows * embedding_dim);
    normal_distribution<float> embedding_distribution;
    for (size_t i = 0; i < embedding_table.size(); ++i) {
      embedding_table[i] = embedding_distribution(generator);
    }
    vector<float16> embedding_table_fp16;
    if (isFp16) {
      embedding_table_fp16.resize(embedding_table.size());
      FloatToFloat16_simd(
          embedding_table.data(),
          embedding_table_fp16.data(),
          embedding_table.size());
    }

    vector<int64_t> lengths, offsets, indices;
    vector<int32_t> lengths_32, offsets_32, indices_32;
    vector<float> weights;
    int lengths_sum = GenerateLengthsIndicesWeights(
        lengths,
        lengths_32,
        offsets,
        offsets_32,
        indices,
        indices_32,
        weights,
        batch_size,
        num_rows,
        average_len,
        corner_case);
    const int64_t* offsets_or_lengths =
        (use_offsets ? offsets : lengths).data();
    const int32_t* offsets_or_lengths_32 =
        (use_offsets ? offsets_32 : lengths_32).data();

    vector<float> output_sls_ref(batch_size * embedding_dim);
    vector<float> output_slws_ref(output_sls_ref.size()),
        output_sls(output_sls_ref.size()), output_slws(output_sls_ref.size());

    vector<float>& output_ref = use_weight ? output_slws_ref : output_sls_ref;
    vector<float>& output = use_weight ? output_slws : output_sls;
    bool success, success_ref;

    if (isOffset64b) {
      if (isIndex64b) {
        if (isFp16) {
          success_ref = EmbeddingSpMDMRowWiseSparse_ref(
              embedding_dim,
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table_fp16.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices.data(),
              mapping_table.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              normalize_by_lengths,
              output_ref.data(),
              is_wt_positional,
              use_offsets);

          auto kernel =
              GenerateEmbeddingSpMDMRowWiseSparse<float16, int64_t, int64_t>(
                  embedding_dim,
                  use_weight,
                  normalize_by_lengths,
                  prefetch,
                  is_wt_positional,
                  use_offsets);
          success = kernel(
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table_fp16.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              output.data(),
              mapping_table.data());
        } else {
          success_ref = EmbeddingSpMDMRowWiseSparse_ref(
              embedding_dim,
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices.data(),
              mapping_table.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              normalize_by_lengths,
              output_ref.data(),
              is_wt_positional,
              use_offsets);

          auto kernel =
              GenerateEmbeddingSpMDMRowWiseSparse<float, int64_t, int64_t>(
                  embedding_dim,
                  use_weight,
                  normalize_by_lengths,
                  prefetch,
                  is_wt_positional,
                  use_offsets);
          success = kernel(
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              output.data(),
              mapping_table.data());
        }
      } else {
        if (isFp16) {
          success_ref = EmbeddingSpMDMRowWiseSparse_ref(
              embedding_dim,
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table_fp16.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
              mapping_table.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              normalize_by_lengths,
              output_ref.data(),
              is_wt_positional,
              use_offsets);

          auto kernel =
              GenerateEmbeddingSpMDMRowWiseSparse<float16, int32_t, int64_t>(
                  embedding_dim,
                  use_weight,
                  normalize_by_lengths,
                  prefetch,
                  is_wt_positional,
                  use_offsets);
          success = kernel(
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table_fp16.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              output.data(),
              mapping_table.data());
        } else {
          success_ref = EmbeddingSpMDMRowWiseSparse_ref(
              embedding_dim,
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
              mapping_table.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              normalize_by_lengths,
              output_ref.data(),
              is_wt_positional,
              use_offsets);

          auto kernel =
              GenerateEmbeddingSpMDMRowWiseSparse<float, int32_t, int64_t>(
                  embedding_dim,
                  use_weight,
                  normalize_by_lengths,
                  prefetch,
                  is_wt_positional,
                  use_offsets);
          success = kernel(
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              output.data(),
              mapping_table.data());
        }
      }
    } else {
      if (isIndex64b) {
        if (isFp16) {
          success_ref = EmbeddingSpMDMRowWiseSparse_ref(
              embedding_dim,
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table_fp16.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices.data(),
              mapping_table.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              normalize_by_lengths,
              output_ref.data(),
              is_wt_positional,
              use_offsets);

          auto kernel =
              GenerateEmbeddingSpMDMRowWiseSparse<float16, int64_t, int32_t>(
                  embedding_dim,
                  use_weight,
                  normalize_by_lengths,
                  prefetch,
                  is_wt_positional,
                  use_offsets);
          success = kernel(
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table_fp16.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices.data(),
              offsets_or_lengths_32,
              use_weight ? weights.data() : nullptr,
              output.data(),
              mapping_table.data());
        } else {
          success_ref = EmbeddingSpMDMRowWiseSparse_ref(
              embedding_dim,
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices.data(),
              mapping_table.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              normalize_by_lengths,
              output_ref.data(),
              is_wt_positional,
              use_offsets);

          auto kernel =
              GenerateEmbeddingSpMDMRowWiseSparse<float, int64_t, int32_t>(
                  embedding_dim,
                  use_weight,
                  normalize_by_lengths,
                  prefetch,
                  is_wt_positional,
                  use_offsets);
          success = kernel(
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices.data(),
              offsets_or_lengths_32,
              use_weight ? weights.data() : nullptr,
              output.data(),
              mapping_table.data());
        }
      } else {
        if (isFp16) {
          success_ref = EmbeddingSpMDMRowWiseSparse_ref(
              embedding_dim,
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table_fp16.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
              mapping_table.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              normalize_by_lengths,
              output_ref.data(),
              is_wt_positional,
              use_offsets);

          auto kernel =
              GenerateEmbeddingSpMDMRowWiseSparse<float16, int32_t, int32_t>(
                  embedding_dim,
                  use_weight,
                  normalize_by_lengths,
                  prefetch,
                  is_wt_positional,
                  use_offsets);
          success = kernel(
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table_fp16.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
              offsets_or_lengths_32,
              use_weight ? weights.data() : nullptr,
              output.data(),
              mapping_table.data());
        } else {
          success_ref = EmbeddingSpMDMRowWiseSparse_ref(
              embedding_dim,
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
              mapping_table.data(),
              offsets_or_lengths,
              use_weight ? weights.data() : nullptr,
              normalize_by_lengths,
              output_ref.data(),
              is_wt_positional,
              use_offsets);

          auto kernel =
              GenerateEmbeddingSpMDMRowWiseSparse<float, int32_t, int32_t>(
                  embedding_dim,
                  use_weight,
                  normalize_by_lengths,
                  prefetch,
                  is_wt_positional,
                  use_offsets);
          success = kernel(
              batch_size,
              lengths_sum,
              num_rows,
              embedding_table.data(),
              corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
              offsets_or_lengths_32,
              use_weight ? weights.data() : nullptr,
              output.data(),
              mapping_table.data());
        }
      }
    }

    // Check correctness
    EXPECT_EQ(success, success_ref)
        << "Reference and JIT impl did not both succeed";
    if (corner_case == OUT_OF_BOUND_INDICES ||
        corner_case == UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM) {
      EXPECT_EQ(success, false);
    }

    if (success) {
      for (size_t i = 0; i < output.size(); ++i) {
        EXPECT_EQ(output[i], output_ref[i])
            << "results differ at (" << i << ") reference: " << output_ref[i]
            << ", FBGEMM: " << output[i] << " emb dim :" << embedding_dim;
      }
    }
  } // end for input
}

TEST_P(IndexRemapTest, basicTest) {
  int batch_size, num_rows, avg_len;
  bool isIndex64b, per_sample_weights;
  tie(batch_size, num_rows, avg_len, isIndex64b, per_sample_weights) =
      GetParam();
  constexpr float sparsity = 0.5;

  vector<int64_t> lengths, offsets, indices;
  vector<int32_t> lengths_32, offsets_32, indices_32;
  vector<float> weights;

  GenerateLengthsIndicesWeights(
      lengths,
      lengths_32,
      offsets,
      offsets_32,
      indices,
      indices_32,
      weights,
      batch_size,
      num_rows,
      avg_len, // average number of indices in a batch
      EmbeddingSpMDMCornerCase::NONE);

  // Create mapping table for rowwise sparsity
  vector<int32_t> mapping_table;
  CreateMappingTableForRowWiseSparsity(mapping_table, num_rows, sparsity);

  // outputs
  vector<int32_t> out_indices_32(indices_32.size(), 0);
  vector<int32_t> out_offsets_32(offsets_32.size(), 0);
  vector<float> out_weights(weights.size(), 0);
  vector<int64_t> out_indices(indices.size(), 0);
  vector<int64_t> out_offsets(offsets.size(), 0);

  // reference outputs
  vector<int32_t> out_indices_32_ref(indices_32.size(), 0);
  vector<int32_t> out_offsets_32_ref(offsets_32.size(), 0);
  vector<float> out_weights_ref(weights.size(), 0);
  vector<int64_t> out_indices_ref(indices.size(), 0);
  vector<int64_t> out_offsets_ref(offsets.size(), 0);

  // number of elements in the offset array (it's equal to batch_size + 1)
  int offset_numel = offsets_32.size();

  if (isIndex64b) {
    if (per_sample_weights) {
      compressed_indices_remap(
          offset_numel,
          indices.data(),
          mapping_table.data(),
          offsets.data(),
          weights.data(),
          out_indices.data(),
          out_offsets.data(),
          out_weights.data());
      compressed_indices_remap_ref(
          offset_numel,
          indices.data(),
          mapping_table.data(),
          offsets.data(),
          weights.data(),
          out_indices_ref.data(),
          out_offsets_ref.data(),
          out_weights_ref.data());
    } else {
      compressed_indices_remap(
          offset_numel,
          indices.data(),
          mapping_table.data(),
          offsets.data(),
          nullptr,
          out_indices.data(),
          out_offsets.data(),
          nullptr);
      compressed_indices_remap_ref(
          offset_numel,
          indices.data(),
          mapping_table.data(),
          offsets.data(),
          nullptr,
          out_indices_ref.data(),
          out_offsets_ref.data(),
          nullptr);
    }
  } else {
    if (per_sample_weights) {
      compressed_indices_remap(
          offset_numel,
          indices_32.data(),
          mapping_table.data(),
          offsets_32.data(),
          weights.data(),
          out_indices_32.data(),
          out_offsets_32.data(),
          out_weights.data());
      compressed_indices_remap_ref(
          offset_numel,
          indices_32.data(),
          mapping_table.data(),
          offsets_32.data(),
          weights.data(),
          out_indices_32_ref.data(),
          out_offsets_32_ref.data(),
          out_weights_ref.data());
    } else {
      compressed_indices_remap(
          offset_numel,
          indices_32.data(),
          mapping_table.data(),
          offsets_32.data(),
          nullptr,
          out_indices_32.data(),
          out_offsets_32.data(),
          nullptr);
      compressed_indices_remap_ref(
          offset_numel,
          indices_32.data(),
          mapping_table.data(),
          offsets_32.data(),
          nullptr,
          out_indices_32_ref.data(),
          out_offsets_32_ref.data(),
          nullptr);
    }
  }

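  // After remapping, out_offsets[offset_numel - 1] holds the total number of
  // surviving indices, so only that prefix of out_indices (and out_weights)
  // is meaningful and compared below; the tail keeps its initial zeros.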
  if (isIndex64b) {
    EXPECT_EQ(out_offsets, out_offsets_ref) << "offsets don't match";
    for (int i = 0; i < out_offsets[offset_numel - 1]; ++i) {
      EXPECT_EQ(out_indices[i], out_indices_ref[i])
          << "indices don't match at " << i;
    }
  } else {
    EXPECT_EQ(out_offsets_32, out_offsets_32_ref) << "offsets don't match";
    for (int i = 0; i < out_offsets_32[offset_numel - 1]; ++i) {
      EXPECT_EQ(out_indices_32[i], out_indices_32_ref[i])
          << "indices don't match at " << i;
    }
  }

  if (per_sample_weights) {
    size_t len = isIndex64b ? out_offsets[offset_numel - 1]
                            : out_offsets_32[offset_numel - 1];
    for (size_t i = 0; i < len; ++i) {
      EXPECT_EQ(out_weights[i], out_weights_ref[i])
          << "weights don't match at " << i;
    }
  }
}