[Bugfix] Fixed epoch level schedulers not being called when val_check_interval < 1.0 #6075

Merged 10 commits on Feb 24, 2021
Changes from 7 commits
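A minimal sketch of the configuration this fix targets (the model and data below are illustrative, not taken from this repository): an epoch-level LR scheduler combined with `val_check_interval < 1.0`, where the scheduler previously missed its per-epoch step.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class SchedulerModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log('val_loss', nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1)
        # StepLR is registered with the default interval='epoch',
        # so it is expected to step exactly once per training epoch
        return [optimizer], [optim.lr_scheduler.StepLR(optimizer, step_size=1)]


def loader():
    return DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)


trainer = pl.Trainer(
    max_epochs=3,
    # mid-epoch validation: before this fix, the epoch-level scheduler step could be skipped
    val_check_interval=0.25,
    limit_val_batches=2,
)
trainer.fit(SchedulerModel(), loader(), loader())
```

With `val_check_interval=0.25`, validation runs four times per epoch; a scheduler registered with the default `interval='epoch'` should still step exactly once per epoch, which is what the changes below restore.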
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -193,6 +193,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed wrapping optimizers upon assignment ([#6006](https://github.com/PyTorchLightning/pytorch-lightning/pull/6006))
- Fixed allowing hashing of metrics with lists in their state ([#5939](https://github.com/PyTorchLightning/pytorch-lightning/pull/5939))

- Fixed epoch level schedulers not being called when `val_check_interval!=1` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))


## [1.1.8] - 2021-02-08

10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/connectors/optimizer_connector.py
@@ -71,13 +71,21 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
continue
# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if lr_scheduler['reduce_on_plateau']:
lr_scheduler['scheduler'].step(monitor_val)
else:
lr_scheduler['scheduler'].step()

new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if self.trainer.dev_debugger.enabled:
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key
self.trainer.batch_idx,
interval,
scheduler_idx,
old_lr,
new_lr,
monitor_key=monitor_key,
monitor_val=monitor_val
)
22 changes: 13 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
@@ -480,6 +480,7 @@ def run_training_epoch(self):
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
should_check_val = False
val_loop_called = False

for batch_idx, (batch, is_last_batch) in train_dataloader:

@@ -515,6 +516,7 @@ def run_training_epoch(self):
should_check_val = self.should_check_val_fx(batch_idx, is_last_batch)
if should_check_val:
self.trainer.run_evaluation()
val_loop_called = True

# reset stage to train
self.trainer._running_stage = RunningStage.TRAINING
@@ -560,21 +562,23 @@
)

should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True)
if should_check_val:
self.trainer.run_evaluation(on_epoch=True)

# reset stage to train
self.trainer._running_stage = RunningStage.TRAINING

should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
should_train_only = self.trainer.disable_validation or should_skip_eval

if should_train_only:
# update epoch level lr_schedulers
# update epoch level lr_schedulers if no val loop outside train loop is triggered
if (val_loop_called and not should_check_val) or should_train_only:
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')

if should_train_only:
self.check_checkpoint_callback(True)
self.check_early_stopping_callback(True)

if should_check_val:
self.trainer.run_evaluation(on_epoch=True)
Comment on lines +570 to +575
@ananthsub (Contributor), Mar 22, 2021:
Are these checks guaranteed to be mutually exclusive? I'm updating to 1.2.4 and see a test failure with one of my modules where it looks like there's a gap:
Before, run_evaluation ran first, letting the module run the validation loop; inside validation_step/validation_epoch_end, the module could log metrics. Afterward, the checkpoint was force-run, so if the checkpoint monitored a metric that was available only during validation, this still worked.
By moving run_evaluation to happen after the forced checkpoint saving, we now fail when looking up the monitor value.

I think the correct fix is to drop the training loop's force-running of the checkpoint/early stopping callbacks here. They should be part of the callbacks themselves, but that's a longer-term change.
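A hedged sketch of the setup described above (module and metric names are illustrative, not from the reporter's code): the checkpoint monitors a metric that is logged only inside the validation loop, so the value exists only after run_evaluation has run.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class ValOnlyMetricModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # 'val_loss' is logged only here, so it exists only after the validation loop has run
        self.log('val_loss', nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.randn(64, 32), torch.randn(64, 1)), batch_size=8)

# if the checkpoint callback is force-run before run_evaluation at epoch end,
# the monitored value has not been logged yet and the monitor lookup fails
trainer = pl.Trainer(max_epochs=1, callbacks=[ModelCheckpoint(monitor='val_loss')])
trainer.fit(ValOnlyMetricModel(), data, data)
```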

Contributor:
I believe the forced checkpoint only runs if validation is disabled or num_val_batches == 0; otherwise run_evaluation is called, the callbacks are invoked inside it, and the monitor is taken care of as expected. Can you paste a small example of your case where it doesn't work?

Contributor:
It's not; see #7207.

This PR introduced the change where we call the evaluation loop after the checkpoint/early-stopping callbacks in the training loop.

As a result, this check for should_train_only is incomplete: it inherently depends on the evaluation loop to populate num_val_batches correctly. In run_evaluation we set the validation dataloader, but that is too late, because the validation dataloader is what determines should_skip_eval above.


# reset stage to train
self.trainer._running_stage = RunningStage.TRAINING

# increment the global step once
# progress global step according to grads progress
self.increment_accumulated_grad_global_step()
@@ -820,7 +824,7 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False):
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
can_check_val = self.trainer.enable_validation and is_val_check_epoch
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0

should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop
or is_last_batch_for_infinite_dataset
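As a toy illustration of why the epoch-end check above was rewritten (values assumed, evaluated outside the Trainer): with `val_check_interval < 1.0` the old comparison can never be true, while the new modulo test still identifies the final batch of the epoch.

```python
# assumed values for illustration only
num_training_batches = 12
val_check_interval = 0.5
val_check_batch = int(num_training_batches * val_check_interval)  # 6

last_batch_idx = num_training_batches - 1  # 11, the final batch of the epoch

old_check = val_check_batch == num_training_batches           # False whenever val_check_interval < 1.0
new_check = (last_batch_idx + 1) % num_training_batches == 0  # True on the epoch's final batch

print(old_check, new_check)  # False True
```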
5 changes: 4 additions & 1 deletion pytorch_lightning/utilities/debugging.py
@@ -121,13 +121,16 @@ def track_train_loss_history(self, batch_idx, loss):
self.saved_train_losses.append(loss_dict)

@enabled_only
def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None):
def track_lr_schedulers_update(
self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None
):
loss_dict = {
'batch_idx': batch_idx,
'interval': interval,
'scheduler_idx': scheduler_idx,
'epoch': self.trainer.current_epoch,
'monitor_key': monitor_key,
'monitor_val': monitor_val,
'old_lr': old_lr,
'new_lr': new_lr
}
149 changes: 136 additions & 13 deletions tests/checkpointing/test_model_checkpoint.py
@@ -26,6 +26,7 @@
import torch
import yaml
from omegaconf import Container, OmegaConf
from torch import optim

import pytorch_lightning as pl
import tests.helpers.utils as tutils
@@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx):

def validation_epoch_end(self, outputs):
outs = torch.stack([x['x'] for x in outputs]).mean()
self.log('epoch', self.current_epoch, on_epoch=True)
self.log('val_acc', outs, on_epoch=True)
self.log('epoch', self.current_epoch)
self.log('val_acc', outs)


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -57,14 +58,16 @@ def validation_epoch_end(self, outputs):
[('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'),
("base", None, 'train_log_epoch')],
)
def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor):
@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True])
def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau):
"""
Test that when a model checkpoint is saved, it saves with
the correct score appended to ckpt_path and checkpoint data
"""
max_epochs = 3
limit_train_batches = 5
limit_val_batches = 7
lr = 1e-1

class CustomBoringModel(BoringModel):

@@ -74,21 +77,28 @@ def __init__(self):
self.val_logs = torch.randn(max_epochs, limit_val_batches)

def training_step(self, batch, batch_idx):
out = super().training_step(batch, batch_idx)
log_value = self.train_log_epochs[self.current_epoch, batch_idx]
self.log('train_log', log_value, on_epoch=True)
return out
return super().training_step(batch, batch_idx)

def validation_step(self, batch, batch_idx):
out = super().validation_step(batch, batch_idx)
log_value = self.val_logs[self.current_epoch, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return out
return super().validation_step(batch, batch_idx)

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
optimizer = optim.SGD(self.parameters(), lr=lr)

if reduce_lr_on_plateau:
lr_scheduler = {
'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': monitor,
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

return [optimizer], [lr_scheduler]

filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
@@ -109,11 +119,15 @@ def configure_optimizers(self):
max_epochs=max_epochs,
progress_bar_refresh_rate=0,
)
trainer.fit(model)
results = trainer.fit(model)
assert results
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
assert len(ckpt_files) == len(scores) == max_epochs
assert len(lr_scheduler_debug) == max_epochs

for epoch in range(max_epochs):
score = scores[epoch]
@@ -130,9 +144,118 @@ def configure_optimizers(self):
assert mc_specific_data['monitor'] == monitor
assert mc_specific_data['current_score'] == score

lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1))
if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
assert lr_scheduler_specific_data['_step_count'] == epoch + 2
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1))

assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)


@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize(
"val_check_interval,reduce_lr_on_plateau",
[
(0.25, True),
(0.25, False),
(0.33, False),
],
)
def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
"""
Test that when a model checkpoint is saved, it saves with the correct
score appended to ckpt_path and checkpoint data with val_check_interval
"""
max_epochs = 3
limit_train_batches = 12
limit_val_batches = 7
lr = 1e-1
monitor = 'val_log'
per_epoch_steps = int(limit_train_batches * val_check_interval)
per_epoch_call_count = limit_train_batches // per_epoch_steps

class CustomBoringModel(BoringModel):

def __init__(self):
super().__init__()
self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches)
self.val_loop_count = 0

def validation_step(self, batch, batch_idx):
log_value = self.val_logs[self.val_loop_count, batch_idx]
self.log('val_log', log_value)
self.log('epoch', self.current_epoch, on_epoch=True)
return super().validation_step(batch, batch_idx)

def validation_epoch_end(self, outputs):
self.val_loop_count += 1
super().validation_epoch_end(outputs)

def configure_optimizers(self):
optimizer = optim.SGD(self.parameters(), lr=lr)

if reduce_lr_on_plateau:
lr_scheduler = {
'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': monitor,
'strict': True,
}
else:
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)

return [optimizer], [lr_scheduler]

filename = '{' + f'{monitor}' + ':.4f}-{epoch}'
checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1)

model = CustomBoringModel()

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[checkpoint],
limit_train_batches=limit_train_batches,
limit_val_batches=limit_val_batches,
max_epochs=max_epochs,
val_check_interval=val_check_interval,
progress_bar_refresh_rate=0,
num_sanity_val_steps=0,
)
results = trainer.fit(model)
assert results
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"

ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
assert len(lr_scheduler_debug) == max_epochs

for epoch in range(max_epochs):
for ix in range(per_epoch_call_count):
global_ix = ix + per_epoch_call_count * epoch
score = scores[global_ix]
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
assert math.isclose(score, expected_score, rel_tol=1e-4)

chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
assert chk['epoch'] == epoch + 1
assert chk['global_step'] == per_epoch_steps * (global_ix + 1)

mc_specific_data = chk['callbacks'][type(checkpoint)]
assert mc_specific_data['dirpath'] == checkpoint.dirpath
assert mc_specific_data['monitor'] == monitor
assert mc_specific_data['current_score'] == score

if not reduce_lr_on_plateau:
lr_scheduler_specific_data = chk['lr_schedulers'][0]
did_update = 1 if ix + 1 == per_epoch_call_count else 0
assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))

assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)


@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
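For reference, the arithmetic behind the new test's parametrization (constants taken from `test_model_checkpoint_score_and_ckpt_val_check_interval` above), showing how many validation loops, checkpoints, and scheduler updates the assertions expect:

```python
# constants from the test above, using the val_check_interval=0.25 case
max_epochs = 3
limit_train_batches = 12
val_check_interval = 0.25

per_epoch_steps = int(limit_train_batches * val_check_interval)  # 3 training batches between val checks
per_epoch_call_count = limit_train_batches // per_epoch_steps    # 4 validation loops per epoch

expected_ckpts = per_epoch_call_count * max_epochs  # 12 checkpoints/scores over the whole run
expected_scheduler_updates = max_epochs             # 3: the epoch-level scheduler still steps once per epoch

print(per_epoch_steps, per_epoch_call_count, expected_ckpts, expected_scheduler_updates)  # 3 4 12 3
```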
5 changes: 3 additions & 2 deletions tests/trainer/optimization/test_optimizers.py
@@ -34,6 +34,7 @@ def test_optimizer_with_scheduling(tmpdir):
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
val_check_interval=0.5
)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
@@ -164,15 +165,15 @@ def test_reducelronplateau_scheduling(tmpdir):
model.configure_optimizers = lambda: {
'optimizer': optimizer,
'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer),
'monitor': 'early_stop_on',
'monitor': 'val_acc',
}
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
lr_scheduler = trainer.lr_schedulers[0]
assert lr_scheduler == dict(
scheduler=lr_scheduler['scheduler'],
monitor='early_stop_on',
monitor='val_acc',
interval='epoch',
frequency=1,
reduce_on_plateau=True,