# Owner(s): ["module: dynamo"] import torch import torch._dynamo.test_case import torch.nn as nn class TestBuffersOverride(torch._dynamo.test_case.TestCase): def test_buffers_override(self): class SomeModel(nn.Module): def __init__(self): super().__init__() # Override buffers; should not cause breakage # this is because we use `named_buffers` for # static marking self.register_buffer("A", torch.ones(3, 3)) self.buffers = [] def forward(self): return self.A * torch.zeros(1, 1) model = SomeModel().to(torch.device("cpu")) compiled_model = torch.compile(model) self.assertEqual(compiled_model.A, torch.ones(3, 3)) compiled_model() def test_named_buffers_override(self): class SomeModel(nn.Module): def __init__(self): super().__init__() # Override buffers; should not cause breakage # but skip the marking static here since # named_buffers is overridden self.register_buffer("B", torch.ones(3, 3)) self.named_buffers = [] def forward(self): return self.B * torch.zeros(1, 1) model = SomeModel().to(torch.device("cpu")) compiled_model = torch.compile(model) self.assertEqual(compiled_model.B, torch.ones(3, 3)) compiled_model() if __name__ == "__main__": from torch._dynamo.test_case import run_tests run_tests()