50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch.nn as nn
|
|
|
|
|
|
class TestBuffersOverride(torch._dynamo.test_case.TestCase):
|
|
def test_buffers_override(self):
|
|
class SomeModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
# Override buffers; should not cause breakage
|
|
# this is because we use `named_buffers` for
|
|
# static marking
|
|
self.register_buffer("A", torch.ones(3, 3))
|
|
self.buffers = []
|
|
|
|
def forward(self):
|
|
return self.A * torch.zeros(1, 1)
|
|
|
|
model = SomeModel().to(torch.device("cpu"))
|
|
compiled_model = torch.compile(model)
|
|
self.assertEqual(compiled_model.A, torch.ones(3, 3))
|
|
compiled_model()
|
|
|
|
def test_named_buffers_override(self):
|
|
class SomeModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
# Override buffers; should not cause breakage
|
|
# but skip the marking static here since
|
|
# named_buffers is overridden
|
|
self.register_buffer("B", torch.ones(3, 3))
|
|
self.named_buffers = []
|
|
|
|
def forward(self):
|
|
return self.B * torch.zeros(1, 1)
|
|
|
|
model = SomeModel().to(torch.device("cpu"))
|
|
compiled_model = torch.compile(model)
|
|
self.assertEqual(compiled_model.B, torch.ones(3, 3))
|
|
compiled_model()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests with the runner exported by the same
    # torch._dynamo.test_case module that supplies the TestCase base class.
    torch._dynamo.test_case.run_tests()
|