#include #include #include #ifdef AT_PER_OPERATOR_HEADERS #include #endif namespace at { class Tensor; namespace native { template static void _assert_match(const O& original, const C& compared, const std::string& name) { if (compared) { bool equal = (original == compared.value()); if (!equal) { std::stringstream msg; msg << "Tensor " << name << " mismatch! Expected: " << compared.value() << ", Got: " << original; throw std::runtime_error(msg.str()); } } } void _assert_tensor_metadata_meta_symint(at::Tensor const& tensor, at::OptionalSymIntArrayRef sizes, at::OptionalSymIntArrayRef strides, std::optional dtype, std::optional device, std::optional layout) { _assert_match(tensor.sym_sizes(), sizes, "sizes"); _assert_match(tensor.sym_strides(), strides, "strides"); _assert_match(tensor.dtype(), dtype, "dtype"); if (tensor.device().type() != DeviceType::Meta) { _assert_match(tensor.device(), device, "device"); } _assert_match(tensor.layout(), layout, "layout"); } void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, std::optional dtype, std::optional device, std::optional layout) { _assert_match(tensor.sizes(), sizes, "sizes"); _assert_match(tensor.strides(), strides, "strides"); _assert_match(tensor.dtype(), dtype, "dtype"); if (tensor.device().type() != DeviceType::Meta) { _assert_match(tensor.device(), device, "device"); } _assert_match(tensor.layout(), layout, "layout"); } } } // namespace at::native