// Unit tests for torch::nativert::Weights plus a pickle round-trip test for a
// custom class that holds a tensor.
//
// NOTE(review): this text looks extraction-damaged. The #include directives
// have no targets, every template argument list is missing (for example
// std::make_unique(...), std::shared_ptr graph, torch::class_(...), .to>()),
// and the file is collapsed onto two physical lines, so each line comment
// runs to the end of its physical line. Restore the include targets, template
// arguments and line breaks from the upstream file before building; they
// cannot be recovered from what is shown here.
//
// WeightsTest (fixture)
//   SetUp() parses a textual graph with inputs %foo, %bar, %baz (one aten.foo
//   node, returning %o2 and %baz) via stringToGraph, and creates `placement`
//   from a c10::Device of type CPU.
//
// TEST_F(WeightsTest, ConstructEmptyStateDict)
//   Builds Weights from the graph, an empty stateDict and the placement, then
//   expects parameters() and buffers() to be empty and
//   contains("non_existent_weight") to be false.
//
// TEST_F(WeightsTest, SetAndGetValue)
//   Calls setValue("added_weight", at::ones({2, 2})) and expects contains() to
//   report the name and at("added_weight").sizes() to equal the tensor's sizes.
//
// ContainsTensorDict
//   torch::CustomClassHolder subclass with a single at::Tensor member t_. It
//   is constructible from a tensor, or from a c10::Dict by reading the
//   "init_tensor" entry. serialize() returns a c10::Dict holding t_ under that
//   same key.
//
// reg
//   Registers the class through torch::class_ with the names "testing" and
//   "ContainsTensorDict". def_pickle uses serialize() for __getstate__ and the
//   dict constructor for __setstate__.
//
// TEST(CustomWeightsTest, TestCustomObjWithContainedTensor)
//   Saves an instance built from torch::tensor({1, 2, 3}) with
//   torch::jit::pickle_save, reloads it with torch::jit::pickle_load_obj from a
//   std::string copy of the bytes, and expects the result to be an object
//   whose t_[0] equals 1.
#include #include #include #include #include #include #include #include namespace torch::nativert { class WeightsTest : public ::testing::Test { protected: void SetUp() override { static constexpr std::string_view source = R"(graph(%foo, %bar, %baz): %o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) return(%o2, %baz) )"; graph = stringToGraph(source); placement = std::make_unique(c10::Device(c10::DeviceType::CPU)); } std::shared_ptr graph; std::unique_ptr placement; }; TEST_F(WeightsTest, ConstructEmptyStateDict) { std::unordered_map stateDict; Weights weights(graph.get(), stateDict, *placement); // Check that weights are initialized correctly EXPECT_TRUE(weights.parameters().empty()); EXPECT_TRUE(weights.buffers().empty()); EXPECT_FALSE(weights.contains("non_existent_weight")); } TEST_F(WeightsTest, SetAndGetValue) { std::unordered_map stateDict; Weights weights(graph.get(), stateDict, *placement); at::Tensor tensor = at::ones({2, 2}); weights.setValue("added_weight", tensor); EXPECT_TRUE(weights.contains("added_weight")); EXPECT_EQ(weights.at("added_weight").sizes(), tensor.sizes()); } } // namespace torch::nativert using namespace ::testing; struct ContainsTensorDict : torch::CustomClassHolder { explicit ContainsTensorDict(at::Tensor t) : t_(t) {} explicit ContainsTensorDict(c10::Dict dict) { t_ = dict.at(std::string("init_tensor")); } c10::Dict serialize() const { c10::Dict dict; dict.insert(std::string("init_tensor"), t_); return dict; } at::Tensor t_; }; static auto reg = torch::class_("testing", "ContainsTensorDict") .def(torch::init()) .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) -> c10::Dict { return self->serialize(); }, // __setstate__ [](c10::Dict data) -> c10::intrusive_ptr { return c10::make_intrusive(std::move(data)); }); TEST(CustomWeightsTest, TestCustomObjWithContainedTensor) { // Save auto customObj = c10::make_intrusive(torch::tensor({1, 2, 3})); const auto bytes = 
torch::jit::pickle_save(c10::IValue(std::move(customObj))); // Load const auto loadedCustomObj = torch::jit::pickle_load_obj(std::string{bytes.begin(), bytes.end()}); EXPECT_TRUE(loadedCustomObj.isObject()); EXPECT_EQ( loadedCustomObj.to>() ->t_[0] .item(), 1); }