Skip to content

Commit

Permalink
Fix for incorrect run on the validation set with overwritten validati…
Browse files Browse the repository at this point in the history
…on_epoch_end and test_end (Lightning-AI#1353)

* reorder if clauses

* fix wrong method overload in test

* fix formatting

* update change_log

* fix line too long
  • Loading branch information
Adrian Wälchli authored and tullie committed May 6, 2020
1 parent d841f03 commit bcae6b8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue where `val_percent_check=0` would not disable validation ([#1251](https://github.com/PyTorchLightning/pytorch-lightning/pull/1251))
- Fixed average of incomplete `TensorRunningMean` ([#1309](https://github.com/PyTorchLightning/pytorch-lightning/pull/1309))
- Fixed an issue with early stopping that would prevent it from monitoring training metrics when validation is disabled / not implemented ([#1235](https://github.com/PyTorchLightning/pytorch-lightning/pull/1235)).
- Fixed a bug that would cause `trainer.test()` to run on the validation set when overloading `validation_epoch_end ` and `test_end` ([#1353](https://github.com/PyTorchLightning/pytorch-lightning/pull/1353)).

## [0.7.1] - 2020-03-07

Expand Down
33 changes: 19 additions & 14 deletions pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,20 +295,25 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: int, test_
if isinstance(model, (LightningDistributedDataParallel, LightningDataParallel)):
model = model.module

# TODO: remove in v1.0.0
if test_mode and self.is_overriden('test_end', model=model):
eval_results = model.test_end(outputs)
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `test_epoch_end` instead.', DeprecationWarning)
elif self.is_overriden('validation_end', model=model):
eval_results = model.validation_end(outputs)
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `validation_epoch_end` instead.', DeprecationWarning)

if test_mode and self.is_overriden('test_epoch_end', model=model):
eval_results = model.test_epoch_end(outputs)
elif self.is_overriden('validation_epoch_end', model=model):
eval_results = model.validation_epoch_end(outputs)
if test_mode:
if self.is_overriden('test_end', model=model):
# TODO: remove in v1.0.0
eval_results = model.test_end(outputs)
warnings.warn('Method `test_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `test_epoch_end` instead.', DeprecationWarning)

elif self.is_overriden('test_epoch_end', model=model):
eval_results = model.test_epoch_end(outputs)

else:
if self.is_overriden('validation_end', model=model):
# TODO: remove in v1.0.0
eval_results = model.validation_end(outputs)
warnings.warn('Method `validation_end` was deprecated in 0.7.0 and will be removed 1.0.0.'
' Use `validation_epoch_end` instead.', DeprecationWarning)

elif self.is_overriden('validation_epoch_end', model=model):
eval_results = model.validation_epoch_end(outputs)

# enable train mode again
model.train()
Expand Down
20 changes: 12 additions & 8 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,15 +528,15 @@ def test_disabled_validation():
class CurrentModel(LightTrainDataloader, LightValidationMixin, TestModelBase):

validation_step_invoked = False
validation_end_invoked = False
validation_epoch_end_invoked = False

def validation_step(self, *args, **kwargs):
self.validation_step_invoked = True
return super().validation_step(*args, **kwargs)

def validation_end(self, *args, **kwargs):
self.validation_end_invoked = True
return super().validation_end(*args, **kwargs)
def validation_epoch_end(self, *args, **kwargs):
self.validation_epoch_end_invoked = True
return super().validation_epoch_end(*args, **kwargs)

hparams = tutils.get_default_hparams()
model = CurrentModel(hparams)
Expand All @@ -555,8 +555,10 @@ def validation_end(self, *args, **kwargs):
# check that val_percent_check=0 turns off validation
assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 1
assert not model.validation_step_invoked, '`validation_step` should not run when `val_percent_check=0`'
assert not model.validation_end_invoked, '`validation_end` should not run when `val_percent_check=0`'
assert not model.validation_step_invoked, \
'`validation_step` should not run when `val_percent_check=0`'
assert not model.validation_epoch_end_invoked, \
'`validation_epoch_end` should not run when `val_percent_check=0`'

# check that val_percent_check has no influence when fast_dev_run is turned on
model = CurrentModel(hparams)
Expand All @@ -566,8 +568,10 @@ def validation_end(self, *args, **kwargs):

assert result == 1, 'training failed to complete'
assert trainer.current_epoch == 0
assert model.validation_step_invoked, 'did not run `validation_step` with `fast_dev_run=True`'
assert model.validation_end_invoked, 'did not run `validation_end` with `fast_dev_run=True`'
assert model.validation_step_invoked, \
'did not run `validation_step` with `fast_dev_run=True`'
assert model.validation_epoch_end_invoked, \
'did not run `validation_epoch_end` with `fast_dev_run=True`'


def test_nan_loss_detection(tmpdir):
Expand Down

0 comments on commit bcae6b8

Please sign in to comment.