Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Match the number of outputs of backward with forward for AllGatherGrad #6625

Merged
merged 4 commits into from
Mar 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def backward(ctx, *grad_output):

torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)

return grad_output[torch.distributed.get_rank()]
return grad_output[torch.distributed.get_rank()], None


def all_gather_ddp_if_available(
Expand Down
23 changes: 23 additions & 0 deletions tests/utilities/test_all_gather_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,26 @@ def training_epoch_end(self, outputs) -> None:

trainer.fit(model)
assert model.training_epoch_end_called


@RunIf(min_gpus=2, skip_windows=True, special=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this won't run
we need to add add it to tests/special_tests.sh
special=True means this test gets skipped by default
and skip_windows would be redundant in that case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I noticed this, I actually commented out special=True when I test this. I put it here just because I was following the same settings of other tests in this script.

def test_all_gather_sync_grads(tmpdir):

class TestModel(BoringModel):

training_step_called = False

def training_step(self, batch, batch_idx):
self.training_step_called = True
tensor = torch.rand(2, 2, requires_grad=True, device=self.device)
gathered_tensor = self.all_gather(tensor, sync_grads=True)
assert gathered_tensor.shape == torch.Size([2, 2, 2])

loss = gathered_tensor.sum()

return loss

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
trainer.fit(model)
assert model.training_step_called