# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import torch
from torch.distributed.pipelining import pipe_split, pipeline
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


# Building block for model
class Block(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=16, out_channels=16, kernel_size=3, padding=1
        )
        self.lin0 = torch.nn.Linear(256, 256)
        self.relu = torch.nn.ReLU()
        self.lin1 = torch.nn.Linear(256, 256)

    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
        x = self.conv(x)
        x = self.lin0(x)
        pipe_split()
        # Assign the result so the `constant` kwarg actually contributes to the
        # stage output (a bare `x.add(constant)` would discard the new tensor).
        x = x + constant
        x = self.lin1(x)
        return self.relu(x)


# Full model
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.block0 = Block()
        self.block1 = Block()

    def forward(self, x: torch.Tensor, constant=None) -> torch.Tensor:
        x = self.block0(x, constant=constant)
        pipe_split()
        x = self.block1(x, constant=constant)
        return x


class UnflattenTests(TestCase):
    def test_unflatten(self, device):
        x = torch.randn(1, 16, 256, 256, device=device)
        constant = torch.ones(1, 16, 256, 256, device=device)

        mod = M().to(device)
        pipe = pipeline(
            mod,
            (x,),
            {"constant": constant},
        )

        # Three pipe_split() calls (one per Block, one in M) yield four stages
        assert pipe.num_stages == 4
        orig_state_dict = mod.state_dict()

        # Check qualnames
        for stage_idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(stage_idx)
            for param_name, _ in stage_mod.named_parameters():
                assert param_name in orig_state_dict, (
                    f"{param_name} not in original state dict"
                )
        print("Param qualname test passed")

        # Check equivalence
        ref = mod(x, constant)
        out = pipe(x, constant)[0]
        torch.testing.assert_close(out, ref)
        print(f"Equivalence test passed {torch.sum(out)} ref {torch.sum(ref)}")


devices = ["cpu", "cuda", "hpu", "xpu"]
instantiate_device_type_tests(UnflattenTests, globals(), only_for=devices)


if __name__ == "__main__":
    run_tests()