Fix toggle optimizer (#5775)
* Update lightning.py

* update changelog

* add a 3 optimizer test

* resolve flake8

* remove extra code

* typo

* resolve typo

* Apply suggestions from code review

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
4 people committed Feb 5, 2021
1 parent 51d40e4 commit e991509
Showing 3 changed files with 254 additions and 13 deletions.
13 changes: 13 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [unreleased] - YYYY-MM-DD

### Added

### Changed

### Deprecated

### Removed

### Fixed

- Fixed `toggle_optimizer` not handling all optimizer parameters ([#5775](https://github.com/PyTorchLightning/pytorch-lightning/pull/5775))

## [unreleased.Features] - YYYY-MM-DD

28 changes: 15 additions & 13 deletions pytorch_lightning/core/lightning.py
@@ -1196,22 +1196,24 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):
             optimizer: Current optimizer used in training_loop
             optimizer_idx: Current optimizer idx in training_loop
         """
+
         # Iterate over all optimizer parameters to preserve their `requires_grad` information
         # in case these are pre-defined during `configure_optimizers`
         param_requires_grad_state = {}
-        # make sure current optimizer is latest to be iterated over.
-        optimizers = [opt for opt in self.optimizers(use_pl_optimizer=False) if opt != optimizer] + [optimizer]
-        num_optimizers = len(optimizers) - 1
-        for opt_idx, opt in enumerate(optimizers):
+        for opt in self.optimizers(use_pl_optimizer=False):
             for group in opt.param_groups:
                 for param in group['params']:
-                    if num_optimizers == opt_idx:
-                        # If a param appears in 2 optimizers, revert `requires_grad` to before toggle.
-                        if param in param_requires_grad_state:
-                            param.requires_grad = param_requires_grad_state[param]
-                    else:
-                        # save requires_grad for later restoration
-                        param_requires_grad_state[param] = param.requires_grad
-                        param.requires_grad = False
+
+                    # If a param already appear in param_requires_grad_state, continue
+                    if param in param_requires_grad_state:
+                        continue
+                    param_requires_grad_state[param] = param.requires_grad
+                    param.requires_grad = False
+
+        # Then iterate over the current optimizer's parameters and set its `requires_grad`
+        # properties accordingly
+        for group in optimizer.param_groups:
+            for param in group['params']:
+                param.requires_grad = param_requires_grad_state[param]
         self._param_requires_grad_state = param_requires_grad_state

     def untoggle_optimizer(self, optimizer_idx: int):
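For orientation, the sketch below restates the fixed toggling logic outside of Lightning, using plain `torch` optimizers. It is a minimal, hypothetical helper (the names `toggle_parameters` and `untoggle_parameters` and the tiny modules are illustrative, not part of this commit, and the restore step is simplified compared to `untoggle_optimizer`), but it follows the same idea as the new code above: record each parameter's `requires_grad` once, even when the parameter is shared by several optimizers, freeze everything, then re-enable only the active optimizer's parameters.

```python
from torch import nn
from torch.optim import SGD, Adam


def toggle_parameters(active_optimizer, all_optimizers):
    """Record every parameter's `requires_grad` once, freeze all parameters,
    then re-enable only the active optimizer's parameters."""
    requires_grad_state = {}
    for opt in all_optimizers:
        for group in opt.param_groups:
            for param in group["params"]:
                # A parameter shared by several optimizers is recorded only once,
                # so its original flag is preserved for later restoration.
                if param in requires_grad_state:
                    continue
                requires_grad_state[param] = param.requires_grad
                param.requires_grad = False

    for group in active_optimizer.param_groups:
        for param in group["params"]:
            param.requires_grad = requires_grad_state[param]
    return requires_grad_state


def untoggle_parameters(requires_grad_state):
    """Restore the flags recorded by `toggle_parameters` (simplified restore)."""
    for param, state in requires_grad_state.items():
        param.requires_grad = state


if __name__ == "__main__":
    shared = nn.Linear(4, 4)  # appears in both optimizers below
    head_a, head_b = nn.Linear(4, 2), nn.Linear(4, 2)
    opt_a = SGD(list(shared.parameters()) + list(head_a.parameters()), lr=0.1)
    opt_b = Adam(list(shared.parameters()) + list(head_b.parameters()), lr=0.1)

    state = toggle_parameters(opt_a, [opt_a, opt_b])
    assert all(p.requires_grad for p in shared.parameters())      # active, re-enabled
    assert all(p.requires_grad for p in head_a.parameters())      # active, re-enabled
    assert not any(p.requires_grad for p in head_b.parameters())  # inactive, frozen
    untoggle_parameters(state)
    assert all(p.requires_grad for p in head_b.parameters())      # restored
```

Because each parameter is recorded at most once, a flag that was already `False` before toggling (as in the new tests below) survives a full toggle/untoggle cycle.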
226 changes: 226 additions & 0 deletions tests/core/test_lightning_module.py
@@ -14,6 +14,7 @@
 from unittest.mock import Mock, patch
 
 import pytest
+from torch import nn
 from torch.optim import Adam, SGD
 
 from pytorch_lightning import Trainer
@@ -184,3 +185,228 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
    )

    trainer.fit(model)


def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):

    class TestModel(BoringModel):

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            return super().training_step(batch, batch_idx)

        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Sequential(
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
            )

            self.layer_2 = nn.Sequential(
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 2)
            )

            # set some weights to False to check untoggle works as expected.
            self.layer_1[2].weight.requires_grad = False
            self.layer_1[4].weight.requires_grad = False

            self.layer_2[1].weight.requires_grad = False
            self.layer_2[3].weight.requires_grad = False

        def configure_optimizers(self):
            optimizer = SGD(self.layer_1.parameters(), lr=0.1)
            optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1)
            return [optimizer, optimizer_2]

        def optimizer_step(
            self,
            current_epoch,
            batch_nb,
            optimizer,
            optimizer_idx,
            closure,
            on_tpu=False,
            using_native_amp=False,
            using_lbfgs=False
        ):
            if optimizer_idx == 0:
                assert self.layer_1[0].weight.requires_grad is True
                assert self.layer_1[2].weight.requires_grad is False
                assert self.layer_1[4].weight.requires_grad is False

                assert self.layer_2[1].weight.requires_grad is False
                assert self.layer_2[3].weight.requires_grad is False
                assert self.layer_2[5].weight.requires_grad is False

            if optimizer_idx == 1:
                assert self.layer_1[0].weight.requires_grad is False
                assert self.layer_1[2].weight.requires_grad is False
                assert self.layer_1[4].weight.requires_grad is False

                assert self.layer_2[1].weight.requires_grad is False
                assert self.layer_2[3].weight.requires_grad is False
                assert self.layer_2[5].weight.requires_grad is True

            optimizer.step(closure=closure)

    model = TestModel()
    model.training_epoch_end = None

    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        limit_train_batches=8,
        accumulate_grad_batches=1,
        limit_val_batches=0,
    )

    results = trainer.fit(model)
    assert results


def test_toggle_untoggle_3_optimizers_shared_parameters(tmpdir):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Sequential(
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
            )

            self.layer_2 = nn.Sequential(
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 2)
            )

            self.layer_3 = nn.Sequential(
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 2)
            )

            # set some weights to False to check untoggle works as expected.
            self.layer_1[2].weight.requires_grad = False
            self.layer_1[4].weight.requires_grad = False

            self.layer_2[1].weight.requires_grad = False
            self.layer_2[3].weight.requires_grad = False

            self.layer_3[1].weight.requires_grad = False
            self.layer_3[5].weight.requires_grad = False

        def optimizer_step(
            self,
            current_epoch,
            batch_nb,
            optimizer,
            optimizer_idx,
            closure,
            on_tpu=False,
            using_native_amp=False,
            using_lbfgs=False
        ):
            if optimizer_idx == 0:
                assert self.layer_1[0].weight.requires_grad is True
                assert self.layer_1[2].weight.requires_grad is False
                assert self.layer_1[4].weight.requires_grad is False

                assert self.layer_2[1].weight.requires_grad is False
                assert self.layer_2[3].weight.requires_grad is False
                assert self.layer_2[5].weight.requires_grad is True

                assert self.layer_3[1].weight.requires_grad is False
                assert self.layer_3[3].weight.requires_grad is False
                assert self.layer_3[5].weight.requires_grad is False

            if optimizer_idx == 1:
                assert self.layer_1[0].weight.requires_grad is False
                assert self.layer_1[2].weight.requires_grad is False
                assert self.layer_1[4].weight.requires_grad is False

                assert self.layer_2[1].weight.requires_grad is False
                assert self.layer_2[3].weight.requires_grad is False
                assert self.layer_2[5].weight.requires_grad is True

                assert self.layer_3[1].weight.requires_grad is False
                assert self.layer_3[3].weight.requires_grad is True
                assert self.layer_3[5].weight.requires_grad is False

            if optimizer_idx == 2:
                assert self.layer_1[0].weight.requires_grad is True
                assert self.layer_1[2].weight.requires_grad is False
                assert self.layer_1[4].weight.requires_grad is False

                assert self.layer_2[1].weight.requires_grad is False
                assert self.layer_2[3].weight.requires_grad is False
                assert self.layer_2[5].weight.requires_grad is False

                assert self.layer_3[1].weight.requires_grad is False
                assert self.layer_3[3].weight.requires_grad is True
                assert self.layer_3[5].weight.requires_grad is False

            optimizer.step(closure=closure)

        def training_step(self, batch, batch_idx, optimizer_idx=None):
            return super().training_step(batch, batch_idx)

        @staticmethod
        def combine_generators(gen_1, gen_2):
            for p in gen_1:
                yield p
            for p in gen_2:
                yield p

        def configure_optimizers(self):
            optimizer_1 = SGD(
                self.combine_generators(
                    self.layer_1.parameters(),
                    self.layer_2.parameters()
                ),
                lr=0.1
            )
            optimizer_2 = Adam(
                self.combine_generators(
                    self.layer_2.parameters(),
                    self.layer_3.parameters()
                ),
                lr=0.1
            )
            optimizer_3 = SGD(
                self.combine_generators(
                    self.layer_3.parameters(),
                    self.layer_1.parameters()
                ),
                lr=0.1
            )
            return [optimizer_1, optimizer_2, optimizer_3]

    model = TestModel()
    model.training_epoch_end = None

    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        limit_train_batches=8,
        accumulate_grad_batches=1,
    )

    trainer.fit(model)
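Condensed, the scenario these tests exercise is a LightningModule whose `configure_optimizers` returns optimizers that partially share parameters. The fragment below is an illustrative sketch rather than code from this commit: the class name and attributes are made up, and the training-related hooks are omitted. With the fix, toggling any one of these optimizers records and restores `requires_grad` for every parameter of every configured optimizer, including the shared `backbone`.

```python
from torch import nn
from torch.optim import SGD, Adam

from pytorch_lightning import LightningModule


class SharedBackboneModule(LightningModule):
    """Illustrative module: `backbone` is optimized by both optimizers."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 32)  # shared by both optimizers
        self.head_a = nn.Linear(32, 2)
        self.head_b = nn.Linear(32, 2)

    def forward(self, x):
        features = self.backbone(x)
        return self.head_a(features), self.head_b(features)

    def configure_optimizers(self):
        # `backbone.parameters()` is handed to both optimizers, mirroring the
        # shared-parameter setup in `test_toggle_untoggle_3_optimizers_shared_parameters`.
        opt_a = SGD(list(self.backbone.parameters()) + list(self.head_a.parameters()), lr=0.1)
        opt_b = Adam(list(self.backbone.parameters()) + list(self.head_b.parameters()), lr=0.1)
        return [opt_a, opt_b]
```

In automatic optimization with more than one optimizer, the training loop is expected to call `toggle_optimizer` and `untoggle_optimizer` around each optimizer's step, which is the code path these tests drive through `trainer.fit`.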
