diff --git a/CHANGELOG.md b/CHANGELOG.md
index 15e8573f34baa..419d3bc07b47f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -101,6 +101,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))

+- Fixed Tuner dump: add `current_epoch` to dumped_params ([#3261](https://github.com/PyTorchLightning/pytorch-lightning/pull/3261))
+
 - Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517))

 - Fixed `current_epoch` and `global_step` properties mismatch between `Trainer` and `LightningModule` ([#3785](https://github.com/PyTorchLightning/pytorch-lightning/pull/3785))
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 8b2e05c66b753..ad3a9eb2e55c9 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -123,6 +123,7 @@ def __scale_batch_dump_params(trainer):
     # Prevent going into infinite loop
     trainer.__dumped_params = {
         'auto_lr_find': trainer.auto_lr_find,
+        'current_epoch': trainer.current_epoch,
         'max_steps': trainer.max_steps,
         'weights_summary': trainer.weights_summary,
         'logger': trainer.logger,
@@ -138,6 +139,7 @@ def __scale_batch_dump_params(trainer):
 def __scale_batch_reset_params(trainer, model, steps_per_trial):
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
+    trainer.current_epoch = 0
     trainer.max_steps = steps_per_trial  # take few steps
     trainer.weights_summary = None  # not needed before full run
     trainer.logger = DummyLogger()
@@ -151,6 +153,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):

 def __scale_batch_restore_params(trainer):
     trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
+    trainer.current_epoch = trainer.__dumped_params['current_epoch']
     trainer.max_steps = trainer.__dumped_params['max_steps']
     trainer.weights_summary = trainer.__dumped_params['weights_summary']
     trainer.logger = trainer.__dumped_params['logger']
diff --git a/tests/trainer/test_trainer_tricks.py b/tests/trainer/test_trainer_tricks.py
index a9297576c6f14..3ee47522d9258 100755
--- a/tests/trainer/test_trainer_tricks.py
+++ b/tests/trainer/test_trainer_tricks.py
@@ -182,7 +182,8 @@ def test_trainer_reset_correctly(tmpdir):
                           'callbacks',
                           'checkpoint_callback',
                           'early_stop_callback',
-                          'limit_train_batches']
+                          'limit_train_batches',
+                          'current_epoch']
     attributes_before = {}
     for ca in changed_attributes:
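
Note (not part of the patch): a minimal sketch of the behavior this change fixes, assuming the Lightning ~1.0-era `auto_scale_batch_size` / `Trainer.tune()` API; the model class below is hypothetical.

    import pytorch_lightning as pl

    model = MyLightningModule()  # hypothetical LightningModule subclass
    trainer = pl.Trainer(auto_scale_batch_size='power', max_epochs=5)

    # The batch-size finder runs short trial fits, each of which advances
    # trainer.current_epoch. Before this patch that value was never saved in
    # __dumped_params, so the real fit() below resumed from a nonzero epoch;
    # with the patch, __scale_batch_restore_params() rolls it back.
    trainer.tune(model)
    assert trainer.current_epoch == 0

    trainer.fit(model)  # starts cleanly from epoch 0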