#include #include #include #include using namespace ::testing; namespace torch::nativert { TEST(GraphTest, Basic) { static constexpr std::string_view source = R"(graph(%foo, %bar, %baz): %o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) return(%o2, %baz) )"; auto graph = stringToGraph(source); EXPECT_EQ(graph->inputs().size(), 3); EXPECT_EQ(graph->inputs()[0]->name(), "foo"); EXPECT_EQ(graph->inputs()[1]->name(), "bar"); EXPECT_EQ(graph->inputs()[2]->name(), "baz"); const auto& nodes = graph->nodes(); EXPECT_EQ(nodes.size(), 3); // First node is the input node auto it = nodes.begin(); { const auto& node = *it; EXPECT_EQ(node.target(), "prim.Input"); EXPECT_EQ(node.inputs().size(), 0); EXPECT_EQ(node.outputs().size(), 3); EXPECT_EQ(node.outputs()[0]->name(), "foo"); EXPECT_EQ(node.outputs()[1]->name(), "bar"); EXPECT_EQ(node.outputs()[2]->name(), "baz"); } { std::advance(it, 1); const auto& node = *it; EXPECT_EQ(node.target(), "aten.foo"); EXPECT_EQ(node.inputs().size(), 2); EXPECT_EQ(node.inputs()[0].name, "self"); EXPECT_EQ(node.inputs()[1].name, "target"); EXPECT_EQ(node.attributes().size(), 1); EXPECT_EQ(node.attributes()[0].name, "alpha"); } { std::advance(it, 1); const auto& node = *it; EXPECT_EQ(node.target(), "prim.Output"); EXPECT_EQ(node.inputs().size(), 2); EXPECT_EQ(node.inputs()[0].name, "o2"); EXPECT_EQ(node.inputs()[1].name, "baz"); } EXPECT_EQ(graph->outputs().size(), 2); EXPECT_EQ(graph->outputs()[0]->name(), "o2"); EXPECT_EQ(graph->outputs()[1]->name(), "baz"); const auto& values = graph->values(); EXPECT_EQ(values.size(), 5); std::vector valueNames; valueNames.reserve(values.size()); for (const auto& v : values) { valueNames.emplace_back(v->name()); } std::sort(valueNames.begin(), valueNames.end()); EXPECT_THAT( valueNames, ContainerEq(std::vector({"bar", "baz", "foo", "o1", "o2"}))); } TEST(GraphTest, ValueProducer) { static constexpr std::string_view source = R"(graph(%foo, %bar, %baz): %o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) return(%o2, %baz) )"; auto graph = stringToGraph(source); auto foo = graph->getValue("foo"); EXPECT_EQ(foo->producer()->target(), "prim.Input"); auto o1 = graph->getValue("o1"); EXPECT_EQ(o1->producer()->target(), "aten.foo"); } TEST(GraphTest, InsertBeforeAfter) { static constexpr std::string_view source = R"(graph(%foo, %bar, %baz): %o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) return(%o2, %baz) )"; auto graph = stringToGraph(source); auto it = graph->nodes().begin(); ++it; auto& node = *it; EXPECT_EQ(node.target(), "aten.foo"); auto before = graph->createNode("before", {}); auto after = graph->createNode("after", {}); auto atEnd = graph->createNode("atEnd", {}); graph->insertBefore(before, &node); graph->insertAfter(after, &node); graph->insert(atEnd); static constexpr std::string_view expected = R"(graph(%foo, %bar, %baz): = before() %o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) = after() = atEnd() return(%o2, %baz) )"; EXPECT_EQ(graphToString(*graph), expected); } TEST(GraphTest, ValueUses) { static constexpr std::string_view source = R"(graph(%foo, %bar, %baz): %o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) return(%o2, %baz) )"; auto graph = stringToGraph(source); auto o2 = graph->getValue("o2"); EXPECT_EQ(o2->users().size(), 1); EXPECT_EQ(o2->users()[0]->target(), "prim.Output"); } TEST(GraphTest, ApplyDevicePlacement) { auto graph = Graph::createGraph(); auto node1 = graph->insertNode("node1"); auto node2 = graph->insertNode("node2"); node1->addAttribute({"a", c10::Device(c10::DeviceType::CPU)}); node1->addAttribute({"b", c10::Device(c10::DeviceType::CUDA, 0)}); node1->addAttribute({"c", c10::Device(c10::DeviceType::CUDA, 1)}); node2->addAttribute({"d", c10::Device(c10::DeviceType::CUDA, 0)}); graph->applyDevicePlacement( Placement(std::unordered_map{ {c10::Device(c10::DeviceType::CUDA, 0), c10::Device(c10::DeviceType::CUDA, 1)}})); EXPECT_EQ( std::get(node1->getAttribute("a").value), c10::Device(c10::DeviceType::CPU)); EXPECT_EQ( std::get(node1->getAttribute("b").value), c10::Device(c10::DeviceType::CUDA, 1)); EXPECT_EQ( std::get(node1->getAttribute("c").value), c10::Device(c10::DeviceType::CUDA, 1)); EXPECT_EQ( std::get(node2->getAttribute("d").value), c10::Device(c10::DeviceType::CUDA, 1)); } TEST(GraphTest, ReplaceAllUses) { static constexpr std::string_view source = R"(graph(%foo, %bar, %baz): %o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) return(%o2, %baz) )"; auto graph = stringToGraph(source); auto o2 = graph->getValue("o2"); auto bar = graph->getValue("bar"); auto foo = graph->getValue("foo"); EXPECT_EQ(o2->users().size(), 1); EXPECT_EQ(bar->users().size(), 1); EXPECT_EQ(foo->users().size(), 1); graph->replaceAllUses(o2, bar); EXPECT_EQ(o2->users().size(), 0); EXPECT_EQ(bar->users().size(), 2); graph->replaceAllUses(bar, foo); EXPECT_EQ(bar->users().size(), 0); EXPECT_EQ(foo->users().size(), 2); static constexpr std::string_view expected = R"(graph(%foo, %bar, %baz): %o1, %o2 = aten.foo(self=%foo, target=%foo, alpha=0.1) return(%foo, %baz) )"; EXPECT_EQ(graphToString(*graph), expected); } TEST(GraphTest, GetUniqueValueName) { static constexpr std::string_view source = R"(graph(%foo, %bar, %baz): %o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1) return(%o2, %bar) )"; auto graph = stringToGraph(source); auto o2 = graph->getValue("o2"); auto fooNode = o2->producer(); auto v0 = graph->getUniqueValueName(); graph->addValue(v0, Type::Kind::None, fooNode); auto v1 = graph->getUniqueValueName(); graph->addValue(v1, Type::Kind::None, fooNode); auto v2 = graph->getUniqueValueName(); EXPECT_EQ(v0, "v0"); EXPECT_EQ(v1, "v1"); EXPECT_EQ(v2, "v2"); } TEST(GraphTest, ReplaceAllUsesMultiUse) { static constexpr std::string_view source = R"(graph(%foo, %bar): %o1 = aten.foo(a=%foo, b=%foo, c=%bar) return(%o1) )"; auto graph = stringToGraph(source); auto foo = graph->getValue("foo"); auto bar = graph->getValue("bar"); graph->replaceAllUses(foo, bar); static constexpr std::string_view expected = R"(graph(%foo, %bar): %o1 = aten.foo(a=%bar, b=%bar, c=%bar) return(%o1) )"; EXPECT_EQ(graphToString(*graph), expected); } TEST(GraphTest, ReplaceAllUsesAfter) { static constexpr std::string_view source = R"(graph(%foo): %o1 = aten.foo1(a=%foo) %o2 = aten.foo2(a=%o1, b=%foo) %o3 = aten.foo3(a=%o2, b=%o2, c=%foo) return(%foo, %o1, %o2, %o3) )"; auto graph = stringToGraph(source); auto foo = graph->getValue("foo"); auto o1 = graph->getValue("o1"); auto foo3Node = graph->getValue("o3")->producer(); graph->replaceAllUsesAfterNode(foo, o1, foo3Node); static constexpr std::string_view expected = R"(graph(%foo): %o1 = aten.foo1(a=%foo) %o2 = aten.foo2(a=%o1, b=%foo) %o3 = aten.foo3(a=%o2, b=%o2, c=%foo) return(%o1, %o1, %o2, %o3) )"; EXPECT_EQ(graphToString(*graph), expected); EXPECT_EQ(foo->users().size(), 3); EXPECT_EQ(o1->users().size(), 2); } TEST(GraphTest, InsertingAfter) { static constexpr std::string_view source = R"(graph(%foo, %bar): %o1 = aten.first(a=%foo) %o2 = aten.foo(c=%bar) return(%o1, %o2) )"; auto graph = stringToGraph(source); auto origNode = graph->getValue("o1")->producer(); { InsertingAfter guard(origNode); graph->insertNode("one"); graph->insertNode("two"); graph->insertNode("three"); } graph->insertNode("four"); static constexpr std::string_view expected = R"(graph(%foo, %bar): %o1 = aten.first(a=%foo) = one() = two() = three() %o2 = aten.foo(c=%bar) = four() return(%o1, %o2) )"; EXPECT_EQ(graphToString(*graph), expected); } TEST(NodeTest, GetInputAndAttribute) { auto graph = Graph::createGraph(); auto input1 = graph->addInput("input1", Type::Kind::Tensor); auto input2 = graph->addInput("input2", Type::Kind::Tensor); auto input3 = graph->addInput("input3", Type::Kind::Tensor); auto node = graph->createNode("foo.bar"); node->addInput({"out_of_order", input1}); node->addInput({"arg1", input2}); node->addInput({"arg2", input3}); node->addAttribute({"b", static_cast(0)}); node->addAttribute({"a", static_cast(2)}); node->addAttribute({"c", static_cast(1)}); { const auto& input = node->getInput("out_of_order"); EXPECT_EQ(input.name, "out_of_order"); EXPECT_EQ(input.value, input1); } { const auto& input = node->getInput("arg1"); EXPECT_EQ(input.name, "arg1"); EXPECT_EQ(input.value, input2); } { const auto& input = node->getInput("arg2"); EXPECT_EQ(input.name, "arg2"); EXPECT_EQ(input.value, input3); } { const auto& attr = node->getAttribute("a"); EXPECT_EQ(attr.name, "a"); EXPECT_EQ(attr.value, Constant(static_cast(2))); } { const auto& attr = node->getAttribute("b"); EXPECT_EQ(attr.name, "b"); EXPECT_EQ(attr.value, Constant(static_cast(0))); } { const auto& attr = node->getAttribute("c"); EXPECT_EQ(attr.name, "c"); EXPECT_EQ(attr.value, Constant(static_cast(1))); } EXPECT_EQ(node->tryGetInput("doesnotexist"), nullptr); EXPECT_EQ(node->tryGetAttribute("doesnotexist"), nullptr); } TEST(NodeTest, NextPrev) { static constexpr std::string_view source = R"(graph(%foo): %o1 = aten.foo1(a=%foo) %o2 = aten.foo2(a=%o1, b=%foo) %o3 = aten.foo3(a=%o2, b=%o2, c=%foo) return(%foo, %o1, %o2, %o3) )"; auto graph = stringToGraph(source); auto foo1 = graph->getValue("o1")->producer(); auto foo2 = graph->getValue("o2")->producer(); auto foo3 = graph->getValue("o3")->producer(); EXPECT_EQ(foo1->next(), foo2); EXPECT_EQ(foo2->next(), foo3); EXPECT_EQ(foo3->prev(), foo2); EXPECT_EQ(foo3->next(), graph->outputNode()); EXPECT_EQ(foo2->prev(), foo1); EXPECT_EQ(foo1->prev(), graph->inputNode()); EXPECT_EQ(graph->inputNode()->prev(), nullptr); EXPECT_EQ(graph->outputNode()->next(), nullptr); } TEST(GraphTest, IsBefore) { auto source = R"IR( graph(%foo): %o1 = aten.foo1(a=%foo) %o2 = aten.foo2(a=%o1) %o3 = aten.foo3(a=%o2) return (%o3) )IR"; auto graph = stringToGraph(source); ASSERT_NE(graph, nullptr); auto* o1 = graph->tryGetValue("o1"); auto* o2 = graph->tryGetValue("o2"); auto* o3 = graph->tryGetValue("o3"); auto* foo1 = o1->producer(); auto* foo2 = o2->producer(); auto* foo3 = o3->producer(); EXPECT_TRUE(foo1->isBefore(foo2)) << "foo1 should appear before foo2"; EXPECT_TRUE(foo2->isBefore(foo3)) << "foo2 should appear before foo3"; EXPECT_TRUE(foo1->isBefore(foo3)) << "foo1 should appear before foo3"; EXPECT_FALSE(foo2->isBefore(foo1)) << "foo2 should not appear before foo1"; EXPECT_FALSE(foo3->isBefore(foo2)) << "foo3 should not appear before foo2"; } TEST(GraphTest, RemoveNodeWithUsers) { // Check we shouldn't be able to remove a node that still has users auto source = R"IR( graph(%foo): %o1 = aten.foo1(a=%foo) %o2 = aten.foo2(a=%o1, b=%foo) %o3 = aten.foo3(a=%o2, b=%o2, c=%foo) return (%foo, %o1, %o3) )IR"; auto graph = stringToGraph(source); ASSERT_NE(graph, nullptr); auto* o2 = graph->tryGetValue("o2"); auto* foo2 = o2->producer(); EXPECT_THROW(graph->removeNode(foo2), c10::Error); } TEST(GraphTest, RemoveNodeUnused) { // Check node removal works as expected auto source = R"IR( graph(%foo): %o1 = aten.foo1(a=%foo) %o2 = aten.foo2(a=%o1, b=%foo) %unused = aten.fooUnused(a=%o2) return(%foo, %o1, %o2) )IR"; auto graph = stringToGraph(source); auto* valUnused = graph->tryGetValue("unused"); Node* nodeUnused = valUnused->producer(); EXPECT_EQ(nodeUnused->target(), "aten.fooUnused"); graph->removeNode(nodeUnused); graph->lint(); // %unused should now be gone EXPECT_EQ(graph->tryGetValue("unused"), nullptr) << "Value %unused should no longer exist in the graph"; for (const auto& node : graph->nodes()) { EXPECT_NE(node.target(), "aten.fooUnused"); for (const auto* output : node.outputs()) { EXPECT_NE(output->name(), "unused") << "Should not find %unused in any remaining node's outputs"; } } } TEST(GraphTest, RemoveValue) { auto source = R"IR( graph(%foo): %o1 = aten.foo1(a=%foo) %o2 = aten.foo2(a=%o1, b=%foo) %o3 = aten.foo3(a=%o2, b=%o2, c=%foo) return (%foo, %o1, %o3) )IR"; auto graph = stringToGraph(source); auto* val_o1 = graph->tryGetValue("o1"); { // Check we shouldn't be able to remove a value that still has users EXPECT_THROW(graph->removeValue(val_o1), c10::Error); } { // Check value removal works as expected graph->replaceAllUses(val_o1, graph->tryGetValue("foo")); graph->removeValue(val_o1); EXPECT_EQ(graph->tryGetValue("%o1"), nullptr); } } TEST(GraphTest, InsertGraph) { auto source = R"IR( graph(%foo): %o1 = aten.foo1(a=%foo) return (%o1) )IR"; // Subgraph to be inserted auto subgraphSource = R"IR( graph(%x): %s1 = aten.subFoo1(a=%x) %s2 = aten.subFoo2(a=%s1) return (%s2) )IR"; auto mainGraph = stringToGraph(source); auto subGraph = stringToGraph(subgraphSource); // Insert subGraph into mainGraph. Use %o1 as the subGraph's %x auto val_o1 = mainGraph->tryGetValue("o1"); std::unordered_map valueMap; std::vector insertedOutputs = mainGraph->insertGraph(*subGraph, {val_o1}, valueMap); EXPECT_EQ(insertedOutputs.size(), 1); // Check all new nodes are inserted correctly from the copied %s2 auto* newS2 = insertedOutputs.front(); auto* newSubFoo2 = newS2->producer(); EXPECT_EQ(newSubFoo2->target(), "aten.subFoo2"); auto* newS1 = newSubFoo2->inputs().front().value; auto* newSubFoo1 = newS1->producer(); EXPECT_EQ(newSubFoo1->target(), "aten.subFoo1"); EXPECT_EQ(newSubFoo1->inputs().front().value, val_o1); auto* subInputVal = subGraph->inputs().front(); EXPECT_EQ(valueMap[subInputVal], val_o1); for (const auto& [val1, val2] : valueMap) { if (val1->name() == "s1") { EXPECT_EQ(val2->name(), newS1->name()); } if (val1->name() == "s2") { EXPECT_EQ(val2->name(), newS2->name()); } if (val1->name() == "x") { EXPECT_EQ(val2->name(), val_o1->name()); } } mainGraph->lint(); } TEST(GraphTest, CleanupDeadNodes) { // %c is unused const std::string source = R"( graph(%x, %y): %a = foo(a=%x, b=%y) %b = foo1(c=%a) %c = foo2(a=%b, b=%y) return(%b) )"; auto graph = stringToGraph(source); // Verify that %c exists initially auto* cVal = graph->tryGetValue("c"); ASSERT_NE(nullptr, cVal); size_t nodeCountBefore = graph->nodes().size(); graph->cleanupDeadNodes(); // %c should now be gone EXPECT_EQ(nullptr, graph->tryGetValue("c")); // %b should still be there EXPECT_NE(nullptr, graph->tryGetValue("b")); EXPECT_EQ(nodeCountBefore - 1, graph->nodes().size()); } TEST(GraphTest, RenumberValues) { const std::string source = R"( graph(%x): %a = foo(a=%x) %b = foo1(a=%a) return (%a) )"; auto graph = stringToGraph(source); graph->cleanupDeadNodes(); // %b should now be gone EXPECT_EQ(nullptr, graph->tryGetValue("b")); // %a should now be the last value EXPECT_EQ(graph->tryGetValue("a")->id(), graph->numValues() - 1); // All values should be renumbered size_t numVals = graph->numValues(); std::unordered_set ids; ids.reserve(numVals); for (const auto* val : graph->values()) { ASSERT_LT(val->id(), numVals); ids.insert(val->id()); } // Check ids are contiguous and unique b/w 0 and numVals EXPECT_EQ(numVals, ids.size()); for (size_t i = 0; i < numVals; ++i) { EXPECT_NE(ids.end(), ids.find(i)); } } TEST(SerializationTest, RoundTrip) { static constexpr std::string_view source = R"(graph(%foo, %bar, %baz): %o1 = aten.foo(self=%foo, target=%bar, alpha=0.1) return(%o1, %baz) )"; const auto graph = stringToGraph(source); const auto serialized = graphToString(*graph); EXPECT_EQ(source, serialized); } TEST(SerializationTest, EscapedStringConstant) { const auto parsed = std::get(convertAtomicConstant(R"("string_\"escape")")); std::string expected = "string_\\\"escape"; EXPECT_EQ(parsed, expected); } TEST(SerializationTest, DeviceConstant) { const auto device = std::get(convertAtomicConstant("Device{cuda:1}")); EXPECT_EQ(device.index(), 1); EXPECT_EQ(device.type(), c10::DeviceType::CUDA); } TEST(SerializationTest, TrueConstant) { const auto parsedTrue = std::get(convertAtomicConstant("true")); EXPECT_EQ(parsedTrue, true); const auto parsedFalse = std::get(convertAtomicConstant("false")); EXPECT_EQ(parsedFalse, false); } TEST(SerializationTest, MemoryFormatConstant) { const auto parsed = std::get( convertAtomicConstant("MemoryFormat::ContiguousFormat")); EXPECT_EQ(parsed, c10::MemoryFormat::Contiguous); } TEST(SerializationTest, FloatConstant) { const auto parsed = std::get(convertAtomicConstant("5.0")); EXPECT_EQ(parsed, 5.0); } TEST(SerializationTest, IntConstant) { const auto parsed = std::get(convertAtomicConstant("5")); EXPECT_EQ(parsed, 5); } TEST(SerializationTest, FloatExponentConstant) { const auto parsed = std::get(convertAtomicConstant("1e-05")); EXPECT_EQ(parsed, 0.00001); } TEST(SerializationTest, SingleElementListConstant) { const auto parsed = std::get>(convertListConstant("[1]")); const auto expected = std::vector{1}; EXPECT_EQ(parsed, expected); } TEST(SerializationTest, IntListConstant) { const auto parsed = std::get>(convertListConstant("[1, 2, 3, 4]")); const auto expected = std::vector{1, 2, 3, 4}; EXPECT_EQ(parsed, expected); } TEST(SerializationTest, FloatListConstant) { const auto parsed = std::get>( convertListConstant("[1.0, 2.0, 3.0, 4.0]")); const auto expected = std::vector{1.0, 2.0, 3.0, 4.0}; EXPECT_EQ(parsed, expected); } TEST(SerializationTest, BoolListConstant) { const auto parsed = std::get>(convertListConstant("[false, true, false]")); const auto expected = std::vector{false, true, false}; EXPECT_EQ(parsed, expected); } } // namespace torch::nativert