Update Gradient Clipping for TPU Accelerator (#6576)
(cherry picked from commit 87c03b1)
kaushikb11 authored and Borda committed Mar 23, 2021
1 parent 9cd985b commit 9ce794c
Showing 4 changed files with 46 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -136,6 +136,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))


### Fixed

16 changes: 16 additions & 0 deletions pytorch_lightning/accelerators/tpu.py
@@ -12,6 +12,9 @@

if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
from torch_xla._patched_functions import clip_grad_norm_

xla_clip_grad_norm_ = clip_grad_norm_


class TPUAccelerator(Accelerator):
@@ -44,3 +47,16 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
return xm.all_gather(tensor).view(-1, *tensor.shape)
return tensor

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[float, int], norm_type: float = 2.0):

model = self.lightning_module
parameters = model.parameters()

grad_clip_val = float(clip_val)
if grad_clip_val <= 0:
return

max_norm = grad_clip_val

xla_clip_grad_norm_(parameters, max_norm, norm_type)
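
For context: `torch_xla._patched_functions.clip_grad_norm_` is called here with the same `(parameters, max_norm, norm_type)` arguments as `torch.nn.utils.clip_grad_norm_`, for which it is presumably an XLA-friendly drop-in. A minimal standalone sketch of what norm-based clipping does, using the stock PyTorch utility rather than the patched one:

```python
import torch

# Two toy parameters with deliberately large gradients attached.
params = [torch.nn.Parameter(torch.randn(3)) for _ in range(2)]
for p in params:
    p.grad = torch.randn(3) * 10

total_before = torch.norm(torch.stack([p.grad.norm(2) for p in params]), 2)

# Rescale all gradients in place so their combined 2-norm is at most max_norm.
torch.nn.utils.clip_grad_norm_(params, max_norm=1.0, norm_type=2.0)

total_after = torch.norm(torch.stack([p.grad.norm(2) for p in params]), 2)
print(f"total grad norm: {total_before.item():.3f} -> {total_after.item():.3f}")
```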
1 change: 0 additions & 1 deletion pytorch_lightning/plugins/precision/precision_plugin.py
@@ -88,7 +88,6 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None:

def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None:
"""Clips the gradients to a specific value"""
# TODO: separate TPU case from here
if clip_val is None:
return

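The removed TODO points at where the TPU-specific case now lives: clipping flows from the accelerator to the precision plugin by default, and `TPUAccelerator.clip_gradients` above overrides that hook to call the XLA-patched function directly. A rough sketch of this dispatch, using simplified stand-in classes rather than Lightning's actual `Accelerator`/`PrecisionPlugin` implementations:

```python
from typing import Iterable, Optional, Union

import torch


class PrecisionPluginSketch:
    """Stand-in for the plugin-level hook: generic guard, then clip."""

    def clip_gradients(self, parameters: Iterable[torch.Tensor],
                       clip_val: Optional[Union[int, float]]) -> None:
        if clip_val is None or float(clip_val) <= 0:
            return
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=float(clip_val))


class AcceleratorSketch:
    """Stand-in for the base accelerator: delegate clipping to the precision plugin."""

    def __init__(self) -> None:
        self.precision_plugin = PrecisionPluginSketch()

    def clip_gradients(self, parameters, clip_val):
        self.precision_plugin.clip_gradients(parameters, clip_val)


class TPUAcceleratorSketch(AcceleratorSketch):
    """Stand-in for the TPU accelerator: override the hook entirely."""

    def clip_gradients(self, parameters, clip_val):
        if float(clip_val) <= 0:
            return
        # On a real TPU host this call would be torch_xla's patched clip_grad_norm_.
        torch.nn.utils.clip_grad_norm_(parameters, max_norm=float(clip_val))
```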
28 changes: 28 additions & 0 deletions tests/models/test_tpu.py
@@ -347,3 +347,31 @@ def test_reduce(rank):
assert result.item() == 8

xmp.spawn(test_reduce, nprocs=8, start_method='fork')


@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(tpu=True)
@pl_multi_process_test
@mock.patch("pytorch_lightning.accelerators.tpu.xla_clip_grad_norm_")
def test_tpu_precision_16_clip_gradients(mock_clip_grad_norm, clip_val, tmpdir):
"""
Ensure that gradient clipping is only applied when the clip value is greater than 0.
"""
tutils.reset_seed()
trainer_options = dict(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=1,
tpu_cores=1,
precision=16,
limit_train_batches=4,
limit_val_batches=4,
gradient_clip_val=clip_val,
)
model = BoringModel()
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)

if clip_val > 0:
mock_clip_grad_norm.assert_called()
else:
mock_clip_grad_norm.assert_not_called()
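
For reference, the user-facing way to exercise this path is simply a non-zero `gradient_clip_val` on the Trainer. A hedged sketch, assuming a TPU host with torch_xla installed; `TinyModel` is a hypothetical placeholder and the Trainer arguments mirror the test above:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Hypothetical placeholder; any LightningModule exercises the same path."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Assumes a TPU host with torch_xla installed.
trainer = pl.Trainer(
    tpu_cores=1,
    precision=16,
    max_epochs=1,
    limit_train_batches=4,
    gradient_clip_val=0.5,  # any value > 0 routes through TPUAccelerator.clip_gradients
)
trainer.fit(TinyModel())
```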
