Update Gradient Clipping for TPU Accelerator #6576

Merged: 6 commits, Mar 19, 2021
CHANGELOG.md (5 changes: 4 additions, 1 deletion)
@@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),
    [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),
    [#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547),
    [#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515),
@@ -161,6 +161,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))


- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))


## [1.2.3] - 2021-03-09

### Fixed
pytorch_lightning/accelerators/tpu.py (18 changes: 17 additions, 1 deletion)
@@ -1,4 +1,4 @@
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
@@ -12,6 +12,9 @@

if _XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
    from torch_xla._patched_functions import clip_grad_norm_

    xla_clip_grad_norm_ = clip_grad_norm_

if TYPE_CHECKING:
    from pytorch_lightning.core.lightning import LightningModule
@@ -55,3 +58,16 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
        if torch.distributed.is_initialized():
            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
        return tensor

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):

        model = self.lightning_module
        parameters = model.parameters()

        grad_clip_val = float(clip_val)
        if grad_clip_val <= 0:
            return

        max_norm = grad_clip_val

        xla_clip_grad_norm_(parameters, max_norm, norm_type)
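For reference, `xla_clip_grad_norm_` is called here with the same positional arguments as `torch.nn.utils.clip_grad_norm_` (parameters, max_norm, norm_type). Below is a minimal sketch of the clipping semantics using the stock PyTorch utility; the XLA-patched variant presumably exists to keep the operation XLA-friendly on TPU, and its internals are not shown here. The toy model and values are placeholders, not part of this PR.

```python
import torch
from torch import nn

# Toy model and optimizer, only to have some gradients to clip.
model = nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(4, 10)).sum()
loss.backward()

# Same semantics as xla_clip_grad_norm_(parameters, max_norm, norm_type):
# rescale all gradients so their combined norm does not exceed max_norm.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5, norm_type=2.0)
print(f"pre-clip gradient norm: {total_norm:.4f}")

optimizer.step()
```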
pytorch_lightning/plugins/precision/precision_plugin.py (1 change: 0 additions, 1 deletion)
@@ -100,7 +100,6 @@ def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> None:

    def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
        """Clips the gradients to a specific value"""
        # TODO: separate TPU case from here
        if clip_val is None:
            return

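The removed TODO is resolved by the accelerator override above rather than by a TPU branch inside the plugin. The following is a rough, self-contained sketch of that dispatch pattern; the classes are stand-ins (not the real `Accelerator`/`PrecisionPlugin` interfaces), and the stock `torch.nn.utils.clip_grad_norm_` call stands in for whatever each path actually uses.

```python
from typing import Union

import torch
from torch.optim import Optimizer


class PrecisionPluginSketch:
    """Stand-in for the precision plugin: owns the generic clipping path."""

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0) -> None:
        if clip_val is None or float(clip_val) <= 0:
            return
        for group in optimizer.param_groups:
            torch.nn.utils.clip_grad_norm_(group["params"], float(clip_val), norm_type)


class AcceleratorSketch:
    """Stand-in for the base accelerator: defers clipping to its precision plugin."""

    def __init__(self) -> None:
        self.precision_plugin = PrecisionPluginSketch()

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0) -> None:
        self.precision_plugin.clip_gradients(optimizer, clip_val, norm_type)


class TPUAcceleratorSketch(AcceleratorSketch):
    """Stand-in for the TPU accelerator: overrides clipping entirely, as in this PR,
    so the plugin no longer needs a TPU-specific branch."""

    def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = 2.0) -> None:
        if float(clip_val) <= 0:
            return
        # A real TPU run would call xla_clip_grad_norm_ here instead.
        for group in optimizer.param_groups:
            torch.nn.utils.clip_grad_norm_(group["params"], float(clip_val), norm_type)
```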
tests/models/test_tpu.py (28 changes: 28 additions, 0 deletions)
@@ -355,3 +355,31 @@ def test_reduce(rank):
        assert result.item() == 8

    xmp.spawn(test_reduce, nprocs=8, start_method='fork')


@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(tpu=True)
@pl_multi_process_test
@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
    """
    Ensure that clip gradients is only called if the value is greater than 0.
    """
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        progress_bar_refresh_rate=0,
        max_epochs=1,
        tpu_cores=1,
        precision=16,
        limit_train_batches=4,
        limit_val_batches=4,
        gradient_clip_val=clip_val,
    )
    model = BoringModel()
    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)

    if clip_val > 0:
        mock_clip_grad_norm.assert_called()
    else:
        mock_clip_grad_norm.assert_not_called()
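For context, the user-facing knob that exercises this path is `Trainer(gradient_clip_val=...)`, mirroring the trainer options in the test above. A minimal end-to-end sketch follows; `TinyModel` is a placeholder module (any `LightningModule` would do), and `tpu_cores=1` assumes a TPU/XLA environment.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Placeholder LightningModule, just enough to run a short fit."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)


if __name__ == "__main__":
    trainer = pl.Trainer(
        tpu_cores=1,            # requires a TPU/XLA environment
        precision=16,
        max_epochs=1,
        limit_train_batches=4,
        gradient_clip_val=0.5,  # on TPU, routed to TPUAccelerator.clip_gradients
    )
    trainer.fit(TinyModel())
```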