diff --git a/CHANGELOG.md b/CHANGELOG.md
index d2a0d3641b40b..bd8f5e31770d2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -167,6 +167,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
 
 
+- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))
+
+
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 5c4fb2815aa6d..fb4af24c93505 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.optim import Optimizer
@@ -12,6 +12,9 @@
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
+    from torch_xla._patched_functions import clip_grad_norm_
+
+    xla_clip_grad_norm_ = clip_grad_norm_
 
 if TYPE_CHECKING:
     from pytorch_lightning.core.lightning import LightningModule
@@ -55,3 +58,20 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
             return xm.all_gather(tensor).view(-1, *tensor.shape)
         return tensor
+
+    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0) -> None:
+        """Clip the gradient norm of the module's parameters with ``xla_clip_grad_norm_``.
+
+        Args:
+            optimizer: not used by this implementation
+            clip_val: maximum gradient norm; ``None`` or a value ``<= 0`` skips clipping
+            norm_type: norm type forwarded to ``xla_clip_grad_norm_``
+        """
+        if clip_val is None:
+            return
+
+        max_norm = float(clip_val)
+        if max_norm <= 0:
+            return
+
+        xla_clip_grad_norm_(self.lightning_module.parameters(), max_norm, norm_type)
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 2b1579cf497c0..7172d82391bd3 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -100,7 +100,6 @@ def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> Non
 
     def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
         """Clips the gradients to a specific value"""
-        # TODO: separate TPU case from here
         if clip_val is None:
             return
 
diff --git a/tests/models/test_tpu.py b/tests/models/test_tpu.py
index 0c922c99149fa..5358b9f881048 100644
--- a/tests/models/test_tpu.py
+++ b/tests/models/test_tpu.py
@@ -355,3 +355,31 @@ def test_reduce(rank):
         assert result.item() == 8
 
     xmp.spawn(test_reduce, nprocs=8, start_method='fork')
+
+
+@pytest.mark.parametrize("clip_val", [0, 10])
+@RunIf(tpu=True)
+@pl_multi_process_test
+@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
+def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
+    """
+    Ensure that clip gradients is only called if the value is greater than 0.
+    """
+    tutils.reset_seed()
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=1,
+        precision=16,
+        limit_train_batches=4,
+        limit_val_batches=4,
+        gradient_clip_val=clip_val,
+    )
+    model = BoringModel()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+    if clip_val > 0:
+        mock_clip_grad_norm.assert_called()
+    else:
+        mock_clip_grad_norm.assert_not_called()