
[bugfix] Resolve bug with multiple optimizers and toggle. #5574

Merged: 22 commits, Jan 25, 2021
Changes from 16 commits
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `toggle_optimizer` to reset `requires_grad` state ([#5574](https://github.com/PyTorchLightning/pytorch-lightning/pull/5574))




@@ -52,7 +54,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Check environ before selecting a seed to prevent warning message ([#4743](https://github.com/PyTorchLightning/pytorch-lightning/pull/4743))
- Fixed signature mismatch in `model_to_device` of `DDPCPUHPCAccelerator` ([#5505](https://github.com/PyTorchLightning/pytorch-lightning/pull/5505))


## [1.1.3] - 2021-01-05

### Added
45 changes: 38 additions & 7 deletions pytorch_lightning/core/lightning.py
@@ -1169,16 +1169,47 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):

Contributor:
maybe just use the opt_idx to get the optimizer down there... just to unify the arguments with untoggle_optimizer. Either way is fine.

Contributor Author:
Sounds great! I will do so.

Contributor:
Do you plan to do it in this PR?
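
In other words, the suggestion is for `toggle_optimizer` to fetch the optimizer from `optimizer_idx` itself, mirroring `untoggle_optimizer`'s signature. A rough, hypothetical sketch of that idea (this is not the code merged in this PR, which keeps both arguments):

def toggle_optimizer(self, optimizer_idx: int):
    # Hypothetical: look the optimizer up by index, the same way untoggle_optimizer does.
    optimizer = self.optimizers(use_pl_optimizer=False)[optimizer_idx]
    # ...then continue exactly as in the merged implementation below...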


         Override for your own behavior
 
+        It works with ``untoggle_optimizer`` to make sure param_requires_grad_state is properly reset.
+
         Args:
-            optimizer:
-            optimizer_idx:
+            optimizer: Current optimizer used in training_loop
+            optimizer_idx: Current optimizer idx in training_loop
         """
-        for param in self.parameters():
-            param.requires_grad = False
-
-        for group in optimizer.param_groups:
-            for param in group['params']:
-                param.requires_grad = True
+        param_requires_grad_state = {}
+        # make sure current optimizer is latest to be iterated over.
+        optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
+        num_optimizers = len(optimizers) - 1
+        for opt_idx, opt in enumerate(optimizers):
+            for group in opt.param_groups:
+                for param in group['params']:
+                    if num_optimizers == opt_idx:
+                        # If a param appears in 2 optimizers, revert `requires_grad` to before toggle.
+                        if param in param_requires_grad_state:
+                            param.requires_grad = param_requires_grad_state[param]
+                    else:
+                        # save requires_grad for later restoration
+                        param_requires_grad_state[param] = param.requires_grad
+                        param.requires_grad = False
+
+        self._param_requires_grad_state = param_requires_grad_state
+
+    def untoggle_optimizer(self, optimizer_idx: int):
+        """
+        .. note:: Only called when using multiple optimizers
+
+        Override for your own behavior
+
+        Args:
+            optimizer_idx: Current optimizer idx in training_loop
+        """
+        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
+            if optimizer_idx != opt_idx:
+                for group in opt.param_groups:
+                    for param in group['params']:
+                        if param in self._param_requires_grad_state:
+                            param.requires_grad = self._param_requires_grad_state[param]
+        # save memory
+        del self._param_requires_grad_state
 
     def optimizer_step(
         self,
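To make the save-and-restore behaviour above concrete, here is a minimal plain-PyTorch sketch of the same idea outside of Lightning. The helper names (`toggle`, `untoggle`) and the two-optimizer setup are illustrative only, not part of the PR:

from torch import nn
from torch.optim import SGD, Adam

shared = nn.Linear(4, 4)
head = nn.Linear(4, 2)
head.weight.requires_grad = False  # a parameter the user froze before any toggling

opt_a = SGD(shared.parameters(), lr=0.1)
opt_b = Adam(list(shared.parameters()) + list(head.parameters()), lr=0.1)

def toggle(current, others):
    """Freeze params exclusive to the other optimizers, remembering their previous state."""
    saved = {}
    for opt in others:
        for group in opt.param_groups:
            for p in group["params"]:
                saved[p] = p.requires_grad
                p.requires_grad = False
    # params that `current` shares with the others keep their original state
    for group in current.param_groups:
        for p in group["params"]:
            if p in saved:
                p.requires_grad = saved[p]
    return saved

def untoggle(saved):
    """Revert every touched param to its pre-toggle requires_grad state."""
    for p, state in saved.items():
        p.requires_grad = state

state = toggle(opt_a, [opt_b])
assert shared.weight.requires_grad       # shared with opt_a, stays trainable
assert not head.bias.requires_grad       # exclusive to opt_b, frozen while opt_a is toggled
untoggle(state)
assert head.bias.requires_grad           # restored after untoggle
assert not head.weight.requires_grad     # the user-frozen flag is untouched by toggle/untoggle
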
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -798,6 +798,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
         if self.trainer.terminate_on_nan:
             self.trainer.detect_nan_tensors(result.loss)
 
+        if len(self.trainer.optimizers) > 1:
+            # revert back to previous state
+            self.trainer.get_model().untoggle_optimizer(opt_idx)
+
         return result
 
     def backward(self, result, optimizer, opt_idx, *args, **kwargs):
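For orientation, the added call runs at the end of the forward/backward closure, so each optimizer's backward pass happens while the other optimizers' exclusive parameters are frozen, and the requires_grad state is reverted right afterwards. Below is a simplified, self-contained sketch of that ordering using plain-PyTorch stand-ins (it is not the actual Trainer code, and the shared-parameter bookkeeping of `toggle_optimizer` is omitted for brevity):

import torch
from torch import nn
from torch.optim import SGD

generator = nn.Linear(4, 4)
discriminator = nn.Linear(4, 1)
optimizers = [SGD(generator.parameters(), lr=0.1), SGD(discriminator.parameters(), lr=0.1)]
all_params = list(generator.parameters()) + list(discriminator.parameters())

batch = torch.randn(8, 4)
for opt_idx, optimizer in enumerate(optimizers):
    owned = {p for group in optimizer.param_groups for p in group["params"]}
    saved = {}
    if len(optimizers) > 1:
        # toggle: freeze params owned only by the other optimizers, remembering their state
        for p in all_params:
            if p not in owned:
                saved[p] = p.requires_grad
                p.requires_grad = False

    loss = discriminator(generator(batch)).sum()  # stand-in for training_step + backward
    loss.backward()

    if len(optimizers) > 1:
        # untoggle: revert back to the previous state (mirrors the call added above)
        for p, state in saved.items():
            p.requires_grad = state

    optimizer.step()
    optimizer.zero_grad()
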
74 changes: 73 additions & 1 deletion tests/core/test_lightning_module.py
@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from argparse import ArgumentParser
 import pickle
+from argparse import ArgumentParser
 from typing import Optional
 from unittest.mock import MagicMock, patch
 
 import pytest
 import torch
+from torch import nn
 from torch.optim import Adam, SGD
 from torch.utils.data import DataLoader, random_split

@@ -139,3 +140,74 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
     )
 
     trainer.fit(model)
+
+
+def test_toggle_untoggle(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx, optimizer_idx=None):
+            return super().training_step(batch, batch_idx)
+
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Sequential(
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+            )
+
+            self.layer_2 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            # set some weights to False to check untoggle works as expected.
+            self.layer_1[2].weight.requires_grad = False
+            self.layer_1[4].weight.requires_grad = False
+
+            self.layer_2[1].weight.requires_grad = False
+            self.layer_2[3].weight.requires_grad = False
+
+        def configure_optimizers(self):
+            optimizer = SGD(self.layer_1.parameters(), lr=0.1)
+            optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1)
+            return [optimizer, optimizer_2]
+
+        def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+            if optimizer_idx == 0:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is False
+
+            if optimizer_idx == 1:
+                assert self.layer_1[0].weight.requires_grad is False
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+            optimizer.step(closure=closure)
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        limit_train_batches=8,
+        accumulate_grad_batches=1,
+    )
+
+    trainer.fit(model)