Add Trainer.validate(…) method to run one validation epoch #4948

Merged: 45 commits merged on Mar 11, 2021
Changes from 22 commits

Commits (45)
edb3e83
Refactor Trainer in advance of implementing Trainer.validate
EliaCereda Dec 2, 2020
03d7994
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 2, 2020
5a54485
Add Trainer.validate(...) method to perform one evaluation epoch over…
EliaCereda Dec 2, 2020
e06775c
Rename methods in Trainer and Accelerator to reflect that they are us…
EliaCereda Dec 2, 2020
b4e409c
Update docs to mention the new Trainer.validate method and associated…
EliaCereda Dec 2, 2020
96e42ba
Add tests for Trainer.validate(…)
EliaCereda Dec 2, 2020
85b3c9f
Update CHANGELOG.md
EliaCereda Dec 2, 2020
39113dc
Merge branch 'master' into feature/trainer-validate-2
tchaton Dec 3, 2020
a6be0d8
Replace usages of Trainer.testing with Trainer.evaluating, should be …
EliaCereda Dec 4, 2020
a922d57
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 4, 2020
595f4e8
Clean up calls to LightningDataModule.setup()
EliaCereda Dec 8, 2020
0b09248
Update test_trainer_validate_loop.py to use BoringModel instead of Ev…
EliaCereda Dec 8, 2020
d691d79
Merge remote-tracking branch 'upstream/master' into feature/trainer-v…
EliaCereda Dec 8, 2020
52eaa70
Merge branch 'feature/trainer-validate-1' into feature/trainer-valida…
EliaCereda Dec 8, 2020
06b4419
Fix ShardedPlugin when evaluating
EliaCereda Dec 8, 2020
e6a8be9
Merge remote-tracking branch 'origin/feature/trainer-validate-1' into…
EliaCereda Dec 8, 2020
389940e
Add tests for Trainer.validate with ShardedPlugin
EliaCereda Dec 8, 2020
6d0a95a
Remove superfluous calls to LoggerConnector.set_stage in validate() a…
EliaCereda Dec 10, 2020
704b121
Update more docstrings to mention Trainer.validate
EliaCereda Dec 10, 2020
f6e0759
Merge branch 'release/1.2-dev' into feature/trainer-validate-1
tchaton Jan 11, 2021
90e59c7
Merge remote-tracking branch 'upstream/release/1.2-dev' into feature/…
EliaCereda Jan 26, 2021
45d7e0a
Merge branch 'feature/trainer-validate-1' into feature/trainer-valida…
EliaCereda Jan 26, 2021
12a85b3
Pass {fit,validate,test,predict} to setup()
carmocca Mar 7, 2021
d49ccd1
Fix doctest
carmocca Mar 7, 2021
23db135
stage: Optional[str] = None
carmocca Mar 7, 2021
84f5fdb
Trailing whitespace
carmocca Mar 7, 2021
188b9fe
Update docs and CHANGELOG
carmocca Mar 7, 2021
37473f0
Mention teardown
carmocca Mar 7, 2021
0a30abf
Self-review
carmocca Mar 7, 2021
0e9d69c
Address Borda's comments
carmocca Mar 7, 2021
04343ce
Merge branch 'deleteme-carmocca' into feature/trainer-validate-2
carmocca Mar 7, 2021
9758c7b
Fixing conflicts
carmocca Mar 7, 2021
18280df
Implement Trainer.validate
carmocca Mar 7, 2021
e582d58
Refactor
carmocca Mar 7, 2021
1a5b620
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 8, 2021
5b99ec0
flake8
carmocca Mar 8, 2021
9f4dce2
Refactor
carmocca Mar 8, 2021
088d4bc
Missing import
carmocca Mar 8, 2021
58fcca4
Fix test
carmocca Mar 8, 2021
babb73d
Same threshold
carmocca Mar 8, 2021
235dc27
Address tchaton's comments
carmocca Mar 8, 2021
73dd265
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 8, 2021
e423b98
Missing import
carmocca Mar 10, 2021
cdec83b
Merge branch 'master' into feature/trainer-validate-2
carmocca Mar 10, 2021
8fab50f
Apply suggestions from code review
carmocca Mar 10, 2021
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -65,6 +65,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464))


- Added `Trainer.validate()` method to perform one evaluation epoch over the validation set (
[#4948](https://github.com/PyTorchLightning/pytorch-lightning/pull/4948))


### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
@@ -264,7 +268,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Enabled `self.log` in most functions ([#4969](https://github.com/PyTorchLightning/pytorch-lightning/pull/4969))
- Added changeable extension variable for `ModelCheckpoint` ([#4977](https://github.com/PyTorchLightning/pytorch-lightning/pull/4977))


### Changed

- Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))
15 changes: 14 additions & 1 deletion docs/source/common/trainer.rst
@@ -148,14 +148,27 @@ So you can run it like so:

------------

Validation
----------
You can perform an evaluation epoch over the validation set, outside of the training loop,
using :meth:`pytorch_lightning.trainer.trainer.Trainer.validate`. This might be
useful if you want to collect new metrics from a model that has just been
initialized or that has already been trained.

.. code-block:: python

trainer.validate(val_dataloaders=val_dataloaders)
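
For example, a minimal sketch validating an already-trained checkpoint
(``LitModel`` and the checkpoint path below are placeholders, not part of this PR):

.. code-block:: python

    # placeholder model class and checkpoint path, shown only for illustration
    model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt")
    trainer = Trainer()
    trainer.validate(model, val_dataloaders=val_dataloaders)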

------------

Testing
-------
Once you're done training, feel free to run the test set!
(Only right before publishing your paper or pushing to production)

.. code-block:: python

trainer.test(test_dataloaders=test_dataloader)
trainer.test(test_dataloaders=test_dataloaders)

------------

8 changes: 4 additions & 4 deletions pytorch_lightning/accelerators/accelerator.py
@@ -61,9 +61,9 @@ def barrier(self, name: Optional[str] = None):
def broadcast(self, obj, src=0):
return obj

def train_or_test(self):
if self.trainer.testing:
results = self.trainer.run_test()
def train_or_evaluate(self):
if self.trainer.evaluating:
results = self.trainer.run_test_or_validate()
else:
results = self.trainer.train()
return results
@@ -134,7 +134,7 @@ def early_stopping_should_stop(self, pl_module):
return self.trainer.should_stop

def setup_optimizers(self, model):
if self.trainer.testing:
if self.trainer.evaluating:
return

optimizers, lr_schedulers, optimizer_frequencies = self.trainer.init_optimizers(model)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/cpu_accelerator.py
@@ -56,8 +56,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()
return results

def _step(self, model_step: Callable, args):
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -198,8 +198,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -302,8 +302,8 @@ def ddp_train(self, process_idx, model):
self.barrier('ddp_setup')
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -157,8 +157,8 @@ def ddp_train(self, process_idx, mp_queue, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# get original model
model = self.trainer.get_model()
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -186,8 +186,8 @@ def ddp_train(self, process_idx, model):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# clean up memory
torch.cuda.empty_cache()
8 changes: 4 additions & 4 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -172,8 +172,8 @@ def ddp_train(self, process_idx, mp_queue, model, is_master: bool = False, proc_
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# get original model
model = self.trainer.get_model()
@@ -242,7 +242,7 @@ def __recover_child_process_weights(self, model, best_path, last_path):
# todo, pass also best score

# load last weights
if last_path is not None and not self.trainer.testing:
if last_path is not None and not self.trainer.evaluating:
ckpt = pl_load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

@@ -261,7 +261,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):

# save the last weights
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
if not self.trainer.evaluating and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
atomic_save(model.state_dict(), last_path)
mp_queue.put(last_path)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/dp_accelerator.py
@@ -105,8 +105,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

return results

5 changes: 3 additions & 2 deletions pytorch_lightning/accelerators/gpu_accelerator.py
@@ -61,8 +61,9 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

return results

def _step(self, model_step: Callable, args):
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/horovod_accelerator.py
@@ -106,8 +106,8 @@ def train(self):
# set up training routine
self.trainer.train_loop.setup_training(self.trainer.model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# Make sure all workers have finished training before returning to the user
hvd.join()
10 changes: 5 additions & 5 deletions pytorch_lightning/accelerators/tpu_accelerator.py
@@ -88,7 +88,7 @@ def teardown(self):
# todo, pass also bets score

# load last weights
if last_path and not self.trainer.testing:
if last_path and not self.trainer.evaluating:
ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt)

@@ -116,7 +116,7 @@ def __load_weights_on_main_process(self):
model = self.trainer.model

# load weights if not interrupted
if self.trainer.on_colab_kaggle and not self.trainer.testing:
if self.trainer.on_colab_kaggle and not self.trainer.evaluating:
self.load_spawn_weights(model)

self.trainer.model = model
@@ -137,8 +137,8 @@ def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, traine
# set up training routine
self.trainer.train_loop.setup_training(model)

# train or test
results = self.train_or_test()
# train or evaluate
results = self.train_or_evaluate()

# save weights at the end of training
self.__save_end_of_training_weights(model, trainer)
@@ -322,7 +322,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):

# save the last weights
last_path = None
if not self.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
if not self.trainer.evaluating and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub('.ckpt', '.tmp_end.ckpt', best_model_path)
state_dict = move_data_to_device(model.state_dict(), torch.device("cpu"))
atomic_save(state_dict, last_path)
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/base.py
@@ -32,11 +32,11 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
pass

def setup(self, trainer, pl_module, stage: str):
"""Called when fit or test begins"""
"""Called when fit, validate, or test begins"""
pass

def teardown(self, trainer, pl_module, stage: str):
"""Called when fit or test ends"""
"""Called when fit, validate, or test ends"""
pass

def on_init_start(self, trainer):
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -154,13 +154,13 @@ def on_load_checkpoint(self, checkpointed_state):
self.patience = checkpointed_state['patience']

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
if trainer.running_sanity_check or trainer.evaluating:
return

self._run_early_stopping_check(trainer, pl_module)

def on_validation_epoch_end(self, trainer, pl_module):
if trainer.fast_dev_run or trainer.running_sanity_check:
if trainer.fast_dev_run or trainer.running_sanity_check or trainer.evaluating:
return

if self._validate_condition_metric(trainer.callback_metrics):
1 change: 1 addition & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -219,6 +219,7 @@ def save_checkpoint(self, trainer, pl_module):
or self.period < 1 # no models are saved
or (epoch + 1) % self.period # skip epoch
or trainer.running_sanity_check # don't save anything during sanity check
or trainer.evaluating # don't save anything during evaluation: might delete the checkpoint being evaluated
or self.last_global_step_saved == global_step # already saved at the last step
):
return
22 changes: 18 additions & 4 deletions pytorch_lightning/callbacks/progress.py
@@ -281,9 +281,13 @@ def init_train_tqdm(self) -> tqdm:

def init_validation_tqdm(self) -> tqdm:
""" Override this to customize the tqdm bar for validation. """

# The main progress bar doesn't exist in trainer.validate(...)
has_main_bar = int(self.main_progress_bar is not None)

bar = tqdm(
desc='Validating',
position=(2 * self.process_position + 1),
position=(2 * self.process_position + has_main_bar),
disable=self.is_disabled,
leave=False,
dynamic_ncols=True,
@@ -340,19 +344,29 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
def on_validation_start(self, trainer, pl_module):
super().on_validation_start(trainer, pl_module)
if not trainer.running_sanity_check:
self._update_bar(self.main_progress_bar) # fill up remaining
# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self._update_bar(self.main_progress_bar) # fill up remaining

self.val_progress_bar = self.init_validation_tqdm()
self.val_progress_bar.total = convert_inf(self.total_val_batches)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
if self._should_update(self.val_batch_idx, self.total_val_batches):
self._update_bar(self.val_progress_bar)
self._update_bar(self.main_progress_bar)

# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self._update_bar(self.main_progress_bar)

def on_validation_end(self, trainer, pl_module):
super().on_validation_end(trainer, pl_module)
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

# The main progress bar doesn't exist in trainer.validate(...)
if self.main_progress_bar is not None:
self.main_progress_bar.set_postfix(trainer.progress_bar_dict)

self.val_progress_bar.close()

def on_train_end(self, trainer, pl_module):
15 changes: 14 additions & 1 deletion pytorch_lightning/core/datamodule.py
@@ -76,13 +76,16 @@ def wrapped_fn(*args, **kwargs):
if fn.__name__ == "setup":

# Get stage either by grabbing from args or checking kwargs.
# If not provided, set call status of 'fit' and 'test' to True.
# If not provided, set call status of 'fit', 'validation', and 'test' to True.
# We do this so __attach_datamodule in trainer.py doesn't mistakenly call setup('test') on trainer.test()
stage = args[1] if len(args) > 1 else kwargs.get("stage", None)

if stage == "fit" or stage is None:
obj._has_setup_fit = True

if stage == "validation" or stage is None:
obj._has_setup_validation = True

if stage == "test" or stage is None:
obj._has_setup_test = True

@@ -155,6 +158,7 @@ def __init__(
# Private attrs to keep track of whether or not data hooks have been called yet
self._has_prepared_data = False
self._has_setup_fit = False
self._has_setup_validation = False
self._has_setup_test = False

@property
@@ -230,6 +234,15 @@ def has_setup_fit(self):
"""
return self._has_setup_fit

@property
def has_setup_validation(self):
"""Return bool letting you know if datamodule.setup('validation') has been called or not.

Returns:
bool: True if datamodule.setup('validation') has been called. False by default.
"""
return self._has_setup_validation

@property
def has_setup_test(self):
"""Return bool letting you know if datamodule.setup('test') has been called or not.
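
As an aside, a minimal sketch of how a user-defined LightningDataModule might handle the new stage value; the class, dataset, and helper names are illustrative assumptions, and the stage string follows the 'validation' spelling used in this commit range:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split

class SketchDataModule(LightningDataModule):
    def setup(self, stage=None):
        # per this diff: stage is 'fit', 'validation', 'test', or None
        if stage in (None, "fit", "validation"):
            full = load_my_train_set()  # hypothetical helper
            n_val = int(0.1 * len(full))
            self.train_set, self.val_set = random_split(full, [len(full) - n_val, n_val])
        if stage in (None, "test"):
            self.test_set = load_my_test_set()  # hypothetical helper

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=32)
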
8 changes: 4 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -27,12 +27,12 @@ class ModelHooks:
"""Hooks to be used in LightningModule."""
def setup(self, stage: str) -> None:
"""
Called at the beginning of fit and test.
Called at the beginning of fit (training + validation), validation, and test.
This is a good hook when you need to build models dynamically or adjust something about them.
This hook is called on every process when using DDP.

Args:
stage: either 'fit' or 'test'
stage: either 'fit', 'validation', or 'test'

Example::

@@ -55,10 +55,10 @@ def setup(stage):

def teardown(self, stage: str) -> None:
"""
Called at the end of fit and test.
Called at the end of fit (training + validation), validation, and test.

Args:
stage: either 'fit' or 'test'
stage: either 'fit', 'validation', or 'test'
"""

def on_fit_start(self) -> None:
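
A matching sketch on the LightningModule side, using the stage values documented in the hunks above ('fit', 'validation', 'test'); the class and helper names are illustrative only, not part of this PR:

from pytorch_lightning import LightningModule

class SketchModel(LightningModule):
    def setup(self, stage: str):
        # 'fit' covers training + validation; Trainer.validate() and
        # Trainer.test() pass 'validation' and 'test' in this commit range
        if stage == "fit":
            self.train_augmentations = build_augmentations()  # hypothetical helper

    def teardown(self, stage: str):
        # release anything created in setup() for this stage
        self.train_augmentations = None
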
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_hook.py
@@ -32,12 +32,12 @@ def on_before_accelerator_backend_setup(self, model):
callback.on_before_accelerator_backend_setup(self, model)

def setup(self, model, stage: str):
"""Called in the beginning of fit and test"""
"""Called in the beginning of fit, validate and test"""
for callback in self.callbacks:
callback.setup(self, model, stage)

def teardown(self, stage: str):
"""Called at the end of fit and test"""
"""Called at the end of fit, validate and test"""
for callback in self.callbacks:
callback.teardown(self, self.get_model(), stage)
