Match the number of outputs of backward with forward for AllGatherGrad (
ArvinZhuang authored and lexierule committed Mar 30, 2021
1 parent 832e771 commit ba370de
Showing 2 changed files with 24 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pytorch_lightning/utilities/distributed.py
@@ -187,7 +187,7 @@ def backward(ctx, *grad_output):
 
         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
 
-        return grad_output[torch.distributed.get_rank()]
+        return grad_output[torch.distributed.get_rank()], None
 
 
 def all_gather_ddp_if_available(
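
The one-line change above follows the torch.autograd.Function contract: backward must return one value for each argument that forward received, with None for arguments that need no gradient. AllGatherGrad.forward takes a tensor and a process group, so its backward now returns the reduced gradient for the tensor plus None for the group. Returning only one value makes autograd complain that the number of gradients does not match the number of forward inputs, which is the failure this commit addresses for sync_grads=True. A minimal, non-distributed sketch of the same rule (the Scale function below is illustrative only, not part of the commit):

import torch

class Scale(torch.autograd.Function):
    """Toy Function with two forward inputs: a tensor and a plain Python number."""

    @staticmethod
    def forward(ctx, tensor, factor):
        ctx.factor = factor
        return tensor * factor

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward argument: a gradient for `tensor`,
        # and None for `factor`, which is not a tensor and needs no gradient.
        return grad_output * ctx.factor, None

x = torch.ones(2, requires_grad=True)
Scale.apply(x, 3.0).sum().backward()
print(x.grad)  # tensor([3., 3.])
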
23 changes: 23 additions & 0 deletions tests/utilities/test_all_gather_grad.py
@@ -96,3 +96,26 @@ def training_epoch_end(self, outputs) -> None:
 
     trainer.fit(model)
     assert model.training_epoch_end_called
+
+
+@RunIf(min_gpus=2, skip_windows=True, special=True)
+def test_all_gather_sync_grads(tmpdir):
+
+    class TestModel(BoringModel):
+
+        training_step_called = False
+
+        def training_step(self, batch, batch_idx):
+            self.training_step_called = True
+            tensor = torch.rand(2, 2, requires_grad=True, device=self.device)
+            gathered_tensor = self.all_gather(tensor, sync_grads=True)
+            assert gathered_tensor.shape == torch.Size([2, 2, 2])
+
+            loss = gathered_tensor.sum()
+
+            return loss
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
+    trainer.fit(model)
+    assert model.training_step_called
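
The new test checks the gathered shape; with sync_grads=True the gather is differentiable, so gradients should also flow back to the per-rank tensor. A hedged sketch of a companion test that asserts this, assuming the same 2-GPU setup and the imports already used in this file (the test name and the extra torch.autograd.grad check are illustrative, not part of the commit):

@RunIf(min_gpus=2, skip_windows=True, special=True)
def test_all_gather_sync_grads_flow(tmpdir):
    """Illustrative variant: check that the gathered tensor propagates a
    gradient back to the local tensor."""

    class TestModel(BoringModel):

        def training_step(self, batch, batch_idx):
            tensor = torch.rand(2, 2, requires_grad=True, device=self.device)
            gathered_tensor = self.all_gather(tensor, sync_grads=True)
            loss = gathered_tensor.sum()
            # Backprop through the gather; retain the graph so Lightning's
            # own backward pass on the returned loss still works afterwards.
            (grad,) = torch.autograd.grad(loss, tensor, retain_graph=True)
            assert grad.shape == tensor.shape
            return loss

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=2)
    trainer.fit(model)
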
