91 lines
2.6 KiB
C++
91 lines
2.6 KiB
C++
#include <limits>
|
|
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
|
#include <ATen/core/Tensor.h>
|
|
#include <c10/core/Device.h>
|
|
#include <c10/core/Layout.h>
|
|
#include <c10/core/MemoryFormat.h>
|
|
#include <c10/core/Scalar.h>
|
|
#include <c10/core/ScalarType.h>
|
|
#include <optional>
|
|
|
|
#ifndef AT_PER_OPERATOR_HEADERS
|
|
#include <ATen/Functions.h>
|
|
#include <ATen/NativeFunctions.h>
|
|
#else
|
|
#include <ATen/ops/_functional_sym_constrain_range_native.h>
|
|
#include <ATen/ops/_make_dep_token_native.h>
|
|
#include <ATen/ops/empty.h>
|
|
#include <ATen/ops/sym_constrain_range_native.h>
|
|
#include <ATen/ops/sym_constrain_range_for_size_native.h>
|
|
#include <ATen/ops/_functional_sym_constrain_range_for_size_native.h>
|
|
#endif
|
|
|
|
namespace at::native {
|
|
|
|
void sym_constrain_range(
|
|
const Scalar& size,
|
|
std::optional<int64_t> min,
|
|
std::optional<int64_t> max) {
|
|
|
|
int64_t min_val = min.has_value() ? min.value() : std::numeric_limits<int64_t>::min();
|
|
int64_t max_val = max.has_value() ? max.value() : std::numeric_limits<int64_t>::max();
|
|
int64_t size_as_int = size.toLong();
|
|
|
|
TORCH_CHECK(
|
|
max_val >= min_val,
|
|
"Max must be greater than or equal to min. Got min=",
|
|
min_val,
|
|
" max=",
|
|
max_val
|
|
);
|
|
|
|
TORCH_CHECK(
|
|
min_val <= size_as_int && size_as_int <= max_val,
|
|
"Invalid value range for ",
|
|
size_as_int,
|
|
" between [",
|
|
min_val,
|
|
", ",
|
|
max_val,
|
|
"]."
|
|
);
|
|
}
|
|
|
|
// Functional variant of sym_constrain_range: performs the same range check on
// `size` and then returns a clone of `dep_token`.
//
// size, min, max: forwarded unchanged to sym_constrain_range.
// dep_token:      tensor that is cloned to produce the return value.
Tensor _functional_sym_constrain_range(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max,
    const Tensor& dep_token) {
  // The check runs first; the clone is only made once it has passed.
  sym_constrain_range(size, min, max);

  Tensor cloned_token = dep_token.clone();
  return cloned_token;
}
|
|
|
|
void sym_constrain_range_for_size(const Scalar& size, std::optional<int64_t> min, std::optional<int64_t> max) {
|
|
int64_t min_val = min.has_value() ? min.value() : 0;
|
|
if (max.has_value() && max.value() <= 2) {
|
|
TORCH_CHECK(false, "Max value to constrain_range_for_size must be greater than 2. got: ", max.value());
|
|
}
|
|
sym_constrain_range(size, min_val, max);
|
|
}
|
|
|
|
// Functional variant of sym_constrain_range_for_size: performs the same check
// on `size` and then returns a clone of `dep_token`.
//
// size, min, max: forwarded unchanged to sym_constrain_range_for_size.
// dep_token:      tensor that is cloned to produce the return value.
Tensor _functional_sym_constrain_range_for_size(
    const Scalar& size,
    std::optional<int64_t> min,
    std::optional<int64_t> max,
    const Tensor& dep_token) {
  // The check runs first; the clone is only made once it has passed.
  sym_constrain_range_for_size(size, min, max);

  Tensor cloned_token = dep_token.clone();
  return cloned_token;
}
|
|
|
|
// Creates a dependency token by calling at::empty with an empty size list and
// the caller-supplied tensor options.
//
// dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt:
//   optional tensor options, all forwarded unchanged to at::empty.
Tensor _make_dep_token_cpu(
    std::optional<ScalarType> dtype_opt,
    std::optional<Layout> layout_opt,
    std::optional<Device> device_opt,
    std::optional<bool> pin_memory_opt,
    std::optional<c10::MemoryFormat> memory_format_opt) {
  // `{}` is the size argument: no dimensions are requested.
  Tensor token = at::empty(
      {},
      dtype_opt,
      layout_opt,
      device_opt,
      pin_memory_opt,
      memory_format_opt);
  return token;
}
|
|
|
|
} // namespace at::native
|