Update Gradient Clipping for TPU Accelerator #6576

Merged · 6 commits · Mar 19, 2021

Changes from 3 commits
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -72,7 +72,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505),

[#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530),

[#6547](https://github.com/PyTorchLightning/pytorch-lightning/pull/6547),

[#6515](https://github.com/PyTorchLightning/pytorch-lightning/pull/6515),
@@ -161,6 +161,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))


- Update Gradient Clipping for the TPU Accelerator ([#6576](https://github.com/PyTorchLightning/pytorch-lightning/pull/6576))


## [1.2.3] - 2021-03-09

### Fixed
11 changes: 10 additions & 1 deletion pytorch_lightning/accelerators/tpu.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torch.optim import Optimizer
@@ -12,6 +12,7 @@

if _XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm
+    from torch_xla._patched_functions import clip_grad_norm_

if TYPE_CHECKING:
    from pytorch_lightning.core.lightning import LightningModule
@@ -55,3 +56,11 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads
        if torch.distributed.is_initialized():
            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
        return tensor

+    def clip_gradients(self, optimizer: Optimizer, grad_clip_val: Union[float, int], norm_type: float = 2.0):
Contributor (review comment on this line):

The patched XLA function does not check for `grad_clip_val > 0`, so we might need to include that check, as done here:

https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/precision_plugin.py#L109-L110

In addition, will this mean we can get rid of the TODO here? https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/plugins/precision/precision_plugin.py#L103

I think the implementation of clip gradients is TPU compliant, but either way I don't think that's an issue.
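For illustration, a minimal sketch of what the suggested guard could look like. This is not code from the PR; it assumes the `clip_gradients` method added in this diff and the XLA-patched `clip_grad_norm_` imported above, and its early return mirrors the linked `precision_plugin.py` check:

```python
from typing import Union

import torch

# Assumes the import added in this diff:
# from torch_xla._patched_functions import clip_grad_norm_

def clip_gradients(self, optimizer: torch.optim.Optimizer,
                   grad_clip_val: Union[float, int], norm_type: float = 2.0):
    grad_clip_val = float(grad_clip_val)
    if grad_clip_val <= 0:
        return  # skip clipping, mirroring precision_plugin.py L109-L110
    clip_grad_norm_(self.lightning_module.parameters(), grad_clip_val, norm_type)
```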

Contributor:

Good point!

Contributor Author (@kaushikb11, Mar 18, 2021):

@SeanNaren @tchaton Regarding the issue for aitextgen:

The code hangs when it calls xm.save in Transformers' save_pretrained method, when xm_save is passed to the save_function param from aitextgen for TPUs.
Code: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L835

My hunch: xm.save has a rendezvous in it, and not all the cores were syncing up for the rendezvous, hence the code was hanging.

The required changes need to be made on the aitextgen end.
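As a hedged illustration of that hypothesis (not code from this PR): `xm.save` performs a rendezvous internally, so every TPU core must reach the call, while only the master ordinal actually writes to disk. The `model` below is a stand-in module:

```python
import torch
import torch_xla.core.xla_model as xm

model = torch.nn.Linear(2, 2)  # stand-in module for illustration

# Safe pattern: every core calls xm.save; all cores reach the internal
# rendezvous, and only the master ordinal writes the file.
xm.save(model.state_dict(), "model.pt")

# Hang-prone pattern (matching the hypothesis above): cores that skip
# the call never join the rendezvous, so the program stalls.
# if xm.is_master_ordinal():
#     xm.save(model.state_dict(), "model.pt")
```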


+        model = self.lightning_module
+        parameters = model.parameters()
+        max_norm = grad_clip_val
+
+        clip_grad_norm_(parameters, max_norm, norm_type)
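For context, a hedged usage sketch (not part of this diff): with `gradient_clip_val` set, the Trainer delegates clipping to the accelerator, so on TPU the XLA-patched `clip_grad_norm_` above runs instead of `torch.nn.utils.clip_grad_norm_`. The values are illustrative and `MyLightningModule` is a placeholder:

```python
from pytorch_lightning import Trainer

# gradient_clip_val is forwarded to the accelerator's clip_gradients;
# on TPU that is the method added in this PR.
trainer = Trainer(tpu_cores=8, gradient_clip_val=0.5)
trainer.fit(MyLightningModule())  # MyLightningModule is a placeholder
```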