// sglang_v0.5.2/pytorch_2.8.0/test/cpp/nativert/test_graph.cpp
//
// Listing metadata: 648 lines, 18 KiB, C++

#include <c10/core/Device.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <torch/nativert/graph/Graph.h>
using namespace ::testing;
namespace torch::nativert {
// Parses a three-input graph from text and checks its inputs, nodes, outputs
// and values.
//
// Size checks that guard a later index or iterator advance use ASSERT_EQ, so
// a mismatch stops the test instead of reading out of bounds.
TEST(GraphTest, Basic) {
  static constexpr std::string_view source =
      R"(graph(%foo, %bar, %baz):
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o2, %baz)
)";
  auto graph = stringToGraph(source);

  ASSERT_EQ(graph->inputs().size(), 3);
  EXPECT_EQ(graph->inputs()[0]->name(), "foo");
  EXPECT_EQ(graph->inputs()[1]->name(), "bar");
  EXPECT_EQ(graph->inputs()[2]->name(), "baz");

  const auto& nodes = graph->nodes();
  // The iterator below is advanced twice, so the node count must be exact.
  ASSERT_EQ(nodes.size(), 3);

  // First node is the input node
  auto it = nodes.begin();
  {
    const auto& node = *it;
    EXPECT_EQ(node.target(), "prim.Input");
    EXPECT_EQ(node.inputs().size(), 0);
    ASSERT_EQ(node.outputs().size(), 3);
    EXPECT_EQ(node.outputs()[0]->name(), "foo");
    EXPECT_EQ(node.outputs()[1]->name(), "bar");
    EXPECT_EQ(node.outputs()[2]->name(), "baz");
  }

  // Second node is the aten.foo call: two inputs and one attribute.
  {
    std::advance(it, 1);
    const auto& node = *it;
    EXPECT_EQ(node.target(), "aten.foo");
    ASSERT_EQ(node.inputs().size(), 2);
    EXPECT_EQ(node.inputs()[0].name, "self");
    EXPECT_EQ(node.inputs()[1].name, "target");
    ASSERT_EQ(node.attributes().size(), 1);
    EXPECT_EQ(node.attributes()[0].name, "alpha");
  }

  // Third node is the output node carrying the returned values.
  {
    std::advance(it, 1);
    const auto& node = *it;
    EXPECT_EQ(node.target(), "prim.Output");
    ASSERT_EQ(node.inputs().size(), 2);
    EXPECT_EQ(node.inputs()[0].name, "o2");
    EXPECT_EQ(node.inputs()[1].name, "baz");
  }

  ASSERT_EQ(graph->outputs().size(), 2);
  EXPECT_EQ(graph->outputs()[0]->name(), "o2");
  EXPECT_EQ(graph->outputs()[1]->name(), "baz");

  // Names are sorted before comparing, so the check does not depend on the
  // iteration order of values().
  const auto& values = graph->values();
  EXPECT_EQ(values.size(), 5);
  std::vector<std::string> valueNames;
  valueNames.reserve(values.size());
  for (const auto& v : values) {
    valueNames.emplace_back(v->name());
  }
  std::sort(valueNames.begin(), valueNames.end());
  EXPECT_THAT(
      valueNames,
      ContainerEq(std::vector<std::string>({"bar", "baz", "foo", "o1", "o2"})));
}
// Checks that a graph input reports prim.Input as its producer and that a
// node output reports the node that defines it.
TEST(GraphTest, ValueProducer) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo, %bar, %baz):
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o2, %baz)
)";
  auto parsed = stringToGraph(kGraphText);

  // %foo is a graph input.
  EXPECT_EQ(parsed->getValue("foo")->producer()->target(), "prim.Input");
  // %o1 is defined by the aten.foo node.
  EXPECT_EQ(parsed->getValue("o1")->producer()->target(), "aten.foo");
}
// Exercises insertBefore(), insertAfter() and insert() around the aten.foo
// node and compares the printed graph with the expected text.
TEST(GraphTest, InsertBeforeAfter) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo, %bar, %baz):
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o2, %baz)
)";
  auto graph = stringToGraph(kGraphText);

  // Step past the first node to reach aten.foo, the anchor for the inserts.
  auto pos = graph->nodes().begin();
  ++pos;
  auto& anchor = *pos;
  EXPECT_EQ(anchor.target(), "aten.foo");

  auto beforeNode = graph->createNode("before", {});
  auto afterNode = graph->createNode("after", {});
  auto tailNode = graph->createNode("atEnd", {});

  graph->insertBefore(beforeNode, &anchor);
  graph->insertAfter(afterNode, &anchor);
  graph->insert(tailNode);

  static constexpr std::string_view kExpectedText =
      R"(graph(%foo, %bar, %baz):
= before()
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
= after()
= atEnd()
return(%o2, %baz)
)";
  EXPECT_EQ(graphToString(*graph), kExpectedText);
}
// Checks that %o2, which is only returned, has the output node as its single
// user.
TEST(GraphTest, ValueUses) {
  static constexpr std::string_view source =
      R"(graph(%foo, %bar, %baz):
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o2, %baz)
)";
  auto graph = stringToGraph(source);
  auto o2 = graph->getValue("o2");

  // Fatal check: users()[0] below must not be read if the list is empty.
  ASSERT_EQ(o2->users().size(), 1);
  EXPECT_EQ(o2->users()[0]->target(), "prim.Output");
}
// Applies a cuda:0 -> cuda:1 placement and checks that only device attributes
// equal to cuda:0 are rewritten, on every node.
TEST(GraphTest, ApplyDevicePlacement) {
  const c10::Device cpu(c10::DeviceType::CPU);
  const c10::Device cuda0(c10::DeviceType::CUDA, 0);
  const c10::Device cuda1(c10::DeviceType::CUDA, 1);

  auto graph = Graph::createGraph();
  auto node1 = graph->insertNode("node1");
  auto node2 = graph->insertNode("node2");

  node1->addAttribute({"a", cpu});
  node1->addAttribute({"b", cuda0});
  node1->addAttribute({"c", cuda1});
  node2->addAttribute({"d", cuda0});

  graph->applyDevicePlacement(Placement(
      std::unordered_map<c10::Device, c10::Device>{{cuda0, cuda1}}));

  // CPU is not in the mapping and stays as is.
  EXPECT_EQ(std::get<c10::Device>(node1->getAttribute("a").value), cpu);
  // cuda:0 is remapped to cuda:1.
  EXPECT_EQ(std::get<c10::Device>(node1->getAttribute("b").value), cuda1);
  // cuda:1 is not a key in the mapping and stays cuda:1.
  EXPECT_EQ(std::get<c10::Device>(node1->getAttribute("c").value), cuda1);
  // The remapping also applies to the second node.
  EXPECT_EQ(std::get<c10::Device>(node2->getAttribute("d").value), cuda1);
}
// Chains two replaceAllUses() calls (o2 -> bar, then bar -> foo) and checks
// both the user counts and the printed graph.
TEST(GraphTest, ReplaceAllUses) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo, %bar, %baz):
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o2, %baz)
)";
  auto graph = stringToGraph(kGraphText);
  auto o2Val = graph->getValue("o2");
  auto barVal = graph->getValue("bar");
  auto fooVal = graph->getValue("foo");

  // Each value starts with exactly one user.
  EXPECT_EQ(o2Val->users().size(), 1);
  EXPECT_EQ(barVal->users().size(), 1);
  EXPECT_EQ(fooVal->users().size(), 1);

  // o2 -> bar: bar picks up o2's user.
  graph->replaceAllUses(o2Val, barVal);
  EXPECT_EQ(o2Val->users().size(), 0);
  EXPECT_EQ(barVal->users().size(), 2);

  // bar -> foo: foo now has both users.
  graph->replaceAllUses(barVal, fooVal);
  EXPECT_EQ(barVal->users().size(), 0);
  EXPECT_EQ(fooVal->users().size(), 2);

  static constexpr std::string_view kExpectedText =
      R"(graph(%foo, %bar, %baz):
%o1, %o2 = aten.foo(self=%foo, target=%foo, alpha=0.1)
return(%foo, %baz)
)";
  EXPECT_EQ(graphToString(*graph), kExpectedText);
}
// Checks that getUniqueValueName() yields "v0", "v1", "v2" in turn when each
// returned name is added to the graph before the next request.
TEST(GraphTest, GetUniqueValueName) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo, %bar, %baz):
%o1, %o2 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o2, %bar)
)";
  auto graph = stringToGraph(kGraphText);

  // The aten.foo node serves as the producer of the values added below.
  auto* producerNode = graph->getValue("o2")->producer();

  auto firstName = graph->getUniqueValueName();
  graph->addValue(firstName, Type::Kind::None, producerNode);

  auto secondName = graph->getUniqueValueName();
  graph->addValue(secondName, Type::Kind::None, producerNode);

  auto thirdName = graph->getUniqueValueName();

  EXPECT_EQ(firstName, "v0");
  EXPECT_EQ(secondName, "v1");
  EXPECT_EQ(thirdName, "v2");
}
// Checks that replaceAllUses() rewrites every occurrence when one node uses
// the replaced value more than once.
TEST(GraphTest, ReplaceAllUsesMultiUse) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo, %bar):
%o1 = aten.foo(a=%foo, b=%foo, c=%bar)
return(%o1)
)";
  static constexpr std::string_view kExpectedText =
      R"(graph(%foo, %bar):
%o1 = aten.foo(a=%bar, b=%bar, c=%bar)
return(%o1)
)";

  auto graph = stringToGraph(kGraphText);
  auto oldValue = graph->getValue("foo");
  auto newValue = graph->getValue("bar");
  graph->replaceAllUses(oldValue, newValue);

  EXPECT_EQ(graphToString(*graph), kExpectedText);
}
// Calls replaceAllUsesAfterNode(foo, o1, <aten.foo3 node>). Per the expected
// text, only the use of %foo in the return is rewritten to %o1; the uses in
// aten.foo1, aten.foo2 and aten.foo3 are unchanged.
TEST(GraphTest, ReplaceAllUsesAfter) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo):
%o1 = aten.foo1(a=%foo)
%o2 = aten.foo2(a=%o1, b=%foo)
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
return(%foo, %o1, %o2, %o3)
)";
  static constexpr std::string_view kExpectedText =
      R"(graph(%foo):
%o1 = aten.foo1(a=%foo)
%o2 = aten.foo2(a=%o1, b=%foo)
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
return(%o1, %o1, %o2, %o3)
)";

  auto graph = stringToGraph(kGraphText);
  auto oldValue = graph->getValue("foo");
  auto newValue = graph->getValue("o1");
  auto boundary = graph->getValue("o3")->producer();

  graph->replaceAllUsesAfterNode(oldValue, newValue, boundary);

  EXPECT_EQ(graphToString(*graph), kExpectedText);
  EXPECT_EQ(oldValue->users().size(), 3);
  EXPECT_EQ(newValue->users().size(), 2);
}
// Checks the InsertingAfter guard. Per the expected text, nodes inserted
// while the guard is alive land after aten.first in call order, and a node
// inserted once the guard is gone lands just before the return.
TEST(GraphTest, InsertingAfter) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo, %bar):
%o1 = aten.first(a=%foo)
%o2 = aten.foo(c=%bar)
return(%o1, %o2)
)";
  auto graph = stringToGraph(kGraphText);
  auto anchor = graph->getValue("o1")->producer();

  {
    // The guard is only active inside this scope.
    InsertingAfter scopedInsertionPoint(anchor);
    graph->insertNode("one");
    graph->insertNode("two");
    graph->insertNode("three");
  }
  graph->insertNode("four");

  static constexpr std::string_view kExpectedText =
      R"(graph(%foo, %bar):
%o1 = aten.first(a=%foo)
= one()
= two()
= three()
%o2 = aten.foo(c=%bar)
= four()
return(%o1, %o2)
)";
  EXPECT_EQ(graphToString(*graph), kExpectedText);
}
// Looks up node inputs and attributes by name, with both added in
// non-alphabetical order, and checks that the tryGet* variants return
// nullptr for unknown names.
TEST(NodeTest, GetInputAndAttribute) {
  auto graph = Graph::createGraph();
  auto firstInput = graph->addInput("input1", Type::Kind::Tensor);
  auto secondInput = graph->addInput("input2", Type::Kind::Tensor);
  auto thirdInput = graph->addInput("input3", Type::Kind::Tensor);

  auto node = graph->createNode("foo.bar");
  node->addInput({"out_of_order", firstInput});
  node->addInput({"arg1", secondInput});
  node->addInput({"arg2", thirdInput});
  node->addAttribute({"b", static_cast<int64_t>(0)});
  node->addAttribute({"a", static_cast<int64_t>(2)});
  node->addAttribute({"c", static_cast<int64_t>(1)});

  // Inputs: each name maps to the value it was registered with.
  const auto& outOfOrder = node->getInput("out_of_order");
  EXPECT_EQ(outOfOrder.name, "out_of_order");
  EXPECT_EQ(outOfOrder.value, firstInput);

  const auto& arg1 = node->getInput("arg1");
  EXPECT_EQ(arg1.name, "arg1");
  EXPECT_EQ(arg1.value, secondInput);

  const auto& arg2 = node->getInput("arg2");
  EXPECT_EQ(arg2.name, "arg2");
  EXPECT_EQ(arg2.value, thirdInput);

  // Attributes: each name maps to the constant it was registered with.
  const auto& attrA = node->getAttribute("a");
  EXPECT_EQ(attrA.name, "a");
  EXPECT_EQ(attrA.value, Constant(static_cast<int64_t>(2)));

  const auto& attrB = node->getAttribute("b");
  EXPECT_EQ(attrB.name, "b");
  EXPECT_EQ(attrB.value, Constant(static_cast<int64_t>(0)));

  const auto& attrC = node->getAttribute("c");
  EXPECT_EQ(attrC.name, "c");
  EXPECT_EQ(attrC.value, Constant(static_cast<int64_t>(1)));

  // Unknown names.
  EXPECT_EQ(node->tryGetInput("doesnotexist"), nullptr);
  EXPECT_EQ(node->tryGetAttribute("doesnotexist"), nullptr);
}
// Walks the node chain in both directions with next()/prev(), including the
// input and output nodes at either end.
TEST(NodeTest, NextPrev) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo):
%o1 = aten.foo1(a=%foo)
%o2 = aten.foo2(a=%o1, b=%foo)
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
return(%foo, %o1, %o2, %o3)
)";
  auto graph = stringToGraph(kGraphText);
  auto first = graph->getValue("o1")->producer();
  auto second = graph->getValue("o2")->producer();
  auto third = graph->getValue("o3")->producer();

  EXPECT_EQ(first->next(), second);
  EXPECT_EQ(second->next(), third);
  EXPECT_EQ(third->prev(), second);
  EXPECT_EQ(third->next(), graph->outputNode());
  EXPECT_EQ(second->prev(), first);
  EXPECT_EQ(first->prev(), graph->inputNode());

  // The two ends of the chain have no neighbour beyond them.
  EXPECT_EQ(graph->inputNode()->prev(), nullptr);
  EXPECT_EQ(graph->outputNode()->next(), nullptr);
}
// Checks Node::isBefore() on a chain of three nodes, in both directions.
TEST(GraphTest, IsBefore) {
  auto source = R"IR(
graph(%foo):
%o1 = aten.foo1(a=%foo)
%o2 = aten.foo2(a=%o1)
%o3 = aten.foo3(a=%o2)
return (%o3)
)IR";
  auto graph = stringToGraph(source);
  ASSERT_NE(graph, nullptr);

  auto* o1 = graph->tryGetValue("o1");
  auto* o2 = graph->tryGetValue("o2");
  auto* o3 = graph->tryGetValue("o3");
  // tryGetValue() can return nullptr, so stop here rather than dereference a
  // null pointer below.
  ASSERT_NE(o1, nullptr);
  ASSERT_NE(o2, nullptr);
  ASSERT_NE(o3, nullptr);

  auto* foo1 = o1->producer();
  auto* foo2 = o2->producer();
  auto* foo3 = o3->producer();

  EXPECT_TRUE(foo1->isBefore(foo2)) << "foo1 should appear before foo2";
  EXPECT_TRUE(foo2->isBefore(foo3)) << "foo2 should appear before foo3";
  EXPECT_TRUE(foo1->isBefore(foo3)) << "foo1 should appear before foo3";
  EXPECT_FALSE(foo2->isBefore(foo1)) << "foo2 should not appear before foo1";
  EXPECT_FALSE(foo3->isBefore(foo2)) << "foo3 should not appear before foo2";
}
// Check we shouldn't be able to remove a node that still has users:
// removeNode() is expected to throw c10::Error for aten.foo2, whose output
// %o2 is read by aten.foo3.
TEST(GraphTest, RemoveNodeWithUsers) {
  auto source = R"IR(
graph(%foo):
%o1 = aten.foo1(a=%foo)
%o2 = aten.foo2(a=%o1, b=%foo)
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
return (%foo, %o1, %o3)
)IR";
  auto graph = stringToGraph(source);
  ASSERT_NE(graph, nullptr);

  auto* o2 = graph->tryGetValue("o2");
  // tryGetValue() can return nullptr; fail here instead of crashing on
  // o2->producer().
  ASSERT_NE(o2, nullptr);
  auto* foo2 = o2->producer();

  EXPECT_THROW(graph->removeNode(foo2), c10::Error);
}
// Check node removal works as expected: after removing aten.fooUnused,
// neither the node nor its output %unused can be found in the graph.
TEST(GraphTest, RemoveNodeUnused) {
  auto source = R"IR(
graph(%foo):
%o1 = aten.foo1(a=%foo)
%o2 = aten.foo2(a=%o1, b=%foo)
%unused = aten.fooUnused(a=%o2)
return(%foo, %o1, %o2)
)IR";
  auto graph = stringToGraph(source);

  auto* valUnused = graph->tryGetValue("unused");
  // tryGetValue() can return nullptr; fail here instead of crashing on
  // valUnused->producer().
  ASSERT_NE(valUnused, nullptr);
  Node* nodeUnused = valUnused->producer();
  EXPECT_EQ(nodeUnused->target(), "aten.fooUnused");

  graph->removeNode(nodeUnused);
  graph->lint();

  // %unused should now be gone
  EXPECT_EQ(graph->tryGetValue("unused"), nullptr)
      << "Value %unused should no longer exist in the graph";
  for (const auto& node : graph->nodes()) {
    EXPECT_NE(node.target(), "aten.fooUnused");
    for (const auto* output : node.outputs()) {
      EXPECT_NE(output->name(), "unused")
          << "Should not find %unused in any remaining node's outputs";
    }
  }
}
// Checks removeValue(): it must throw while the value still has users, and
// once the uses are redirected the value can be removed and no longer be
// found by name.
TEST(GraphTest, RemoveValue) {
  auto source = R"IR(
graph(%foo):
%o1 = aten.foo1(a=%foo)
%o2 = aten.foo2(a=%o1, b=%foo)
%o3 = aten.foo3(a=%o2, b=%o2, c=%foo)
return (%foo, %o1, %o3)
)IR";
  auto graph = stringToGraph(source);

  auto* val_o1 = graph->tryGetValue("o1");
  // tryGetValue() can return nullptr; the calls below need a real value.
  ASSERT_NE(val_o1, nullptr);
  {
    // Check we shouldn't be able to remove a value that still has users
    EXPECT_THROW(graph->removeValue(val_o1), c10::Error);
  }
  {
    // Check value removal works as expected
    auto* val_foo = graph->tryGetValue("foo");
    ASSERT_NE(val_foo, nullptr);
    graph->replaceAllUses(val_o1, val_foo);
    graph->removeValue(val_o1);
    // Values are looked up by bare name, without the '%' used in the graph
    // text. A lookup of "%o1" would return nullptr whether or not the removal
    // happened, so it would not test anything.
    EXPECT_EQ(graph->tryGetValue("o1"), nullptr);
  }
}
// Inserts a two-node subgraph into a main graph, binding the subgraph input
// %x to the main graph's %o1, then checks the copied nodes, the returned
// outputs and the subgraph-to-main-graph value map.
TEST(GraphTest, InsertGraph) {
  auto source = R"IR(
graph(%foo):
%o1 = aten.foo1(a=%foo)
return (%o1)
)IR";

  // Subgraph to be inserted
  auto subgraphSource = R"IR(
graph(%x):
%s1 = aten.subFoo1(a=%x)
%s2 = aten.subFoo2(a=%s1)
return (%s2)
)IR";

  auto mainGraph = stringToGraph(source);
  auto subGraph = stringToGraph(subgraphSource);

  // Insert subGraph into mainGraph. Use %o1 as the subGraph's %x
  auto val_o1 = mainGraph->tryGetValue("o1");
  // tryGetValue() can return nullptr; a null input would only fail later and
  // less clearly.
  ASSERT_NE(val_o1, nullptr);
  std::unordered_map<const Value*, Value*> valueMap;
  std::vector<Value*> insertedOutputs =
      mainGraph->insertGraph(*subGraph, {val_o1}, valueMap);
  // Fatal check: front() below must not be called on an empty vector.
  ASSERT_EQ(insertedOutputs.size(), 1);

  // Check all new nodes are inserted correctly from the copied %s2
  auto* newS2 = insertedOutputs.front();
  auto* newSubFoo2 = newS2->producer();
  EXPECT_EQ(newSubFoo2->target(), "aten.subFoo2");
  ASSERT_EQ(newSubFoo2->inputs().size(), 1);
  auto* newS1 = newSubFoo2->inputs().front().value;
  auto* newSubFoo1 = newS1->producer();
  EXPECT_EQ(newSubFoo1->target(), "aten.subFoo1");
  ASSERT_EQ(newSubFoo1->inputs().size(), 1);
  EXPECT_EQ(newSubFoo1->inputs().front().value, val_o1);

  // The value map is keyed by subgraph values and points at their
  // counterparts in the main graph.
  ASSERT_EQ(subGraph->inputs().size(), 1);
  auto* subInputVal = subGraph->inputs().front();
  EXPECT_EQ(valueMap[subInputVal], val_o1);
  for (const auto& [val1, val2] : valueMap) {
    if (val1->name() == "s1") {
      EXPECT_EQ(val2->name(), newS1->name());
    }
    if (val1->name() == "s2") {
      EXPECT_EQ(val2->name(), newS2->name());
    }
    if (val1->name() == "x") {
      EXPECT_EQ(val2->name(), val_o1->name());
    }
  }
  mainGraph->lint();
}
// Checks that cleanupDeadNodes() drops the node producing the unused value %c
// (exactly one node fewer) and keeps the returned value %b.
TEST(GraphTest, CleanupDeadNodes) {
  // %c is unused
  const std::string graphText = R"(
graph(%x, %y):
%a = foo(a=%x, b=%y)
%b = foo1(c=%a)
%c = foo2(a=%b, b=%y)
return(%b)
)";
  auto graph = stringToGraph(graphText);

  // Verify that %c exists initially
  auto* deadValue = graph->tryGetValue("c");
  ASSERT_NE(nullptr, deadValue);

  const size_t initialNodeCount = graph->nodes().size();
  graph->cleanupDeadNodes();

  // %c should now be gone
  EXPECT_EQ(nullptr, graph->tryGetValue("c"));
  // %b should still be there
  EXPECT_NE(nullptr, graph->tryGetValue("b"));
  // Exactly one node was removed.
  EXPECT_EQ(initialNodeCount - 1, graph->nodes().size());
}
// Checks that after cleanupDeadNodes() removes %b, the remaining value ids
// are unique and cover 0..numValues()-1, with %a holding the last id.
TEST(GraphTest, RenumberValues) {
  const std::string source = R"(
graph(%x):
%a = foo(a=%x)
%b = foo1(a=%a)
return (%a)
)";
  auto graph = stringToGraph(source);
  graph->cleanupDeadNodes();

  // %b should now be gone
  EXPECT_EQ(nullptr, graph->tryGetValue("b"));

  // %a should now be the last value
  auto* aVal = graph->tryGetValue("a");
  // tryGetValue() can return nullptr; fail here instead of crashing on
  // aVal->id().
  ASSERT_NE(nullptr, aVal);
  EXPECT_EQ(aVal->id(), graph->numValues() - 1);

  // All values should be renumbered
  size_t numVals = graph->numValues();
  std::unordered_set<ValueId> ids;
  ids.reserve(numVals);
  for (const auto* val : graph->values()) {
    ASSERT_LT(val->id(), numVals);
    ids.insert(val->id());
  }

  // Check ids are contiguous and unique b/w 0 and numVals
  EXPECT_EQ(numVals, ids.size());
  for (size_t i = 0; i < numVals; ++i) {
    EXPECT_NE(ids.end(), ids.find(i));
  }
}
// Parsing graph text and printing it again must reproduce the input exactly.
TEST(SerializationTest, RoundTrip) {
  static constexpr std::string_view kGraphText =
      R"(graph(%foo, %bar, %baz):
%o1 = aten.foo(self=%foo, target=%bar, alpha=0.1)
return(%o1, %baz)
)";
  const auto parsedGraph = stringToGraph(kGraphText);
  const auto printed = graphToString(*parsedGraph);
  EXPECT_EQ(kGraphText, printed);
}
// A quoted constant containing \" parses to a std::string whose contents
// match the expectation below (backslash and quote both present).
TEST(SerializationTest, EscapedStringConstant) {
  const std::string expected = "string_\\\"escape";
  const auto actual =
      std::get<std::string>(convertAtomicConstant(R"("string_\"escape")"));
  EXPECT_EQ(actual, expected);
}
// "Device{cuda:1}" parses to a c10::Device with type CUDA and index 1.
TEST(SerializationTest, DeviceConstant) {
  const auto parsedDevice =
      std::get<c10::Device>(convertAtomicConstant("Device{cuda:1}"));
  EXPECT_EQ(parsedDevice.index(), 1);
  EXPECT_EQ(parsedDevice.type(), c10::DeviceType::CUDA);
}
// The literals "true" and "false" parse to the matching bool values.
TEST(SerializationTest, TrueConstant) {
  const bool fromTrue = std::get<bool>(convertAtomicConstant("true"));
  EXPECT_TRUE(fromTrue);

  const bool fromFalse = std::get<bool>(convertAtomicConstant("false"));
  EXPECT_FALSE(fromFalse);
}
// "MemoryFormat::ContiguousFormat" parses to c10::MemoryFormat::Contiguous.
TEST(SerializationTest, MemoryFormatConstant) {
  const auto memoryFormat = std::get<c10::MemoryFormat>(
      convertAtomicConstant("MemoryFormat::ContiguousFormat"));
  EXPECT_EQ(memoryFormat, c10::MemoryFormat::Contiguous);
}
// A literal with a decimal point parses to the double alternative.
TEST(SerializationTest, FloatConstant) {
  const double value = std::get<double>(convertAtomicConstant("5.0"));
  EXPECT_EQ(value, 5.0);
}
// A literal without a decimal point parses to the int64_t alternative.
TEST(SerializationTest, IntConstant) {
  const int64_t value = std::get<int64_t>(convertAtomicConstant("5"));
  EXPECT_EQ(value, 5);
}
// A literal in exponent notation parses to the double alternative. The
// comparison is exact: the parsed value must be the same double as the
// literal 0.00001.
TEST(SerializationTest, FloatExponentConstant) {
  const double value = std::get<double>(convertAtomicConstant("1e-05"));
  EXPECT_EQ(value, 0.00001);
}
// A one-element bracketed list parses to a one-element int64_t vector.
TEST(SerializationTest, SingleElementListConstant) {
  const std::vector<int64_t> expected{1};
  const auto actual =
      std::get<std::vector<int64_t>>(convertListConstant("[1]"));
  EXPECT_EQ(actual, expected);
}
// A bracketed list of integers parses to an int64_t vector in the same order.
TEST(SerializationTest, IntListConstant) {
  const std::vector<int64_t> expected{1, 2, 3, 4};
  const auto actual =
      std::get<std::vector<int64_t>>(convertListConstant("[1, 2, 3, 4]"));
  EXPECT_EQ(actual, expected);
}
// A bracketed list of decimal literals parses to a double vector in the same
// order.
TEST(SerializationTest, FloatListConstant) {
  const std::vector<double> expected{1.0, 2.0, 3.0, 4.0};
  const auto actual = std::get<std::vector<double>>(
      convertListConstant("[1.0, 2.0, 3.0, 4.0]"));
  EXPECT_EQ(actual, expected);
}
// A bracketed list of true/false literals parses to a bool vector in the same
// order.
TEST(SerializationTest, BoolListConstant) {
  const std::vector<bool> expected{false, true, false};
  const auto actual =
      std::get<std::vector<bool>>(convertListConstant("[false, true, false]"));
  EXPECT_EQ(actual, expected);
}
} // namespace torch::nativert