Learning rate stepping option #941

Merged on Mar 5, 2020
Changes from 24 commits
34 commits
20f32b4
remove deprecated args to learning rate step function
Feb 18, 2020
ef237f1
step based scheduler
Feb 25, 2020
67ae533
mixing models for testing
Feb 25, 2020
efe19e0
merge
Feb 25, 2020
e640403
fix styling
Feb 25, 2020
2e674e8
tests
Feb 25, 2020
4b96634
update documentation
Feb 25, 2020
b3a8d09
smaller fix
Feb 25, 2020
5c876c2
merge
Feb 26, 2020
8fa7a03
update to dict structure
Feb 26, 2020
12a526c
updated test
Feb 26, 2020
a3ac63f
update documentation
Feb 26, 2020
fc4847b
update CHANGELOG.md
Feb 26, 2020
4167975
fix styling
Feb 26, 2020
6e2d712
fix problems with trainer io
Feb 26, 2020
7d18fab
fix tests
Feb 26, 2020
215b85f
rebase
Feb 27, 2020
f01597d
simplification of code
Feb 27, 2020
55d9661
fix styling
Feb 27, 2020
2766910
change from batch to step
Feb 28, 2020
2c848b9
update to tests
Feb 28, 2020
1906239
fix styling
Feb 28, 2020
fc0ae09
fixed some logic
Feb 28, 2020
44207bc
Update pytorch_lightning/core/lightning.py
Borda Feb 28, 2020
1bcbf11
Merge branch 'master' into lr_stepping_option
williamFalcon Mar 3, 2020
ec15729
duplicated test
Borda Mar 3, 2020
2e5e9ba
fix test on amp
Mar 4, 2020
8dc4c31
small update to tests
Mar 4, 2020
284afe5
added monitor key for ReduceLROnPlateau
Mar 4, 2020
436ac59
Merge branch 'master' into lr_stepping_option
williamFalcon Mar 4, 2020
167886f
Merge branch 'master' into lr_stepping_option
williamFalcon Mar 5, 2020
bf4c2bb
Update trainer.py
williamFalcon Mar 5, 2020
383ed9a
Update training_loop.py
williamFalcon Mar 5, 2020
1f42822
fix test after introducing monitor keyword
Mar 5, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added semantic segmentation example ([#751](https://github.com/PyTorchLightning/pytorch-lightning/pull/751),[#876](https://github.com/PyTorchLightning/pytorch-lightning/pull/876))
- Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
- Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941))

### Changed

11 changes: 11 additions & 0 deletions pytorch_lightning/core/lightning.py
@@ -751,6 +751,15 @@ def configure_optimizers(self):
discriminator_sched = CosineAnnealing(discriminator_opt, T_max=10)
return [generator_opt, discriminator_opt], [discriminator_sched]

# example with step-based learning rate schedulers
def configure_optimizers(self):
gen_opt = Adam(self.model_gen.parameters(), lr=0.01)
dis_opt = Adam(self.model_disc.parameters(), lr=0.02)
gen_sched = {'scheduler': ExponentialLR(gen_opt, 0.99),
'interval': 'step'} # called after each training step
dis_sched = CosineAnnealing(dis_opt, T_max=10) # called after each epoch
return [gen_opt, dis_opt], [gen_sched, dis_sched]

.. note:: Lightning calls .backward() and .step() on each optimizer and learning rate scheduler as needed.

.. note:: If you use 16-bit precision (use_amp=True), Lightning will automatically
@@ -766,6 +775,8 @@ def configure_optimizers(self):
.. note:: If you need to control how often those optimizers step or override the default .step() schedule,
override the `optimizer_step` hook.

.. note:: If you only want to call a learning rate scheduler every `x` steps or epochs,
you can pass this as a 'frequency' key: dict(scheduler=lr_scheduler, interval='step' or 'epoch', frequency=x)

"""

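To illustrate the dict format described in the note above, here is a minimal sketch of a configure_optimizers that steps its scheduler every 10 training steps. The optimizer, scheduler, and hyperparameter choices are assumptions for illustration, not taken from this PR:

from torch import optim

def configure_optimizers(self):
    opt = optim.Adam(self.parameters(), lr=0.01)
    sched = {'scheduler': optim.lr_scheduler.ExponentialLR(opt, gamma=0.99),
             'interval': 'step',  # update per training step instead of per epoch
             'frequency': 10}     # but only on every 10th step
    return [opt], [sched]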
54 changes: 41 additions & 13 deletions pytorch_lightning/trainer/trainer.py
@@ -5,6 +5,7 @@
from typing import Union, Optional, List, Dict, Tuple, Iterable

import torch
from torch import optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
@@ -739,8 +740,6 @@ def on_train_end(self):
# creates a default one if none passed in
self.configure_early_stopping(early_stop_callback)

self.reduce_lr_on_plateau_scheduler = None

# configure checkpoint callback
self.checkpoint_callback = checkpoint_callback
self.weights_save_path = weights_save_path
@@ -1028,26 +1027,55 @@ def init_optimizers(
optimizers: Union[Optimizer, Tuple[List, List], List[Optimizer], Tuple[Optimizer]]
) -> Tuple[List, List]:

# single optimizer
# single output, single optimizer
if isinstance(optimizers, Optimizer):
return [optimizers], []

# two lists
if len(optimizers) == 2 and isinstance(optimizers[0], list):
# two lists, optimizer + lr schedulers
elif len(optimizers) == 2 and isinstance(optimizers[0], list):
optimizers, lr_schedulers = optimizers
lr_schedulers, self.reduce_lr_on_plateau_scheduler = self.configure_schedulers(lr_schedulers)
lr_schedulers = self.configure_schedulers(lr_schedulers)
return optimizers, lr_schedulers

# single list or tuple
if isinstance(optimizers, (list, tuple)):
# single list or tuple, multiple optimizers
elif isinstance(optimizers, (list, tuple)):
return optimizers, []

# unknown configuration
else:
raise ValueError('Unknown configuration for model optimizers. Output '
'from model.configure_optimizers() should either be: '
'* single output, single torch.optim.Optimizer '
'* single output, list of torch.optim.Optimizer '
'* two outputs, first being a list of torch.optim.Optimizer, '
'second being a list of torch.optim.lr_scheduler')

def configure_schedulers(self, schedulers: list):
for i, scheduler in enumerate(schedulers):
if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
reduce_lr_on_plateau_scheduler = schedulers.pop(i)
return schedulers, reduce_lr_on_plateau_scheduler
return schedulers, None
# Convert each scheduler into a dict structure with relevant information
lr_schedulers = []
default_config = {'interval': 'epoch', # default every epoch
'frequency': 1, # default every epoch/batch
'reduce_on_plateau': False} # most often not ReduceLROnPlateau scheduler
for scheduler in schedulers:
if isinstance(scheduler, dict):
if 'scheduler' not in scheduler:
raise ValueError('Lr scheduler should have key `scheduler` '
'with item being a lr scheduler')
scheduler['reduce_on_plateau'] = \
isinstance(scheduler['scheduler'], optim.lr_scheduler.ReduceLROnPlateau)

lr_schedulers.append({**default_config, **scheduler})

elif isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
lr_schedulers.append({**default_config, 'scheduler': scheduler,
'reduce_on_plateau': True})

elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
lr_schedulers.append({**default_config, 'scheduler': scheduler})
else:
raise ValueError(f'Input {scheduler} to lr schedulers '
'is an invalid input.')
return lr_schedulers

def run_pretrain_routine(self, model: LightningModule):
"""Sanity check a few things before starting actual training.
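To make the normalization above concrete, a rough sketch of the three accepted scheduler inputs and the dict each one becomes after configure_schedulers (expected results shown as comments; the toy parameter and optimizer are assumptions):

import torch
from torch import optim

# toy parameter so real optimizers/schedulers can be constructed
params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.Adam(params, lr=0.01)

bare = optim.lr_scheduler.StepLR(opt, step_size=10)
plateau = optim.lr_scheduler.ReduceLROnPlateau(opt)
as_dict = {'scheduler': optim.lr_scheduler.ExponentialLR(opt, gamma=0.99),
           'interval': 'step', 'frequency': 5}

# Passing ([opt], [bare, plateau, as_dict]) through init_optimizers should return
# [opt] plus the normalized scheduler configs:
#   {'scheduler': bare,    'interval': 'epoch', 'frequency': 1, 'reduce_on_plateau': False}
#   {'scheduler': plateau, 'interval': 'epoch', 'frequency': 1, 'reduce_on_plateau': True}
#   {'scheduler': as_dict['scheduler'], 'interval': 'step', 'frequency': 5, 'reduce_on_plateau': False}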
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/training_io.py
@@ -82,7 +82,7 @@
# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
scheduler.load_state_dict(lrs_state)
scheduler['scheduler'].load_state_dict(lrs_state)

# uses the model you passed into trainer
model.load_state_dict(checkpoint['state_dict'])
@@ -343,8 +343,8 @@ def dump_checkpoint(self):

# save lr schedulers
lr_schedulers = []
for i, scheduler in enumerate(self.lr_schedulers):
lr_schedulers.append(scheduler.state_dict())
for scheduler in self.lr_schedulers:
lr_schedulers.append(scheduler['scheduler'].state_dict())

checkpoint['lr_schedulers'] = lr_schedulers

@@ -431,7 +431,7 @@ def restore_training_state(self, checkpoint):
# restore the lr schedulers
lr_schedulers = checkpoint['lr_schedulers']
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers):
scheduler.load_state_dict(lrs_state)
scheduler['scheduler'].load_state_dict(lrs_state)

# ----------------------------------
# PRIVATE OPS
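Outside the Trainer, the save/restore pattern above boils down to persisting only the wrapped scheduler's state_dict. A small self-contained sketch (the optimizer, scheduler, and variable names are assumptions):

import torch
from torch import optim

params = [torch.nn.Parameter(torch.zeros(1))]
opt = optim.Adam(params, lr=0.01)
lr_schedulers = [{'scheduler': optim.lr_scheduler.StepLR(opt, step_size=10),
                  'interval': 'epoch', 'frequency': 1, 'reduce_on_plateau': False}]

# dump: keep only the inner scheduler state, as dump_checkpoint does above
checkpoint = {'lr_schedulers': [s['scheduler'].state_dict() for s in lr_schedulers]}

# restore: load each saved state back into the wrapped scheduler
for scheduler, lrs_state in zip(lr_schedulers, checkpoint['lr_schedulers']):
    scheduler['scheduler'].load_state_dict(lrs_state)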
41 changes: 30 additions & 11 deletions pytorch_lightning/trainer/training_loop.py
@@ -400,17 +400,7 @@ def train(self):
self.run_training_epoch()

# update LR schedulers
if self.lr_schedulers is not None:
for lr_scheduler in self.lr_schedulers:
lr_scheduler.step()
if self.reduce_lr_on_plateau_scheduler is not None:
val_loss = self.callback_metrics.get('val_loss')
if val_loss is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
f'which is not available. Available metrics are: {avail_metrics}'
raise MisconfigurationException(m)
self.reduce_lr_on_plateau_scheduler.step(val_loss)
self.update_learning_rates(interval='epoch')

if self.max_steps and self.max_steps == self.global_step:
self.main_progress_bar.close()
@@ -487,6 +477,9 @@ def run_training_epoch(self):
# when returning -1 from train_step, we end epoch early
early_stop_epoch = batch_result == -1

# update lr
self.update_learning_rates(interval='step')

# ---------------
# RUN VAL STEP
# ---------------
@@ -751,3 +744,29 @@ def training_forward(self, batch, batch_idx, opt_idx, hiddens):
output = self.process_output(output, train=True)

return output

def update_learning_rates(self, interval):
''' Update learning rates
Args:
interval (str): either 'epoch' or 'step'.
'''
if not self.lr_schedulers:
return

for lr_scheduler in self.lr_schedulers:
current_idx = self.batch_idx if interval == 'step' else self.current_epoch
current_idx += 1 # batch_idx and current_epoch both start from 0
# Take step if the call to update_learning_rates matches the interval key and
# the current step modulo the scheduler's frequency is zero
if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency'] == 0:
# If instance of ReduceLROnPlateau, we need to pass validation loss
if lr_scheduler['reduce_on_plateau']:
val_loss = self.callback_metrics.get('val_loss')
if val_loss is None:
avail_metrics = ','.join(list(self.callback_metrics.keys()))
m = f'ReduceLROnPlateau conditioned on metric val_loss ' \
f'which is not available. Available metrics are: {avail_metrics}'
raise MisconfigurationException(m)
lr_scheduler['scheduler'].step(val_loss)
else:
lr_scheduler['scheduler'].step()
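To make the interval/frequency gating above concrete, a standalone sketch of the same modulo check (not the Trainer code itself; the helper name is made up for illustration):

# With interval='step' and frequency=10, a scheduler fires on the 10th, 20th, ...
# training batch of the epoch (batch_idx is zero-based, hence the +1).
def should_step(config, interval, batch_idx, current_epoch):
    current_idx = batch_idx if interval == 'step' else current_epoch
    current_idx += 1  # batch_idx and current_epoch both start at 0
    return config['interval'] == interval and current_idx % config['frequency'] == 0

config = {'interval': 'step', 'frequency': 10}
fired = [i for i in range(100) if should_step(config, 'step', i, current_epoch=0)]
assert fired == list(range(9, 100, 10))  # zero-based indices of the 10th, 20th, ... batches
assert not should_step(config, 'epoch', batch_idx=0, current_epoch=9)  # wrong interval -> no step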
3 changes: 3 additions & 0 deletions tests/models/__init__.py
@@ -19,6 +19,9 @@
LightValStepFitMultipleDataloadersMixin,
LightTrainDataloader,
LightTestDataloader,
LightTestOptimizerWithSchedulingMixin,
LightTestMultipleOptimizersWithSchedulingMixin,
LightTestOptimizersWithMixedSchedulingMixin
)


2 changes: 1 addition & 1 deletion tests/models/base.py
@@ -101,7 +101,7 @@ def loss(self, labels, logits):
nll = F.nll_loss(logits, labels)
return nll

def training_step(self, batch, batch_idx):
def training_step(self, batch, batch_idx, optimizer_idx=None):
"""
Lightning calls this inside the training loop
:param batch:
41 changes: 40 additions & 1 deletion tests/models/mixins.py
@@ -1,7 +1,7 @@
from collections import OrderedDict

import torch

from torch import optim
from pytorch_lightning.core.decorators import data_loader


@@ -598,6 +598,45 @@ def test_end(self, outputs):
return result


class LightTestOptimizerWithSchedulingMixin:
def configure_optimizers(self):
if self.hparams.optimizer_name == 'lbfgs':
optimizer = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
else:
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.1)
return [optimizer], [lr_scheduler]


class LightTestMultipleOptimizersWithSchedulingMixin:
def configure_optimizers(self):
if self.hparams.optimizer_name == 'lbfgs':
optimizer1 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
else:
optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)

return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]


class LightTestOptimizersWithMixedSchedulingMixin:
def configure_optimizers(self):
if self.hparams.optimizer_name == 'lbfgs':
optimizer1 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.LBFGS(self.parameters(), lr=self.hparams.learning_rate)
else:
optimizer1 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
optimizer2 = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 4, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)

return [optimizer1, optimizer2], \
[{'scheduler': lr_scheduler1, 'interval': 'step'}, lr_scheduler2]


def _get_output_metric(output, name):
if isinstance(output, dict):
val = output[name]
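For context, a hedged sketch of how a scheduling mixin like the ones above is typically composed into a test model — the base class name and import path are assumptions about the test suite, not shown in this diff:

# assumed re-exports via tests/models/__init__.py (see the import hunk above)
from tests.models import (
    TestModelBase,  # base test LightningModule; exact name is an assumption
    LightTrainDataloader,
    LightTestOptimizersWithMixedSchedulingMixin,
)

class CurrentTestModel(LightTestOptimizersWithMixedSchedulingMixin,
                       LightTrainDataloader,
                       TestModelBase):
    """lr_scheduler1 is stepped every training step ('interval': 'step'),
    lr_scheduler2 once per epoch (the default interval)."""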
7 changes: 5 additions & 2 deletions tests/test_gpu_models.py
@@ -91,10 +91,13 @@ def test_optimizer_return_options():
assert len(lr_sched) == 0

# opt tuple of lists
opts = ([opt_a], ['lr_scheduler'])
scheduler = torch.optim.lr_scheduler.StepLR(opt_a, 10)
opts = ([opt_a], [scheduler])
optim, lr_sched = trainer.init_optimizers(opts)
assert len(optim) == 1 and len(lr_sched) == 1
assert optim[0] == opts[0][0] and lr_sched[0] == 'lr_scheduler'
assert optim[0] == opts[0][0] and \
lr_sched[0] == dict(scheduler=scheduler, interval='epoch',
frequency=1, reduce_on_plateau=False)


def test_cpu_slurm_save_load(tmpdir):