add current_epoch to dumped_params (#3261)
* add current epoch to __dumped_params

* log

* reset

* add to test

* Update CHANGELOG.md

Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
4 people authored Oct 6, 2020
1 parent 69833da commit 39b3704
Showing 3 changed files with 7 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -104,6 +104,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed counter-intuitive error being thrown in `Accuracy` metric for zero target tensor ([#3764](https://github.com/PyTorchLightning/pytorch-lightning/pull/3764))
 
+- Fixed Tuner dump: add `current_epoch` to dumped_params ([#3261](https://github.com/PyTorchLightning/pytorch-lightning/pull/3261))
+
 - Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517))
 
 - Fixed `current_epoch` and `global_step` properties mismatch between `Trainer` and `LightningModule` ([#3785](https://github.com/PyTorchLightning/pytorch-lightning/pull/3785))
3 changes: 3 additions & 0 deletions pytorch_lightning/tuner/batch_size_scaling.py
@@ -123,6 +123,7 @@ def __scale_batch_dump_params(trainer):
     # Prevent going into infinite loop
     trainer.__dumped_params = {
         'auto_lr_find': trainer.auto_lr_find,
+        'current_epoch': trainer.current_epoch,
         'max_steps': trainer.max_steps,
         'weights_summary': trainer.weights_summary,
         'logger': trainer.logger,
@@ -138,6 +139,7 @@ def __scale_batch_dump_params(trainer):
 def __scale_batch_reset_params(trainer, model, steps_per_trial):
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
+    trainer.current_epoch = 0
     trainer.max_steps = steps_per_trial  # take few steps
     trainer.weights_summary = None  # not needed before full run
     trainer.logger = DummyLogger()
@@ -151,6 +153,7 @@ def __scale_batch_reset_params(trainer, model, steps_per_trial):
 
 def __scale_batch_restore_params(trainer):
     trainer.auto_lr_find = trainer.__dumped_params['auto_lr_find']
+    trainer.current_epoch = trainer.__dumped_params['current_epoch']
     trainer.max_steps = trainer.__dumped_params['max_steps']
     trainer.weights_summary = trainer.__dumped_params['weights_summary']
     trainer.logger = trainer.__dumped_params['logger']
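Taken together, the three hunks implement a dump/reset/restore pattern: the tuner snapshots every `Trainer` attribute its trial runs will mutate, overwrites them with short-run values, and puts them back afterwards. Below is a minimal sketch of that pattern; the bare-bones `Trainer` stand-in and the `_run_trials` placeholder are illustrative assumptions, not the actual Lightning source:

```python
def _run_trials(trainer):
    # Placeholder: the real tuner repeatedly fits the model with
    # progressively larger batch sizes to find the largest one that fits.
    pass


class Trainer:
    """Bare-bones stand-in for pytorch_lightning.Trainer (illustration only)."""

    def __init__(self):
        self.auto_lr_find = False
        self.current_epoch = 5  # pretend the user already trained a few epochs
        self.max_steps = 1000


def scale_batch_size(trainer, steps_per_trial=3):
    # Dump: snapshot every attribute the trials will mutate.
    dumped_params = {
        'auto_lr_find': trainer.auto_lr_find,
        'current_epoch': trainer.current_epoch,  # the key this commit adds
        'max_steps': trainer.max_steps,
    }
    # Reset: configure a short, quiet trial run.
    trainer.auto_lr_find = False
    trainer.current_epoch = 0
    trainer.max_steps = steps_per_trial
    _run_trials(trainer)
    # Restore: hand the trainer back exactly as the user configured it.
    for name, value in dumped_params.items():
        setattr(trainer, name, value)
```

Before this change, `current_epoch` was missing from the snapshot, so the trial runs could leave the trainer at a different epoch than the user's configuration; adding the key to all three steps closes that leak.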
3 changes: 2 additions & 1 deletion tests/trainer/test_trainer_tricks.py
@@ -182,7 +182,8 @@ def test_trainer_reset_correctly(tmpdir):
                           'callbacks',
                           'checkpoint_callback',
                           'early_stop_callback',
-                          'limit_train_batches']
+                          'limit_train_batches',
+                          'current_epoch']
 
     attributes_before = {}
     for ca in changed_attributes:
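The test extended above guards against exactly that kind of leak with a snapshot/compare pattern: record each attribute in `changed_attributes`, run the tuner, and assert every value survived. A condensed sketch of the idea; the `trainer.scale_batch_size(model)` entry point and the abbreviated attribute list are assumptions rather than a copy of the real test:

```python
def check_trainer_reset(trainer, model):
    """Condensed version of the pattern in tests/trainer/test_trainer_tricks.py."""
    changed_attributes = ['max_steps', 'current_epoch']  # abbreviated list

    # Snapshot, run the tuner, snapshot again; nothing should leak out.
    attributes_before = {ca: getattr(trainer, ca) for ca in changed_attributes}
    trainer.scale_batch_size(model)  # assumed batch-size finder entry point
    attributes_after = {ca: getattr(trainer, ca) for ca in changed_attributes}

    for ca in changed_attributes:
        assert attributes_before[ca] == attributes_after[ca], \
            f'Attribute {ca} was not restored after batch size scaling'
```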
