# Owner(s): ["module: dynamo"] import os import unittest import torch import torch._dynamo.testing import torch._inductor.config import torch._inductor.test_case import torch.onnx.operators import torch.utils.cpp_extension from torch._dynamo.package import CompilePackage, DynamoStore from torch._functorch import config as functorch_config from torch._inductor.runtime.runtime_utils import cache_dir from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) from torch.testing._internal.inductor_utils import HAS_CUDA @functorch_config.patch("bundled_autograd_cache", True) @instantiate_parametrized_tests class TestPackage(torch._inductor.test_case.TestCase): def path(self): path = os.path.join(cache_dir(), f"package_{self.id()}") os.makedirs(path, exist_ok=True) return path @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda")) def test_basic_fn(self, backend, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") ctx = DynamoStore() def fn(x): return x + 1 args = ( torch.randn( 3, 2, device=device, ), ) # Saving package = CompilePackage(fn) compiled_fn = torch._dynamo.optimize(backend, package=package)(fn) expected = compiled_fn(*args) if backend == "eager": for backend_id, backend in package.cached_backends.items(): ctx.record_eager_backend(backend_id, backend) ctx.save_package(package, self.path()) # Loading torch._dynamo.reset() with torch.compiler.set_stance("fail_on_recompile"): with self.assertRaisesRegex( RuntimeError, "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(*args) package, backends = ctx.load_package(fn, self.path()) compiled_fn = torch._dynamo.optimize(package=package)(fn) package.install(backends) self.assertEqual(expected, compiled_fn(*args)) @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda")) def test_graph_break_bomb(self, backend, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") ctx = DynamoStore() def fn(x, l, r): if l > r: return x.sum() mid = (l + r) // 2 if x.sum() == mid: return x.sum() elif x.sum() < mid: return fn(x, l, mid) else: return fn(x, mid + 1, r) def guard_filter_fn(guards): return [ guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH") for guard in guards ] # Saving package = CompilePackage(fn) compiled_fn = torch._dynamo.optimize( backend=backend, package=package, guard_filter_fn=guard_filter_fn )(fn) N = 10 args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)] for args in args_list: compiled_fn(*args) if backend == "eager": for backend_id, backend in package.cached_backends.items(): ctx.record_eager_backend(backend_id, backend) ctx.save_package(package, self.path()) # Loading torch._dynamo.reset() with torch.compiler.set_stance("fail_on_recompile"): for args in args_list: with self.assertRaisesRegex( RuntimeError, "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(*args) package, backends = ctx.load_package(fn, self.path()) compiled_fn = torch._dynamo.optimize( backend="eager", package=package, guard_filter_fn=guard_filter_fn )(fn) package.install(backends) for args in args_list: self.assertEqual(compiled_fn(*args), args[0].sum()) with self.assertRaisesRegex( RuntimeError, "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(torch.tensor(N), 0, N - 1) @parametrize("backend", ("eager", "inductor")) @parametrize("device", ("cpu", "cuda")) def test_dynamic_shape(self, backend, device): if device == "cuda" and not HAS_CUDA: raise unittest.SkipTest("Requires CUDA/Triton") ctx = DynamoStore() def fn(x): return x + x.shape[0] args = (torch.randn(3, 2, device=device),) args1 = (torch.randn(5, 2, device=device),) args2 = (torch.randn(7, 2, device=device),) expected1 = fn(*args1) torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5) # Saving package = CompilePackage(fn) compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn) compiled_fn(*args) if backend == "eager": for backend_id, backend in package.cached_backends.items(): ctx.record_eager_backend(backend_id, backend) ctx.save_package(package, self.path()) # Loading torch._dynamo.reset() with torch.compiler.set_stance("fail_on_recompile"): with self.assertRaisesRegex( RuntimeError, "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(*args1) package, backends = ctx.load_package(fn, self.path()) compiled_fn = torch._dynamo.optimize(package=package)(fn) package.install(backends) self.assertEqual(expected1, compiled_fn(*args1)) with self.assertRaisesRegex( RuntimeError, "Detected recompile when torch.compile stance is 'fail_on_recompile'", ): compiled_fn(*args2) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()