# Owner(s): ["module: dynamo"] # ruff: noqa: TRY002 import itertools import types import unittest import weakref from collections import defaultdict, namedtuple, OrderedDict from typing import Any import torch import torch._dynamo.test_case import torch._dynamo.testing import torch._functorch.config import torch.nn import torch.utils.checkpoint from torch._dynamo.testing import same from torch._dynamo.utils import dict_items class SimpleDict(dict): pass class DictTests(torch._dynamo.test_case.TestCase): def test_dict_subclass_instantiation(self): def fn(x): sd = SimpleDict(x=5) return sd["x"] * x x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) def test_dict_subclass_local_mutation(self): def fn(x): sd = SimpleDict(x=5) z = sd["x"] * x sd["x"] = 10 return z * sd["x"] x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) def test_dict_subclass_local_with_non_dict_method(self): # Checks that add_1 method is inlined class MethodDict(dict): def add_1(self, x): return x + 1 def fn(x): sd = MethodDict(x=5) z = sd["x"] * x sd["x"] = 10 return sd.add_1(z * sd["x"]) x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) def test_dict_contains(self): sd = dict() sd[2] = 5 sd[4] = 10 def fn(x): if 1 in sd: x = x * 2 else: x = x * 3 return x x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) # Ensure a recompilation sd[1] = 15 self.assertEqual(fn(x), opt_fn(x)) # Ensure not recompilation because the traced program remains same here. sd[2] = 10 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): self.assertEqual(fn(x), opt_fn(x)) def test_dict_subclass_methods_fallback_readonly(self): sd = SimpleDict() sd[2] = 5 sd[4] = 10 # check that regular attr accesses work well sd.attr = 4 def fn(x): for value in sd.values(): x = x * value for key in sd.keys(): x = x * key for k, v in sd.items(): x = x * k x = x * v # for k in sd: # x = x * k if 1 in sd: x = x * 2 else: x = x * 3 x = x * sd.get(2, 0) x = x * sd.get(3, 4) x = len(sd) * x x = x * sd.attr return x x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) # Ensure a recompilation sd[6] = 15 self.assertEqual(fn(x), opt_fn(x)) def test_dict_subclass_instantiation_return(self): def fn(x): sd = SimpleDict(x=5 * x) sd["y"] = 10 return sd x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) ref = fn(x) res = opt_fn(x) self.assertEqual(type(ref), type(res)) self.assertEqual(ref["x"], res["x"]) self.assertEqual(ref["y"], res["y"]) def test_dict_subclass_methods_fallback_mutation(self): def fn(sd, x): for value in sd.values(): x = x * value sd[6] = 14 for key in sd.keys(): x = x * key for k, v in sd.items(): x = x * k x = x * v # for k in sd: # x = x * k if 1 in sd: x = x * 2 else: x = x * 3 x = x * sd.get(2, 0) x = x * sd.get(3, 4) x = len(sd) * x return x x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) sd1 = SimpleDict() sd1[2] = 5 sd1[4] = 10 sd2 = SimpleDict() sd2[2] = 5 sd2[4] = 10 self.assertTrue(sd1 == sd2) self.assertEqual(fn(sd1, x), opt_fn(sd2, x)) self.assertTrue(sd1 == sd2) def test_dict_subclass_setitem(self): class SetItemDict(dict): def __setitem__(self, key, value): super().__setitem__(key, value + 1) def fn(x): sd = SetItemDict(x=5 * x) sd["y"] = 10 return sd x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) ref = fn(x) res = opt_fn(x) self.assertEqual(type(ref), type(res)) self.assertEqual(ref["x"], res["x"]) self.assertEqual(ref["y"], res["y"]) def test_custom_iter_dict(self): class ReversedDict(dict): def __iter__(self): return reversed(list(self.keys())) d = { "foo": 1, "bar": 2, } d = ReversedDict(d) @torch.compile(backend="eager") def fn(x, d): # Forces side effects attribute reapplication logic d.sample = 1 d["baz"] = 4 return x * d["foo"] * d["bar"] fn(torch.randn(4), d) # This is intentional because the dict is mutated, so we will have a recompilation. fn(torch.randn(4), d) with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): fn(torch.randn(4), d) def test_custom_keys_iter_dict(self): class ReversedDict(dict): def keys(self): return ["bar", "foo"] d = { "foo": 1, "bar": 2, } d = ReversedDict(d) @torch.compile(backend="eager") def fn(x, d): return x * d["foo"] * d["bar"] fn(torch.randn(4), d) with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): fn(torch.randn(4), d) def test_dict_guard_on_keys_order(self): d = { 2: 4, 3: 5, } cnts = torch._dynamo.testing.CompileCounter() def fn(x, d): for key, value in d.items(): x = x * key + value return x opt_fn = torch.compile(fn, backend=cnts) opt_fn(torch.randn(4), d) opt_fn(torch.randn(4), d) # No recompilation self.assertEqual(cnts.frame_count, 1) # move 2 to the end d[2] = d.pop(2) x = torch.randn(4) res = opt_fn(x, d) # Check recompilation self.assertEqual(cnts.frame_count, 2) self.assertEqual(res, fn(x, d)) def test_dict_guard_on_keys_order2(self): d = { 2: 4, 3: 5, } cnts = torch._dynamo.testing.CompileCounter() def fn(x, d): for key in d: value = d[key] x = x * key + value return x opt_fn = torch.compile(fn, backend=cnts) opt_fn(torch.randn(4), d) opt_fn(torch.randn(4), d) # No recompilation self.assertEqual(cnts.frame_count, 1) # move 2 to the end d[2] = d.pop(2) x = torch.randn(4) res = opt_fn(x, d) # Check recompilation self.assertEqual(cnts.frame_count, 2) self.assertEqual(res, fn(x, d)) def test_ordered_dict_reordered_keys(self): d = OrderedDict() d[2] = 4 d[3] = 5 d.move_to_end(2) cnts = torch._dynamo.testing.CompileCounter() def fn(x, d): y = 0 for idx, (key, value) in enumerate(d.items()): if idx == 0: y += torch.sin(x * value) else: y += torch.cos(x * value) return y opt_fn = torch.compile(fn, backend=cnts) x = torch.randn(4) self.assertEqual(opt_fn(x, d), fn(x, d)) def test_ordered_dict_subclass_reordered_keys(self): class ODSubclass(OrderedDict): def keys(self): return super().keys() d = ODSubclass() d[2] = 4 d[3] = 5 d.move_to_end(2) cnts = torch._dynamo.testing.CompileCounter() def fn(x, d): y = 0 for idx, (key, value) in enumerate(d.items()): if idx == 0: y += torch.sin(x * value) else: y += torch.cos(x * value) return y opt_fn = torch.compile(fn, backend=cnts) x = torch.randn(4) self.assertEqual(opt_fn(x, d), fn(x, d)) def test_lazy_key_guarding(self): d = {"a": 2, "b": 3, "c": 5} def fn(x): return x * d["a"] opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) ref = fn(x) res = opt_fn(x) self.assertEqual(ref, res) # Since key c was not used, it should not lead to a recompilation d.pop("c") d["d"] = 10 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): ref = fn(x) res = opt_fn(x) self.assertEqual(ref, res) def test_lazy_key_non_const_guarding(self): d = { list: 2, dict: 3, OrderedDict: 5, namedtuple: 7, } def fn(x): return x * d[list] opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) ref = fn(x) res = opt_fn(x) self.assertEqual(ref, res) # Since key c was not used, it should not lead to a recompilation d.pop(dict) d[defaultdict] = 10 with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True): ref = fn(x) res = opt_fn(x) self.assertEqual(ref, res) def test_dict_mutation_side_effect(self): def fn(d): d["c"] = d["a"] + d.pop("b") return d args1 = {"a": torch.randn(10), "b": torch.randn(10)} args2 = dict(args1) assert fn(args1) is args1 cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) self.assertIs(opt_fn(args2), args2) self.assertTrue(same(args1, args2)) self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 1) def test_dict_copy_alias(self): @torch.compile(backend="eager", fullgraph=True) def run(x, d0): d1 = d0.copy() d1[0] = 1 return x + 1, d1 d0 = {} res, d1 = run(torch.zeros(1), d0) self.assertTrue(same(res, torch.ones(1))) self.assertEqual(d0, {}) self.assertEqual(d1, {0: 1}) def test_dict_subclass_get_method(self): class dotdict(dict): """dot.notation access to dictionary attributes""" __getattr__ = dict.get __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ config = dotdict({"a": 1, "b": 2}) def fn(x): x2 = x * 2 # noqa: F841 x3 = x * config.get("a", 3) return x3 x = torch.randn(2) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) self.assertEqual(fn(x), opt_fn(x)) def test_dict_order_keys(self): def fn(d): c = 0 for v in d.values(): c += v return c args1 = {} args1["a"] = torch.rand(10) args1["b"] = torch.rand(10) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) self.assertEqual(fn(args1), opt_fn(args1)) self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 2) # A different order of keys recompiles args2 = {} args2["b"] = args1["b"] args2["a"] = args1["a"] self.assertEqual(fn(args2), opt_fn(args2)) self.assertEqual(cnts.frame_count, 2) # Extra calls don't recompile self.assertEqual(cnts.frame_count, 2) def test_dict_namedtuple(self): def fn(d): if namedtuple in d: return d[3] * 2 else: return d[3] * 3 args1 = {namedtuple: None, 3: torch.randn(3)} cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) self.assertEqual(fn(args1), opt_fn(args1)) self.assertEqual(cnts.frame_count, 1) # Test a failing namedtuple guard args2 = {2: None, 3: torch.randn(3)} self.assertEqual(fn(args2), opt_fn(args2)) self.assertEqual(cnts.frame_count, 2) def test_dict_order_keys_tensors(self): def fn(d, x): return d[x] + 3 args1 = {} x = torch.randn(10) y = torch.randn(10) z = torch.randn(10) args1[x] = y args1[3] = z cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) self.assertEqual(fn(args1, x), opt_fn(args1, x)) self.assertEqual(cnts.frame_count, 1) # Calling again doesn't recompile (same id and key order) opt_fn(args1, x) self.assertEqual(cnts.frame_count, 1) args2 = {} args2[3] = z args2[x] = y # Different order recompiles self.assertEqual(fn(args2, x), opt_fn(args2, x)) self.assertEqual(cnts.frame_count, 2) def test_dict_order_keys_modules(self): def fn(d, x): return d[x](torch.ones(2, 2)) args1 = {} x = torch.nn.Linear(2, 2) y = torch.nn.Linear(2, 2) z = torch.nn.Linear(2, 2) args1[x] = y args1[3] = z cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) self.assertEqual(fn(args1, x), opt_fn(args1, x)) self.assertEqual(cnts.frame_count, 1) # Calling again doesn't recompile (same id and key order) opt_fn(args1, x) self.assertEqual(cnts.frame_count, 1) args2 = {} args2[3] = z args2[x] = y # Different order recompiles self.assertEqual(fn(args2, x), opt_fn(args2, x)) self.assertEqual(cnts.frame_count, 2) def test_contains_dunder_dict(self): class UserDefined: def __init__(self) -> None: self.a = 3 self.b = 5 def run(self, x): if "a" in self.__dict__: x = x * self.a if "b" in self.__dict__: x = x * self.b self.c = 7 if "c" in self.__dict__: x = x * self.c return x * self.__dict__.get("a") * self.__dict__.get("z", 2) obj = UserDefined() def fn(x): return obj.run(x) x = torch.randn(4) ref = fn(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) res = opt_fn(x) self.assertEqual(ref, res) def test_contains_module_dunder_dict(self): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() self.foo = 1 self.bar = 2 self.baz = 3 def forward(self, x): if "foo" in self.__dict__: return x * self.bar return x * self.baz mod = MyModule() x = torch.randn(10) opt_mod = torch.compile(mod, backend="eager", fullgraph=True) self.assertEqual(mod(x), opt_mod(x)) def test_update_dunder_dict(self): class UserDefined: def run(self, x): self.__dict__["a"] = 10 return x * self.a + self.__dict__["a"] obj1 = UserDefined() obj2 = UserDefined() def fn(x, obj): return obj.run(x) x = torch.randn(4) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) ref = fn(x, obj1) res = opt_fn(x, obj2) self.assertEqual(ref, res) # Make sure only `a` is updated. self.assertEqual(obj1.__dict__, obj2.__dict__) def test_update_module_dunder_dict(self): class MyModule(torch.nn.Module): def __init__(self) -> None: super().__init__() def forward(self, x): self.__dict__["a"] = 10 return x * self.a + self.__dict__["a"] mod = MyModule() x = torch.randn(10) opt_mod = torch.compile(mod, backend="eager", fullgraph=True) self.assertEqual(mod(x), opt_mod(x)) def test_dict_reconstruct_keeps_original_order(self): def fn(): modules = OrderedDict([("act", torch.nn.ReLU())]) module_dict = torch.nn.ModuleDict(modules) next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()} modules.update(next_modules.items()) module_dict.update(next_modules) return modules, module_dict cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch.compile(fn, backend=cnts) modules, module_dict = opt_fn() self.assertEqual(len(module_dict), len(modules)) for k1, m2 in zip(modules, module_dict.children()): self.assertTrue(modules[k1] is m2) def test_dict_subclass_initialization_in_graph(self): for super_class in ( OrderedDict, dict, ): class CustomDict(super_class): def __new__(cls, *args, **kwargs): return super().__new__(cls, *args, **kwargs) def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) def fn(x): c = CustomDict() c["key"] = x assert "key" in c return c["key"] + 1 opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.rand(4) self.assertEqual(fn(x), opt_fn(x)) def test_dict_list_values(self): def inner_fn(args): return [x[1].shape for x in args] @torch.compile(backend="eager") def fn(tensors): return inner_fn(zip(itertools.count(), tensors["args"])) fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]}) fn({"args": [torch.ones(5, 5)]}) def test_dict_iter(self): class MyMod(torch.nn.Module): def forward(self, x): z = {"my": 1, "const": 2, "dict": 3, "variable": 4} tot = 0 for key in z: tot += z[key] return tot x = torch.tensor([0]) model = MyMod() opt_model = torch.compile(model, backend="eager", fullgraph=True) y = opt_model(x) self.assertEqual(y, 10) def test_dict_subclass_contains(self): # pattern from huggingface class ClassInstantier(OrderedDict): pass @torch.compile(fullgraph=True, backend="eager") def f(x, d): if "key1" in d: x = x + 2 if "key2" in d: x = x + 4 x = x + 8 return x result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)})) self.assertTrue(same(result, torch.full([8], 11.0))) result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)})) self.assertTrue(same(result, torch.full([8], 13.0))) def test_dict_tag_guard(self): class Foo: def __init__(self) -> None: self.scalar = 10 def fn(d, x): return d["a"] * d["b"] * d["c"].scalar * x foo = Foo() d = {"a": 2, "b": 3, "c": foo} opt_fn = torch.compile(fn, backend="eager") inp = torch.randn(3, 3) self.assertEqual(fn(d, inp), opt_fn(d, inp)) d["a"] = 4 self.assertEqual(fn(d, inp), opt_fn(d, inp)) # Check that recompilation happens foo.scalar = 12 self.assertEqual(fn(d, inp), opt_fn(d, inp)) def test_empty_dict_recompilation(self): def fn(d, x): if d: return torch.cos(x) return torch.sin(x) opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) self.assertEqual(fn({}, x), opt_fn({}, x)) self.assertEqual(fn({"a": 1}, x), opt_fn({"a": 1}, x)) def test_udf_dict_reconstruction(self): class MyDict(dict): pass def fn(x, klass): x = x * 2 sc_dict = dict.__new__(klass) sc_dict["x"] = x if isinstance(sc_dict, MyDict): sc_dict.attr = 3 return sc_dict opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) ref = fn(x, MyDict) res = opt_fn(x, MyDict) self.assertEqual(ref, res) self.assertTrue(isinstance(res, MyDict)) self.assertEqual(ref.attr, res.attr) ref = fn(x, dict) res = opt_fn(x, dict) self.assertEqual(ref, res) self.assertTrue(isinstance(res, dict)) def test_weakref_dict(self): states = weakref.WeakKeyDictionary() mod1 = torch.nn.Module() mod2 = torch.nn.Module() states[mod1] = 2 states[mod2] = 3 def fn(x): if mod1 in states: x = torch.sin(x) if mod2 in states: x = torch.cos(x) return x opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) self.assertEqual(fn(x), opt_fn(x)) def test_fn_id(self): def fn(x, f): d = {id(f): 3} return x * d[id(f)] opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) def nothing(): pass f = nothing self.assertEqual(fn(x, f), opt_fn(x, f)) def test_mapping_proxy_for_local(self): def fn(x): d = {"a": 2, "b": 3, "c": 5 * x} mp = types.MappingProxyType(d) y = torch.sin(x * mp["a"]) for k, v in mp.items(): # noqa: PERF102 y += torch.cos(x * v) return mp opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) ref = fn(x) res = opt_fn(x) self.assertEqual(ref, res) self.assertTrue(type(res) is types.MappingProxyType) def test_mapping_proxy_for_nonlocal(self): d = {"a": 2, "b": 3, "c": 5} def fn(x): mp = types.MappingProxyType(d) y = torch.sin(x * mp["a"]) for k, v in mp.items(): # noqa: PERF102 y += torch.cos(x * v) d["d"] = 4 return mp opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) ref = fn(x) res = opt_fn(x) self.assertEqual(ref, res) self.assertTrue(type(res) is types.MappingProxyType) # check update to d is reflected in res d["e"] = 5 self.assertEqual(d["e"], res["e"]) def test_mapping_proxy_existing(self): d = {"a": 2, "b": 3, "c": 5} def fn(x, mp): y = torch.sin(x * mp["a"]) for k, v in mp.items(): # noqa: PERF102 y += torch.cos(x * v) if isinstance(mp, types.MappingProxyType): y *= 2 return y opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) mp = types.MappingProxyType(d) ref = fn(x, mp) res = opt_fn(x, mp) self.assertEqual(ref, res) d["a"] = 3 ref = fn(x, mp) res = opt_fn(x, mp) self.assertEqual(ref, res) d.pop("b") ref = fn(x, mp) res = opt_fn(x, mp) self.assertEqual(ref, res) def test_dict_construction_from_mapping_proxy(self): d = {"a": 2, "b": 3, "c": 5} def fn(x, mp): d = dict(mp) y = torch.sin(x * d["a"]) return y opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) mp = types.MappingProxyType(d) ref = fn(x, mp) res = opt_fn(x, mp) self.assertEqual(ref, res) def test_mapping_proxy_existing_mutation(self): d = {"a": 2, "b": 3, "c": 5} mp = types.MappingProxyType(d) def fn(x): d["d"] = 4 y = torch.sin(x * mp["d"]) return y opt_fn = torch.compile(fn, backend="eager") x = torch.randn(4) ref = torch.sin(x * 4) res = opt_fn(x) self.assertEqual(ref, res) self.assertEqual(d.keys(), mp.keys()) def test_mapping_proxy_existing_local_mutation(self): d = {"a": 2, "b": 3, "c": 5} mp = types.MappingProxyType(d) def fn(x): # Dynamo should not cause a graph break here because it knows that # the existing proxy cant point to this new dict other_dict = {} other_dict["d"] = 4 y = torch.sin(x * mp["c"]) return y opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) ref = torch.sin(x * mp["c"]) res = opt_fn(x) self.assertEqual(ref, res) self.assertEqual(d.keys(), mp.keys()) def test_move_to_end(self): def fn(x): d = OrderedDict({"a": torch.cos(x), "b": 3, "c": 5}) d.move_to_end("a") return d opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) self.assertEqual(["b", "c", "a"], list(opt_fn(x).keys())) self.assertEqual(fn(x), opt_fn(x)) def test_overridden_get_item(self): class MyDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.calls = 0 def __getitem__(self, key): self.calls += 1 return super().__getitem__(key) + 1 def fn(x, d): d["d"] = 4 return x * d["a"] + d["b"] + d["c"] + d["d"] opt_fn = torch.compile(fn, backend="eager", fullgraph=True) x = torch.randn(4) d1 = MyDict({"a": 2, "b": 3, "c": 5}) ref = fn(x, d1) d2 = MyDict({"a": 2, "b": 3, "c": 5}) res = opt_fn(x, d2) self.assertEqual(ref, res) self.assertEqual(d1.calls, d2.calls) def test_items_type(self): def fn(): d = dict({"a": 1, "b": "2", "c": torch.tensor(3)}) # noqa: C418 return d.items() opt_fn = torch.compile(fn, backend="eager", fullgraph=True) ref = fn() res = opt_fn() self.assertEqual(ref, res) self.assertEqual(type(res), dict_items) def test_builtin_or_with_invalid_types(self): args = ( 1, # int 1.0, # float "a", # str (1, 2), # tuple [1, 2], # list ) @torch.compile(backend="eager", fullgraph=True) def fn(b: Any): a = {"one": torch.ones(1)} return a | b from torch._dynamo.exc import InternalTorchDynamoError for arg in args: with self.assertRaisesRegex( InternalTorchDynamoError, "unsupported operand type" ): _ = fn(arg) def test_builtin_or_with_diff_keys(self): def f(): a = {"one": torch.ones(1)} b = {"two": torch.ones(2)} return a, b, a | b, b | a, a.__or__(b), b.__or__(a) opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) def test_builtin_or_with_same_keys(self): def f(): a = {"one": torch.ones(1), "two": torch.ones(2)} b = {"one": torch.ones(1), "three": torch.ones(3)} return a, b, a | b, b | a, a.__or__(b), b.__or__(a) opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) def test_builtin_ior_(self): def f(): a = {"one": torch.ones(1)} b = {"two": torch.ones(2)} a |= b return a, b opt_f = torch.compile(f, backend="eager", fullgraph=True) self.assertEqual(f(), opt_f()) def test_newly_constructed_default_dict(self): def f(x): d = defaultdict(list) d[0] = 42 return x + 1, d x = torch.ones(2) ref = f(x) res = torch.compile(f, backend="eager", fullgraph=True)(x) self.assertEqual(ref, res) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()