Add on_epoch_start to run at the beginning of every loop irrespective of train/val/test (#6498)

* update docs

* add hook and update docs

* update tests

* chlog

* Update CHANGELOG.md

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* chlog

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
rohitgr7 and awaelchli authored Mar 25, 2021
1 parent 40976e4 commit 9be092d
Showing 15 changed files with 139 additions and 31 deletions.
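In effect, a `Callback` (or `LightningModule`) overriding `on_epoch_start`/`on_epoch_end` now sees them fire around validation and test epochs, not just training epochs. A minimal sketch of the new behavior (this callback is illustrative, not part of the commit):

```python
import pytorch_lightning as pl


class EpochBoundaries(pl.Callback):
    def on_epoch_start(self, trainer, pl_module):
        # after #6498: fires at the start of train, validation, AND test epochs
        print("on_epoch_start")

    def on_train_epoch_start(self, trainer, pl_module):
        # still fires only at the start of a training epoch
        print("on_train_epoch_start")


# usage, with `model` being any LightningModule:
# trainer = pl.Trainer(max_epochs=1, callbacks=[EpochBoundaries()])
# trainer.fit(model); trainer.validate(model); trainer.test(model)
```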
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `PyTorchProfiler` to use `torch.autograd.profiler.record_function` to record functions ([#6349](https://github.com/PyTorchLightning/pytorch-lightning/pull/6349))


- Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498))


### Deprecated

- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
91 changes: 83 additions & 8 deletions docs/source/common/lightning_module.rst
@@ -1045,6 +1045,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
teardown()
def train_loop():
on_epoch_start()
on_train_epoch_start()
train_outs = []
for train_batch in train_dataloader():
@@ -1070,12 +1071,15 @@ This is the pseudocode to describe how all the hooks are called during a call to
val_loop()
# end training epoch
-    logs = training_epoch_end(outs)
+    outs = training_epoch_end(outs)
on_train_epoch_end(outs)
on_epoch_end()
def val_loop():
model.eval()
torch.set_grad_enabled(False)
on_epoch_start()
on_validation_epoch_start()
val_outs = []
for val_batch in val_dataloader():
@@ -1089,6 +1093,7 @@ This is the pseudocode to describe how all the hooks are called during a call to
validation_epoch_end(val_outs)
on_validation_epoch_end()
on_epoch_end()
# set up for train
model.train()
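The ordering above can be checked empirically with a small callback that records hooks as they fire (a sketch; the hook names are real, the trainer/model wiring is assumed):

```python
import pytorch_lightning as pl


class HookOrderRecorder(pl.Callback):
    """Record epoch-level hooks in the order the trainer invokes them."""

    def __init__(self):
        self.calls = []

    def on_epoch_start(self, trainer, pl_module):
        self.calls.append("on_epoch_start")

    def on_train_epoch_start(self, trainer, pl_module):
        self.calls.append("on_train_epoch_start")

    def on_validation_epoch_start(self, trainer, pl_module):
        self.calls.append("on_validation_epoch_start")

    def on_epoch_end(self, trainer, pl_module):
        self.calls.append("on_epoch_end")


# after Trainer(num_sanity_val_steps=0, callbacks=[rec]).fit(model),
# rec.calls begins ["on_epoch_start", "on_train_epoch_start", ...]
# as in the pseudocode above
```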
@@ -1116,12 +1121,12 @@ manual_backward
on_after_backward
~~~~~~~~~~~~~~~~~

-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_after_backward
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_after_backward
:noindex:

on_before_zero_grad
~~~~~~~~~~~~~~~~~~~
-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_before_zero_grad
+.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_before_zero_grad
:noindex:

on_fit_start
@@ -1140,15 +1145,38 @@ on_fit_end
on_load_checkpoint
~~~~~~~~~~~~~~~~~~

-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_load_checkpoint
+.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint
:noindex:

on_save_checkpoint
~~~~~~~~~~~~~~~~~~

-.. automethod:: pytorch_lightning.core.lightning.LightningModule.on_save_checkpoint
+.. automethod:: pytorch_lightning.core.hooks.CheckpointHooks.on_save_checkpoint
:noindex:

on_train_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_start
:noindex:

on_train_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_end
:noindex:

on_validation_start
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_start
:noindex:

on_validation_end
~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_end
:noindex:

on_pretrain_routine_start
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -1186,6 +1214,11 @@ on_test_epoch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_epoch_end
:noindex:

on_test_end
~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_end
:noindex:

on_train_batch_start
~~~~~~~~~~~~~~~~~~~~
@@ -1199,6 +1232,18 @@ on_train_batch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_train_batch_end
:noindex:

on_epoch_start
~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_start
:noindex:

on_epoch_end
~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_epoch_end
:noindex:

on_train_epoch_start
~~~~~~~~~~~~~~~~~~~~

@@ -1235,6 +1280,36 @@ on_validation_epoch_end
.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_epoch_end
:noindex:

on_post_move_to_device
~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_post_move_to_device
:noindex:

on_validation_model_eval
~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_eval
:noindex:

on_validation_model_train
~~~~~~~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_validation_model_train
:noindex:

on_test_model_eval
~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_eval
:noindex:

on_test_model_train
~~~~~~~~~~~~~~~~~~~

.. automethod:: pytorch_lightning.core.hooks.ModelHooks.on_test_model_train
:noindex:

optimizer_step
~~~~~~~~~~~~~~

@@ -1274,19 +1349,19 @@ teardown
train_dataloader
~~~~~~~~~~~~~~~~

-.. automethod:: pytorch_lightning.core.lightning.LightningModule.train_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.train_dataloader
:noindex:

val_dataloader
~~~~~~~~~~~~~~

-.. automethod:: pytorch_lightning.core.lightning.LightningModule.val_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.val_dataloader
:noindex:

test_dataloader
~~~~~~~~~~~~~~~

-.. automethod:: pytorch_lightning.core.lightning.LightningModule.test_dataloader
+.. automethod:: pytorch_lightning.core.hooks.DataHooks.test_dataloader
:noindex:

transfer_batch_to_device
12 changes: 12 additions & 0 deletions docs/source/extensions/callbacks.rst
@@ -349,3 +349,15 @@ on_load_checkpoint

.. automethod:: pytorch_lightning.callbacks.Callback.on_load_checkpoint
:noindex:

on_after_backward
^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_after_backward
:noindex:

on_before_zero_grad
^^^^^^^^^^^^^^^^^^^

.. automethod:: pytorch_lightning.callbacks.Callback.on_before_zero_grad
:noindex:
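These two hooks were already part of the `Callback` API and are only now documented. As a hedged sketch of how they might be used, a callback could watch gradient norms from `on_after_backward` (the norm computation and attribute name are illustrative):

```python
import torch
import pytorch_lightning as pl


class GradNormMonitor(pl.Callback):
    """Track the global gradient norm right after each backward pass."""

    def on_after_backward(self, trainer, pl_module):
        # gradients exist here, before optimizer.step() / zero_grad()
        norms = [p.grad.detach().norm() for p in pl_module.parameters() if p.grad is not None]
        if norms:
            self.last_grad_norm = torch.stack(norms).norm().item()
```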
2 changes: 1 addition & 1 deletion docs/source/extensions/logging.rst
@@ -90,7 +90,7 @@ The :func:`~pytorch_lightning.core.lightning.LightningModule.log` method has a
.. note::

- Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a
-  reduction `on_epoch_end`. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.
+  reduction in ``on_train_epoch_end``. We recommend using the :doc:`metrics <../extensions/metrics>` API when working with custom reduction.

- Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with
suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor`
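For example, inside `training_step` (a sketch; `compute_loss` is a hypothetical helper):

```python
def training_step(self, batch, batch_idx):
    loss = self.compute_loss(batch)  # hypothetical helper
    # logs `train_loss_step` on every step and a reduced `train_loss_epoch`
    # once the training epoch ends (in on_train_epoch_end)
    self.log("train_loss", loss, on_step=True, on_epoch=True)
    return loss
```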
4 changes: 2 additions & 2 deletions pytorch_lightning/callbacks/base.py
@@ -102,11 +102,11 @@ def on_test_epoch_end(self, trainer, pl_module: LightningModule, outputs: List[A
pass

def on_epoch_start(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch begins."""
"""Called when either of train/val/test epoch begins."""
pass

def on_epoch_end(self, trainer, pl_module: LightningModule) -> None:
"""Called when the epoch ends."""
"""Called when either of train/val/test epoch ends."""
pass

def on_batch_start(self, trainer, pl_module: LightningModule) -> None:
pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -74,7 +74,7 @@ def __init__(self, scheduling: Dict[int, int]):
def going_to_accumulate_grad_batches(self):
return any([v > 1 for v in self.scheduling.values()])

-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
epoch = trainer.current_epoch
for i in reversed(range(len(self.epochs))):
if epoch >= self.epochs[i]:
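The callback's public usage is unchanged by this rename; only the hook it binds to differs. For example (the scheduling values are illustrative):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# accumulate every batch for epochs 0-3, then 4 batches from epoch 4 on
accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 4})
trainer = Trainer(callbacks=[accumulator])
```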
6 changes: 3 additions & 3 deletions pytorch_lightning/callbacks/progress.py
@@ -200,7 +200,7 @@ def on_init_end(self, trainer):
def on_train_start(self, trainer, pl_module):
self._train_batch_idx = trainer.batch_idx

-    def on_epoch_start(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
self._train_batch_idx = 0

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
@@ -392,8 +392,8 @@ def on_train_start(self, trainer, pl_module):
super().on_train_start(trainer, pl_module)
self.main_progress_bar = self.init_train_tqdm()

-    def on_epoch_start(self, trainer, pl_module):
-        super().on_epoch_start(trainer, pl_module)
+    def on_train_epoch_start(self, trainer, pl_module):
+        super().on_train_epoch_start(trainer, pl_module)
total_train_batches = self.total_train_batches
total_val_batches = self.total_val_batches
if total_train_batches != float('inf'):
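Custom progress bars that did per-epoch setup in `on_epoch_start` should follow suit and override `on_train_epoch_start`; a sketch, assuming the built-in tqdm-based `ProgressBar`:

```python
from pytorch_lightning.callbacks import ProgressBar


class NamedEpochBar(ProgressBar):
    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        # main_progress_bar is the tqdm instance created in on_train_start
        self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
```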
4 changes: 2 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -188,13 +188,13 @@ def on_predict_model_eval(self) -> None:

def on_epoch_start(self) -> None:
"""
-        Called in the training loop at the very beginning of the epoch.
+        Called when either of train/val/test epoch begins.
"""
# do something when the epoch starts

def on_epoch_end(self) -> None:
"""
-        Called in the training loop at the very end of the epoch.
+        Called when either of train/val/test epoch ends.
"""
# do something when the epoch ends

11 changes: 7 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -720,10 +720,13 @@ def validation_step(self, *args, **kwargs):
.. code-block:: python
# pseudocode of order
-out = validation_step()
-if defined('validation_step_end'):
-    out = validation_step_end(out)
-out = validation_epoch_end(out)
+val_outs = []
+for val_batch in val_data:
+    out = validation_step(val_batch)
+    if defined('validation_step_end'):
+        out = validation_step_end(out)
+    val_outs.append(out)
+val_outs = validation_epoch_end(val_outs)
.. code-block:: python
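A concrete `validation_step` matching the corrected pseudocode might look like this (a sketch; a classification model and labeled batches are assumed):

```python
import torch.nn.functional as F


def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = F.cross_entropy(logits, y)
    self.log("val_loss", loss)
    return loss  # collected into val_outs and handed to validation_epoch_end
```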
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_hook.py
@@ -137,12 +137,12 @@ def on_test_epoch_end(self, outputs: List[Any]):
callback.on_test_epoch_end(self, self.lightning_module)

def on_epoch_start(self):
"""Called when the epoch begins."""
"""Called when either of train/val/test epoch begins."""
for callback in self.callbacks:
callback.on_epoch_start(self, self.lightning_module)

def on_epoch_end(self):
"""Called when the epoch ends."""
"""Called when either of train/val/test epoch ends."""
for callback in self.callbacks:
callback.on_epoch_end(self, self.lightning_module)

2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -125,6 +125,8 @@ def setup(self, model, max_batches, dataloaders):
self._predictions = [[] for _ in range(self.num_dataloaders)]

def on_evaluation_epoch_start(self, *args, **kwargs):
self.trainer.call_hook('on_epoch_start', *args, **kwargs)

if self.trainer.testing:
self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
else:
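Because the generic hook now fires before the loop-specific `on_validation_epoch_start`/`on_test_epoch_start`, a single callback can cover every epoch, train or eval, from one place (a sketch):

```python
import time

import pytorch_lightning as pl


class EpochTimer(pl.Callback):
    """Time each epoch, whether it belongs to the train, val, or test loop."""

    def on_epoch_start(self, trainer, pl_module):
        self._t0 = time.monotonic()

    def on_epoch_end(self, trainer, pl_module):
        print(f"epoch took {time.monotonic() - self._t0:.2f}s")
```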
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -177,7 +177,7 @@ def on_train_epoch_start(self, epoch):
self.trainer.train_dataloader.sampler.set_epoch(epoch)

# changing gradient according accumulation_scheduler
-        self.trainer.accumulation_scheduler.on_epoch_start(self.trainer, self.trainer.lightning_module)
+        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

# stores accumulated grad fractions per batch
self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)
@@ -540,7 +540,7 @@ def run_training_epoch(self):
self.increment_accumulated_grad_global_step()

# epoch end hook
-        self.run_on_epoch_end_hook(epoch_output)
+        self.on_train_epoch_end(epoch_output)

# log epoch metrics
self.trainer.logger_connector.log_train_epoch_end_metrics(
@@ -782,7 +782,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):
# update lr
self.trainer.optimizer_connector.update_learning_rates(interval="step", monitor_metrics=monitor_metrics)

-    def run_on_epoch_end_hook(self, epoch_output):
+    def on_train_epoch_end(self, epoch_output):
# inform logger the batch loop has finished
self.trainer.logger_connector.on_train_epoch_end()

4 changes: 4 additions & 0 deletions tests/callbacks/test_callbacks.py
@@ -53,6 +53,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_pretrain_routine_end(trainer, model),
call.on_sanity_check_start(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -84,6 +85,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir):
call.on_train_epoch_end(trainer, model, ANY),
call.on_epoch_end(trainer, model),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -118,6 +120,7 @@ def test_trainer_callback_hook_system_test(tmpdir):
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'test'),
call.on_test_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_test_epoch_start(trainer, model),
call.on_test_batch_start(trainer, model, ANY, 0, 0),
call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
@@ -151,6 +154,7 @@ def test_trainer_callback_hook_system_validate(tmpdir):
call.on_before_accelerator_backend_setup(trainer, model),
call.setup(trainer, model, 'validate'),
call.on_validation_start(trainer, model),
call.on_epoch_start(trainer, model),
call.on_validation_epoch_start(trainer, model),
call.on_validation_batch_start(trainer, model, ANY, 0, 0),
call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
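The pattern behind these assertions: a `MagicMock` stands in for a callback, and its recorded `method_calls` are compared against the expected hook sequence. A stripped-down sketch of that pattern (`BoringModel` is Lightning's test helper; its exact import path is assumed here):

```python
from unittest.mock import MagicMock

from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # test-suite helper; path assumed


def test_epoch_start_precedes_validation_epoch_start(tmpdir):
    callback = MagicMock()
    trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=1, callbacks=[callback])
    trainer.validate(BoringModel())

    hooks = [name for name, *_ in callback.method_calls
             if name in ("on_epoch_start", "on_validation_epoch_start")]
    assert hooks == ["on_epoch_start", "on_validation_epoch_start"]
```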
