116 lines
4.4 KiB
C++
116 lines
4.4 KiB
C++
// Copyright 2004-present Facebook. All Rights Reserved.
|
|
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
|
|
|
#include <ATen/core/Tensor.h>
|
|
#include <ATen/Dispatch.h>
|
|
#include <c10/util/irange.h>
|
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS
|
|
#include <ATen/Functions.h>
|
|
#include <ATen/NativeFunctions.h>
|
|
#else
|
|
#include <ATen/ops/_rowwise_prune_native.h>
|
|
#include <ATen/ops/empty.h>
|
|
#endif
|
|
|
|
namespace at::native {
|
|
|
|
namespace {
|
|
|
|
template <typename input_t>
|
|
std::tuple<Tensor, Tensor> _rowwise_prune_helper(
|
|
const Tensor& weights, const Tensor& mask,
|
|
ScalarType compressed_indices_dtype) {
|
|
int num_non_masked_rows = 0;
|
|
auto mask_contig = mask.contiguous();
|
|
auto mask_data = mask_contig.data_ptr<bool>();
|
|
for (const auto i : c10::irange(mask.numel())) {
|
|
num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0);
|
|
}
|
|
int num_cols = weights.size(1);
|
|
auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols},
|
|
weights.options());
|
|
auto compressed_indices_mapping = at::empty({mask.numel()},
|
|
compressed_indices_dtype);
|
|
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half,
|
|
at::ScalarType::BFloat16,
|
|
weights.scalar_type(),
|
|
"rowwise_prune_helper", [&]() {
|
|
auto* pruned_2d_tensor_data = pruned_2d_tensor.data_ptr<scalar_t>();
|
|
auto compressed_indices_mapping_data =
|
|
compressed_indices_mapping.data_ptr<input_t>();
|
|
auto weights_data = weights.data_ptr<scalar_t>();
|
|
int last_row_kept = 0;
|
|
for (const auto i : c10::irange(mask.numel())) {
|
|
if (mask_data[i]) {
|
|
memcpy(pruned_2d_tensor_data + last_row_kept * num_cols,
|
|
weights_data + i * num_cols,
|
|
num_cols * sizeof (scalar_t));
|
|
compressed_indices_mapping_data[i] = last_row_kept;
|
|
last_row_kept++;
|
|
} else {
|
|
compressed_indices_mapping_data[i] = -1;
|
|
}
|
|
}
|
|
});
|
|
return std::tuple<Tensor, Tensor>(pruned_2d_tensor,
|
|
compressed_indices_mapping);
|
|
}
|
|
|
|
} // namespace
|
|
|
|
|
|
// This operator introduces sparsity to the 'weights' matrix with the help
|
|
// of the importance indicator 'mask'.
|
|
//
|
|
// A row is considered important and not pruned if the mask value for that
|
|
// particular row is 1(True) and not important otherwise.
|
|
//
|
|
// This operator doesn't zero out the pruned rows in-place. Instead, it
|
|
// returns a tuple that contains a pruned weights tensor as well as a map that
|
|
// can be used to look up the original row in the pruned weights tensor.
|
|
// We refer this map as 'compressed indices map' going forward.
|
|
|
|
// The 'compressed indices map' is an 1D tensor that contains one entry per
|
|
// original row in 'weights'. The array index is the index for the original
|
|
// non-pruned weight tensor and the value would be the re-mapped index in the
|
|
// pruned weights tensor. If the value for a index is -1, it means the
|
|
// corresponding row has been pruned from the original weight tensor.
|
|
|
|
// Arguments:
|
|
// 'weights' - two dimensional matrix that needs to be prune.
|
|
// 'mask' - 1D boolean tensor that represents whether a row is important or
|
|
// not. A mask value of 1 means the row should be kept and 0 means the row
|
|
// should be pruned.
|
|
//
|
|
// Returns:
|
|
// A tuple containing two tensors,
|
|
// 1. A pruned weight tensor that contains only the weights that are preserved
|
|
// post pruning.
|
|
// 2. An 1D tensor that contains the mapping between original weight row and
|
|
// the corresponding row in the pruned weights tensor.
|
|
std::tuple<Tensor, Tensor> _rowwise_prune(const Tensor& weights,
|
|
const Tensor& mask,
|
|
ScalarType compressed_indices_dtype) {
|
|
TORCH_CHECK(weights.ndimension() == 2,
|
|
"'weights' should have 2 dimensions.");
|
|
TORCH_CHECK(
|
|
mask.numel() == weights.size(0),
|
|
"Number of elements in 'mask' should be equivalent to the "
|
|
"number of rows in 'weights'."
|
|
)
|
|
TORCH_CHECK(
|
|
compressed_indices_dtype == ScalarType::Int ||
|
|
compressed_indices_dtype == ScalarType::Long,
|
|
"compressed_indices_dtype should be either int(int32) or long(int64).");
|
|
|
|
if (compressed_indices_dtype == at::ScalarType::Int) {
|
|
return _rowwise_prune_helper<int32_t>(weights, mask,
|
|
compressed_indices_dtype);
|
|
}
|
|
return _rowwise_prune_helper<int64_t>(weights, mask,
|
|
compressed_indices_dtype);
|
|
}
|
|
|
|
} // namespace at::native
|