
[bugfix] Resolve bug with multiple optimizers and toggle. #5574

Merged: 22 commits, Jan 25, 2021

Changes from 3 commits
32 changes: 27 additions & 5 deletions pytorch_lightning/core/lightning.py
@@ -1166,16 +1166,38 @@ def toggle_optimizer(self, optimizer: Optimizer, optimizer_idx: int):

Contributor:

maybe just use the opt_idx to get the optimizer down there... just to unify the arguments with untoggle_optimizer. Either way is fine.

Contributor Author:

Sounds great! I will do so.

Contributor:

Do you plan to do it in this PR?
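
For illustration, here is a minimal sketch of the unification being suggested: derive the optimizer from `optimizer_idx` so that both hooks take the same single argument. This is an interpretation of the comment, not code from this PR, and it assumes `self.optimizers(use_pl_optimizer=False)` returns the configured optimizers in order:

```python
from pytorch_lightning import LightningModule


class MyModule(LightningModule):
    # Hypothetical variant of toggle_optimizer keyed off optimizer_idx alone,
    # mirroring untoggle_optimizer's signature. Not the code merged in this PR.
    def toggle_optimizer(self, optimizer_idx: int):
        param_requires_grad_state = {}
        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
            if optimizer_idx != opt_idx:
                for group in opt.param_groups:
                    for param in group['params']:
                        # remember the original flag so untoggle_optimizer can restore it
                        param_requires_grad_state[param] = param.requires_grad
                        param.requires_grad = False
        self._param_requires_grad_state = param_requires_grad_state
```

With this signature the `optimizer` argument becomes redundant, which is what makes the two hooks symmetric.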


         Override for your own behavior
 
+        It works with `untoggle_optimizer` to make sure param_requires_grad_state is properly reset.
+
         Args:
             optimizer:
             optimizer_idx:
         """
-        for param in self.parameters():
-            param.requires_grad = False
-        for group in optimizer.param_groups:
-            for param in group['params']:
-                param.requires_grad = True
+        param_requires_grad_state = {}
+        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
+            if optimizer_idx != opt_idx:
+                for group in opt.param_groups:
+                    for param in group['params']:
+                        param_requires_grad_state[param] = param.requires_grad
+                        param.requires_grad = False
+
+        self._param_requires_grad_state = param_requires_grad_state
+
+    def untoggle_optimizer(self, optimizer_idx: int):
+        """
+        .. note:: Only called when using multiple optimizers
+
+        Override for your own behavior
+
+        Args:
+            optimizer_idx:
+        """
+        for opt_idx, opt in enumerate(self.optimizers(use_pl_optimizer=False)):
+            if optimizer_idx != opt_idx:
+                for group in opt.param_groups:
+                    for param in group['params']:
+                        param.requires_grad = self._param_requires_grad_state[param]
+        # save memory
+        del self._param_requires_grad_state
 
     def optimizer_step(
         self,
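
To see the save/restore contract the new hooks implement, here is a self-contained sketch that mirrors the PR's logic outside of Lightning. The class, layer names, and optimizers below are illustrative, not part of the PR:

```python
from torch import nn
from torch.optim import SGD, Adam


class TwoOptimizerToy(nn.Module):
    """Illustrative stand-in for a LightningModule that configures two optimizers."""

    def __init__(self):
        super().__init__()
        self.gen = nn.Linear(4, 4)
        self.disc = nn.Linear(4, 1)
        self._opts = [SGD(self.gen.parameters(), lr=0.1), Adam(self.disc.parameters(), lr=0.1)]
        self._param_requires_grad_state = {}

    def toggle_optimizer(self, optimizer, optimizer_idx):
        # freeze every other optimizer's params, remembering their original flags
        state = {}
        for opt_idx, opt in enumerate(self._opts):
            if optimizer_idx != opt_idx:
                for group in opt.param_groups:
                    for param in group['params']:
                        state[param] = param.requires_grad
                        param.requires_grad = False
        self._param_requires_grad_state = state

    def untoggle_optimizer(self, optimizer_idx):
        # restore whatever requires_grad flags the params had before toggling
        for opt_idx, opt in enumerate(self._opts):
            if optimizer_idx != opt_idx:
                for group in opt.param_groups:
                    for param in group['params']:
                        param.requires_grad = self._param_requires_grad_state[param]
        del self._param_requires_grad_state


model = TwoOptimizerToy()
model.disc.weight.requires_grad = False          # a parameter the user froze on purpose
model.toggle_optimizer(model._opts[0], 0)        # disc params frozen, original flags saved
assert model.disc.bias.requires_grad is False
model.untoggle_optimizer(0)                      # saved flags restored
assert model.disc.bias.requires_grad is True
assert model.disc.weight.requires_grad is False  # the user's choice is preserved
```

The point of the saved state is visible in the last assertion: a flag that was already `False` before toggling is still `False` after untoggling, rather than being blanket-reset.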
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/training_loop.py
@@ -798,6 +798,10 @@ def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer,
         if self.trainer.terminate_on_nan:
             self.trainer.detect_nan_tensors(result.loss)
 
+        if len(self.trainer.optimizers) > 1:
+            # revert back to previous state
+            self.trainer.get_model().untoggle_optimizer(opt_idx)
+
         return result
 
     def backward(self, result, optimizer, opt_idx, *args, **kwargs):
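
For orientation only, a simplified sketch of how a training loop would pair the two calls around each optimizer's step when more than one optimizer is configured. This is not the actual Trainer/TrainLoop code; `model`, `optimizers`, and the step sequence are placeholders:

```python
# Simplified, hypothetical loop: shows only the toggle/untoggle pairing,
# not PyTorch Lightning's real training loop.
def run_training_batch(model, optimizers, batch, batch_idx):
    for opt_idx, optimizer in enumerate(optimizers):
        if len(optimizers) > 1:
            # freeze the other optimizers' params before this backward pass
            model.toggle_optimizer(optimizer, opt_idx)

        loss = model.training_step(batch, batch_idx, opt_idx)["loss"]
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if len(optimizers) > 1:
            # revert the requires_grad flags saved by toggle_optimizer
            model.untoggle_optimizer(opt_idx)
```

Here the untoggle call mirrors the guard added in `training_step_and_backward` above.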
77 changes: 76 additions & 1 deletion tests/core/test_lightning_module.py
@@ -11,13 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from argparse import ArgumentParser
 import pickle
+from argparse import ArgumentParser
 from typing import Optional
 from unittest.mock import MagicMock, patch
 
 import pytest
 import torch
+from torch import nn
 from torch.optim import Adam, SGD
 from torch.utils.data import DataLoader, random_split

@@ -139,3 +140,77 @@ def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, clos
     )
 
     trainer.fit(model)
+
+
+def test_toggle_untoggle(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Sequential(
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+            )
+
+            self.layer_2 = nn.Sequential(
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 32),
+                nn.ReLU(),
+                nn.Linear(32, 2)
+            )
+
+            # set some weights to False to check untoggle works as expected.
+            self.layer_1[2].weight.requires_grad = False
+            self.layer_1[4].weight.requires_grad = False
+
+            self.layer_2[1].weight.requires_grad = False
+            self.layer_2[3].weight.requires_grad = False
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"loss": loss}
+
+        def configure_optimizers(self):
+            optimizer = SGD(self.layer_1.parameters(), lr=0.1)
+            optimizer_2 = Adam(self.layer_2.parameters(), lr=0.1)
+            return [optimizer, optimizer_2]
+
+        def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
+            if optimizer_idx == 0:
+                assert self.layer_1[0].weight.requires_grad is True
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is False
+                optimizer.step(closure=closure)
+
+            if optimizer_idx == 1:
+                assert self.layer_1[0].weight.requires_grad is False
+                assert self.layer_1[2].weight.requires_grad is False
+                assert self.layer_1[4].weight.requires_grad is False
+
+                assert self.layer_2[1].weight.requires_grad is False
+                assert self.layer_2[3].weight.requires_grad is False
+                assert self.layer_2[5].weight.requires_grad is True
+                optimizer.step(closure=closure)
+
+    model = TestModel()
+    model.training_epoch_end = None
+
+    trainer = Trainer(
+        max_epochs=1,
+        default_root_dir=tmpdir,
+        limit_train_batches=8,
+        accumulate_grad_batches=1,
+    )
+
+    trainer.fit(model)