/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstdint>
#include <vector>

namespace fbgemm {

/**
 * Selects the kind of input GenerateLengthsIndicesWeights() should produce;
 * passed as its `corner_case` argument.
 *
 * NOTE(review): the per-enumerator descriptions below are inferred from the
 * names — confirm against the implementation.
 */
enum EmbeddingSpMDMCornerCase {
  NONE, ///< regular input, no corner case (assumed)
  EMPTY_INDICES, ///< presumably an empty index list
  OUT_OF_BOUND_INDICES, ///< presumably indices outside the valid row range
  UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM, ///< presumably #indices != sum(lengths)
};

/**
 * Weighting mode choice. The names suggest: no weights, per-index weights,
 * and position-based weights — TODO confirm against users of this enum.
 */
enum EmbeddingSpMDMWeightChoice {
  UNWEIGHTED,
  WEIGHTED,
  POSITIONAL_WEIGHTED,
};

/**
 * Element data type choice. The same enum is used for both input and output
 * types through the two aliases declared below.
 */
enum EmbeddingSpMDMDtypeChoice {
  FLOAT,
  FLOAT16,
  BFLOAT16,
};

using EmbeddingSpMDMInputDtypeChoice = EmbeddingSpMDMDtypeChoice;
using EmbeddingSpMDMOutputDtypeChoice = EmbeddingSpMDMDtypeChoice;

// NOTE(review): the std::vector element types below must match the
// definitions in the implementation file; a mismatch declares a different
// overload and fails at link time.

/**
 * Generates lengths/offsets/indices/weights test inputs into the vectors
 * passed by non-const reference.
 *
 * Lengths, offsets and indices are each requested in two widths: the
 * unsuffixed vector holds 64-bit values and the `_32` vector is its 32-bit
 * counterpart.
 *
 * @param lengths      64-bit lengths
 * @param lengths_32   32-bit lengths
 * @param offsets      64-bit offsets
 * @param offsets_32   32-bit offsets
 * @param indices      64-bit indices
 * @param indices_32   32-bit indices
 * @param weights      per-index weights
 * @param batch_size   batch size of the generated input
 * @param num_rows     number of embedding-table rows indices may refer to
 *                     (assumed)
 * @param average_len  target average length (assumed; see implementation)
 * @param corner_case  which EmbeddingSpMDMCornerCase to produce
 * @return lengths_sum
 */
int GenerateLengthsIndicesWeights(
    std::vector<std::int64_t>& lengths,
    std::vector<std::int32_t>& lengths_32,
    std::vector<std::int64_t>& offsets,
    std::vector<std::int32_t>& offsets_32,
    std::vector<std::int64_t>& indices,
    std::vector<std::int32_t>& indices_32,
    std::vector<float>& weights,
    int batch_size,
    int num_rows,
    int average_len,
    EmbeddingSpMDMCornerCase corner_case);

/**
 * Builds a mapping table for row-wise sparsity into `mapping_table`, which is
 * passed by non-const reference.
 *
 * @param mapping_table  table written by this function
 * @param num_rows       number of rows before compression (assumed)
 * @param sparsity       sparsity level (presumably the fraction of rows
 *                       dropped — confirm against the implementation)
 * @return num_compressed_rows
 */
int CreateMappingTableForRowWiseSparsity(
    std::vector<std::int32_t>& mapping_table,
    int num_rows,
    float sparsity);

} // namespace fbgemm