From 76eadd424ec5ed94bfd11885f3c357fc409e1021 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 8 Apr 2021 21:29:17 +0530 Subject: [PATCH 1/7] Fix Reduce error for TPU Spawn --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 68068935127e2..48166e22c63d8 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -154,15 +154,9 @@ def broadcast(self, obj: object, src: int = 0) -> object: obj = torch.load(buffer) return obj - def reduce_boolean_decision(self, decision: bool) -> bool: - decision = torch.tensor(int(decision), device=self.device) - decision = self.reduce(decision, "sum") - decision = bool(decision == self.world_size) - return decision - def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): - output = torch.tensor(output, device=self.device) + output = torch.tensor(output, device=self.lightning_module.device) _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg") From 78a2684a801686e3ab364f07714ac63e05b64872 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 9 Apr 2021 00:10:08 +0530 Subject: [PATCH 2/7] Fix all gather --- pytorch_lightning/accelerators/tpu.py | 15 --------------- .../plugins/training_type/tpu_spawn.py | 12 ++++++++++++ 2 files changed, 12 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index 35a475e3e790d..ee1f5dcb9c0a2 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -57,21 +57,6 @@ def run_optimizer_step( ) -> None: xm.optimizer_step(optimizer, barrier=False, optimizer_args={'closure': lambda_closure, **kwargs}) - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: - """ - Function to gather a tensor from several distributed processes - Args: - tensor: tensor of shape (batch, ...) - group: not available with TPUs - sync_grads: not available with TPUs - Return: - A tensor of shape (world_size, batch, ...) - """ - # todo: Add support for backward with all_gather - if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed: - return xm.all_gather(tensor).view(-1, *tensor.shape) - return tensor - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0): model = self.lightning_module diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 48166e22c63d8..418952563f237 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -221,3 +221,15 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: if _OMEGACONF_AVAILABLE: checkpoint = apply_to_collection(checkpoint, (DictConfig, ListConfig), OmegaConf.to_container) self.save({k: v for k, v in checkpoint.items() if k != "callbacks"}, filepath) + + def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: + """ + Function to gather a tensor from several distributed processes + Args: + tensor: tensor of shape (batch, ...) + group: not available with TPUs + sync_grads: not available with TPUs + Return: + A tensor of shape (world_size, batch, ...) + """ + return xm.all_gather(tensor.unsqueeze(0)).view(-1, *tensor.shape) From 4a85dde4c69faa38a0a8083e7d5a5f518e841c1e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 9 Apr 2021 00:14:50 +0530 Subject: [PATCH 3/7] Update reduce bool decision --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 418952563f237..175af76f2eff3 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -154,6 +154,12 @@ def broadcast(self, obj: object, src: int = 0) -> object: obj = torch.load(buffer) return obj + def reduce_boolean_decision(self, decision: bool) -> bool: + decision = torch.tensor(int(decision), device=self.lightning_module.device) + decision = self.reduce(decision, reduce_op="sum") + decision = bool(decision == self.world_size) + return decision + def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None): if not isinstance(output, torch.Tensor): output = torch.tensor(output, device=self.lightning_module.device) From c4f3706a2cccbf85889170e0d5b7f643e61c571e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 9 Apr 2021 01:11:58 +0530 Subject: [PATCH 4/7] Update --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 175af76f2eff3..f5c38dc4d4e46 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -238,4 +238,4 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra Return: A tensor of shape (world_size, batch, ...) """ - return xm.all_gather(tensor.unsqueeze(0)).view(-1, *tensor.shape) + return xm.all_gather(tensor.unsqueeze(0)) From 31435d5b538b1106e6832591b2eac952d401cb4d Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 9 Apr 2021 01:37:30 +0530 Subject: [PATCH 5/7] fix code format --- pytorch_lightning/accelerators/tpu.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py index ee1f5dcb9c0a2..087f6df7a1c6a 100644 --- a/pytorch_lightning/accelerators/tpu.py +++ b/pytorch_lightning/accelerators/tpu.py @@ -11,9 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, TYPE_CHECKING, Union -import torch from torch.optim import Optimizer from pytorch_lightning.accelerators.accelerator import Accelerator From 2b439a487cf8379f5626dcb7a0986791a8af1dd6 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 9 Apr 2021 14:06:59 +0530 Subject: [PATCH 6/7] Update Changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9eda878028e21..348aa04744bb0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -218,6 +218,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705)) +- Fixed TPU Spawn all gather ([#6896](https://github.com/PyTorchLightning/pytorch-lightning/pull/6896)) + + ## [1.2.7] - 2021-04-06 ### Fixed From 0d6e27bb750bf68ab4476f5a75ba6c67e54d420c Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Fri, 9 Apr 2021 18:04:42 +0530 Subject: [PATCH 7/7] Update test --- tests/models/test_tpu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py index 30149a7e022cd..6409f2ef4bcbf 100644 --- a/tests/models/test_tpu.py +++ b/tests/models/test_tpu.py @@ -229,8 +229,8 @@ def test_tpu_clip_grad_by_value(tmpdir): progress_bar_refresh_rate=0, max_epochs=4, tpu_cores=1, - limit_train_batches=4, - limit_val_batches=4, + limit_train_batches=10, + limit_val_batches=10, gradient_clip_val=0.5, gradient_clip_algorithm='value' )