648 lines
18 KiB
C++
648 lines
18 KiB
C++
#include <c10/core/Device.h>
|
|
#include <gmock/gmock.h>
|
|
#include <gtest/gtest.h>
|
|
|
|
#include <torch/nativert/graph/Graph.h>
|
|
|
|
using namespace ::testing;
|
|
|
|
namespace torch::nativert {
|
|
TEST(GraphTest, Basic) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar, %baz):
|
|
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
|
return(%o2, %baz)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
EXPECT_EQ(graph->inputs().size(), 3);
|
|
EXPECT_EQ(graph->inputs()[0]->name(), "foo");
|
|
EXPECT_EQ(graph->inputs()[1]->name(), "bar");
|
|
EXPECT_EQ(graph->inputs()[2]->name(), "baz");
|
|
|
|
const auto& nodes = graph->nodes();
|
|
EXPECT_EQ(nodes.size(), 3);
|
|
// First node is the input node
|
|
auto it = nodes.begin();
|
|
{
|
|
const auto& node = *it;
|
|
EXPECT_EQ(node.target(), "prim.Input");
|
|
EXPECT_EQ(node.inputs().size(), 0);
|
|
EXPECT_EQ(node.outputs().size(), 3);
|
|
EXPECT_EQ(node.outputs()[0]->name(), "foo");
|
|
EXPECT_EQ(node.outputs()[1]->name(), "bar");
|
|
EXPECT_EQ(node.outputs()[2]->name(), "baz");
|
|
}
|
|
{
|
|
std::advance(it, 1);
|
|
const auto& node = *it;
|
|
EXPECT_EQ(node.target(), "aten.foo");
|
|
EXPECT_EQ(node.inputs().size(), 2);
|
|
EXPECT_EQ(node.inputs()[0].name, "self");
|
|
EXPECT_EQ(node.inputs()[1].name, "target");
|
|
|
|
EXPECT_EQ(node.attributes().size(), 1);
|
|
EXPECT_EQ(node.attributes()[0].name, "alpha");
|
|
}
|
|
{
|
|
std::advance(it, 1);
|
|
const auto& node = *it;
|
|
EXPECT_EQ(node.target(), "prim.Output");
|
|
EXPECT_EQ(node.inputs().size(), 2);
|
|
EXPECT_EQ(node.inputs()[0].name, "o2");
|
|
EXPECT_EQ(node.inputs()[1].name, "baz");
|
|
}
|
|
EXPECT_EQ(graph->outputs().size(), 2);
|
|
EXPECT_EQ(graph->outputs()[0]->name(), "o2");
|
|
EXPECT_EQ(graph->outputs()[1]->name(), "baz");
|
|
|
|
const auto& values = graph->values();
|
|
EXPECT_EQ(values.size(), 5);
|
|
std::vector<std::string> valueNames;
|
|
valueNames.reserve(values.size());
|
|
for (const auto& v : values) {
|
|
valueNames.emplace_back(v->name());
|
|
}
|
|
std::sort(valueNames.begin(), valueNames.end());
|
|
|
|
EXPECT_THAT(
|
|
valueNames,
|
|
ContainerEq(std::vector<std::string>({"bar", "baz", "foo", "o1", "o2"})));
|
|
}
|
|
|
|
TEST(GraphTest, ValueProducer) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar, %baz):
|
|
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
|
return(%o2, %baz)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto foo = graph->getValue("foo");
|
|
EXPECT_EQ(foo->producer()->target(), "prim.Input");
|
|
auto o1 = graph->getValue("o1");
|
|
EXPECT_EQ(o1->producer()->target(), "aten.foo");
|
|
}
|
|
|
|
TEST(GraphTest, InsertBeforeAfter) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar, %baz):
|
|
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
|
return(%o2, %baz)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto it = graph->nodes().begin();
|
|
++it;
|
|
auto& node = *it;
|
|
EXPECT_EQ(node.target(), "aten.foo");
|
|
auto before = graph->createNode("before", {});
|
|
auto after = graph->createNode("after", {});
|
|
auto atEnd = graph->createNode("atEnd", {});
|
|
|
|
graph->insertBefore(before, &node);
|
|
graph->insertAfter(after, &node);
|
|
graph->insert(atEnd);
|
|
|
|
static constexpr std::string_view expected =
|
|
R"(graph(%foo, %bar, %baz):
|
|
= before()
|
|
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
|
= after()
|
|
= atEnd()
|
|
return(%o2, %baz)
|
|
)";
|
|
EXPECT_EQ(graphToString(*graph), expected);
|
|
}
|
|
|
|
TEST(GraphTest, ValueUses) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar, %baz):
|
|
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
|
return(%o2, %baz)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto o2 = graph->getValue("o2");
|
|
EXPECT_EQ(o2->users().size(), 1);
|
|
EXPECT_EQ(o2->users()[0]->target(), "prim.Output");
|
|
}
|
|
|
|
TEST(GraphTest, ApplyDevicePlacement) {
|
|
auto graph = Graph::createGraph();
|
|
auto node1 = graph->insertNode("node1");
|
|
auto node2 = graph->insertNode("node2");
|
|
|
|
node1->addAttribute({"a", c10::Device(c10::DeviceType::CPU)});
|
|
node1->addAttribute({"b", c10::Device(c10::DeviceType::CUDA, 0)});
|
|
node1->addAttribute({"c", c10::Device(c10::DeviceType::CUDA, 1)});
|
|
|
|
node2->addAttribute({"d", c10::Device(c10::DeviceType::CUDA, 0)});
|
|
|
|
graph->applyDevicePlacement(
|
|
Placement(std::unordered_map<c10::Device, c10::Device>{
|
|
{c10::Device(c10::DeviceType::CUDA, 0),
|
|
c10::Device(c10::DeviceType::CUDA, 1)}}));
|
|
|
|
EXPECT_EQ(
|
|
std::get<c10::Device>(node1->getAttribute("a").value),
|
|
c10::Device(c10::DeviceType::CPU));
|
|
EXPECT_EQ(
|
|
std::get<c10::Device>(node1->getAttribute("b").value),
|
|
c10::Device(c10::DeviceType::CUDA, 1));
|
|
EXPECT_EQ(
|
|
std::get<c10::Device>(node1->getAttribute("c").value),
|
|
c10::Device(c10::DeviceType::CUDA, 1));
|
|
EXPECT_EQ(
|
|
std::get<c10::Device>(node2->getAttribute("d").value),
|
|
c10::Device(c10::DeviceType::CUDA, 1));
|
|
}
|
|
|
|
TEST(GraphTest, ReplaceAllUses) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar, %baz):
|
|
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
|
return(%o2, %baz)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto o2 = graph->getValue("o2");
|
|
auto bar = graph->getValue("bar");
|
|
auto foo = graph->getValue("foo");
|
|
|
|
EXPECT_EQ(o2->users().size(), 1);
|
|
EXPECT_EQ(bar->users().size(), 1);
|
|
EXPECT_EQ(foo->users().size(), 1);
|
|
|
|
graph->replaceAllUses(o2, bar);
|
|
EXPECT_EQ(o2->users().size(), 0);
|
|
EXPECT_EQ(bar->users().size(), 2);
|
|
|
|
graph->replaceAllUses(bar, foo);
|
|
EXPECT_EQ(bar->users().size(), 0);
|
|
EXPECT_EQ(foo->users().size(), 2);
|
|
static constexpr std::string_view expected =
|
|
R"(graph(%foo, %bar, %baz):
|
|
%o1, %o2 = aten.foo(self=%foo, target=%foo, alpha=0.1)
|
|
return(%foo, %baz)
|
|
)";
|
|
EXPECT_EQ(graphToString(*graph), expected);
|
|
}
|
|
|
|
TEST(GraphTest, GetUniqueValueName) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar, %baz):
|
|
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
|
return(%o2, %bar)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto o2 = graph->getValue("o2");
|
|
auto fooNode = o2->producer();
|
|
auto v0 = graph->getUniqueValueName();
|
|
graph->addValue(v0, Type::Kind::None, fooNode);
|
|
auto v1 = graph->getUniqueValueName();
|
|
graph->addValue(v1, Type::Kind::None, fooNode);
|
|
auto v2 = graph->getUniqueValueName();
|
|
EXPECT_EQ(v0, "v0");
|
|
EXPECT_EQ(v1, "v1");
|
|
EXPECT_EQ(v2, "v2");
|
|
}
|
|
|
|
TEST(GraphTest, ReplaceAllUsesMultiUse) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar):
|
|
%o1 = aten.foo(a=%foo, b=%foo, c=%bar)
|
|
return(%o1)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto foo = graph->getValue("foo");
|
|
auto bar = graph->getValue("bar");
|
|
graph->replaceAllUses(foo, bar);
|
|
|
|
static constexpr std::string_view expected =
|
|
R"(graph(%foo, %bar):
|
|
%o1 = aten.foo(a=%bar, b=%bar, c=%bar)
|
|
return(%o1)
|
|
)";
|
|
EXPECT_EQ(graphToString(*graph), expected);
|
|
}
|
|
|
|
TEST(GraphTest, ReplaceAllUsesAfter) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo):
|
|
%o1 = aten.foo1(a=%foo)
|
|
%o2 = aten.foo2(a=%o1, b=%foo)
|
|
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
|
|
return(%foo, %o1, %o2, %o3)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto foo = graph->getValue("foo");
|
|
auto o1 = graph->getValue("o1");
|
|
auto foo3Node = graph->getValue("o3")->producer();
|
|
graph->replaceAllUsesAfterNode(foo, o1, foo3Node);
|
|
|
|
static constexpr std::string_view expected =
|
|
R"(graph(%foo):
|
|
%o1 = aten.foo1(a=%foo)
|
|
%o2 = aten.foo2(a=%o1, b=%foo)
|
|
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
|
|
return(%o1, %o1, %o2, %o3)
|
|
)";
|
|
EXPECT_EQ(graphToString(*graph), expected);
|
|
EXPECT_EQ(foo->users().size(), 3);
|
|
EXPECT_EQ(o1->users().size(), 2);
|
|
}
|
|
|
|
TEST(GraphTest, InsertingAfter) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar):
|
|
%o1 = aten.first(a=%foo)
|
|
%o2 = aten.foo(c=%bar)
|
|
return(%o1, %o2)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto origNode = graph->getValue("o1")->producer();
|
|
{
|
|
InsertingAfter guard(origNode);
|
|
graph->insertNode("one");
|
|
graph->insertNode("two");
|
|
graph->insertNode("three");
|
|
}
|
|
graph->insertNode("four");
|
|
static constexpr std::string_view expected =
|
|
R"(graph(%foo, %bar):
|
|
%o1 = aten.first(a=%foo)
|
|
= one()
|
|
= two()
|
|
= three()
|
|
%o2 = aten.foo(c=%bar)
|
|
= four()
|
|
return(%o1, %o2)
|
|
)";
|
|
EXPECT_EQ(graphToString(*graph), expected);
|
|
}
|
|
|
|
TEST(NodeTest, GetInputAndAttribute) {
|
|
auto graph = Graph::createGraph();
|
|
auto input1 = graph->addInput("input1", Type::Kind::Tensor);
|
|
auto input2 = graph->addInput("input2", Type::Kind::Tensor);
|
|
auto input3 = graph->addInput("input3", Type::Kind::Tensor);
|
|
auto node = graph->createNode("foo.bar");
|
|
|
|
node->addInput({"out_of_order", input1});
|
|
node->addInput({"arg1", input2});
|
|
node->addInput({"arg2", input3});
|
|
|
|
node->addAttribute({"b", static_cast<int64_t>(0)});
|
|
node->addAttribute({"a", static_cast<int64_t>(2)});
|
|
node->addAttribute({"c", static_cast<int64_t>(1)});
|
|
{
|
|
const auto& input = node->getInput("out_of_order");
|
|
EXPECT_EQ(input.name, "out_of_order");
|
|
EXPECT_EQ(input.value, input1);
|
|
}
|
|
{
|
|
const auto& input = node->getInput("arg1");
|
|
EXPECT_EQ(input.name, "arg1");
|
|
EXPECT_EQ(input.value, input2);
|
|
}
|
|
{
|
|
const auto& input = node->getInput("arg2");
|
|
EXPECT_EQ(input.name, "arg2");
|
|
EXPECT_EQ(input.value, input3);
|
|
}
|
|
{
|
|
const auto& attr = node->getAttribute("a");
|
|
EXPECT_EQ(attr.name, "a");
|
|
EXPECT_EQ(attr.value, Constant(static_cast<int64_t>(2)));
|
|
}
|
|
{
|
|
const auto& attr = node->getAttribute("b");
|
|
EXPECT_EQ(attr.name, "b");
|
|
EXPECT_EQ(attr.value, Constant(static_cast<int64_t>(0)));
|
|
}
|
|
{
|
|
const auto& attr = node->getAttribute("c");
|
|
EXPECT_EQ(attr.name, "c");
|
|
EXPECT_EQ(attr.value, Constant(static_cast<int64_t>(1)));
|
|
}
|
|
|
|
EXPECT_EQ(node->tryGetInput("doesnotexist"), nullptr);
|
|
EXPECT_EQ(node->tryGetAttribute("doesnotexist"), nullptr);
|
|
}
|
|
|
|
TEST(NodeTest, NextPrev) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo):
|
|
%o1 = aten.foo1(a=%foo)
|
|
%o2 = aten.foo2(a=%o1, b=%foo)
|
|
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
|
|
return(%foo, %o1, %o2, %o3)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
auto foo1 = graph->getValue("o1")->producer();
|
|
auto foo2 = graph->getValue("o2")->producer();
|
|
auto foo3 = graph->getValue("o3")->producer();
|
|
EXPECT_EQ(foo1->next(), foo2);
|
|
EXPECT_EQ(foo2->next(), foo3);
|
|
EXPECT_EQ(foo3->prev(), foo2);
|
|
EXPECT_EQ(foo3->next(), graph->outputNode());
|
|
EXPECT_EQ(foo2->prev(), foo1);
|
|
EXPECT_EQ(foo1->prev(), graph->inputNode());
|
|
EXPECT_EQ(graph->inputNode()->prev(), nullptr);
|
|
EXPECT_EQ(graph->outputNode()->next(), nullptr);
|
|
}
|
|
|
|
TEST(GraphTest, IsBefore) {
|
|
auto source = R"IR(
|
|
graph(%foo):
|
|
%o1 = aten.foo1(a=%foo)
|
|
%o2 = aten.foo2(a=%o1)
|
|
%o3 = aten.foo3(a=%o2)
|
|
return (%o3)
|
|
)IR";
|
|
|
|
auto graph = stringToGraph(source);
|
|
ASSERT_NE(graph, nullptr);
|
|
|
|
auto* o1 = graph->tryGetValue("o1");
|
|
auto* o2 = graph->tryGetValue("o2");
|
|
auto* o3 = graph->tryGetValue("o3");
|
|
|
|
auto* foo1 = o1->producer();
|
|
auto* foo2 = o2->producer();
|
|
auto* foo3 = o3->producer();
|
|
|
|
EXPECT_TRUE(foo1->isBefore(foo2)) << "foo1 should appear before foo2";
|
|
EXPECT_TRUE(foo2->isBefore(foo3)) << "foo2 should appear before foo3";
|
|
EXPECT_TRUE(foo1->isBefore(foo3)) << "foo1 should appear before foo3";
|
|
|
|
EXPECT_FALSE(foo2->isBefore(foo1)) << "foo2 should not appear before foo1";
|
|
EXPECT_FALSE(foo3->isBefore(foo2)) << "foo3 should not appear before foo2";
|
|
}
|
|
|
|
TEST(GraphTest, RemoveNodeWithUsers) {
|
|
// Check we shouldn't be able to remove a node that still has users
|
|
auto source = R"IR(
|
|
graph(%foo):
|
|
%o1 = aten.foo1(a=%foo)
|
|
%o2 = aten.foo2(a=%o1, b=%foo)
|
|
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
|
|
return (%foo, %o1, %o3)
|
|
)IR";
|
|
|
|
auto graph = stringToGraph(source);
|
|
ASSERT_NE(graph, nullptr);
|
|
|
|
auto* o2 = graph->tryGetValue("o2");
|
|
auto* foo2 = o2->producer();
|
|
|
|
EXPECT_THROW(graph->removeNode(foo2), c10::Error);
|
|
}
|
|
|
|
TEST(GraphTest, RemoveNodeUnused) {
|
|
// Check node removal works as expected
|
|
auto source = R"IR(
|
|
graph(%foo):
|
|
%o1 = aten.foo1(a=%foo)
|
|
%o2 = aten.foo2(a=%o1, b=%foo)
|
|
%unused = aten.fooUnused(a=%o2)
|
|
return(%foo, %o1, %o2)
|
|
)IR";
|
|
auto graph = stringToGraph(source);
|
|
|
|
auto* valUnused = graph->tryGetValue("unused");
|
|
Node* nodeUnused = valUnused->producer();
|
|
EXPECT_EQ(nodeUnused->target(), "aten.fooUnused");
|
|
|
|
graph->removeNode(nodeUnused);
|
|
graph->lint();
|
|
|
|
// %unused should now be gone
|
|
EXPECT_EQ(graph->tryGetValue("unused"), nullptr)
|
|
<< "Value %unused should no longer exist in the graph";
|
|
|
|
for (const auto& node : graph->nodes()) {
|
|
EXPECT_NE(node.target(), "aten.fooUnused");
|
|
for (const auto* output : node.outputs()) {
|
|
EXPECT_NE(output->name(), "unused")
|
|
<< "Should not find %unused in any remaining node's outputs";
|
|
}
|
|
}
|
|
}
|
|
|
|
TEST(GraphTest, RemoveValue) {
|
|
auto source = R"IR(
|
|
graph(%foo):
|
|
%o1 = aten.foo1(a=%foo)
|
|
%o2 = aten.foo2(a=%o1, b=%foo)
|
|
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
|
|
return (%foo, %o1, %o3)
|
|
)IR";
|
|
|
|
auto graph = stringToGraph(source);
|
|
auto* val_o1 = graph->tryGetValue("o1");
|
|
|
|
{
|
|
// Check we shouldn't be able to remove a value that still has users
|
|
EXPECT_THROW(graph->removeValue(val_o1), c10::Error);
|
|
}
|
|
|
|
{
|
|
// Check value removal works as expected
|
|
graph->replaceAllUses(val_o1, graph->tryGetValue("foo"));
|
|
graph->removeValue(val_o1);
|
|
EXPECT_EQ(graph->tryGetValue("%o1"), nullptr);
|
|
}
|
|
}
|
|
|
|
TEST(GraphTest, InsertGraph) {
|
|
auto source = R"IR(
|
|
graph(%foo):
|
|
%o1 = aten.foo1(a=%foo)
|
|
return (%o1)
|
|
)IR";
|
|
|
|
// Subgraph to be inserted
|
|
auto subgraphSource = R"IR(
|
|
graph(%x):
|
|
%s1 = aten.subFoo1(a=%x)
|
|
%s2 = aten.subFoo2(a=%s1)
|
|
return (%s2)
|
|
)IR";
|
|
|
|
auto mainGraph = stringToGraph(source);
|
|
auto subGraph = stringToGraph(subgraphSource);
|
|
|
|
// Insert subGraph into mainGraph. Use %o1 as the subGraph's %x
|
|
auto val_o1 = mainGraph->tryGetValue("o1");
|
|
std::unordered_map<const Value*, Value*> valueMap;
|
|
std::vector<Value*> insertedOutputs =
|
|
mainGraph->insertGraph(*subGraph, {val_o1}, valueMap);
|
|
|
|
EXPECT_EQ(insertedOutputs.size(), 1);
|
|
|
|
// Check all new nodes are inserted correctly from the copied %s2
|
|
auto* newS2 = insertedOutputs.front();
|
|
|
|
auto* newSubFoo2 = newS2->producer();
|
|
EXPECT_EQ(newSubFoo2->target(), "aten.subFoo2");
|
|
|
|
auto* newS1 = newSubFoo2->inputs().front().value;
|
|
auto* newSubFoo1 = newS1->producer();
|
|
EXPECT_EQ(newSubFoo1->target(), "aten.subFoo1");
|
|
|
|
EXPECT_EQ(newSubFoo1->inputs().front().value, val_o1);
|
|
|
|
auto* subInputVal = subGraph->inputs().front();
|
|
EXPECT_EQ(valueMap[subInputVal], val_o1);
|
|
for (const auto& [val1, val2] : valueMap) {
|
|
if (val1->name() == "s1") {
|
|
EXPECT_EQ(val2->name(), newS1->name());
|
|
}
|
|
if (val1->name() == "s2") {
|
|
EXPECT_EQ(val2->name(), newS2->name());
|
|
}
|
|
if (val1->name() == "x") {
|
|
EXPECT_EQ(val2->name(), val_o1->name());
|
|
}
|
|
}
|
|
|
|
mainGraph->lint();
|
|
}
|
|
|
|
TEST(GraphTest, CleanupDeadNodes) {
|
|
// %c is unused
|
|
const std::string source = R"(
|
|
graph(%x, %y):
|
|
%a = foo(a=%x, b=%y)
|
|
%b = foo1(c=%a)
|
|
%c = foo2(a=%b, b=%y)
|
|
return(%b)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
|
|
// Verify that %c exists initially
|
|
auto* cVal = graph->tryGetValue("c");
|
|
ASSERT_NE(nullptr, cVal);
|
|
size_t nodeCountBefore = graph->nodes().size();
|
|
|
|
graph->cleanupDeadNodes();
|
|
|
|
// %c should now be gone
|
|
EXPECT_EQ(nullptr, graph->tryGetValue("c"));
|
|
// %b should still be there
|
|
EXPECT_NE(nullptr, graph->tryGetValue("b"));
|
|
EXPECT_EQ(nodeCountBefore - 1, graph->nodes().size());
|
|
}
|
|
|
|
TEST(GraphTest, RenumberValues) {
|
|
const std::string source = R"(
|
|
graph(%x):
|
|
%a = foo(a=%x)
|
|
%b = foo1(a=%a)
|
|
return (%a)
|
|
)";
|
|
auto graph = stringToGraph(source);
|
|
graph->cleanupDeadNodes();
|
|
|
|
// %b should now be gone
|
|
EXPECT_EQ(nullptr, graph->tryGetValue("b"));
|
|
|
|
// %a should now be the last value
|
|
EXPECT_EQ(graph->tryGetValue("a")->id(), graph->numValues() - 1);
|
|
|
|
// All values should be renumbered
|
|
size_t numVals = graph->numValues();
|
|
std::unordered_set<ValueId> ids;
|
|
ids.reserve(numVals);
|
|
for (const auto* val : graph->values()) {
|
|
ASSERT_LT(val->id(), numVals);
|
|
ids.insert(val->id());
|
|
}
|
|
|
|
// Check ids are contiguous and unique b/w 0 and numVals
|
|
EXPECT_EQ(numVals, ids.size());
|
|
for (size_t i = 0; i < numVals; ++i) {
|
|
EXPECT_NE(ids.end(), ids.find(i));
|
|
}
|
|
}
|
|
|
|
TEST(SerializationTest, RoundTrip) {
|
|
static constexpr std::string_view source =
|
|
R"(graph(%foo, %bar, %baz):
|
|
%o1 = aten.foo(self=%foo, target=%bar, alpha=0.1)
|
|
return(%o1, %baz)
|
|
)";
|
|
const auto graph = stringToGraph(source);
|
|
const auto serialized = graphToString(*graph);
|
|
EXPECT_EQ(source, serialized);
|
|
}
|
|
|
|
TEST(SerializationTest, EscapedStringConstant) {
|
|
const auto parsed =
|
|
std::get<std::string>(convertAtomicConstant(R"("string_\"escape")"));
|
|
std::string expected = "string_\\\"escape";
|
|
EXPECT_EQ(parsed, expected);
|
|
}
|
|
|
|
TEST(SerializationTest, DeviceConstant) {
|
|
const auto device =
|
|
std::get<c10::Device>(convertAtomicConstant("Device{cuda:1}"));
|
|
EXPECT_EQ(device.index(), 1);
|
|
EXPECT_EQ(device.type(), c10::DeviceType::CUDA);
|
|
}
|
|
|
|
TEST(SerializationTest, TrueConstant) {
|
|
const auto parsedTrue = std::get<bool>(convertAtomicConstant("true"));
|
|
EXPECT_EQ(parsedTrue, true);
|
|
const auto parsedFalse = std::get<bool>(convertAtomicConstant("false"));
|
|
EXPECT_EQ(parsedFalse, false);
|
|
}
|
|
|
|
TEST(SerializationTest, MemoryFormatConstant) {
|
|
const auto parsed = std::get<c10::MemoryFormat>(
|
|
convertAtomicConstant("MemoryFormat::ContiguousFormat"));
|
|
EXPECT_EQ(parsed, c10::MemoryFormat::Contiguous);
|
|
}
|
|
|
|
TEST(SerializationTest, FloatConstant) {
|
|
const auto parsed = std::get<double>(convertAtomicConstant("5.0"));
|
|
EXPECT_EQ(parsed, 5.0);
|
|
}
|
|
|
|
TEST(SerializationTest, IntConstant) {
|
|
const auto parsed = std::get<int64_t>(convertAtomicConstant("5"));
|
|
EXPECT_EQ(parsed, 5);
|
|
}
|
|
|
|
TEST(SerializationTest, FloatExponentConstant) {
|
|
const auto parsed = std::get<double>(convertAtomicConstant("1e-05"));
|
|
EXPECT_EQ(parsed, 0.00001);
|
|
}
|
|
|
|
TEST(SerializationTest, SingleElementListConstant) {
|
|
const auto parsed =
|
|
std::get<std::vector<int64_t>>(convertListConstant("[1]"));
|
|
const auto expected = std::vector<int64_t>{1};
|
|
EXPECT_EQ(parsed, expected);
|
|
}
|
|
|
|
TEST(SerializationTest, IntListConstant) {
|
|
const auto parsed =
|
|
std::get<std::vector<int64_t>>(convertListConstant("[1, 2, 3, 4]"));
|
|
const auto expected = std::vector<int64_t>{1, 2, 3, 4};
|
|
EXPECT_EQ(parsed, expected);
|
|
}
|
|
|
|
TEST(SerializationTest, FloatListConstant) {
|
|
const auto parsed = std::get<std::vector<double>>(
|
|
convertListConstant("[1.0, 2.0, 3.0, 4.0]"));
|
|
const auto expected = std::vector<double>{1.0, 2.0, 3.0, 4.0};
|
|
EXPECT_EQ(parsed, expected);
|
|
}
|
|
|
|
TEST(SerializationTest, BoolListConstant) {
|
|
const auto parsed =
|
|
std::get<std::vector<bool>>(convertListConstant("[false, true, false]"));
|
|
const auto expected = std::vector<bool>{false, true, false};
|
|
EXPECT_EQ(parsed, expected);
|
|
}
|
|
|
|
} // namespace torch::nativert
|