From eeb9787aab186633caec81d8e5f4ccb1bf9d19d8 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 19 Feb 2021 09:17:22 +0100 Subject: [PATCH 1/9] fix bug --- pytorch_lightning/trainer/trainer.py | 6 +----- pytorch_lightning/trainer/training_loop.py | 7 ++++--- tests/trainer/optimization/test_optimizers.py | 1 + 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cf3bfd7a3e5a3..e3837a3b7bfe9 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -666,7 +666,7 @@ def run_train(self): # hook self.train_loop.on_train_end() - def run_evaluation(self, max_batches=None, on_epoch=False): + def run_evaluation(self, max_batches=None): # used to know if we are logging for val, test + reset cached results self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING @@ -739,10 +739,6 @@ def run_evaluation(self, max_batches=None, on_epoch=False): # hook self.evaluation_loop.on_evaluation_epoch_end() - # update epoch-level lr_schedulers - if on_epoch: - self.optimizer_connector.update_learning_rates(interval='epoch') - # hook self.evaluation_loop.on_evaluation_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index d2298c8c4e860..3711186642121 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -561,7 +561,7 @@ def run_training_epoch(self): should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) if should_check_val: - self.trainer.run_evaluation(on_epoch=True) + self.trainer.run_evaluation() # reset stage to train self.trainer._running_stage = RunningStage.TRAINING @@ -569,9 +569,10 @@ def run_training_epoch(self): should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval + # update epoch level lr_schedulers + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: - # update epoch level lr_schedulers - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') self.check_checkpoint_callback(True) self.check_early_stopping_callback(True) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index 7172b2dca76da..c32ed35a99c20 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -34,6 +34,7 @@ def test_optimizer_with_scheduling(tmpdir): max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2, + val_check_interval=0.5 ) trainer.fit(model) assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" From c99d769b792adb1f8f7860e13cba0b5c5d0cc720 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 19 Feb 2021 11:55:33 +0100 Subject: [PATCH 2/9] fix tests --- pytorch_lightning/trainer/training_loop.py | 13 +++++++++---- tests/checkpointing/test_model_checkpoint.py | 2 +- tests/trainer/optimization/test_optimizers.py | 4 ++-- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 3711186642121..c1f162f21fe29 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -480,6 +480,7 @@ def run_training_epoch(self): train_dataloader = 
self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 should_check_val = False + val_loop_called = False for batch_idx, (batch, is_last_batch) in train_dataloader: @@ -518,6 +519,7 @@ def run_training_epoch(self): # reset stage to train self.trainer._running_stage = RunningStage.TRAINING + val_loop_called = True # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) @@ -558,7 +560,7 @@ def run_training_epoch(self): self.trainer.logger_connector.log_train_epoch_end_metrics( epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers ) - + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) if should_check_val: self.trainer.run_evaluation() @@ -566,13 +568,16 @@ def run_training_epoch(self): # reset stage to train self.trainer._running_stage = RunningStage.TRAINING + if should_check_val or val_loop_called: + # update epoch level lr_schedulers + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval - # update epoch level lr_schedulers - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - if should_train_only: + # update epoch level lr_schedulers + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') self.check_checkpoint_callback(True) self.check_early_stopping_callback(True) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index bd4a02536c5c3..a922c152caf7c 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -131,7 +131,7 @@ def configure_optimizers(self): assert mc_specific_data['current_score'] == score lr_scheduler_specific_data = chk['lr_schedulers'][0] - assert lr_scheduler_specific_data['_step_count'] == epoch + 2 + assert lr_scheduler_specific_data['_step_count'] == epoch + 1 assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1)) diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index c32ed35a99c20..36713de792f11 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -165,7 +165,7 @@ def test_reducelronplateau_scheduling(tmpdir): model.configure_optimizers = lambda: { 'optimizer': optimizer, 'lr_scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer), - 'monitor': 'early_stop_on', + 'monitor': 'val_acc', } trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model) @@ -173,7 +173,7 @@ def test_reducelronplateau_scheduling(tmpdir): lr_scheduler = trainer.lr_schedulers[0] assert lr_scheduler == dict( scheduler=lr_scheduler['scheduler'], - monitor='early_stop_on', + monitor='val_acc', interval='epoch', frequency=1, reduce_on_plateau=True, From 8781a23e2d7ce22ab89f1b4239bf3307f7765824 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 19 Feb 2021 11:56:54 +0100 Subject: [PATCH 3/9] changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dfc9a1c021ac..e14733c0c632c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -193,6 +193,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed wrapping optimizers upon assignment ([#6006](https://github.com/PyTorchLightning/pytorch-lightning/pull/6006)) - Fixed allowing hashing of metrics with lists in their state ([#5939](https://github.com/PyTorchLightning/pytorch-lightning/pull/5939)) +- Fixed epoch level schedulers not being called when `val_check_interval!=1` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075)) + ## [1.1.8] - 2021-02-08 From 961f7f89222f0ae93286ed6207b9b11ceb4e9cff Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 19 Feb 2021 11:57:43 +0100 Subject: [PATCH 4/9] fix pep8 --- pytorch_lightning/trainer/training_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index c1f162f21fe29..72149f2bf4696 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -560,7 +560,7 @@ def run_training_epoch(self): self.trainer.logger_connector.log_train_epoch_end_metrics( epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers ) - + should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) if should_check_val: self.trainer.run_evaluation() From 6830344400e67bc3239052c57d131bcbd9a4548b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 19 Feb 2021 12:37:35 +0100 Subject: [PATCH 5/9] fix tests --- pytorch_lightning/trainer/trainer.py | 6 +++++- pytorch_lightning/trainer/training_loop.py | 12 ++++++------ tests/checkpointing/test_model_checkpoint.py | 2 +- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e3837a3b7bfe9..cf3bfd7a3e5a3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -666,7 +666,7 @@ def run_train(self): # hook self.train_loop.on_train_end() - def run_evaluation(self, max_batches=None): + def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING @@ -739,6 +739,10 @@ def run_evaluation(self, max_batches=None): # hook self.evaluation_loop.on_evaluation_epoch_end() + # update epoch-level lr_schedulers + if on_epoch: + self.optimizer_connector.update_learning_rates(interval='epoch') + # hook self.evaluation_loop.on_evaluation_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 72149f2bf4696..7dcedc88f43d2 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -516,10 +516,10 @@ def run_training_epoch(self): should_check_val = self.should_check_val_fx(batch_idx, is_last_batch) if should_check_val: self.trainer.run_evaluation() + val_loop_called = True # reset stage to train self.trainer._running_stage = RunningStage.TRAINING - val_loop_called = True # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
@@ -562,16 +562,16 @@ def run_training_epoch(self): ) should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) + + if val_loop_called: + self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_check_val: - self.trainer.run_evaluation() + self.trainer.run_evaluation(on_epoch=True) # reset stage to train self.trainer._running_stage = RunningStage.TRAINING - if should_check_val or val_loop_called: - # update epoch level lr_schedulers - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index a922c152caf7c..bd4a02536c5c3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -131,7 +131,7 @@ def configure_optimizers(self): assert mc_specific_data['current_score'] == score lr_scheduler_specific_data = chk['lr_schedulers'][0] - assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + assert lr_scheduler_specific_data['_step_count'] == epoch + 2 assert lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1)) From fc93a3e367151a5922c557892988953ed028fb70 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sat, 20 Feb 2021 02:34:04 +0530 Subject: [PATCH 6/9] fix and add some tests --- .../trainer/connectors/optimizer_connector.py | 2 + pytorch_lightning/trainer/training_loop.py | 20 ++-- tests/checkpointing/test_model_checkpoint.py | 99 ++++++++++++++++--- 3 files changed, 99 insertions(+), 22 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 5fb7b698b1669..18aae62681a5f 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -71,10 +71,12 @@ def update_learning_rates(self, interval: str, monitor_metrics=None): continue # update LR old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) else: lr_scheduler['scheduler'].step() + new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] if self.trainer.dev_debugger.enabled: diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 7dcedc88f43d2..2f060be70c326 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -562,25 +562,23 @@ def run_training_epoch(self): ) should_check_val = self.should_check_val_fx(batch_idx, is_last_batch, on_epoch=True) + should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) + should_train_only = self.trainer.disable_validation or should_skip_eval - if val_loop_called: + # update epoch level lr_schedulers if no val loop outside train loop is triggered + if (val_loop_called and not should_check_val) or should_train_only: self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + if should_train_only: + self.check_checkpoint_callback(True) + self.check_early_stopping_callback(True) + if should_check_val: self.trainer.run_evaluation(on_epoch=True) # reset stage to train self.trainer._running_stage = RunningStage.TRAINING - should_skip_eval = 
self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) - should_train_only = self.trainer.disable_validation or should_skip_eval - - if should_train_only: - # update epoch level lr_schedulers - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') - self.check_checkpoint_callback(True) - self.check_early_stopping_callback(True) - # increment the global step once # progress global step according to grads progress self.increment_accumulated_grad_global_step() @@ -826,7 +824,7 @@ def should_check_val_fx(self, batch_idx, is_last_batch, on_epoch=False): is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 can_check_val = self.trainer.enable_validation and is_val_check_epoch is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") - epoch_end_val_check = self.trainer.val_check_batch == self.trainer.num_training_batches + epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 should_check_val = ((is_val_check_batch and epoch_end_val_check) or self.trainer.should_stop or is_last_batch_for_infinite_dataset diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index bd4a02536c5c3..0913beb9d0f4d 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -57,7 +57,7 @@ def validation_epoch_end(self, outputs): [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'), ("base", None, 'train_log_epoch')], ) -def test_model_checkpoint_correct_score_and_checkpoint(tmpdir, validation_step, val_dataloaders, monitor): +def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor): """ Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint data @@ -74,22 +74,15 @@ def __init__(self): self.val_logs = torch.randn(max_epochs, limit_val_batches) def training_step(self, batch, batch_idx): - out = super().training_step(batch, batch_idx) log_value = self.train_log_epochs[self.current_epoch, batch_idx] self.log('train_log', log_value, on_epoch=True) - return out + return super().training_step(batch, batch_idx) def validation_step(self, batch, batch_idx): - out = super().validation_step(batch, batch_idx) log_value = self.val_logs[self.current_epoch, batch_idx] self.log('val_log', log_value) self.log('epoch', self.current_epoch, on_epoch=True) - return out - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.2) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - return [optimizer], [lr_scheduler] + return super().validation_step(batch, batch_idx) filename = '{' + f'{monitor}' + ':.4f}-{epoch}' checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) @@ -114,6 +107,7 @@ def configure_optimizers(self): ckpt_files = list(Path(tmpdir).glob('*.ckpt')) scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] assert len(ckpt_files) == len(scores) == max_epochs + assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs for epoch in range(max_epochs): score = scores[epoch] @@ -132,7 +126,90 @@ def configure_optimizers(self): lr_scheduler_specific_data = chk['lr_schedulers'][0] assert lr_scheduler_specific_data['_step_count'] == epoch + 2 - assert 
lr_scheduler_specific_data['_last_lr'][0], 4 == 0.2 * (0.1**(epoch + 1)) + assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1**(epoch + 1)) + + +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +@pytest.mark.parametrize( + "val_check_interval,lr_sched_step_count_inc", + [ + (0.25, 1), + (0.33, 0), + ], +) +def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, lr_sched_step_count_inc): + """ + Test that when a model checkpoint is saved, it saves with the correct + score appended to ckpt_path and checkpoint data with val_check_interval + """ + max_epochs = 3 + limit_train_batches = 12 + limit_val_batches = 7 + monitor = 'val_log' + per_epoch_steps = int(limit_train_batches * val_check_interval) + per_epoch_call_count = limit_train_batches // per_epoch_steps + + class CustomBoringModel(BoringModel): + + def __init__(self): + super().__init__() + self.val_logs = torch.randn(per_epoch_call_count * max_epochs, limit_val_batches) + self.val_loop_count = 0 + + def validation_step(self, batch, batch_idx): + log_value = self.val_logs[self.val_loop_count, batch_idx] + self.log('val_log', log_value) + self.log('epoch', self.current_epoch, on_epoch=True) + return super().validation_step(batch, batch_idx) + + def validation_epoch_end(self, outputs): + self.val_loop_count += 1 + super().validation_epoch_end(outputs) + + filename = '{' + f'{monitor}' + ':.4f}-{epoch}' + checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) + + model = CustomBoringModel() + + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint], + limit_train_batches=limit_train_batches, + limit_val_batches=limit_val_batches, + max_epochs=max_epochs, + val_check_interval=val_check_interval, + progress_bar_refresh_rate=0, + num_sanity_val_steps=0, + ) + trainer.fit(model) + + ckpt_files = list(Path(tmpdir).glob('*.ckpt')) + scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] + assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs + assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs + + for epoch in range(max_epochs): + for ix in range(per_epoch_call_count): + global_ix = ix + per_epoch_call_count * epoch + score = scores[global_ix] + expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item() + expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt' + assert math.isclose(score, expected_score, rel_tol=1e-4) + + chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename)) + assert chk['epoch'] == epoch + 1 + assert chk['global_step'] == per_epoch_steps * (global_ix + 1) + + mc_specific_data = chk['callbacks'][type(checkpoint)] + assert mc_specific_data['dirpath'] == checkpoint.dirpath + assert mc_specific_data['monitor'] == monitor + assert mc_specific_data['current_score'] == score + + lr_scheduler_specific_data = chk['lr_schedulers'][0] + + did_update = 1 if ix + 1 == per_epoch_call_count else 0 + assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update + assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1**(epoch + did_update)) @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) From ba124a2a8e869c59d6e70b38bba99fd20fbdc2e9 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Mon, 22 Feb 2021 23:39:25 +0530 Subject: [PATCH 7/9] add test for rlop --- .../trainer/connectors/optimizer_connector.py | 8 +- pytorch_lightning/utilities/debugging.py | 5 +- tests/checkpointing/test_model_checkpoint.py | 
82 +++++++++++++++---- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 18aae62681a5f..48958dcf45ddd 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -81,5 +81,11 @@ def update_learning_rates(self, interval: str, monitor_metrics=None): if self.trainer.dev_debugger.enabled: self.trainer.dev_debugger.track_lr_schedulers_update( - self.trainer.batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=monitor_key + self.trainer.batch_idx, + interval, + scheduler_idx, + old_lr, + new_lr, + monitor_key=monitor_key, + monitor_val=monitor_val ) diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py index 5a5157d9e23f7..65cf4472d156c 100644 --- a/pytorch_lightning/utilities/debugging.py +++ b/pytorch_lightning/utilities/debugging.py @@ -121,13 +121,16 @@ def track_train_loss_history(self, batch_idx, loss): self.saved_train_losses.append(loss_dict) @enabled_only - def track_lr_schedulers_update(self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None): + def track_lr_schedulers_update( + self, batch_idx, interval, scheduler_idx, old_lr, new_lr, monitor_key=None, monitor_val=None + ): loss_dict = { 'batch_idx': batch_idx, 'interval': interval, 'scheduler_idx': scheduler_idx, 'epoch': self.trainer.current_epoch, 'monitor_key': monitor_key, + 'monitor_val': monitor_val, 'old_lr': old_lr, 'new_lr': new_lr } diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 0913beb9d0f4d..bde4c8939195b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -26,6 +26,7 @@ import torch import yaml from omegaconf import Container, OmegaConf +from torch import optim import pytorch_lightning as pl import tests.helpers.utils as tutils @@ -47,8 +48,8 @@ def training_step(self, batch, batch_idx): def validation_epoch_end(self, outputs): outs = torch.stack([x['x'] for x in outputs]).mean() - self.log('epoch', self.current_epoch, on_epoch=True) - self.log('val_acc', outs, on_epoch=True) + self.log('epoch', self.current_epoch) + self.log('val_acc', outs) @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @@ -57,7 +58,8 @@ def validation_epoch_end(self, outputs): [('base', "base", 'val_log'), ('base', "base", 'train_log_epoch'), (None, "base", 'train_log_epoch'), ("base", None, 'train_log_epoch')], ) -def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor): +@pytest.mark.parametrize('reduce_lr_on_plateau', [False, True]) +def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloaders, monitor, reduce_lr_on_plateau): """ Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint data @@ -65,6 +67,7 @@ def test_model_checkpoint_score_and_ckpt(tmpdir, validation_step, val_dataloader max_epochs = 3 limit_train_batches = 5 limit_val_batches = 7 + lr = 1e-1 class CustomBoringModel(BoringModel): @@ -84,6 +87,20 @@ def validation_step(self, batch, batch_idx): self.log('epoch', self.current_epoch, on_epoch=True) return super().validation_step(batch, batch_idx) + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=lr) + + if reduce_lr_on_plateau: + lr_scheduler = { + 'scheduler': 
optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': monitor, + 'strict': True, + } + else: + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1) + + return [optimizer], [lr_scheduler] + filename = '{' + f'{monitor}' + ':.4f}-{epoch}' checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) @@ -102,12 +119,15 @@ def validation_step(self, batch, batch_idx): max_epochs=max_epochs, progress_bar_refresh_rate=0, ) - trainer.fit(model) + results = trainer.fit(model) + assert results + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" ckpt_files = list(Path(tmpdir).glob('*.ckpt')) scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] + lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates assert len(ckpt_files) == len(scores) == max_epochs - assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs + assert len(lr_scheduler_debug) == max_epochs for epoch in range(max_epochs): score = scores[epoch] @@ -124,20 +144,25 @@ def validation_step(self, batch, batch_idx): assert mc_specific_data['monitor'] == monitor assert mc_specific_data['current_score'] == score - lr_scheduler_specific_data = chk['lr_schedulers'][0] - assert lr_scheduler_specific_data['_step_count'] == epoch + 2 - assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1**(epoch + 1)) + if not reduce_lr_on_plateau: + lr_scheduler_specific_data = chk['lr_schedulers'][0] + assert lr_scheduler_specific_data['_step_count'] == epoch + 2 + assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + 1)) + + assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None) + assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None) @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @pytest.mark.parametrize( - "val_check_interval,lr_sched_step_count_inc", + "val_check_interval,reduce_lr_on_plateau", [ - (0.25, 1), - (0.33, 0), + (0.25, True), + (0.25, False), + (0.33, False), ], ) -def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, lr_sched_step_count_inc): +def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau): """ Test that when a model checkpoint is saved, it saves with the correct score appended to ckpt_path and checkpoint data with val_check_interval @@ -145,6 +170,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_in max_epochs = 3 limit_train_batches = 12 limit_val_batches = 7 + lr = 1e-1 monitor = 'val_log' per_epoch_steps = int(limit_train_batches * val_check_interval) per_epoch_call_count = limit_train_batches // per_epoch_steps @@ -166,6 +192,20 @@ def validation_epoch_end(self, outputs): self.val_loop_count += 1 super().validation_epoch_end(outputs) + def configure_optimizers(self): + optimizer = optim.SGD(self.parameters(), lr=lr) + + if reduce_lr_on_plateau: + lr_scheduler = { + 'scheduler': optim.lr_scheduler.ReduceLROnPlateau(optimizer), + 'monitor': monitor, + 'strict': True, + } + else: + lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1) + + return [optimizer], [lr_scheduler] + filename = '{' + f'{monitor}' + ':.4f}-{epoch}' checkpoint = ModelCheckpoint(dirpath=tmpdir, filename=filename, monitor=monitor, save_top_k=-1) @@ -181,12 +221,15 @@ def validation_epoch_end(self, outputs): progress_bar_refresh_rate=0, num_sanity_val_steps=0, ) - 
trainer.fit(model) + results = trainer.fit(model) + assert results + assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}" ckpt_files = list(Path(tmpdir).glob('*.ckpt')) scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric] + lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs - assert len(trainer.dev_debugger.saved_lr_scheduler_updates) == max_epochs + assert len(lr_scheduler_debug) == max_epochs for epoch in range(max_epochs): for ix in range(per_epoch_call_count): @@ -205,11 +248,14 @@ def validation_epoch_end(self, outputs): assert mc_specific_data['monitor'] == monitor assert mc_specific_data['current_score'] == score - lr_scheduler_specific_data = chk['lr_schedulers'][0] + if not reduce_lr_on_plateau: + lr_scheduler_specific_data = chk['lr_schedulers'][0] + did_update = 1 if ix + 1 == per_epoch_call_count else 0 + assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update + assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update)) - did_update = 1 if ix + 1 == per_epoch_call_count else 0 - assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update - assert lr_scheduler_specific_data['_last_lr'][0] == 0.1 * (0.1**(epoch + did_update)) + assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None) + assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None) @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2]) From f596ab5e57bc4bb629b1176a39d174fcb055c946 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 23 Feb 2021 19:25:09 +0530 Subject: [PATCH 8/9] chlog --- CHANGELOG.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e14733c0c632c..5a603e90315da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107)) -- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115) +- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) + + +- Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075)) ## [1.2.0] - 2021-02-18 @@ -193,8 +196,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed wrapping optimizers upon assignment ([#6006](https://github.com/PyTorchLightning/pytorch-lightning/pull/6006)) - Fixed allowing hashing of metrics with lists in their state ([#5939](https://github.com/PyTorchLightning/pytorch-lightning/pull/5939)) -- Fixed epoch level schedulers not being called when `val_check_interval!=1` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075)) - ## [1.1.8] - 2021-02-08 From cede66a9e91e4ebf1e28cc14dc83aed1a9f96433 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 24 Feb 2021 11:20:20 +0100 Subject: [PATCH 9/9] Update CHANGELOG.md --- CHANGELOG.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1231a7f7a4161..7857ac9543c5a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,9 +38,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) -- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)) - - - Fixed epoch level schedulers not being called when `val_check_interval < 1.0` ([#6075](https://github.com/PyTorchLightning/pytorch-lightning/pull/6075))
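
For context on the series above, a minimal sketch of the configuration these patches target: an epoch-interval scheduler (here StepLR) combined with mid-epoch validation via val_check_interval < 1.0, which previously caused the epoch-level scheduler step to be skipped. The model and dataloader below are illustrative only (they are not the repository's test helpers), written against the PyTorch Lightning ~1.2 API assumed by this series.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class SchedulerModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            (x,) = batch
            return self(x).sum()

        def validation_step(self, batch, batch_idx):
            (x,) = batch
            self.log("val_log", self(x).sum())

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
            # plain scheduler -> defaults to interval="epoch", frequency=1
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [scheduler]


    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    model = SchedulerModel()
    trainer = pl.Trainer(
        max_epochs=3,
        val_check_interval=0.5,  # validation runs mid-epoch, the case the fix covers
    )
    trainer.fit(model, data, data)
    # with this fix, the epoch-level StepLR still steps exactly once per epoch
    # (3 steps total), matching the assertions added in test_model_checkpoint.py
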