# Owner(s): ["module: dynamo"]

import contextlib
import importlib.util
import os
import re
import tempfile

import torch._dynamo.config
import torch._dynamo.test_case
import torch._inductor.mock_cache as mock_cache
import torch.compiler.config
import torch.nested
from torch._dynamo.testing import CompileCounter
from torch._inductor.utils import clear_caches, fresh_cache


class PgoTest(torch._dynamo.test_case.TestCase):
def setUp(self):
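        # isolate PGO state per test: profiles are keyed by job_id, and
        # fresh_cache avoids cross-test cache reuse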
super().setUp()
self._test_stack = contextlib.ExitStack()
self._test_stack.enter_context(torch.compiler.config.patch(job_id=self.id()))
self._test_stack.enter_context(
torch._dynamo.config.patch(automatic_dynamic_local_pgo=True)
)
if os.environ.get("INDUCTOR_TEST_DISABLE_FRESH_CACHE") != "1":
self._test_stack.enter_context(fresh_cache())
mock_cache.PatchCaches.setUp()

    def tearDown(self):
super().tearDown()
torch._dynamo.reset()
self._test_stack.close()
mock_cache.PatchCaches.tearDown()

    def reset(self):
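        # reset compiler state and clear caches between runs; the recorded
        # PGO profile survives, which is what lets later runs skip recompiles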
torch._dynamo.reset()
clear_caches()

    def test_basic(self):
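        # first run: static 2x3, then recompile dynamic on 2x4; the PGO
        # profile then lets a fresh run compile dynamically up front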
cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
def f(x):
return x * 2

        f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(cnts.frame_count, 2)
self.reset()
cnts.clear()
f(torch.randn(2, 5))
f(torch.randn(2, 6))
self.assertEqual(cnts.frame_count, 1)

    def test_whitelist_suggestion(self):
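        # render_code_state suggests a TORCH_COMPILE_DYNAMIC_SOURCES whitelist
        # that grows as sources go dynamic; feeding it back avoids recompiles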
cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(4, 4)
self.attr = torch.randn(4)
def forward(self, x, y):
return self.lin(x) + self.attr + y

        sources = [
"L['x']",
"L['self']._modules['lin']._parameters['weight']",
"L['self']._modules['lin']._parameters['bias']",
"L['self'].attr",
"L['y']",
]
def check_whitelist(sources_):
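            # every expected source should appear in the suggested whitelist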
state = torch._dynamo.pgo.render_code_state(
torch._dynamo.pgo.get_code_state()
)
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(
1
)
for src in sources_:
self.assertTrue(src in whitelist)

        # check growing whitelist
f = Foo()
f(torch.randn(2, 4), torch.randn(4))
# only x
f(torch.randn(4, 4), torch.randn(4))
check_whitelist(sources[:1])
# x, lin.weight
f.lin = torch.nn.Linear(8, 4)
f(torch.randn(8, 8), torch.randn(4))
check_whitelist(sources[:2])
# x, y, lin.weight, lin.bias, attr
f.lin = torch.nn.Linear(8, 8)
f.attr = torch.randn(8)
f(torch.randn(8, 8), torch.randn(8))
check_whitelist(sources)

        # now use suggested whitelist
self.reset()
cnts.clear()
state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state())
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1)
with torch.compiler.config.patch(dynamic_sources=whitelist):
f = Foo()
f(torch.randn(2, 4), torch.randn(4))
f(torch.randn(4, 4), torch.randn(4))
f.lin = torch.nn.Linear(8, 8)
f.attr = torch.randn(8)
f(torch.randn(8, 8), torch.randn(8))
self.assertEqual(cnts.frame_count, 1)

    def test_pgo_dynamic_false(self):
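        # with dynamic=False, varying input shapes must not make PGO record
        # dynamism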
@torch.compile(backend="eager", dynamic=False)
class Foo(torch.nn.Module):
def forward(self, x, y):
x += 2
y += 2
torch._dynamo.graph_break()
x -= 2
y *= 2
return x, y

        self.reset()
f = Foo()
f(torch.randn(2, 4), torch.randn(2, 4))
f(torch.randn(4, 4), torch.randn(6, 8))

        # check that the PGO code state is overwritten with static values,
        # both before and after the graph break
for code_state in torch._dynamo.pgo.get_code_state().values():
self.assertEqual(code_state.automatic_dynamic["L['x']"].size, (4, 4))
self.assertEqual(code_state.automatic_dynamic["L['y']"].size, (6, 8))

    def test_whitelist_ints_floats(self):
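        # dynamic ints/floats should make the whitelist; static scalars and
        # ephemeral tensor sources should not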
@torch.compile(backend="eager", fullgraph=True)
class Bar(torch.nn.Module):
def __init__(self, c):
super().__init__()
self.c = c
def forward(self, x, y, z):
if self.c == 1.0:
return x + y + torch.tensor([z])

        f = Bar(1.0)
f(2, 1.0, 2.0)
f(3, 1.2, 2.0)
state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state())
whitelist = re.search(r'TORCH_COMPILE_DYNAMIC_SOURCES="(.*)"', state).group(1)
self.assertTrue("L['x']" in whitelist)
self.assertTrue("L['y']" in whitelist)
self.assertTrue(
"___as_tensor(L['y'])" not in whitelist
) # ephemeral FloatTensor source
self.assertTrue("L['z']" not in whitelist) # static float
self.assertTrue("L['self'].c" not in whitelist) # static float property

    def test_pgo_dynamic_params(self):
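        # parameters default to static shapes; once the force_* flags are
        # flipped, PGO-recorded dynamism applies to them as well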
cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = None
def forward(self, x):
return self.lin(x)

        f = Foo()

        def run():
self.reset()
cnts.clear()
f.lin = torch.nn.Linear(4, 4)
f(torch.randn(2, 4))
f(torch.randn(4, 4))
f.lin = torch.nn.Linear(8, 8)
f(torch.randn(8, 8))

        # without PGO state, every shape change recompiles
run()
self.assertEqual(cnts.frame_count, 3)
        # parameter shapes are forced static, so the new Linear costs one recompile
run()
self.assertEqual(cnts.frame_count, 2)
        # with the flags flipped, PGO has recorded dynamism, so params
        # compile dynamically from the start
torch._dynamo.config.force_parameter_static_shapes = False
torch._dynamo.config.force_nn_module_property_static_shapes = False
run()
self.assertEqual(cnts.frame_count, 1)

    def test_njt(self):
cnts = CompileCounter()

        # NB: PGO doesn't do anything here; the point is to catch pickle
        # problems with nested ints
@torch.compile(backend=cnts, fullgraph=True)
def f(x):
return x * 2

        x = torch.nested.nested_tensor_from_jagged(
torch.randn(10, 3), torch.tensor([0, 3, 7, 10]), torch.tensor([1, 2, 3])
)
y = torch.nested.nested_tensor_from_jagged(
torch.randn(13, 3), torch.tensor([0, 3, 7, 13]), torch.tensor([1, 2, 6])
)
f(x)
f(y)
self.assertEqual(cnts.frame_count, 1)
self.reset()
cnts.clear()
a = torch.nested.nested_tensor_from_jagged(
torch.randn(14, 3), torch.tensor([0, 3, 7, 14]), torch.tensor([1, 2, 7])
)
b = torch.nested.nested_tensor_from_jagged(
torch.randn(15, 3), torch.tensor([0, 3, 7, 15]), torch.tensor([1, 2, 8])
)
f(a)
f(b)
self.assertEqual(cnts.frame_count, 1)

    def test_distinct_compile_id(self):
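        # PGO profiles are keyed by job_id: "bar" can't see "foo"'s profile,
        # but a later "foo" run can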
cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
def f(x):
return x * 2

        with torch.compiler.config.patch(job_id="foo"):
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(cnts.frame_count, 2)
self.reset()
cnts.clear()
with torch.compiler.config.patch(job_id="bar"):
f(torch.randn(2, 5))
f(torch.randn(2, 6))
self.assertEqual(cnts.frame_count, 2)
torch._dynamo.reset()
clear_caches()
cnts.clear()
with torch.compiler.config.patch(job_id="foo"):
f(torch.randn(2, 7))
f(torch.randn(2, 8))
self.assertEqual(cnts.frame_count, 1)

    # TODO: to test local PGO, we need to ensure the local filesystem gets cleared out
@torch._dynamo.config.patch(
automatic_dynamic_remote_pgo=True, automatic_dynamic_local_pgo=False
)
def test_remote_basic(self):
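        # same flow as test_basic, but through the mocked remote PGO cache;
        # the Stats assertions track remote get/put traffic, and
        # cache_key_tag namespaces entries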
cnts = CompileCounter()

        @torch.compile(backend=cnts, fullgraph=True)
def f(x):
return x * 2

        with mock_cache.PatchCaches():
f(torch.randn(2, 3))
f(torch.randn(2, 4))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(
mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 0, 1)
)
self.reset()
cnts.clear()
f(torch.randn(2, 5))
f(torch.randn(2, 6))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(
mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(2, 1, 1)
)
self.reset()
cnts.clear()
with torch.compiler.config.patch({"cache_key_tag": "test"}):
f(torch.randn(2, 7))
f(torch.randn(2, 8))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(
mock_cache.global_stats.dynamo_pgo, mock_cache.Stats(4, 1, 2)
)

    # Test that PGO still works when the same file appears at two different
    # paths across two different compilations.
def test_different_file_paths_local_pgo(self):
content = """
import torch
def run(cnt):
@torch.compile(backend=cnt, fullgraph=True)
def func(x):
return x*10
func(torch.rand(10))
func(torch.rand(20))
func(torch.rand(30))
"""

        temp_dir1 = tempfile.TemporaryDirectory()
temp_dir2 = tempfile.TemporaryDirectory()
path1 = os.path.join(temp_dir1.name, "example.py")
path2 = os.path.join(temp_dir2.name, "example.py")
cnts = CompileCounter()
assert path1 != path2

        def write_load_and_run(path):
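            # write the module source to the given path, import it from
            # there, and run it against the shared CompileCounter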
with open(path, "w") as file:
file.write(content)
            spec = importlib.util.spec_from_file_location("example", path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
assert spec.loader is not None
spec.loader.exec_module(module)
module.run(cnts)

        write_load_and_run(path1)
self.assertEqual(cnts.frame_count, 2)
state = torch._dynamo.pgo.render_code_state(torch._dynamo.pgo.get_code_state())
self.assertTrue("hash(390fe689)" in state)
self.assertTrue("/example.py:4:func:" in state)
self.assertTrue(" L['x']: tensor size=[?] stride=[1]" in state)

        # We should compile this only once due to PGO.
cnts.clear()
write_load_and_run(path2)
self.assertEqual(cnts.frame_count, 1)


if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()