
Commit

[bugfix] TPU + all_gather + SingleTPU shouldn't call xm.all_gather (#6296)

* resolve an issue with TPU

* update

* add changelog
tchaton authored Mar 3, 2021
1 parent 4a8422c commit 484dce1
Showing 2 changed files with 7 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -95,6 +95,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))


- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))


- Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))


5 changes: 4 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -44,4 +44,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        # todo: Add support for backward with all_gather
+        if torch.distributed.is_initialized():
+            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        return tensor
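For context, the following is a minimal sketch of how the patched method reads in pytorch_lightning/accelerators/tpu.py after this change. The enclosing class name, the imports, the sync_grads default, and the abbreviated docstring are assumptions inferred from the hunk header and the call site; only the method body below the docstring comes from the diff.

# Sketch of the patched all_gather. Class name, imports, and the sync_grads
# default are assumptions; the guarded body is taken from the diff above.
from typing import Any, Optional

import torch
import torch_xla.core.xla_model as xm


class TPUAccelerator:  # in Lightning this sits in pytorch_lightning/accelerators/tpu.py

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """
        Return:
            A tensor of shape (world_size, batch, ...)
        """
        # todo: Add support for backward with all_gather
        if torch.distributed.is_initialized():
            # Multi-core (spawned) TPU training: a process group exists,
            # so the cross-device gather is valid.
            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
        # SingleTPU: no process group has been initialized, so calling
        # xm.all_gather would fail; return the local tensor unchanged.
        return tensor

The effect of the guard is that a SingleTPU run, which never initializes a distributed process group, now gets the input tensor back instead of hitting xm.all_gather.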
