From 0733590332296a49a3fc60b3dd6484ed16b487a5 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Thu, 4 Mar 2021 19:45:58 +0000 Subject: [PATCH] [Fix] Call clip gradients if clip val greater than 0 (#6330) * Call clip gradients if clip val greater than 0 * format * Format * Move to top of file --- CHANGELOG.md | 3 +++ .../plugins/precision/sharded_native_amp.py | 5 ++++- tests/plugins/test_sharded_plugin.py | 20 +++++++++++++++++++ 3 files changed, 27 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c3251a32a49ba3..9ba56999139608 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Resolve memory leak for evaluation ([#6326](https://github.com/PyTorchLightning/pytorch-lightning/pull/6326) +- Ensure that clip gradients is only called if the value is greater than 0 ([#6330](https://github.com/PyTorchLightning/pytorch-lightning/pull/6330) + + ## [1.2.2] - 2021-03-02 ### Added diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index b3b01fc720d2ba..8ade1396a174c0 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -31,6 +31,9 @@ def __init__(self): super().__init__() self.scaler = ShardedGradScaler() - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + if clip_val <= 0: + return + optimizer = cast(OSS, optimizer) optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index f3683ffcba252c..bca5079f82f82e 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -1,5 +1,6 @@ import os import platform +from unittest import mock import pytest import torch @@ -12,6 +13,25 @@ from tests.helpers.boring_model import BoringModel +@pytest.mark.parametrize("clip_val", [0, 10]) +@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine") +@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") +@mock.patch('fairscale.optim.oss.OSS.clip_grad_norm') +def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir): + """ + Ensure that clip gradients is only called if the value is greater than 0. + """ + model = BoringModel() + trainer = Trainer(accelerator='ddp_sharded', gpus=1, precision=16, fast_dev_run=True, gradient_clip_val=clip_val) + trainer.fit(model) + if clip_val > 0: + mock_oss_clip_grad_norm.assert_called() + else: + mock_oss_clip_grad_norm.assert_not_called() + + +@RunIf(fairscale=True) @pytest.mark.parametrize(["accelerator"], [("ddp_sharded", ), ("ddp_sharded_spawn", )]) @pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available") def test_sharded_ddp_choice(tmpdir, accelerator):