# sglang_v0.5.2/pytorch_2.8.0/test/dynamo/test_package.py


# Owner(s): ["module: dynamo"]
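"""Tests for torch._dynamo.package: saving and reloading precompiled Dynamo artifacts.

The round trip exercised below, roughly (a sketch distilled from the tests in
this file; see each test for the exact calls):

    package = CompilePackage(fn)
    compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
    compiled_fn(*args)               # trace + compile; artifacts land in the package
    ctx = DynamoStore()
    ctx.save_package(package, path)  # serialize to disk

    torch._dynamo.reset()
    package, backends = ctx.load_package(fn, path)
    compiled_fn = torch._dynamo.optimize(package=package)(fn)
    package.install(backends)        # reuse the saved artifacts without recompiling
"""
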
import os
import unittest
import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage, DynamoStore
from torch._functorch import config as functorch_config
from torch._inductor.runtime.runtime_utils import cache_dir
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
from torch.testing._internal.inductor_utils import HAS_CUDA


@functorch_config.patch("bundled_autograd_cache", True)
@instantiate_parametrized_tests
class TestPackage(torch._inductor.test_case.TestCase):
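    """Round-trip tests for CompilePackage + DynamoStore: compile, save, reset,
    reload, and verify that the saved artifacts serve calls without recompiling."""
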
def path(self):
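        """Return a per-test directory under the inductor cache dir for saved packages."""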
path = os.path.join(cache_dir(), f"package_{self.id()}")
os.makedirs(path, exist_ok=True)
        return path

@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda"))
def test_basic_fn(self, backend, device):
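        """Round-trip a trivial function: save, reset, reload, expect no recompile."""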
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("Requires CUDA/Triton")
        ctx = DynamoStore()

        def fn(x):
            return x + 1

        args = (torch.randn(3, 2, device=device),)

# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(backend, package=package)(fn)
expected = compiled_fn(*args)
if backend == "eager":
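            # For the eager backend the compiled artifacts must be recorded with
            # the store explicitly (by id) so load_package can return them later.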
            for backend_id, eager_backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, eager_backend)
        ctx.save_package(package, self.path())

# Loading
torch._dynamo.reset()
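        # After the reset, compiled_fn has lost its cache entries, so any call
        # would recompile; the 'fail_on_recompile' stance turns that into an error.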
with torch.compiler.set_stance("fail_on_recompile"):
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package, backends = ctx.load_package(fn, self.path())
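            # Reattach the loaded package and install the saved backends; the call
            # below should then be served entirely from the reloaded artifacts.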
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
            self.assertEqual(expected, compiled_fn(*args))

@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda"))
def test_graph_break_bomb(self, backend, device):
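        """Compile a recursive binary search whose data-dependent branches force a
        graph break per subrange, save the resulting pile of entries, then check
        they all replay without recompilation (and that an unseen input fails)."""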
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("Requires CUDA/Triton")
        ctx = DynamoStore()

        def fn(x, l, r):
if l > r:
return x.sum()
mid = (l + r) // 2
if x.sum() == mid:
return x.sum()
elif x.sum() < mid:
return fn(x, l, mid)
else:
return fn(x, mid + 1, r)
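
        # Drop guards that match on the concrete function/closure objects; those
        # identities change across save/reload, so keeping them would make the
        # saved entries unusable.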
def guard_filter_fn(guards):
return [
guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
for guard in guards
            ]

# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(
backend=backend, package=package, guard_filter_fn=guard_filter_fn
)(fn)
N = 10
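        # One target value per entry: searching for each x in [0, N) takes a
        # different recursion path, so the package accumulates many cache entries.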
args_list = [(torch.tensor(x, device=device), 0, N - 1) for x in range(N)]
for args in args_list:
compiled_fn(*args)
if backend == "eager":
                for backend_id, eager_backend in package.cached_backends.items():
                    ctx.record_eager_backend(backend_id, eager_backend)
        ctx.save_package(package, self.path())

# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
for args in args_list:
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args)
package, backends = ctx.load_package(fn, self.path())
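            # The saved backends are installed below, so the backend chosen here
            # shouldn't matter; "eager" keeps the reload path cheap.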
compiled_fn = torch._dynamo.optimize(
backend="eager", package=package, guard_filter_fn=guard_filter_fn
)(fn)
package.install(backends)
for args in args_list:
self.assertEqual(compiled_fn(*args), args[0].sum())
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
                compiled_fn(torch.tensor(N, device=device), 0, N - 1)

@parametrize("backend", ("eager", "inductor"))
@parametrize("device", ("cpu", "cuda"))
def test_dynamic_shape(self, backend, device):
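        """Save a graph compiled with dynamic dim-0 in [3, 5], then check the
        reloaded package serves an in-range size and rejects an out-of-range one."""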
if device == "cuda" and not HAS_CUDA:
raise unittest.SkipTest("Requires CUDA/Triton")
        ctx = DynamoStore()

        def fn(x):
            return x + x.shape[0]

args = (torch.randn(3, 2, device=device),)
args1 = (torch.randn(5, 2, device=device),)
args2 = (torch.randn(7, 2, device=device),)
expected1 = fn(*args1)
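        # Mark dim 0 dynamic over [3, 5]: args1 (size 5) should be served by the
        # saved graph, while args2 (size 7) falls outside the range and must recompile.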
        torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)

# Saving
package = CompilePackage(fn)
compiled_fn = torch._dynamo.optimize(backend=backend, package=package)(fn)
compiled_fn(*args)
if backend == "eager":
            for backend_id, eager_backend in package.cached_backends.items():
                ctx.record_eager_backend(backend_id, eager_backend)
        ctx.save_package(package, self.path())

# Loading
torch._dynamo.reset()
with torch.compiler.set_stance("fail_on_recompile"):
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
compiled_fn(*args1)
package, backends = ctx.load_package(fn, self.path())
compiled_fn = torch._dynamo.optimize(package=package)(fn)
package.install(backends)
self.assertEqual(expected1, compiled_fn(*args1))
with self.assertRaisesRegex(
RuntimeError,
"Detected recompile when torch.compile stance is 'fail_on_recompile'",
):
                compiled_fn(*args2)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

run_tests()