# Owner(s): ["module: dynamo"]

import os
import unittest

import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DynamoStore
from torch._functorch import config as functorch_config
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA

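# These tests cover the CompilePackage / DynamoStore round trip: a function is
# compiled with a package attached, the package is saved to disk, dynamo state
# is reset, and the package is loaded and installed again. The
# "fail_on_recompile" stance is used throughout to prove that the reloaded
# package serves the same calls without triggering any new compilation.
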
@functorch_config.patch("bundled_autograd_cache", True)
@instantiate_parametrized_tests
class TestPackage(torch._inductor.test_case.TestCase):
    def path(self):
        # One package directory per test, named after the unittest id, under
        # the inductor cache dir.
        path = os.path.join(cache_dir(), f"package_{self.id()}")
        os.makedirs(path, exist_ok=True)
        return path

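    # Round-trips a single-graph function (x + 1) through save/load and checks
    # that the reloaded package reproduces the original result under
    # "fail_on_recompile".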
@parametrize("backend", ("eager", "inductor"))
|
|
@parametrize("device", ("cpu", "cuda"))
|
|
def test_basic_fn(self, backend, device):
|
|
if device == "cuda" and not HAS_CUDA:
|
|
raise unittest.SkipTest("Requires CUDA/Triton")
|
|
ctx = DynamoStore()
|
|
|
|
def fn(x):
|
|
return x + 1
|
|
|
|
args = (
|
|
torch.randn(
|
|
3,
|
|
2,
|
|
device=device,
|
|
),
|
|
)
|
|
|
|
# Saving
|
|
package = CompilePackage(fn)
|
|
compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
|
|
expected = compiled_fn(*args)
|
|
if backend == "eager":
|
|
for backend_id, backend in package.cached_backends.items():
|
|
ctx.record_eager_backend(backend_id, backend)
|
|
|
|
ctx.save_package(package, self.path())
|
|
# Loading
|
|
torch._dynamo.reset()
|
|
with torch.compiler.set_stance("fail_on_recompile"):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
|
|
):
|
|
compiled_fn(*args)
|
|
|
|
package, backends = ctx.load_package(fn, self.path())
|
|
compiled_fn = torch._dynamo.optimize(package=package)(fn)
|
|
package.install(backends)
|
|
self.assertEqual(expected, compiled_fn(*args))
|
|
|
|
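    # Stresses packaging with many compiled entries: the recursive fn below
    # branches on data-dependent values of x.sum(), so every input in
    # args_list produces its own compiled artifact that must survive the
    # save/load round trip.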
@parametrize("backend", ("eager", "inductor"))
|
|
@parametrize("device", ("cpu", "cuda"))
|
|
def test_graph_break_bomb(self, backend, device):
|
|
if device == "cuda" and not HAS_CUDA:
|
|
raise unittest.SkipTest("Requires CUDA/Triton")
|
|
|
|
ctx = DynamoStore()
|
|
|
|
def fn(x, l, r):
|
|
if l > r:
|
|
return x.sum()
|
|
mid = (l + r) // 2
|
|
if x.sum() == mid:
|
|
return x.sum()
|
|
elif x.sum() < mid:
|
|
return fn(x, l, mid)
|
|
else:
|
|
return fn(x, mid + 1, r)
|
|
|
|
def guard_filter_fn(guards):
|
|
return [
|
|
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
|
|
for guard in guards
|
|
]
|
|
|
|
# Saving
|
|
package = CompilePackage(fn)
|
|
compiled_fn = torch._dynamo.optimize(
|
|
backend=backend, package=package, guard_filter_fn=guard_filter_fn
|
|
)(fn)
|
|
N = 10
|
|
args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
|
|
for args in args_list:
|
|
compiled_fn(*args)
|
|
if backend == "eager":
|
|
for backend_id, backend in package.cached_backends.items():
|
|
ctx.record_eager_backend(backend_id, backend)
|
|
ctx.save_package(package, self.path())
|
|
|
|
# Loading
|
|
torch._dynamo.reset()
|
|
with torch.compiler.set_stance("fail_on_recompile"):
|
|
for args in args_list:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
|
|
):
|
|
compiled_fn(*args)
|
|
package, backends = ctx.load_package(fn, self.path())
|
|
compiled_fn = torch._dynamo.optimize(
|
|
backend="eager", package=package, guard_filter_fn=guard_filter_fn
|
|
)(fn)
|
|
package.install(backends)
|
|
for args in args_list:
|
|
self.assertEqual(compiled_fn(*args), args[0].sum())
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
|
|
):
|
|
compiled_fn(torch.tensor(N), 0, N - 1)
|
|
|
|
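    # Checks that a dynamic-shape compilation (dim 0 marked dynamic in [3, 5])
    # is preserved by the package: an in-range input is served without
    # recompiling, while an out-of-range input still trips "fail_on_recompile".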
@parametrize("backend", ("eager", "inductor"))
|
|
@parametrize("device", ("cpu", "cuda"))
|
|
def test_dynamic_shape(self, backend, device):
|
|
if device == "cuda" and not HAS_CUDA:
|
|
raise unittest.SkipTest("Requires CUDA/Triton")
|
|
ctx = DynamoStore()
|
|
|
|
def fn(x):
|
|
return x + x.shape[0]
|
|
|
|
args = (torch.randn(3, 2, device=device),)
|
|
args1 = (torch.randn(5, 2, device=device),)
|
|
args2 = (torch.randn(7, 2, device=device),)
|
|
expected1 = fn(*args1)
|
|
|
|
torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)
|
|
|
|
# Saving
|
|
package = CompilePackage(fn)
|
|
compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn)
|
|
compiled_fn(*args)
|
|
if backend == "eager":
|
|
for backend_id, backend in package.cached_backends.items():
|
|
ctx.record_eager_backend(backend_id, backend)
|
|
ctx.save_package(package, self.path())
|
|
|
|
# Loading
|
|
torch._dynamo.reset()
|
|
with torch.compiler.set_stance("fail_on_recompile"):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
|
|
):
|
|
compiled_fn(*args1)
|
|
|
|
package, backends = ctx.load_package(fn, self.path())
|
|
compiled_fn = torch._dynamo.optimize(package=package)(fn)
|
|
package.install(backends)
|
|
|
|
self.assertEqual(expected1, compiled_fn(*args1))
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
|
|
):
|
|
compiled_fn(*args2)
|
|
|
|
|
|
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()