diff --git a/CHANGELOG.md b/CHANGELOG.md
index a981d2f580145..3643655d55ec4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `toggle_optimizer` to reset `requires_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574))
+
+
 - Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))
 
 
@@ -63,7 +66,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))
 - Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505))
 
-
 ## [1.1.3] - 2021-01-05
 
 ### Added
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 02015f6dc6229..4880f8336d067 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1170,16 +1170,47 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
 
         Override for your own behavior
 
+        It works with ``untoggle_optimizer`` to make sure ``param_requires_grad_state`` is properly reset.
+
         Args:
-            optimizer:
-            optimizer_idx:
+            optimizer: Current optimizer used in the training loop
+            optimizer_idx: Current optimizer idx in the training loop
         """
-        for param in self.parameters():
-            param.requires_grad = False
+        param_requires_grad_state = {}
+        # make sure the current optimizer is the last one iterated over
+        optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
+        num_optimizers = len(optimizers) - 1
+        for opt_idx, opt in enumerate(optimizers):
+            for group in opt.param_groups:
+                for param in group['params']:
+                    if num_optimizers == opt_idx:
+                        # if a param appears in 2 optimizers, revert `requires_grad` to its state before the toggle
+                        if param in param_requires_grad_state:
+                            param.requires_grad = param_requires_grad_state[param]
+                    else:
+                        # save requires_grad for later restoration
+                        param_requires_grad_state[param] = param.requires_grad
+                        param.requires_grad = False
+
+        self._param_requires_grad_state = param_requires_grad_state
+
+    def untoggle_optimizer(self, optimizer_idx: int):
+        """
+        .. note:: Only called when using multiple optimizers
 
-        for group in optimizer.param_groups:
-            for param in group['params']:
-                param.requires_grad = True
+        Override for your own behavior
+
+        Args:
+            optimizer_idx: Current optimizer idx in the training loop
+        """
+        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
+            if optimizer_idx != opt_idx:
+                for group in opt.param_groups:
+                    for param in group['params']:
+                        if param in self._param_requires_grad_state:
+                            param.requires_grad = self._param_requires_grad_state[param]
+        # save memory
+        del self._param_requires_grad_state
 
     def optimizer_step(
         self,
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 47e254606af93..0925bc78a9533 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -798,6 +798,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
                 if self.trainer.terminate_on_nan:
                     self.trainer.detect_nan_tensors(result.loss)
 
+                if len(self.trainer.optimizers) > 1:
+                    # revert to the previous requires_grad state
+                    self.trainer.get_model().untoggle_optimizer(opt_idx)
+
         return result
 
     def backward(self, result, optimizer, opt_idx, *args, **kwargs):
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 64b68245ba66e..6c4416da380e0 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from argparse import ArgumentParser
 import pickle
+from argparse import ArgumentParser
 from typing import Optional
 from unittest.mock import MagicMock, patch
 
 import pytest
 import torch
+from torch import nn
 from torch.optim import Adam, SGD
 from torch.utils.data import DataLoader, random_split
 
@@ -139,3 +140,74 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
     )
 
     trainer.fit(model)
+
+
+def test_toggle_untoggle(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx, optimizer_idx=None):
+            return super().training_step(batch, batch_idx)
+
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Sequential(
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+            )
+
+            self.layer_2 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            # set requires_grad of some weights to False to check that untoggle works as expected
+            self.layer_1[2].weight.requires_grad = False
+            self.layer_1[4].weight.requires_grad = False
+
+            self.layer_2[1].weight.requires_grad = False
+            self.layer_2[3].weight.requires_grad = False
+
+        def configure_optimizers(self):
+            optimizer = SGD(self.layer_1.parameters(), lr=0.1)
+            optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1)
+            return [optimizer, optimizer_2]
+
+        def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+            if optimizer_idx == 0:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is False
+
+            if optimizer_idx == 1:
+                assert self.layer_1[0].weight.requires_grad is False
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+            optimizer.step(closure=closure)
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        limit_train_batches=8,
+        accumulate_grad_batches=1,
+    )
+
+    trainer.fit(model)
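
For context, the following is a standalone sketch (not part of the patch) that mirrors the toggle/untoggle bookkeeping from the diff above with plain torch optimizers. It shows why the saved `requires_grad` state matters: a parameter the user froze before training must come out of a toggle/untoggle cycle still frozen. The helper names `toggle` and `untoggle`, the layers, and the optimizers here are illustrative assumptions, not the Lightning API.

# Illustrative sketch only: reimplements the patched bookkeeping outside of Lightning.
from torch import nn
from torch.optim import SGD, Adam

layer_1 = nn.Linear(32, 32)
layer_2 = nn.Linear(32, 2)
layer_2.weight.requires_grad = False  # user-frozen weight that must survive a toggle cycle

opt_1 = SGD(layer_1.parameters(), lr=0.1)
opt_2 = Adam(layer_2.parameters(), lr=0.1)
optimizers = [opt_1, opt_2]


def toggle(current, all_optimizers):
    """Disable grads for the other optimizers' params, remembering their previous flags."""
    saved = {}
    # iterate the current optimizer last, mirroring the ordering trick in the patch
    ordered = [opt for opt in all_optimizers if opt is not current] + [current]
    for idx, opt in enumerate(ordered):
        for group in opt.param_groups:
            for param in group["params"]:
                if idx == len(ordered) - 1:
                    # param shared with another optimizer: restore its pre-toggle flag
                    if param in saved:
                        param.requires_grad = saved[param]
                else:
                    saved[param] = param.requires_grad
                    param.requires_grad = False
    return saved


def untoggle(current, all_optimizers, saved):
    """Restore the remembered requires_grad flags of the other optimizers' params."""
    for opt in all_optimizers:
        if opt is not current:
            for group in opt.param_groups:
                for param in group["params"]:
                    if param in saved:
                        param.requires_grad = saved[param]


state = toggle(opt_2, optimizers)
assert layer_2.weight.requires_grad is False  # user freeze is kept (the old toggle forced it to True)
assert layer_2.bias.requires_grad is True     # the current optimizer's unfrozen params stay trainable
assert layer_1.weight.requires_grad is False  # the other optimizer's params are switched off

untoggle(opt_2, optimizers, state)
assert layer_1.weight.requires_grad is True   # restored for the next optimizer's step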