# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import copy

from model_registry import MLPModule

import torch
from torch.distributed.pipelining._backward import (
    stage_backward,
    stage_backward_input,
    stage_backward_weight,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, TestCase


d_hid = 512
batch_size = 256
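
# Each test below runs a forward pass through an MLP stage, drives the backward
# pass through the pipelining helpers (stage_backward, stage_backward_input,
# stage_backward_weight), and compares the resulting gradients against a plain
# loss.backward() on a deep-copied reference module.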


class StageBackwardTests(TestCase):
    def test_stage_backward(self, device):
        # MLP as a stage module
        mod = MLPModule(d_hid).to(device)
        x = torch.randn(batch_size, d_hid, device=device)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid, device=device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod).to(device)
        ref_x = x.detach().requires_grad_(x.requires_grad).to(device)
        ref_target = target.detach().to(device)

        # Forward and backward in stage manner
        out = mod(x)
        loss = loss_fn(out, target)
        grad_inputs = stage_backward(
            stage_output=loss,
            output_grads=None,
            input_values=(x,),
        )
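
        # stage_backward runs autograd from the stage output (here the loss) and
        # returns the gradients for input_values, i.e. what a pipeline stage
        # would send back to its upstream peer.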

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        torch.testing.assert_close(grad_inputs[0], ref_x.grad)

        # Every rank checks gradients
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    def test_stage_backward_input(self, device):
        # MLP as a stage module
        mod = MLPModule(d_hid).to(device)
        x = torch.randn(batch_size, d_hid, device=device)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid, device=device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod).to(device)
        ref_x = x.detach().requires_grad_(x.requires_grad).to(device)
        ref_target = target.detach().to(device)

        # Forward, then backward of loss with respect to inputs
        out = mod(x)
        loss = loss_fn(out, target)
        dinputs, _param_groups = stage_backward_input(
            stage_outputs_or_loss=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )
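
        # stage_backward_input computes only the gradients with respect to the
        # stage inputs; the returned param_groups record the graph state needed
        # to fill in the parameter gradients later via stage_backward_weight.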

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        torch.testing.assert_close(x.grad, ref_x.grad)
        torch.testing.assert_close(dinputs[0], ref_x.grad)
        for _, p in mod.named_parameters():
            # Check that the weight gradients were not updated
            self.assertEqual(p.grad, None)

    def test_stage_backward_weight(self, device):
        # MLP as a stage module
        mod = MLPModule(d_hid).to(device)
        x = torch.randn(batch_size, d_hid, device=device)
        # As in a pipeline stage, the inputs to this stage require gradients
        x.requires_grad_(True)
        target = torch.randn(batch_size, d_hid, device=device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod).to(device)
        ref_x = x.detach().requires_grad_(x.requires_grad).to(device)
        ref_target = target.detach().to(device)

        # Forward, then backward of loss with respect to inputs
        out = mod(x)
        loss = loss_fn(out, target)
        _dinputs, param_groups = stage_backward_input(
            stage_outputs_or_loss=(loss,),
            output_grads=None,
            input_values=[x],
            weights=mod.parameters(),
        )

        # Backward of loss with respect to weights
        stage_backward_weight(mod.parameters(), param_groups, retain_graph=True)
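
        # stage_backward_weight consumes the param_groups recorded by
        # stage_backward_input to compute the deferred parameter gradients.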

        # Run reference
        ref_out = ref_mod(ref_x)
        ref_loss = loss_fn(ref_out, ref_target)
        ref_loss.backward()

        # Every rank checks gradients
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise

    def test_stage_backward_weight_multiple_iters(self, device):
        # MLP as a stage module
        mod = MLPModule(d_hid).to(device)
        inputs = []
        for _ in range(10):
            x = torch.randn(batch_size, d_hid, device=device)
            inputs.append(x)
            # As in a pipeline stage, the inputs to this stage require gradients
            x.requires_grad_(True)

        target = torch.randn(batch_size, d_hid, device=device)
        loss_fn = torch.nn.MSELoss(reduction="sum")

        # Make a copy
        ref_mod = copy.deepcopy(mod).to(device)
        ref_inputs = []
        for x in inputs:
            ref_x = x.detach().requires_grad_(x.requires_grad).to(device)
            ref_inputs.append(ref_x)
        ref_target = target.detach().to(device)

        # Forward, then backward of loss with respect to inputs
        for x in inputs:
            out = mod(x)
            loss = loss_fn(out, target)
            _dinputs, param_groups = stage_backward_input(
                stage_outputs_or_loss=(loss,),
                output_grads=None,
                input_values=[x],
                weights=mod.parameters(),
            )

            # Backward of loss with respect to weights
            stage_backward_weight(mod.parameters(), param_groups)
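
            # Parameter gradients accumulate across iterations, matching the
            # repeated loss.backward() calls in the reference loop below.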

        # Run reference
        for ref_x in ref_inputs:
            ref_out = ref_mod(ref_x)
            ref_loss = loss_fn(ref_out, ref_target)
            ref_loss.backward()

        # Every rank checks gradients
        for name, p in mod.named_parameters():
            ref_p = ref_mod.get_parameter(name)
            try:
                torch.testing.assert_close(p.grad, ref_p.grad)
            except AssertionError:
                print(f"Gradient test failed for {name}: {p.grad} vs {ref_p.grad}")
                raise


devices = ["cpu", "cuda", "hpu", "xpu"]
instantiate_device_type_tests(StageBackwardTests, globals(), only_for=devices)

if __name__ == "__main__":
    run_tests()