[bugfix] Resolve bug with multiple optimizers and toggle. (#5574)
* fix toggle_optimizer

* update doc

* resolve bug

* update

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* update on comments

* update on comments

* update

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
tchaton and rohitgr7 committed Jan 25, 2021
1 parent e87424a commit c76cc23
Showing 4 changed files with 118 additions and 9 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -26,6 +26,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `toggle_optimizer` to reset `requires_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574))


- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))


@@ -63,7 +66,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))
- Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505))


## [1.1.3] - 2021-01-05

### Added
45 changes: 38 additions & 7 deletions pytorch_lightning/core/lightning.py
@@ -1170,16 +1170,47 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
    Override for your own behavior

    It works with ``untoggle_optimizer`` to make sure param_requires_grad_state is properly reset.

    Args:
        optimizer: Current optimizer used in training_loop
        optimizer_idx: Current optimizer idx in training_loop
    """
    param_requires_grad_state = {}
    # make sure current optimizer is latest to be iterated over.
    optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
    num_optimizers = len(optimizers) - 1
    for opt_idx, opt in enumerate(optimizers):
        for group in opt.param_groups:
            for param in group['params']:
                if num_optimizers == opt_idx:
                    # If a param appears in 2 optimizers, revert `requires_grad` to before toggle.
                    if param in param_requires_grad_state:
                        param.requires_grad = param_requires_grad_state[param]
                else:
                    # save requires_grad for later restoration
                    param_requires_grad_state[param] = param.requires_grad
                    param.requires_grad = False

    self._param_requires_grad_state = param_requires_grad_state

def untoggle_optimizer(self, optimizer_idx: int):
    """
    .. note:: Only called when using multiple optimizers

    Override for your own behavior

    Args:
        optimizer_idx: Current optimizer idx in training_loop
    """
    for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
        if optimizer_idx != opt_idx:
            for group in opt.param_groups:
                for param in group['params']:
                    if param in self._param_requires_grad_state:
                        param.requires_grad = self._param_requires_grad_state[param]
    # save memory
    del self._param_requires_grad_state

def optimizer_step(
self,
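To illustrate what the new `toggle_optimizer`/`untoggle_optimizer` pair does, here is a minimal, self-contained sketch that re-implements the same save-and-restore idea in plain PyTorch. It does not call Lightning's API (those hooks need an attached `Trainer`); `shared`, `head_a`, `head_b`, `toggle`, and `untoggle` are made-up names for illustration only. Parameters owned solely by another optimizer are frozen, a parameter that was already frozen keeps its original `requires_grad` value, and untoggling restores every recorded flag.

```python
from torch import nn
from torch.optim import SGD, Adam

# Hypothetical setup: a layer shared by both optimizers plus one private head each.
shared = nn.Linear(4, 4)
head_a = nn.Linear(4, 2)
head_b = nn.Linear(4, 2)
head_b.weight.requires_grad = False  # a weight that was already frozen before toggling

opt_a = SGD(list(shared.parameters()) + list(head_a.parameters()), lr=0.1)
opt_b = Adam(list(shared.parameters()) + list(head_b.parameters()), lr=0.1)
optimizers = [opt_a, opt_b]


def toggle(current):
    """Freeze params not owned by `current`, remembering their previous requires_grad."""
    saved = {}
    # visit the current optimizer last so its params win back their original state
    ordered = [opt for opt in optimizers if opt is not current] + [current]
    for idx, opt in enumerate(ordered):
        for group in opt.param_groups:
            for param in group["params"]:
                if idx == len(ordered) - 1:
                    # shared param: revert to the state it had before toggling
                    if param in saved:
                        param.requires_grad = saved[param]
                else:
                    saved[param] = param.requires_grad
                    param.requires_grad = False
    return saved


def untoggle(saved):
    """Restore every recorded requires_grad flag."""
    for param, state in saved.items():
        param.requires_grad = state


saved = toggle(opt_a)
assert shared.weight.requires_grad        # shared with opt_a, stays trainable
assert not head_b.bias.requires_grad      # owned only by opt_b, frozen during opt_a's step
assert not head_b.weight.requires_grad    # was frozen before, still frozen

untoggle(saved)
assert head_b.bias.requires_grad          # thawed again
assert not head_b.weight.requires_grad    # original frozen state preserved
```

The ordering trick (iterating the current optimizer last) is what lets a parameter shared between two optimizers stay trainable for the optimizer currently stepping while everything else is frozen.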
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -798,6 +798,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
if self.trainer.terminate_on_nan:
    self.trainer.detect_nan_tensors(result.loss)

if len(self.trainer.optimizers) > 1:
    # revert back to previous state
    self.trainer.get_model().untoggle_optimizer(opt_idx)

return result

def backward(self, result, optimizer, opt_idx, *args, **kwargs):
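The four lines added above call `untoggle_optimizer` right after the step for one optimizer finishes, so the flags saved by `toggle_optimizer` are restored before the next optimizer runs. A rough, self-contained sketch of that per-batch call order follows; it uses plain PyTorch stand-ins rather than the trainer's actual code, and `toggle`/`untoggle` here are simplified placeholders for the LightningModule hooks.

```python
import torch
from torch import nn
from torch.optim import SGD

# Hypothetical two-optimizer setup (e.g. a GAN-like model); names are placeholders.
model = nn.ModuleDict({"gen": nn.Linear(4, 4), "disc": nn.Linear(4, 1)})
optimizers = [SGD(model["gen"].parameters(), lr=0.1), SGD(model["disc"].parameters(), lr=0.1)]
batch = torch.randn(8, 4)

saved = {}

def toggle(current):
    # freeze every param that belongs to another optimizer, remembering its old flag
    for opt in optimizers:
        if opt is not current:
            for group in opt.param_groups:
                for p in group["params"]:
                    saved[p] = p.requires_grad
                    p.requires_grad = False

def untoggle():
    # restore the recorded flags so the next optimizer sees the original state
    for p, state in saved.items():
        p.requires_grad = state
    saved.clear()

def training_step(opt_idx):
    if opt_idx == 0:
        return model["gen"](batch).pow(2).mean()
    return model["disc"](model["gen"](batch)).pow(2).mean()

for opt_idx, optimizer in enumerate(optimizers):
    toggle(optimizer)    # before the step: only this optimizer's params train
    loss = training_step(opt_idx)
    loss.backward()      # frozen params receive no gradients
    optimizer.step()
    optimizer.zero_grad()
    untoggle()           # after the step: the previous requires_grad state is back
```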
74 changes: 73 additions & 1 deletion tests/core/test_lightning_module.py
@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import ArgumentParser
from typing import Optional
from unittest.mock import MagicMock, patch

import pytest
import torch
from torch import nn
from torch.optim import Adam, SGD
from torch.utils.data import DataLoader, random_split

@@ -139,3 +140,74 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
    )

    trainer.fit(model)


def test_toggle_untoggle(tmpdir):

    class TestModel(BoringModel):

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            return super().training_step(batch, batch_idx)

        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Sequential(
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
            )

            self.layer_2 = nn.Sequential(
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 2)
            )

            # set some weights to False to check untoggle works as expected.
            self.layer_1[2].weight.requires_grad = False
            self.layer_1[4].weight.requires_grad = False

            self.layer_2[1].weight.requires_grad = False
            self.layer_2[3].weight.requires_grad = False

        def configure_optimizers(self):
            optimizer = SGD(self.layer_1.parameters(), lr=0.1)
            optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1)
            return [optimizer, optimizer_2]

        def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
            if optimizer_idx == 0:
                assert self.layer_1[0].weight.requires_grad is True
                assert self.layer_1[2].weight.requires_grad is False
                assert self.layer_1[4].weight.requires_grad is False

                assert self.layer_2[1].weight.requires_grad is False
                assert self.layer_2[3].weight.requires_grad is False
                assert self.layer_2[5].weight.requires_grad is False

            if optimizer_idx == 1:
                assert self.layer_1[0].weight.requires_grad is False
                assert self.layer_1[2].weight.requires_grad is False
                assert self.layer_1[4].weight.requires_grad is False

                assert self.layer_2[1].weight.requires_grad is False
                assert self.layer_2[3].weight.requires_grad is False
                assert self.layer_2[5].weight.requires_grad is True
            optimizer.step(closure=closure)

    model = TestModel()
    model.training_epoch_end = None

    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        limit_train_batches=8,
        accumulate_grad_batches=1,
    )

    trainer.fit(model)
