# Owner(s): ["module: dynamo"] import contextlib import importlib.util import os import re import tempfile import torch._dynamo.config import torch._dynamo.test_case import torch._inductor.mock_cache as mock_cache import torch.compiler.config import torch.nested from torch._dynamo.testing import CompileCounter from torch._inductor.utils import clear_caches, fresh_cache class PgoTest(torch._dynamo.test_case.TestCase): def setUp(self): super().setUp() self._test_stack = contextlib.ExitStack() self._test_stack.enter_context(torch.compiler.config.patch(job_id=self.id())) self._test_stack.enter_context( torch._dynamo.config.patch(automatic_dynamic_local_pgo=True) ) if os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1": self._test_stack.enter_context(fresh_cache()) mock_cache.PatchCaches.setUp() def tearDown(self): super().tearDown() torch._dynamo.reset() self._test_stack.close() mock_cache.PatchCaches.tearDown() def reset(self): torch._dynamo.reset() clear_caches() def test_basic(self): cnts = CompileCounter() @torch.compile(backend=cnts, fullgraph=True) def f(x): return x * 2 f(torch.randn(2, 3)) f(torch.randn(2, 4)) self.assertEqual(cnts.frame_count, 2) self.reset() cnts.clear() f(torch.randn(2, 5)) f(torch.randn(2, 6)) self.assertEqual(cnts.frame_count, 1) def test_whitelist_suggestion(self): cnts = CompileCounter() @torch.compile(backend=cnts, fullgraph=True) class Foo(torch.nn.Module): def __init__(self): super().__init__() self.lin = torch.nn.Linear(4, 4) self.attr = torch.randn(4) def forward(self, x, y): return self.lin(x) + self.attr + y sources = [ "L['x']", "L['self']._modules['lin']._parameters['weight']", "L['self']._modules['lin']._parameters['bias']", "L['self'].attr", "L['y']", ] def check_whitelist(sources_): state = torch._dynamo.pgo.render_code_state( torch._dynamo.pgo.get_code_state() ) whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group( 1 ) for src in sources_: self.assertTrue(src in whitelist) # check growing whitelist f = Foo() f(torch.randn(2, 4), torch.randn(4)) # only x f(torch.randn(4, 4), torch.randn(4)) check_whitelist(sources[:1]) # x, lin.weight f.lin = torch.nn.Linear(8, 4) f(torch.randn(8, 8), torch.randn(4)) check_whitelist(sources[:2]) # x, y, lin.weight, lin.bias, attr f.lin = torch.nn.Linear(8, 8) f.attr = torch.randn(8) f(torch.randn(8, 8), torch.randn(8)) check_whitelist(sources) # now use suggested whitelist self.reset() cnts.clear() state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state()) whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1) with torch.compiler.config.patch(dynamic_sources=whitelist): f = Foo() f(torch.randn(2, 4), torch.randn(4)) f(torch.randn(4, 4), torch.randn(4)) f.lin = torch.nn.Linear(8, 8) f.attr = torch.randn(8) f(torch.randn(8, 8), torch.randn(8)) self.assertEqual(cnts.frame_count, 1) def test_pgo_dynamic_false(self): @torch.compile(backend="eager", dynamic=False) class Foo(torch.nn.Module): def forward(self, x, y): x += 2 y += 2 torch._dynamo.graph_break() x -= 2 y *= 2 return x, y self.reset() f = Foo() f(torch.randn(2, 4), torch.randn(2, 4)) f(torch.randn(4, 4), torch.randn(6, 8)) # check PGO code state is overwritten with static value, both before/after graph break for code_state in torch._dynamo.pgo.get_code_state().values(): self.assertEqual(code_state.automatic_dynamic["L['x']"].size, (4, 4)) self.assertEqual(code_state.automatic_dynamic["L['y']"].size, (6, 8)) def test_whitelist_ints_floats(self): @torch.compile(backend="eager", 
fullgraph=True) class Bar(torch.nn.Module): def __init__(self, c): super().__init__() self.c = c def forward(self, x, y, z): if self.c == 1.0: return x + y + torch.tensor([z]) f = Bar(1.0) f(2, 1.0, 2.0) f(3, 1.2, 2.0) state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state()) whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1) self.assertTrue("L['x']" in whitelist) self.assertTrue("L['y']" in whitelist) self.assertTrue( "___as_tensor(L['y'])" not in whitelist ) # ephemeral FloatTensor source self.assertTrue("L['z']" not in whitelist) # static float self.assertTrue("L['self'].c" not in whitelist) # static float property def test_pgo_dynamic_params(self): cnts = CompileCounter() @torch.compile(backend=cnts, fullgraph=True) class Foo(torch.nn.Module): def __init__(self): super().__init__() self.lin = None def forward(self, x): return self.lin(x) f = Foo() def run(): self.reset() cnts.clear() f.lin = torch.nn.Linear(4, 4) f(torch.randn(2, 4)) f(torch.randn(4, 4)) f.lin = torch.nn.Linear(8, 8) f(torch.randn(8, 8)) # recompile each run run() self.assertEqual(cnts.frame_count, 3) # parameter static shapes are forced static, so we recompile once run() self.assertEqual(cnts.frame_count, 2) # flags are flipped, PGO records dynamism, so params are dynamically compiled to start torch._dynamo.config.force_parameter_static_shapes = False torch._dynamo.config.force_nn_module_property_static_shapes = False run() self.assertEqual(cnts.frame_count, 1) def test_njt(self): cnts = CompileCounter() # NB: PGO doesn't do anything here, the point is to catch pickle # problem with nested int @torch.compile(backend=cnts, fullgraph=True) def f(x): return x * 2 x = torch.nested.nested_tensor_from_jagged( torch.randn(10, 3), torch.tensor([0, 3, 7, 10]), torch.tensor([1, 2, 3]) ) y = torch.nested.nested_tensor_from_jagged( torch.randn(13, 3), torch.tensor([0, 3, 7, 13]), torch.tensor([1, 2, 6]) ) f(x) f(y) self.assertEqual(cnts.frame_count, 1) self.reset() cnts.clear() a = torch.nested.nested_tensor_from_jagged( torch.randn(14, 3), torch.tensor([0, 3, 7, 14]), torch.tensor([1, 2, 7]) ) b = torch.nested.nested_tensor_from_jagged( torch.randn(15, 3), torch.tensor([0, 3, 7, 15]), torch.tensor([1, 2, 8]) ) f(a) f(b) self.assertEqual(cnts.frame_count, 1) def test_distinct_compile_id(self): cnts = CompileCounter() @torch.compile(backend=cnts, fullgraph=True) def f(x): return x * 2 with torch.compiler.config.patch(job_id="foo"): f(torch.randn(2, 3)) f(torch.randn(2, 4)) self.assertEqual(cnts.frame_count, 2) self.reset() cnts.clear() with torch.compiler.config.patch(job_id="bar"): f(torch.randn(2, 5)) f(torch.randn(2, 6)) self.assertEqual(cnts.frame_count, 2) torch._dynamo.reset() clear_caches() cnts.clear() with torch.compiler.config.patch(job_id="foo"): f(torch.randn(2, 7)) f(torch.randn(2, 8)) self.assertEqual(cnts.frame_count, 1) # TODO: to test local need to ensure the local filesystem gets cleared out @torch._dynamo.config.patch( automatic_dynamic_remote_pgo=True, automatic_dynamic_local_pgo=False ) def test_remote_basic(self): cnts = CompileCounter() @torch.compile(backend=cnts, fullgraph=True) def f(x): return x * 2 with mock_cache.PatchCaches(): f(torch.randn(2, 3)) f(torch.randn(2, 4)) self.assertEqual(cnts.frame_count, 2) self.assertEqual( mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 0, 1) ) self.reset() cnts.clear() f(torch.randn(2, 5)) f(torch.randn(2, 6)) self.assertEqual(cnts.frame_count, 1) self.assertEqual( mock_cache.global_stats.dynamo_pgo, 
mock_cache.Stats(2, 1, 1) ) self.reset() cnts.clear() with torch.compiler.config.patch({"cache_key_tag": "test"}): f(torch.randn(2, 7)) f(torch.randn(2, 8)) self.assertEqual(cnts.frame_count, 2) self.assertEqual( mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(4, 1, 2) ) # Test that if the same file appears in two different paths for two different compilations PGO still works. def test_different_file_paths_local_pgo(self): content = """ import torch def run(cnt): @torch.compile(backend=cnt, fullgraph=True) def func(x): return x*10 func(torch.rand(10)) func(torch.rand(20)) func(torch.rand(30)) """ temp_dir1 = tempfile.TemporaryDirectory() temp_dir2 = tempfile.TemporaryDirectory() path1 = os.path.join(temp_dir1.name, "example.py") path2 = os.path.join(temp_dir2.name, "example.py") cnts = CompileCounter() assert path1 != path2 def write_load_and_run(path): with open(path, "w") as file: file.write(content) spec = importlib.util.spec_from_file_location("example", path1) assert spec is not None module = importlib.util.module_from_spec(spec) assert spec.loader is not None spec.loader.exec_module(module) module.run(cnts) write_load_and_run(path1) self.assertEqual(cnts.frame_count, 2) state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state()) self.assertTrue("hash(390fe689)" in state) self.assertTrue("/example.py:4:func:" in state) self.assertTrue(" L['x']: tensor size=[?] stride=[1]" in state) # We should compile this only once due to PGO. cnts.clear() write_load_and_run(path2) self.assertEqual(cnts.frame_count, 1) if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()