Skip to content

Commit

Permalink
[Bugfix] Fixed epoch level schedulers not being called when val_check…
Browse files Browse the repository at this point in the history
…_interval < 1.0 (#6075)

* fix bug

* fix tests

* changelog

* fix pep8

* fix tests

* fix and add some tests

* add test for rlop

* chlog

* Update CHANGELOG.md

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
  • Loading branch information
2 people authored and lexierule committed Mar 5, 2021
1 parent c219aa4 commit 76a6321
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 30 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


## [1.2.2] - 2021-03-02

### Added


### Changed


### Deprecated


### Removed


### Fixed

- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))


## [1.2.1] - 2021-02-23

### Fixed
Expand Down
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/connectors/optimizer_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,21 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
continue
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if lr_scheduler['reduce_on_plateau']:
lr_scheduler['scheduler'].step(monitor_val)
else:
lr_scheduler['scheduler'].step()

new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
self.trainer.batch_idx,
interval,
scheduler_idx,
old_lr,
new_lr,
monitor_key=monitor_key,
monitor_val=monitor_val
)
22 changes: 13 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def run_training_epoch(self):
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
should_check_val = False
val_loop_called = False

for batch_idx, (batch, is_last_batch) in train_dataloader:

Expand Down Expand Up @@ -515,6 +516,7 @@ def run_training_epoch(self):
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
if should_check_val:
self.trainer.run_evaluation()
val_loop_called = True

# reset stage to train
self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
Expand Down Expand Up @@ -560,21 +562,23 @@ def run_training_epoch(self):
)

should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
if should_check_val:
self.trainer.run_evaluation(on_epoch=True)

# reset stage to train
self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)

should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval

if should_train_only:
# update epoch level lr_schedulers
# update epoch level lr_schedulers if no val loop outside train loop is triggered
if (val_loop_called and not should_check_val) or should_train_only:
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

if should_train_only:
self.check_checkpoint_callback(True)
self.check_early_stopping_callback(True)

if should_check_val:
self.trainer.run_evaluation(on_epoch=True)

# reset stage to train
self.trainer._running_stage = RunningStage.TRAINING

# increment the global step once
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()
Expand Down Expand Up @@ -820,7 +824,7 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
can_check_val = self.trainer.enable_validation and is_val_check_epoch
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0

should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
or is_last_batch_for_infinite_dataset
Expand Down
5 changes: 4 additions & 1 deletion pytorch_lightning/utilities/debugging.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,16 @@ def track_train_loss_history(self, batch_idx, loss):
self.saved_train_losses.append(loss_dict)

@enabled_only
def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None):
def track_lr_schedulers_update(
self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
):
loss_dict = {
'batch_idx': batch_idx,
'interval': interval,
'scheduler_idx': scheduler_idx,
'epoch': self.trainer.current_epoch,
'monitor_key': monitor_key,
'monitor_val': monitor_val,
'old_lr': old_lr,
'new_lr': new_lr
}
Expand Down
149 changes: 136 additions & 13 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import torch
import yaml
from omegaconf import Container, OmegaConf
from torch import optim

import pytorch_lightning as pl
import tests.helpers.utils as tutils
Expand All @@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx):

def validation_epoch_end(self, outputs):
outs = torch.stack([x['x'] for x in outputs]).mean()
self.log('epoch', self.current_epoch, on_epoch=True)
self.log('val_acc', outs, on_epoch=True)
self.log('epoch', self.current_epoch)
self.log('val_acc', outs)


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
Expand All @@ -57,14 +58,16 @@ def validation_epoch_end(self, outputs):
[('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
("base", None, 'train_log_epoch')],
)
def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
"""
Test that when a model checkpoint is saved, it saves with
the correct score appended to ckpt_path and checkpoint data
"""
max_epochs = 3
limit_train_batches = 5
limit_val_batches = 7
lr = 1e-1

class CustomBoringModel(BoringModel):

Expand All @@ -74,21 +77,28 @@ def __init__(self):
self.val_logs = torch.randn(max_epochs, limit_val_batches)

def training_step(self, batch, batch_idx):
out = super().training_step(batch, batch_idx)
log_value = self.train_log_epochs[self.current_epoch, batch_idx]
self.log('train_log', log_value, on_epoch=True)
return out
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
out = super().validation_step(batch, batch_idx)
log_value = self.val_logs[self.current_epoch, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return out
return super().validation_step(batch, batch_idx)

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
optimizer = optim.SGD(self.parameters(), lr=lr)

if reduce_lr_on_plateau:
lr_scheduler = {
'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': monitor,
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

return [optimizer], [lr_scheduler]

filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
Expand All @@ -109,11 +119,15 @@ def configure_optimizers(self):
max_epochs=max_epochs,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
results = trainer.fit(model)
assert results
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
assert len(ckpt_files) == len(scores) == max_epochs
assert len(lr_scheduler_debug) == max_epochs

for epoch in range(max_epochs):
score = scores[epoch]
Expand All @@ -130,9 +144,118 @@ def configure_optimizers(self):
assert mc_specific_data['monitor'] == monitor
assert mc_specific_data['current_score'] == score

lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1))
if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))

assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize(
"val_check_interval,reduce_lr_on_plateau",
[
(0.25, True),
(0.25, False),
(0.33, False),
],
)
def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
"""
Test that when a model checkpoint is saved, it saves with the correct
score appended to ckpt_path and checkpoint data with val_check_interval
"""
max_epochs = 3
limit_train_batches = 12
limit_val_batches = 7
lr = 1e-1
monitor = 'val_log'
per_epoch_steps = int(limit_train_batches * val_check_interval)
per_epoch_call_count = limit_train_batches // per_epoch_steps

class CustomBoringModel(BoringModel):

def __init__(self):
super().__init__()
self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
self.val_loop_count = 0

def validation_step(self, batch, batch_idx):
log_value = self.val_logs[self.val_loop_count, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return super().validation_step(batch, batch_idx)

def validation_epoch_end(self, outputs):
self.val_loop_count += 1
super().validation_epoch_end(outputs)

def configure_optimizers(self):
optimizer = optim.SGD(self.parameters(), lr=lr)

if reduce_lr_on_plateau:
lr_scheduler = {
'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': monitor,
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

return [optimizer], [lr_scheduler]

filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

model = CustomBoringModel()

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint],
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
max_epochs=max_epochs,
val_check_interval=val_check_interval,
progress_bar_refresh_rate=0,
num_sanity_val_steps=0,
)
results = trainer.fit(model)
assert results
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
assert len(lr_scheduler_debug) == max_epochs

for epoch in range(max_epochs):
for ix in range(per_epoch_call_count):
global_ix = ix + per_epoch_call_count * epoch
score = scores[global_ix]
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
assert math.isclose(score, expected_score, rel_tol=1e-4)

chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
assert chk['epoch'] == epoch + 1
assert chk['global_step'] == per_epoch_steps * (global_ix + 1)

mc_specific_data = chk['callbacks'][type(checkpoint)]
assert mc_specific_data['dirpath'] == checkpoint.dirpath
assert mc_specific_data['monitor'] == monitor
assert mc_specific_data['current_score'] == score

if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
did_update = 1 if ix + 1 == per_epoch_call_count else 0
assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))

assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)


@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
Expand Down
9 changes: 3 additions & 6 deletions tests/trainer/optimization/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@ def test_optimizer_with_scheduling(tmpdir):

# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2, val_check_interval=0.5
)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
Expand Down Expand Up @@ -164,15 +161,15 @@ def test_reducelronplateau_scheduling(tmpdir):
model.configure_optimizers = lambda: {
'optimizer': optimizer,
'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': 'early_stop_on',
'monitor': 'val_acc',
}
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
lr_scheduler = trainer.lr_schedulers[0]
assert lr_scheduler == dict(
scheduler=lr_scheduler['scheduler'],
monitor='early_stop_on',
monitor='val_acc',
interval='epoch',
frequency=1,
reduce_on_plateau=True,
Expand Down

0 comments on commit 76a6321

Please sign in to comment.