# Owner(s): ["module: dynamo"]
# flake8: noqa: B950
# Tests for Dynamo graph deduplication: repeated subgraph regions are outlined
# into invoke_subgraph higher-order ops, plus the supporting cycle-detection
# and stable-topological-sort helpers.
import contextlib

import torch
import torch.fx
from torch._dynamo.graph_utils import _detect_cycles
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import (
    AotEagerAndRecordGraphs,
    extract_graph_and_tracker,
    normalize_gm,
)
from torch.utils._ordered_set import OrderedSet


def extract_graph(fn, *args, **kwargs):
    backend = AotEagerAndRecordGraphs()
    result = torch.compile(backend=backend)(fn)(*args, **kwargs)
    return result, backend.graphs, backend.fw_graphs


def graph_str(gm):
    return normalize_gm(gm.print_readable(print_output=False))


class GraphDededuplicationTests(TestCase):
    def setUp(self):
        self.exit_stack = contextlib.ExitStack()
        self.exit_stack.enter_context(
            torch._dynamo.config.patch("use_graph_deduplication", True)
        )
        super().setUp()

    def tearDown(self):
        self.exit_stack.close()
        super().tearDown()

    def run_and_return_graphs(self, fn, *args, **kwargs):
        return extract_graph(fn, *args, **kwargs)

    def run_and_get_simple_graph(self):
        def fn(x, y):
            x0 = x + 1
            y0 = y + 2
            z = x0.sum() + y0.sum()
            return z

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        _, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
        return fw_graphs[0]

    def test_single_subgraph(self):
        def inner_fn(x, y):
            x0 = x + 1
            y0 = y + 2
            z = x0.sum() + y0.sum()
            return z

        def fn(x, y):
            _o0 = inner_fn(x, y)
            o1 = torch.sin(y)
            o2 = inner_fn(x, o1)
            o3 = inner_fn(x, y)
            o4 = o3 * o3
            return o2 * o4

        x = torch.rand(10, 10, requires_grad=True)
        y = torch.rand(10, 20, requires_grad=True)
        x_clone = x.clone().requires_grad_(True)
        y_clone = y.clone().requires_grad_(True)
        ref_result = fn(x, y)
        result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
        self.assertTrue(torch.allclose(ref_result, result))
        ref_result.sum().backward()
        result.sum().backward()
        self.assertEqual(len(graphs), 1)
        self.assertEqual(len(fw_graphs), 1)
        self.assertExpectedInline(
            graph_str(graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"):
        subgraph_0 = self.subgraph_0
        l_x_ = L_x_
        l_y_ = L_y_
        o1: "f32[10, 20]" = torch.sin(l_y_)
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); invoke_subgraph = None
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, o1); o1 = None
        getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
        getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        o4: "f32[]" = getitem_2 * getitem_2; getitem_2 = None
        mul_1: "f32[]" = getitem_1 * o4; getitem_1 = o4 = None
        return (mul_1,)

    class subgraph_0(torch.nn.Module):
        def forward(self, subgraph_input_l_x_, subgraph_input_l_y_):
            x0: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None
            y0: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None
            sum_1: "f32[]" = x0.sum(); x0 = None
            sum_2: "f32[]" = y0.sum(); y0 = None
            z: "f32[]" = sum_1 + sum_2; sum_1 = sum_2 = None
            return (z,)
""",
        )
        self.assertExpectedInline(
            graph_str(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
        sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
        partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_5 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, sin); partitioned_fw_subgraph_0_0 = sin = None
        getitem_1: "f32[]" = invoke_subgraph_5[0]; invoke_subgraph_5 = None
        partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_7 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', primals_1, primals_2); partitioned_fw_subgraph_0_1 = primals_1 = None
        getitem_2: "f32[]" = invoke_subgraph_7[0]; invoke_subgraph_7 = None
        mul: "f32[]" = torch.ops.aten.mul.Tensor(getitem_2, getitem_2)
        mul_1: "f32[]" = torch.ops.aten.mul.Tensor(getitem_1, mul); mul = None
        return (mul_1, primals_2, getitem_1, getitem_2)

    class partitioned_fw_subgraph_0_0(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
            add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
            add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
            return (add_2,)
""",
        )

    def test_single_subgraph2(self):
        def fn(x):
            x0 = x + 2
            o = inner_fn(x0)
            o = torch.cos(o)
            o = inner_fn(o)
            return torch.sin(o)

        def inner_fn(x):
            o = x * 7
            o += 1
            o += 2
            return o

        x = torch.rand(10, 10, requires_grad=True)
        x_clone = x.clone().requires_grad_(True)
        ref_result = fn(x)
        result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone)
        self.assertTrue(torch.allclose(ref_result, result))
        ref_result.sum().backward()
        result.sum().backward()
        self.assertEqual(len(graphs), 1)
        self.assertEqual(len(fw_graphs), 1)
        self.assertExpectedInline(
            graph_str(graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10, 10]"):
        subgraph_0 = self.subgraph_0
        l_x_ = L_x_
        x0: "f32[10, 10]" = l_x_ + 2; l_x_ = None
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', x0); x0 = None
        getitem: "f32[10, 10]" = invoke_subgraph[0]; invoke_subgraph = None
        o_3: "f32[10, 10]" = torch.cos(getitem); getitem = None
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', o_3); subgraph_0 = o_3 = None
        getitem_1: "f32[10, 10]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        sin: "f32[10, 10]" = torch.sin(getitem_1); getitem_1 = None
        return (sin,)

    class subgraph_0(torch.nn.Module):
        def forward(self, subgraph_input_x0):
            o: "f32[10, 10]" = subgraph_input_x0 * 7; subgraph_input_x0 = None
            o += 1; o_1: "f32[10, 10]" = o; o = None
            o_1 += 2; o_2: "f32[10, 10]" = o_1; o_1 = None
            return (o_2,)
""",
        )
        self.assertExpectedInline(
            graph_str(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10, 10]"):
        add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
        partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', add); partitioned_fw_subgraph_0_0 = add = None
        getitem: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
        cos: "f32[10, 10]" = torch.ops.aten.cos.default(getitem)
        partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', cos); partitioned_fw_subgraph_0_1 = cos = None
        getitem_1: "f32[10, 10]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
        sin: "f32[10, 10]" = torch.ops.aten.sin.default(getitem_1)
        cos_1: "f32[10, 10]" = torch.ops.aten.cos.default(getitem_1); getitem_1 = None
        sin_1: "f32[10, 10]" = torch.ops.aten.sin.default(getitem); getitem = None
        neg: "f32[10, 10]" = torch.ops.aten.neg.default(sin_1); sin_1 = None
        return (sin, cos_1, neg)

    class partitioned_fw_subgraph_0_0(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]"):
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(primals_0, 7); primals_0 = None
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul, 1); mul = None
            add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, 2); add = None
            return (add_1,)
""",
        )

    def test_multiple_subgraphs(self):
        def inner_fn(x, y):
            x1 = x + 1
            y1 = y + 2
            z = x1.sum() + y1.sum()
            return z

        def inner_fn2(a, b):
            a0 = a + 2
            b0 = b + 3
            c = a0 * b0.cos().sum()
            return c

        def fn(x, y):
            x0 = torch.cos(x)
            y0 = torch.sin(y)
            o1 = inner_fn2(x0, y0)
            o0 = inner_fn(x, y)
            o1 = torch.sin(o0)
            o2 = inner_fn(x, y0)
            o3 = inner_fn2(x0, y0)
            o4 = inner_fn(x, y)
            return o1 * o2 * o3 + o4

        x = torch.rand(10, 10, requires_grad=True)
        y = torch.rand(10, 20, requires_grad=True)
        x_clone = x.clone().requires_grad_(True)
        y_clone = y.clone().requires_grad_(True)
        ref_result = fn(x, y)
        result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
        self.assertTrue(torch.allclose(ref_result, result))
        ref_result.sum().backward()
        result.sum().backward()
        self.assertEqual(len(graphs), 1)
        self.assertEqual(len(fw_graphs), 1)
        self.assertExpectedInline(
            graph_str(graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"):
        subgraph_1 = self.subgraph_1
        subgraph_0 = self.subgraph_0
        l_x_ = L_x_
        l_y_ = L_y_
        x0: "f32[10, 10]" = torch.cos(l_x_)
        y0: "f32[10, 20]" = torch.sin(l_y_)
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
        getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
        o1: "f32[]" = torch.sin(getitem); getitem = None
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, y0)
        getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        mul_2: "f32[]" = o1 * getitem_1; o1 = getitem_1 = None
        invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
        getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None
        invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', x0, y0); invoke_subgraph_3 = None
        invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', x0, y0); subgraph_1 = x0 = y0 = None
        getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
        mul_3: "f32[10, 10]" = mul_2 * getitem_4; mul_2 = getitem_4 = None
        add_13: "f32[10, 10]" = mul_3 + getitem_2; mul_3 = getitem_2 = None
        return (add_13,)

    class subgraph_1(torch.nn.Module):
        def forward(self, subgraph_input_x0, subgraph_input_y0):
            a0: "f32[10, 10]" = subgraph_input_x0 + 2; subgraph_input_x0 = None
            b0: "f32[10, 20]" = subgraph_input_y0 + 3; subgraph_input_y0 = None
            cos_1: "f32[10, 20]" = b0.cos(); b0 = None
            sum_1: "f32[]" = cos_1.sum(); cos_1 = None
            c: "f32[10, 10]" = a0 * sum_1; a0 = sum_1 = None
            return (c,)

    class subgraph_0(torch.nn.Module):
        def forward(self, subgraph_input_l_x_, subgraph_input_l_y_):
            x1: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None
            y1: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None
            sum_2: "f32[]" = x1.sum(); x1 = None
            sum_3: "f32[]" = y1.sum(); y1 = None
            z: "f32[]" = sum_2 + sum_3; sum_2 = sum_3 = None
            return (z,)
""",
        )
        self.assertExpectedInline(
            graph_str(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
        cos: "f32[10, 10]" = torch.ops.aten.cos.default(primals_1)
        sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2)
        partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_9 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, primals_2); partitioned_fw_subgraph_0_0 = None
        getitem: "f32[]" = invoke_subgraph_9[0]; invoke_subgraph_9 = None
        sin_1: "f32[]" = torch.ops.aten.sin.default(getitem)
        partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_11 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', primals_1, sin); partitioned_fw_subgraph_0_1 = None
        getitem_1: "f32[]" = invoke_subgraph_11[0]; invoke_subgraph_11 = None
        mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_1); sin_1 = None
        partitioned_fw_subgraph_0_2 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_13 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_2, 'partitioned_fw_subgraph_0_0', primals_1, primals_2); partitioned_fw_subgraph_0_2 = None
        getitem_2: "f32[]" = invoke_subgraph_13[0]; invoke_subgraph_13 = None
        partitioned_fw_subgraph_1_0 = self.partitioned_fw_subgraph_1_0
        invoke_subgraph_15 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_1_0, 'partitioned_fw_subgraph_1_0', cos, sin); partitioned_fw_subgraph_1_0 = cos = sin = None
        getitem_19: "f32[]" = invoke_subgraph_15[3]
        getitem_18: "f32[10, 20]" = invoke_subgraph_15[2]
        getitem_17: "f32[10, 10]" = invoke_subgraph_15[1]
        getitem_4: "f32[10, 10]" = invoke_subgraph_15[0]; invoke_subgraph_15 = None
        mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(mul, getitem_4); mul = None
        add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul_1, getitem_2); mul_1 = getitem_2 = None
        return (add, primals_1, primals_2, getitem, getitem_1, getitem_19, getitem_18, getitem_17, getitem_4)

    class partitioned_fw_subgraph_0_0(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
            add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
            add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
            return (add_2,)

    class partitioned_fw_subgraph_1_0(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 2)
            add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 3)
            cos: "f32[10, 20]" = torch.ops.aten.cos.default(add_1); add_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add, sum_1); add = None
            return (mul, primals_0, primals_1, sum_1)
""",
        )

    def test_dependent_subgraphs(self):
        def inner_fn(x, y):
            x0 = x + 1
            y0 = y + 2
            z = x0.sum() + y0.sum()
            return z

        def fn(x, y):
            o0 = inner_fn(x, y)
            o1 = inner_fn(x, o0)
            return o1

        x = torch.rand(10, 10, requires_grad=True)
        y = torch.rand(10, 20, requires_grad=True)
        x_clone = x.clone().requires_grad_(True)
        y_clone = y.clone().requires_grad_(True)
        ref_result = fn(x, y)
        result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
        self.assertTrue(torch.allclose(ref_result, result))
        ref_result.sum().backward()
        result.sum().backward()
        self.assertEqual(len(graphs), 1)
        self.assertEqual(len(fw_graphs), 1)
        self.assertExpectedInline(
            graph_str(fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"):
        add: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_2, 2); primals_2 = None
        sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
        partitioned_fw_subgraph_0_0 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_0, 'partitioned_fw_subgraph_0_0', primals_1, sum_1); partitioned_fw_subgraph_0_0 = sum_1 = None
        getitem: "f32[]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None
        add_1: "f32[]" = torch.ops.aten.add.Tensor(getitem, 2); getitem = None
        sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
        partitioned_fw_subgraph_0_1 = self.partitioned_fw_subgraph_0_0
        invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(partitioned_fw_subgraph_0_1, 'partitioned_fw_subgraph_0_0', primals_1, sum_2); partitioned_fw_subgraph_0_1 = primals_1 = sum_2 = None
        getitem_1: "f32[]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None
        return (getitem_1,)

    class partitioned_fw_subgraph_0_0(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[]"):
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
            add_1: "f32[]" = torch.ops.aten.add.Tensor(sum_1, primals_1); sum_1 = primals_1 = None
            return (add_1,)
""",
        )

    def test_input_mutation(self):
        def inner_fn2(x, y):
            x0 = x + 1
            y0 = y + 1
            x.add_(x0)
            y.add_(y0)
            return x.sum() + y.sum()

        def fn(x, y):
            x0 = torch.sin(x)
            o2 = inner_fn2(x0, y)
            o3 = inner_fn2(x0.clone(), y.clone())
            return o2 + o3

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        x_clone = x.clone()
        y_clone = y.clone()
        ref_result = fn(x, y)
        result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
        self.assertTrue(torch.allclose(ref_result, result))
        self.assertEqual(len(graphs), 1)
        self.assertEqual(len(fw_graphs), 1)
        self.assertExpectedInline(
            graph_str(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
        sin: "f32[10, 10]" = torch.ops.aten.sin.default(arg0_1); arg0_1 = None
        add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, 1)
        add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, 1)
        add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, add); sin = add = None
        add_3: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, add_1); add_1 = None
        clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2)
        clone_1: "f32[10, 20]" = torch.ops.aten.clone.default(add_3)
        add_4: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, 1)
        add_5: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, 1)
        add_6: "f32[10, 10]" = torch.ops.aten.add.Tensor(clone, add_4); clone = add_4 = None
        add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', add_2, add_3); repeated_subgraph0 = add_2 = None
        getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
        repeated_subgraph0_1 = self.repeated_subgraph0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', add_6, add_7); repeated_subgraph0_1 = add_6 = add_7 = None
        getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        add_8: "f32[]" = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
        copy_: "f32[10, 20]" = torch.ops.aten.copy_.default(arg1_1, add_3); arg1_1 = add_3 = copy_ = None
        return (add_8,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
            sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(arg1_1); arg1_1 = None
            add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
            return (add,)
""",
        )

    def test_input_aliasing(self):
        def inner_fn(x, y):
            x0 = x.view(x.size())
            return x0.view(x.size())

        def inner_fn2(x, y):
            x = x * 2
            y = y * 2
            return x.sum() + y.sum()

        def fn(x, y):
            o0 = inner_fn(x, y)
            o1 = inner_fn(x, y)
            o2 = inner_fn2(x, y)
            o3 = inner_fn2(x, y)
            return o0 + o1 + o2.sum() + o3.sum()

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        x_clone = x.clone()
        y_clone = y.clone()
        ref_result = fn(x, y)
        result, graphs, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
        self.assertTrue(torch.allclose(ref_result, result))
        self.assertEqual(len(graphs), 1)
        self.assertEqual(len(fw_graphs), 1)
        self.assertExpectedInline(
            graph_str(fw_graphs[0]),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
        view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
        view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
        add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
        getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
        add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
        repeated_subgraph0_1 = self.repeated_subgraph0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
        getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
        add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
        return (add_2,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
            mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
            add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
            return (add,)
""",
        )

    def test_cycle_detection_no_cycle(self):
        mod = self.run_and_get_simple_graph()
        self.assertExpectedInline(
            _detect_cycles(mod.graph, {}), """no cycle detected"""
        )

    def test_cycle_detection_single_node(self):
        def fn(x, y):
            x0 = x + 1
            y0 = y + 2
            z = x0.sum() + y0.sum()
            return z

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        _, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
        mod = fw_graphs[0]
        add_node = next(n for n in mod.graph.nodes if n.name == "add")
        add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
        # Rewire "add" to consume "add_2", introducing a cycle into the graph
        args = add_node.args
        add_node.args = (args[0], add_2)
        self.assertExpectedInline(
            _detect_cycles(mod.graph, {add_2: OrderedSet([add_2])}),
            """cycle detected in path: deque([output, add_2, add_2])""",
        )

    def test_cycle_detection_two_node(self):
        def fn(x, y):
            x0 = x + 1
            y0 = y + 2
            z = x0.sum() + y0.sum()
            return z

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        _, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
        mod = fw_graphs[0]
        add_node = next(n for n in mod.graph.nodes if n.name == "add")
        add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
        args = add_node.args
        add_node.args = (args[0], add_2)
        self.assertExpectedInline(
            _detect_cycles(
                mod.graph,
                {add_2: OrderedSet([add_node]), add_node: OrderedSet([add_2])},
            ),
            """cycle detected in path: deque([output, add_2, add, add_2])""",
        )

    def test_cycle_detection_arg_and_additional_deps(self):
        def fn(x, y):
            x0 = x + 1
            y0 = y + 2
            z = x0.sum() + y0.sum()
            return z

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        _, _, fw_graphs = self.run_and_return_graphs(fn, x, y)
        mod = fw_graphs[0]
        add_node = next(n for n in mod.graph.nodes if n.name == "add")
        add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
        args = add_node.args
        add_node.args = (args[0], add_2)
        self.assertExpectedInline(
            _detect_cycles(mod.graph, {add_2: OrderedSet([add_node])}),
            """cycle detected in path: deque([output, add_2, add, add_2])""",
        )

    def test_cycle_detection_simple(self):
        mod = self.run_and_get_simple_graph()
        add_node = next(n for n in mod.graph.nodes if n.name == "add")
        add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
        args = add_node.args
        add_node.args = (args[0], add_2)
        self.assertExpectedInline(
            _detect_cycles(mod.graph, {}),
            """cycle detected in path: deque([output, add_2, sum_1, add, add_2])""",
        )

    def test_cycle_detection_complex(self):
        def inner_fn(x, y):
            x0 = x.view(x.size())
            return x0.view(x.size())

        def inner_fn2(x, y):
            x = x * 2
            y = y * 2
            return x.sum() + y.sum()

        def fn(x, y):
            o0 = inner_fn(x, y)
            o1 = inner_fn(x, y)
            o2 = inner_fn2(x, y)
            o3 = inner_fn2(x, y)
            return o0 + o1 + o2.sum() + o3.sum()

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        x_clone = x.clone()
        y_clone = y.clone()
        _, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
        mod = fw_graphs[0]
        invoke_subgraph_node = next(
            n for n in mod.graph.nodes if n.name == "invoke_subgraph"
        )
        add_2 = next(n for n in mod.graph.nodes if n.name == "add_2")
        args = invoke_subgraph_node.args
        invoke_subgraph_node.args = (add_2, args[1])
        self.assertExpectedInline(
            _detect_cycles(mod.graph, {}),
            """cycle detected in path: deque([output, add_2, add_1, sum_1, getitem, invoke_subgraph, add_2])""",
        )

    def test_autocast_ordering(self):
        from torch._dynamo.graph_deduplication import (
            _populate_additional_deps,
            _stable_topological_sort,
        )

        def inner_fn(x, y):
            x0 = x.view(x.size())
            return x0.view(x.size())

        def inner_fn2(x, y):
            x = x * 2
            y = y * 2
            return x.sum() + y.sum()

        def fn(x, y):
            o0 = inner_fn(x, y)
            o1 = inner_fn(x, y)
            o2 = inner_fn2(x, y)
            o3 = inner_fn2(x, y)
            return o0 + o1 + o2.sum() + o3.sum()

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        x_clone = x.clone()
        y_clone = y.clone()
        _, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
        mod = fw_graphs[0]

        def get_node(name):
            return next(n for n in mod.graph.nodes if n.name == name)

        # Record the intended positions of the autocast enter/exit nodes, then
        # move them out of order so the sort has to restore the ordering
        sum_1 = get_node("sum_1")
        enter_autocast = mod.graph.call_function(torch.amp._enter_autocast)
        sum_1.append(enter_autocast)
        sum_2 = get_node("sum_2")
        exit_autocast = mod.graph.call_function(torch.amp._exit_autocast)
        sum_2.append(exit_autocast)
        additional_deps = _populate_additional_deps(mod.graph, {})
        invoke_subgraph = get_node("invoke_subgraph")
        invoke_subgraph.append(enter_autocast)
        getitem_1 = get_node("getitem_1")
        getitem_1.append(exit_autocast)
        self.assertExpectedInline(
            graph_str(mod),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
        view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
        view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
        add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
        _enter_autocast = torch.amp.autocast_mode._enter_autocast(); _enter_autocast = None
        getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
        add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
        repeated_subgraph0_1 = self.repeated_subgraph0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
        getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        _exit_autocast = torch.amp.autocast_mode._exit_autocast(); _exit_autocast = None
        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
        add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
        return (add_2,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
            mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
            add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
            return (add,)
""",
        )
        _stable_topological_sort(mod.graph, additional_deps)
        self.assertExpectedInline(
            graph_str(mod),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
        view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
        view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
        add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
        getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
        _enter_autocast = torch.amp.autocast_mode._enter_autocast(); _enter_autocast = None
        add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
        repeated_subgraph0_1 = self.repeated_subgraph0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
        getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
        _exit_autocast = torch.amp.autocast_mode._exit_autocast(); _exit_autocast = None
        add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
        return (add_2,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
            mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
            add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
            return (add,)
""",
        )

    def test_output_nodes_last(self):
        from torch._dynamo.graph_deduplication import _stable_topological_sort

        def inner_fn(x, y):
            x0 = x.view(x.size())
            return x0.view(x.size())

        def inner_fn2(x, y):
            x = x * 2
            y = y * 2
            return x.sum() + y.sum()

        def fn(x, y):
            o0 = inner_fn(x, y)
            o1 = inner_fn(x, y)
            o2 = inner_fn2(x, y)
            o3 = inner_fn2(x, y)
            return o0 + o1 + o2.sum() + o3.sum()

        x = torch.rand(10, 10, requires_grad=False)
        y = torch.rand(10, 20, requires_grad=False)
        x_clone = x.clone()
        y_clone = y.clone()
        _, _, fw_graphs = self.run_and_return_graphs(fn, x_clone, y_clone)
        mod = fw_graphs[0]
        # Move the output node so it is no longer the last node in the graph
        output = next(n for n in mod.graph.nodes if n.op == "output")
        add_2 = next(n for n in mod.graph.nodes if n.name == "sum_2")
        add_2.append(output)
        self.assertExpectedInline(
            graph_str(mod),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
        view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
        view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
        add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
        getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
        add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
        repeated_subgraph0_1 = self.repeated_subgraph0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
        getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
        return (add_2,)
        add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
            mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
            add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
            return (add,)
""",
        )
        _stable_topological_sort(mod.graph, {})
        self.assertExpectedInline(
            graph_str(mod),
            """\
class <lambda>(torch.nn.Module):
    def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
        view: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_1: "f32[10, 10]" = torch.ops.aten.view.default(view, [10, 10]); view = None
        view_2: "f32[10, 10]" = torch.ops.aten.view.default(arg0_1, [10, 10])
        view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None
        add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None
        repeated_subgraph0 = self.repeated_subgraph0
        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = None
        getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None
        sum_1: "f32[]" = torch.ops.aten.sum.default(getitem); getitem = None
        add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, sum_1); add = sum_1 = None
        repeated_subgraph0_1 = self.repeated_subgraph0
        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0_1 = arg0_1 = arg1_1 = None
        getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
        sum_2: "f32[]" = torch.ops.aten.sum.default(getitem_1); getitem_1 = None
        add_2: "f32[10, 10]" = torch.ops.aten.add.Tensor(add_1, sum_2); add_1 = sum_2 = None
        return (add_2,)

    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None
            mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
            add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
            return (add,)
""",
        )

    def test_mutation_ordering(self):
        from torch._dynamo.graph_deduplication import _stable_topological_sort

        def inner_fn(x, y):
            x0 = x.view(x.size())
            return x0.view(x.size())

        def inner_fn2(x, y):
            x = x * 2
            y = y * 2
            return x.sum() + y.sum()

        def fn(x, y):
            o0 = inner_fn(x, y)
            o1 = inner_fn(x, y)
            x.add_(x)
            o2 = inner_fn2(x, y)
            y.mul_(y)
            o3 = inner_fn2(x, y)
            return o0 + o1 + o2.sum() + o3.sum()

        x = torch.rand(10, 10)
        y = torch.rand(10, 20)
        x_clone = x.clone()
        y_clone = y.clone()
        graph, _ = extract_graph_and_tracker(fn, x_clone, y_clone)

        def graph_code(graph):
            return graph.python_code("self").src

        def get_node(name):
            return next(n for n in graph.nodes if n.name == name)

        self.assertExpectedInline(
            graph_code(graph),
            """\
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    subgraph_0 = self.subgraph_0
    l_x_ = L_x_
    l_y_ = L_y_
    x0 = l_x_.view((10, 10))
    o0 = x0.view((10, 10)); x0 = None
    x0_1 = l_x_.view((10, 10))
    o1 = x0_1.view((10, 10)); x0_1 = None
    add_ = l_x_.add_(l_x_); add_ = None
    add_2 = o0 + o1; o0 = o1 = None
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
    mul_ = l_y_.mul_(l_y_); mul_ = None
    getitem = invoke_subgraph[0]; invoke_subgraph = None
    sum_5 = getitem.sum(); getitem = None
    add_3 = add_2 + sum_5; add_2 = sum_5 = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
    getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
    sum_6 = getitem_1.sum(); getitem_1 = None
    add_4 = add_3 + sum_6; add_3 = sum_6 = None
    return (add_4,)
""",
        )
        # Shuffle nodes in the graph
        add_ = get_node("add_")
        mul_ = get_node("mul_")
        o1 = get_node("o1")
        o1.append(mul_)
        add_2 = get_node("add_2")
        add_2.append(add_)
        self.assertExpectedInline(
            graph_code(graph),
            """\
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    subgraph_0 = self.subgraph_0
    l_x_ = L_x_
    l_y_ = L_y_
    x0 = l_x_.view((10, 10))
    o0 = x0.view((10, 10)); x0 = None
    x0_1 = l_x_.view((10, 10))
    o1 = x0_1.view((10, 10)); x0_1 = None
    mul_ = l_y_.mul_(l_y_); mul_ = None
    add_2 = o0 + o1; o0 = o1 = None
    add_ = l_x_.add_(l_x_); add_ = None
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
    getitem = invoke_subgraph[0]; invoke_subgraph = None
    sum_5 = getitem.sum(); getitem = None
    add_3 = add_2 + sum_5; add_2 = sum_5 = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
    getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
    sum_6 = getitem_1.sum(); getitem_1 = None
    add_4 = add_3 + sum_6; add_3 = sum_6 = None
    return (add_4,)
""",
        )
        _stable_topological_sort(
            graph, torch._dynamo.graph_deduplication.last_node_to_additional_deps
        )
        self.assertExpectedInline(
            graph_code(graph),
            """\
def forward(self, L_x_ : torch.Tensor, L_y_ : torch.Tensor):
    subgraph_0 = self.subgraph_0
    l_x_ = L_x_
    l_y_ = L_y_
    x0 = l_x_.view((10, 10))
    o0 = x0.view((10, 10)); x0 = None
    x0_1 = l_x_.view((10, 10))
    o1 = x0_1.view((10, 10)); x0_1 = None
    add_2 = o0 + o1; o0 = o1 = None
    add_ = l_x_.add_(l_x_); add_ = None
    invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_)
    mul_ = l_y_.mul_(l_y_); mul_ = None
    getitem = invoke_subgraph[0]; invoke_subgraph = None
    sum_5 = getitem.sum(); getitem = None
    add_3 = add_2 + sum_5; add_2 = sum_5 = None
    invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', l_x_, l_y_); subgraph_0 = l_x_ = l_y_ = None
    getitem_1 = invoke_subgraph_1[0]; invoke_subgraph_1 = None
    sum_6 = getitem_1.sum(); getitem_1 = None
    add_4 = add_3 + sum_6; add_3 = sum_6 = None
    return (add_4,)
""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()