Fix all_gather for tpu_cores=8 (#6587)
(cherry picked from commit 983a888)
ethanwharris authored and lexierule committed Mar 24, 2021
1 parent b895dd9 commit f4a2dff
Showing 2 changed files with 27 additions and 16 deletions.
35 changes: 23 additions & 12 deletions CHANGELOG.md
@@ -119,6 +119,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
 
+- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
+
+
+- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
+
+
+- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+
+- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
+
+
+
+## [1.2.5] - 2021-03-23
+
+### Changed
+
+
+### Fixed
+
+- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
+
+
 ## [1.2.4] - 2021-03-16
 
 ### Changed
@@ -139,9 +162,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
 
 
-- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
-
-
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
@@ -180,9 +200,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))
 
 
-- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
-
-
 ## [1.2.1] - 2021-02-23
 
 ### Fixed
@@ -192,12 +209,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
 
 
-- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
-
-
-- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
-
-
 ## [1.2.0] - 2021-02-18
 
 ### Added
8 changes: 4 additions & 4 deletions pytorch_lightning/accelerators/tpu.py
@@ -35,12 +35,12 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
         Function to gather a tensor from several distributed processes
         Args:
             tensor: tensor of shape (batch, ...)
-            group: the process group to gather results from. Defaults to all processes (world)
-            sync_grads: flag that allows users to synchronize gradients for all_gather op
+            group: not available with TPUs
+            sync_grads: not available with TPUs
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         # todo: Add support for backward with all_gather
-        if torch.distributed.is_initialized():
-            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
+            return xm.all_gather(tensor).view(-1, *tensor.shape)
         return tensor
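
Why the final `view`: `xm.all_gather` concatenates the per-core tensors along dim 0, so a `(batch, ...)` input gathered across 8 cores comes back flat as `(8 * batch, ...)`, and `view(-1, *tensor.shape)` restores the documented `(world_size, batch, ...)` layout. A minimal sketch of that reshape, with plain `torch.cat` standing in for the XLA gather (the `world_size` value and tensor shapes below are illustrative assumptions, not taken from the commit):

```python
import torch

# Stand-in for xm.all_gather with tpu_cores=8: XLA concatenates each core's
# (batch, ...) tensor along dim 0, yielding a flat (world_size * batch, ...).
world_size = 8                                   # assumption: tpu_cores=8
local = torch.randn(4, 3)                        # this core's (batch, ...) tensor
flat = torch.cat([local] * world_size, dim=0)    # shape (32, 3)

# The commit's reshape recovers the documented (world_size, batch, ...) layout.
gathered = flat.view(-1, *local.shape)           # shape (8, 4, 3)
assert gathered.shape == (world_size, *local.shape)
```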
