# Owner(s): ["module: dynamo"]

# ruff: noqa: TRY002

import itertools
import types
import unittest
import weakref
from collections import defaultdict, namedtuple, OrderedDict
from typing import Any

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.nn
import torch.utils.checkpoint
from torch._dynamo.testing import same
from torch._dynamo.utils import dict_items


class SimpleDict(dict):
    pass


class DictTests(torch._dynamo.test_case.TestCase):
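    # Checks that a dict subclass can be constructed and indexed inside the compiled region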
    def test_dict_subclass_instantiation(self):
        def fn(x):
            sd = SimpleDict(x=5)
            return sd["x"] * x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_local_mutation(self):
        def fn(x):
            sd = SimpleDict(x=5)
            z = sd["x"] * x
            sd["x"] = 10
            return z * sd["x"]

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_local_with_non_dict_method(self):
        # Checks that the add_1 method is inlined
        class MethodDict(dict):
            def add_1(self, x):
                return x + 1

        def fn(x):
            sd = MethodDict(x=5)
            z = sd["x"] * x
            sd["x"] = 10
            return sd.add_1(z * sd["x"])

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_contains(self):
        sd = dict()
        sd[2] = 5
        sd[4] = 10

        def fn(x):
            if 1 in sd:
                x = x * 2
            else:
                x = x * 3
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure a recompilation
        sd[1] = 15
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure no recompilation because the traced program remains the same here
        sd[2] = 10
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            self.assertEqual(fn(x), opt_fn(x))

    def test_dict_subclass_methods_fallback_readonly(self):
        sd = SimpleDict()
        sd[2] = 5
        sd[4] = 10
        # Check that regular attribute accesses work well
        sd.attr = 4

        def fn(x):
            for value in sd.values():
                x = x * value
            for key in sd.keys():
                x = x * key
            for k, v in sd.items():
                x = x * k
                x = x * v
            # for k in sd:
            #     x = x * k

            if 1 in sd:
                x = x * 2
            else:
                x = x * 3

            x = x * sd.get(2, 0)
            x = x * sd.get(3, 4)
            x = len(sd) * x
            x = x * sd.attr
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

        # Ensure a recompilation
        sd[6] = 15
        self.assertEqual(fn(x), opt_fn(x))

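    # Checks that a dict subclass created in the graph is reconstructed with the correct type and contents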
    def test_dict_subclass_instantiation_return(self):
        def fn(x):
            sd = SimpleDict(x=5 * x)
            sd["y"] = 10
            return sd

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(type(ref), type(res))
        self.assertEqual(ref["x"], res["x"])
        self.assertEqual(ref["y"], res["y"])

    def test_dict_subclass_methods_fallback_mutation(self):
        def fn(sd, x):
            for value in sd.values():
                x = x * value
            sd[6] = 14
            for key in sd.keys():
                x = x * key
            for k, v in sd.items():
                x = x * k
                x = x * v
            # for k in sd:
            #     x = x * k

            if 1 in sd:
                x = x * 2
            else:
                x = x * 3

            x = x * sd.get(2, 0)
            x = x * sd.get(3, 4)
            x = len(sd) * x
            return x

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        sd1 = SimpleDict()
        sd1[2] = 5
        sd1[4] = 10

        sd2 = SimpleDict()
        sd2[2] = 5
        sd2[4] = 10
        self.assertTrue(sd1 == sd2)

        self.assertEqual(fn(sd1, x), opt_fn(sd2, x))
        self.assertTrue(sd1 == sd2)

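    # Checks that an overridden __setitem__ is honored both during tracing and on reconstruction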
    def test_dict_subclass_setitem(self):
        class SetItemDict(dict):
            def __setitem__(self, key, value):
                super().__setitem__(key, value + 1)

        def fn(x):
            sd = SetItemDict(x=5 * x)
            sd["y"] = 10
            return sd

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(type(ref), type(res))
        self.assertEqual(ref["x"], res["x"])
        self.assertEqual(ref["y"], res["y"])

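    # Checks a dict subclass with a custom __iter__, plus attribute and key mutation inside the graph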
    def test_custom_iter_dict(self):
        class ReversedDict(dict):
            def __iter__(self):
                return reversed(list(self.keys()))

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            # Force the side-effect attribute reapplication logic
            d.sample = 1
            d["baz"] = 4
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        # The first call mutated the dict, so this call recompiles; that is intentional.
        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

    def test_custom_keys_iter_dict(self):
        class ReversedDict(dict):
            def keys(self):
                return ["bar", "foo"]

        d = {
            "foo": 1,
            "bar": 2,
        }

        d = ReversedDict(d)

        @torch.compile(backend="eager")
        def fn(x, d):
            return x * d["foo"] * d["bar"]

        fn(torch.randn(4), d)
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn(torch.randn(4), d)

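    # Checks that Dynamo guards on the insertion order of the dict keys observed via items()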
    def test_dict_guard_on_keys_order(self):
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key, value in d.items():
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # Move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

    def test_dict_guard_on_keys_order2(self):
        d = {
            2: 4,
            3: 5,
        }

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            for key in d:
                value = d[key]
                x = x * key + value
            return x

        opt_fn = torch.compile(fn, backend=cnts)
        opt_fn(torch.randn(4), d)
        opt_fn(torch.randn(4), d)
        # No recompilation
        self.assertEqual(cnts.frame_count, 1)

        # Move 2 to the end
        d[2] = d.pop(2)

        x = torch.randn(4)
        res = opt_fn(x, d)
        # Check recompilation
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(res, fn(x, d))

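    # Checks that a move_to_end performed before compilation is reflected in the traced iteration order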
    def test_ordered_dict_reordered_keys(self):
        d = OrderedDict()
        d[2] = 4
        d[3] = 5
        d.move_to_end(2)

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            y = 0
            for idx, (key, value) in enumerate(d.items()):
                if idx == 0:
                    y += torch.sin(x * value)
                else:
                    y += torch.cos(x * value)
            return y

        opt_fn = torch.compile(fn, backend=cnts)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, d), fn(x, d))

    def test_ordered_dict_subclass_reordered_keys(self):
        class ODSubclass(OrderedDict):
            def keys(self):
                return super().keys()

        d = ODSubclass()
        d[2] = 4
        d[3] = 5
        d.move_to_end(2)

        cnts = torch._dynamo.testing.CompileCounter()

        def fn(x, d):
            y = 0
            for idx, (key, value) in enumerate(d.items()):
                if idx == 0:
                    y += torch.sin(x * value)
                else:
                    y += torch.cos(x * value)
            return y

        opt_fn = torch.compile(fn, backend=cnts)
        x = torch.randn(4)
        self.assertEqual(opt_fn(x, d), fn(x, d))

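    # Checks that key guards are installed lazily, only for the keys the compiled function reads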
    def test_lazy_key_guarding(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x):
            return x * d["a"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

        # Since key "c" was not used, removing it should not lead to a recompilation
        d.pop("c")
        d["d"] = 10

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_lazy_key_non_const_guarding(self):
        d = {
            list: 2,
            dict: 3,
            OrderedDict: 5,
            namedtuple: 7,
        }

        def fn(x):
            return x * d[list]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)

        # Since the `dict` key was not used, removing it should not lead to a recompilation
        d.pop(dict)
        d[defaultdict] = 10

        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_dict_mutation_side_effect(self):
        def fn(d):
            d["c"] = d["a"] + d.pop("b")
            return d

        args1 = {"a": torch.randn(10), "b": torch.randn(10)}
        args2 = dict(args1)
        assert fn(args1) is args1
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertIs(opt_fn(args2), args2)
        self.assertTrue(same(args1, args2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 1)

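    # Checks that dict.copy() in the graph returns a new dict instead of aliasing the input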
    def test_dict_copy_alias(self):
        @torch.compile(backend="eager", fullgraph=True)
        def run(x, d0):
            d1 = d0.copy()
            d1[0] = 1
            return x + 1, d1

        d0 = {}
        res, d1 = run(torch.zeros(1), d0)
        self.assertTrue(same(res, torch.ones(1)))
        self.assertEqual(d0, {})
        self.assertEqual(d1, {0: 1})

    def test_dict_subclass_get_method(self):
        class dotdict(dict):
            """dot.notation access to dictionary attributes"""

            __getattr__ = dict.get
            __setattr__ = dict.__setitem__
            __delattr__ = dict.__delitem__

        config = dotdict({"a": 1, "b": 2})

        def fn(x):
            x2 = x * 2  # noqa: F841
            x3 = x * config.get("a", 3)
            return x3

        x = torch.randn(2)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        self.assertEqual(fn(x), opt_fn(x))

    def test_dict_order_keys(self):
        def fn(d):
            c = 0
            for v in d.values():
                c += v
            return c

        args1 = {}
        args1["a"] = torch.rand(10)
        args1["b"] = torch.rand(10)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

        # A different order of keys recompiles
        args2 = {}
        args2["b"] = args1["b"]
        args2["a"] = args1["a"]
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)
        # Extra calls with the same key order don't recompile
        opt_fn(args2)
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_namedtuple(self):
        def fn(d):
            if namedtuple in d:
                return d[3] * 2
            else:
                return d[3] * 3

        args1 = {namedtuple: None, 3: torch.randn(3)}
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1), opt_fn(args1))
        self.assertEqual(cnts.frame_count, 1)
        # Test a failing namedtuple guard
        args2 = {2: None, 3: torch.randn(3)}
        self.assertEqual(fn(args2), opt_fn(args2))
        self.assertEqual(cnts.frame_count, 2)

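    # Checks guarding on key order when some of the keys are tensors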
    def test_dict_order_keys_tensors(self):
        def fn(d, x):
            return d[x] + 3

        args1 = {}
        x = torch.randn(10)
        y = torch.randn(10)
        z = torch.randn(10)
        args1[x] = y
        args1[3] = z

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1, x), opt_fn(args1, x))
        self.assertEqual(cnts.frame_count, 1)

        # Calling again doesn't recompile (same id and key order)
        opt_fn(args1, x)
        self.assertEqual(cnts.frame_count, 1)
        args2 = {}
        args2[3] = z
        args2[x] = y

        # Different order recompiles
        self.assertEqual(fn(args2, x), opt_fn(args2, x))
        self.assertEqual(cnts.frame_count, 2)

    def test_dict_order_keys_modules(self):
        def fn(d, x):
            return d[x](torch.ones(2, 2))

        args1 = {}
        x = torch.nn.Linear(2, 2)
        y = torch.nn.Linear(2, 2)
        z = torch.nn.Linear(2, 2)
        args1[x] = y
        args1[3] = z

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        self.assertEqual(fn(args1, x), opt_fn(args1, x))
        self.assertEqual(cnts.frame_count, 1)

        # Calling again doesn't recompile (same id and key order)
        opt_fn(args1, x)
        self.assertEqual(cnts.frame_count, 1)
        args2 = {}
        args2[3] = z
        args2[x] = y

        # Different order recompiles
        self.assertEqual(fn(args2, x), opt_fn(args2, x))
        self.assertEqual(cnts.frame_count, 2)

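    # Checks `in` lookups on an instance __dict__, including a key added during tracing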
    def test_contains_dunder_dict(self):
        class UserDefined:
            def __init__(self) -> None:
                self.a = 3
                self.b = 5

            def run(self, x):
                if "a" in self.__dict__:
                    x = x * self.a
                if "b" in self.__dict__:
                    x = x * self.b
                self.c = 7
                if "c" in self.__dict__:
                    x = x * self.c
                return x * self.__dict__.get("a") * self.__dict__.get("z", 2)

        obj = UserDefined()

        def fn(x):
            return obj.run(x)

        x = torch.randn(4)
        ref = fn(x)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        res = opt_fn(x)
        self.assertEqual(ref, res)

    def test_contains_module_dunder_dict(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = 1
                self.bar = 2
                self.baz = 3

            def forward(self, x):
                if "foo" in self.__dict__:
                    return x * self.bar
                return x * self.baz

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))

    def test_update_dunder_dict(self):
        class UserDefined:
            def run(self, x):
                self.__dict__["a"] = 10
                return x * self.a + self.__dict__["a"]

        obj1 = UserDefined()
        obj2 = UserDefined()

        def fn(x, obj):
            return obj.run(x)

        x = torch.randn(4)
        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn(x, obj1)
        res = opt_fn(x, obj2)
        self.assertEqual(ref, res)
        # Make sure only `a` is updated.
        self.assertEqual(obj1.__dict__, obj2.__dict__)

    def test_update_module_dunder_dict(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                self.__dict__["a"] = 10
                return x * self.a + self.__dict__["a"]

        mod = MyModule()
        x = torch.randn(10)
        opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
        self.assertEqual(mod(x), opt_mod(x))

    def test_dict_reconstruct_keeps_original_order(self):
        def fn():
            modules = OrderedDict([("act", torch.nn.ReLU())])
            module_dict = torch.nn.ModuleDict(modules)

            next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
            modules.update(next_modules.items())
            module_dict.update(next_modules)
            return modules, module_dict

        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch.compile(fn, backend=cnts)
        modules, module_dict = opt_fn()

        self.assertEqual(len(module_dict), len(modules))
        for k1, m2 in zip(modules, module_dict.children()):
            self.assertTrue(modules[k1] is m2)

    def test_dict_subclass_initialization_in_graph(self):
        for super_class in (
            OrderedDict,
            dict,
        ):

            class CustomDict(super_class):
                def __new__(cls, *args, **kwargs):
                    return super().__new__(cls, *args, **kwargs)

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

            def fn(x):
                c = CustomDict()
                c["key"] = x
                assert "key" in c
                return c["key"] + 1

            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)

            x = torch.rand(4)
            self.assertEqual(fn(x), opt_fn(x))

    def test_dict_list_values(self):
        def inner_fn(args):
            return [x[1].shape for x in args]

        @torch.compile(backend="eager")
        def fn(tensors):
            return inner_fn(zip(itertools.count(), tensors["args"]))

        fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
        fn({"args": [torch.ones(5, 5)]})

    def test_dict_iter(self):
        class MyMod(torch.nn.Module):
            def forward(self, x):
                z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
                tot = 0
                for key in z:
                    tot += z[key]

                return tot

        x = torch.tensor([0])
        model = MyMod()
        opt_model = torch.compile(model, backend="eager", fullgraph=True)
        y = opt_model(x)

        self.assertEqual(y, 10)

    def test_dict_subclass_contains(self):
        # Pattern from huggingface
        class ClassInstantier(OrderedDict):
            pass

        @torch.compile(fullgraph=True, backend="eager")
        def f(x, d):
            if "key1" in d:
                x = x + 2
            if "key2" in d:
                x = x + 4
            x = x + 8
            return x

        result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 11.0)))

        result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
        self.assertTrue(same(result, torch.full([8], 13.0)))

    def test_dict_tag_guard(self):
        class Foo:
            def __init__(self) -> None:
                self.scalar = 10

        def fn(d, x):
            return d["a"] * d["b"] * d["c"].scalar * x

        foo = Foo()

        d = {"a": 2, "b": 3, "c": foo}

        opt_fn = torch.compile(fn, backend="eager")
        inp = torch.randn(3, 3)
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        d["a"] = 4
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

        # Check that recompilation happens
        foo.scalar = 12
        self.assertEqual(fn(d, inp), opt_fn(d, inp))

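    # Checks that the truthiness of the input dict is guarded, so empty vs. non-empty recompiles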
    def test_empty_dict_recompilation(self):
        def fn(d, x):
            if d:
                return torch.cos(x)
            return torch.sin(x)

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn({}, x), opt_fn({}, x))
        self.assertEqual(fn({"a": 1}, x), opt_fn({"a": 1}, x))

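    # Checks reconstruction of a user-defined dict subclass created via dict.__new__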
    def test_udf_dict_reconstruction(self):
        class MyDict(dict):
            pass

        def fn(x, klass):
            x = x * 2
            sc_dict = dict.__new__(klass)
            sc_dict["x"] = x
            if isinstance(sc_dict, MyDict):
                sc_dict.attr = 3
            return sc_dict

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x, MyDict)
        res = opt_fn(x, MyDict)
        self.assertEqual(ref, res)
        self.assertTrue(isinstance(res, MyDict))
        self.assertEqual(ref.attr, res.attr)

        ref = fn(x, dict)
        res = opt_fn(x, dict)
        self.assertEqual(ref, res)
        self.assertTrue(isinstance(res, dict))

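    # Checks membership tests against a WeakKeyDictionary captured from the enclosing scope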
    def test_weakref_dict(self):
        states = weakref.WeakKeyDictionary()

        mod1 = torch.nn.Module()
        mod2 = torch.nn.Module()

        states[mod1] = 2
        states[mod2] = 3

        def fn(x):
            if mod1 in states:
                x = torch.sin(x)
            if mod2 in states:
                x = torch.cos(x)
            return x

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(fn(x), opt_fn(x))

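    # Checks that id() of a function can be used as a dict key inside the compiled region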
    def test_fn_id(self):
        def fn(x, f):
            d = {id(f): 3}
            return x * d[id(f)]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)

        def nothing():
            pass

        f = nothing
        self.assertEqual(fn(x, f), opt_fn(x, f))

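    # Checks that a MappingProxyType built in the graph is returned with the proper type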
    def test_mapping_proxy_for_local(self):
        def fn(x):
            d = {"a": 2, "b": 3, "c": 5 * x}
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            return mp

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertTrue(type(res) is types.MappingProxyType)

    def test_mapping_proxy_for_nonlocal(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x):
            mp = types.MappingProxyType(d)
            y = torch.sin(x * mp["a"])
            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            d["d"] = 4
            return mp

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = fn(x)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertTrue(type(res) is types.MappingProxyType)

        # Check that an update to d is reflected in res
        d["e"] = 5
        self.assertEqual(d["e"], res["e"])

    def test_mapping_proxy_existing(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x, mp):
            y = torch.sin(x * mp["a"])
            for k, v in mp.items():  # noqa: PERF102
                y += torch.cos(x * v)
            if isinstance(mp, types.MappingProxyType):
                y *= 2
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        mp = types.MappingProxyType(d)
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

        d["a"] = 3
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

        d.pop("b")
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

    def test_dict_construction_from_mapping_proxy(self):
        d = {"a": 2, "b": 3, "c": 5}

        def fn(x, mp):
            d = dict(mp)
            y = torch.sin(x * d["a"])
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        mp = types.MappingProxyType(d)
        ref = fn(x, mp)
        res = opt_fn(x, mp)
        self.assertEqual(ref, res)

    def test_mapping_proxy_existing_mutation(self):
        d = {"a": 2, "b": 3, "c": 5}

        mp = types.MappingProxyType(d)

        def fn(x):
            d["d"] = 4
            y = torch.sin(x * mp["d"])
            return y

        opt_fn = torch.compile(fn, backend="eager")
        x = torch.randn(4)
        ref = torch.sin(x * 4)
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(d.keys(), mp.keys())

    def test_mapping_proxy_existing_local_mutation(self):
        d = {"a": 2, "b": 3, "c": 5}

        mp = types.MappingProxyType(d)

        def fn(x):
            # Dynamo should not cause a graph break here because it knows that
            # the existing proxy can't point to this new dict
            other_dict = {}
            other_dict["d"] = 4
            y = torch.sin(x * mp["c"])
            return y

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        ref = torch.sin(x * mp["c"])
        res = opt_fn(x)
        self.assertEqual(ref, res)
        self.assertEqual(d.keys(), mp.keys())

    def test_move_to_end(self):
        def fn(x):
            d = OrderedDict({"a": torch.cos(x), "b": 3, "c": 5})
            d.move_to_end("a")
            return d

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys()))
        self.assertEqual(fn(x), opt_fn(x))

    def test_overridden_get_item(self):
        class MyDict(dict):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.calls = 0

            def __getitem__(self, key):
                self.calls += 1
                return super().__getitem__(key) + 1

        def fn(x, d):
            d["d"] = 4
            return x * d["a"] + d["b"] + d["c"] + d["d"]

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        x = torch.randn(4)
        d1 = MyDict({"a": 2, "b": 3, "c": 5})
        ref = fn(x, d1)

        d2 = MyDict({"a": 2, "b": 3, "c": 5})
        res = opt_fn(x, d2)
        self.assertEqual(ref, res)
        self.assertEqual(d1.calls, d2.calls)

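    # Checks that d.items() returned from the compiled function has the dict_items type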
    def test_items_type(self):
        def fn():
            d = dict({"a": 1, "b": "2", "c": torch.tensor(3)})  # noqa: C418
            return d.items()

        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
        ref = fn()
        res = opt_fn()
        self.assertEqual(ref, res)
        self.assertEqual(type(res), dict_items)

    def test_builtin_or_with_invalid_types(self):
        args = (
            1,  # int
            1.0,  # float
            "a",  # str
            (1, 2),  # tuple
            [1, 2],  # list
        )

        @torch.compile(backend="eager", fullgraph=True)
        def fn(b: Any):
            a = {"one": torch.ones(1)}
            return a | b

        from torch._dynamo.exc import InternalTorchDynamoError

        for arg in args:
            with self.assertRaisesRegex(
                InternalTorchDynamoError, "unsupported operand type"
            ):
                _ = fn(arg)

    def test_builtin_or_with_diff_keys(self):
        def f():
            a = {"one": torch.ones(1)}
            b = {"two": torch.ones(2)}
            return a, b, a | b, b | a, a.__or__(b), b.__or__(a)

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_builtin_or_with_same_keys(self):
        def f():
            a = {"one": torch.ones(1), "two": torch.ones(2)}
            b = {"one": torch.ones(1), "three": torch.ones(3)}
            return a, b, a | b, b | a, a.__or__(b), b.__or__(a)

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

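    # Checks that the in-place |= update is applied to the original dict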
    def test_builtin_ior_(self):
        def f():
            a = {"one": torch.ones(1)}
            b = {"two": torch.ones(2)}
            a |= b
            return a, b

        opt_f = torch.compile(f, backend="eager", fullgraph=True)
        self.assertEqual(f(), opt_f())

    def test_newly_constructed_default_dict(self):
        def f(x):
            d = defaultdict(list)
            d[0] = 42
            return x + 1, d

        x = torch.ones(2)
        ref = f(x)
        res = torch.compile(f, backend="eager", fullgraph=True)(x)

        self.assertEqual(ref, res)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()