Add missing val/test hooks in LightningModule #5467

Merged (8 commits) on Jan 13, 2021
Changes from 3 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -47,6 +47,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))


- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))


### Changed

- Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
29 changes: 27 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -17,10 +17,11 @@
from typing import Any, Dict, List, Optional, Union

import torch
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn


class ModelHooks:
"""Hooks to be used in LightningModule."""
@@ -74,7 +75,7 @@ def on_fit_end(self):

def on_train_start(self) -> None:
"""
Called at the beginning of training before sanity check.
Called at the beginning of training after sanity check.
"""
# do something at the start of training

@@ -84,6 +85,18 @@ def on_train_end(self) -> None:
"""
# do something at the end of training

def on_validation_start(self):
"""
Called at the beginning of validation.
"""
# do something at the start of validation

def on_validation_end(self):
"""
Called at the end of validation.
"""
# do something at the end of validation

def on_pretrain_routine_start(self) -> None:
"""
Called at the beginning of the pretrain routine (between fit and train start).
@@ -253,6 +266,18 @@ def on_test_epoch_end(self) -> None:
"""
# do something when the epoch ends

def on_test_start(self):
"""
Called at the beginning of testing.
"""
# do something at the start of testing

def on_test_end(self):
"""
Called at the end of testing.
"""
# do something at the end of testing

def on_before_zero_grad(self, optimizer: Optimizer) -> None:
"""
Called after optimizer.step() and before optimizer.zero_grad().
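The hooks added above are no-ops by default; subclasses override them when they need setup or cleanup around the validation and test loops. A minimal sketch of how a user's `LightningModule` might use them (the timing logic and class name are illustrative, not part of this PR; required methods such as `training_step` and `configure_optimizers` are omitted):

```python
import time

import pytorch_lightning as pl


class TimedModule(pl.LightningModule):
    """Illustrative module overriding the hooks added in this PR."""

    def on_validation_start(self):
        # Called once at the beginning of validation.
        self._val_t0 = time.monotonic()

    def on_validation_end(self):
        # Called once at the end of validation.
        print(f"validation took {time.monotonic() - self._val_t0:.1f}s")

    def on_test_start(self):
        # Called once at the beginning of testing.
        self._test_t0 = time.monotonic()

    def on_test_end(self):
        # Called once at the end of testing.
        print(f"testing took {time.monotonic() - self._test_t0:.1f}s")
```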
25 changes: 19 additions & 6 deletions tests/models/test_hooks.py
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from unittest.mock import MagicMock

import pytest
import torch
from unittest.mock import MagicMock

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from tests.base import EvalModelTemplate, BoringModel
from tests.base import BoringModel, EvalModelTemplate


@pytest.mark.parametrize('max_steps', [1, 2, 3])
@@ -253,10 +253,6 @@ def on_test_start(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_start()

def on_test_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_end()

def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_batch_start(batch, batch_idx, dataloader_idx)
@@ -289,6 +285,14 @@ def on_test_model_train(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_model_train()

def on_test_end(self):
self.called.append(inspect.currentframe().f_code.co_name)
super().on_test_end()

def teardown(self, stage: str):
self.called.append(inspect.currentframe().f_code.co_name)
super().teardown(stage)

model = HookedModel()

assert model.called == []
@@ -312,10 +316,12 @@ def on_test_model_train(self):
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_validation_model_eval',
'on_validation_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
'on_validation_epoch_end',
'on_validation_end',
'on_validation_model_train',
'on_train_start',
'on_epoch_start',
@@ -329,16 +335,19 @@ def on_test_model_train(self):
'on_before_zero_grad',
'on_train_batch_end',
'on_validation_model_eval',
'on_validation_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_validation_batch_end',
'on_validation_epoch_end',
'on_save_checkpoint',
'on_validation_end',
'on_validation_model_train',
'on_epoch_end',
'on_train_epoch_end',
'on_train_end',
'on_fit_end',
'teardown',
]

assert model.called == expected
Contributor Author:

Shall this be updated the way it's done with callbacks, using mock?

Contributor:

That would be great (in a follow-up PR). Let me know if you need help.

Contributor Author:

sure 👍

@@ -351,12 +360,16 @@ def on_test_model_train(self):
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_test_model_eval',
'on_test_start',
'on_test_epoch_start',
'on_test_batch_start',
'on_test_batch_end',
'on_test_epoch_end',
'on_test_end',
'on_test_model_train',
'on_fit_end',
'teardown', # for 'fit'
'teardown', # for 'test'
]

assert model2.called == expected
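As a rough sketch of the mock-based follow-up discussed in the review comments above (not part of this PR), the manual `self.called` bookkeeping could be replaced by wrapping hooks with `unittest.mock`. The test name, hook selection, and Trainer arguments below are illustrative assumptions:

```python
from unittest import mock

from pytorch_lightning import Trainer
from tests.base import BoringModel


def test_val_test_hooks_are_called(tmpdir):
    """Hypothetical mock-based variant of the manual ``self.called`` bookkeeping above."""
    model = BoringModel()

    # Wrap the new hooks so the real (no-op) implementations still run
    # while the mocks record that they were called.
    with mock.patch.object(model, 'on_validation_start', wraps=model.on_validation_start) as val_start, \
            mock.patch.object(model, 'on_test_end', wraps=model.on_test_end) as test_end:
        trainer = Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            limit_train_batches=1,
            limit_val_batches=1,
            limit_test_batches=1,
        )
        trainer.fit(model)
        trainer.test(model)

    # The call records survive after the patches are undone.
    val_start.assert_called()
    test_end.assert_called()
```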