[1/2] Deprecate outputs in on_train_epoch_end hooks #7339

Merged: 15 commits, May 5, 2021
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -207,6 +207,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))


- Deprecated `Trainer.truncated_bptt_steps` in favor of `LightningModule.truncated_bptt_steps` ([#7323](https://github.com/PyTorchLightning/pytorch-lightning/pull/7323))

@@ -217,7 +219,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the `save_function` property from the `ModelCheckpoint` callback ([#7201](https://github.com/PyTorchLightning/pytorch-lightning/pull/7201))


- Deprecated `LightningModule.write_predictions` and `LigtningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))
- Deprecated `LightningModule.write_predictions` and `LightningModule.write_predictions_dict` ([#7066](https://github.com/PyTorchLightning/pytorch-lightning/pull/7066))


- Deprecated `TrainerLoggingMixin` in favor of a separate utilities module for metric handling ([#7180](https://github.com/PyTorchLightning/pytorch-lightning/pull/7180))
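For readers migrating user code, a minimal illustrative sketch of the signature change this entry describes (not part of the diff; class names are examples, not project APIs):

```python
import pytorch_lightning as pl


class MyCallback(pl.Callback):

    # Deprecated in v1.3, to be removed in v1.5:
    # def on_train_epoch_end(self, trainer, pl_module, outputs): ...

    def on_train_epoch_end(self, trainer, pl_module):
        # per-step outputs are no longer passed to this hook
        print(f"finished epoch {trainer.current_epoch}")


class MyModel(pl.LightningModule):

    # Deprecated in v1.3, to be removed in v1.5:
    # def on_train_epoch_end(self, outputs): ...

    def on_train_epoch_end(self):
        pass
```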
10 changes: 3 additions & 7 deletions pytorch_lightning/accelerators/accelerator.py
@@ -28,7 +28,7 @@
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.enums import AMPType, GradClipAlgorithmType, LightningEnum
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
from pytorch_lightning.utilities.types import STEP_OUTPUT

if _NATIVE_AMP_AVAILABLE:
from torch.cuda.amp import GradScaler
@@ -354,12 +354,8 @@ def clip_gradients(
model=self.model,
)

def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
"""Hook to do something on the end of an training epoch

Args:
outputs: the outputs of the training steps
"""
def on_train_epoch_end(self) -> None:
"""Hook to do something on the end of an training epoch."""
pass

def on_train_end(self) -> None:
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/callbacks/base.py
@@ -98,7 +98,9 @@ def on_train_epoch_start(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningMo
"""Called when the train epoch begins."""
pass

def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', outputs: EPOCH_OUTPUT) -> None:
def on_train_epoch_end(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', unused: Optional = None
) -> None:
"""Called when the train epoch ends."""
pass

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/early_stopping.py
@@ -161,7 +161,7 @@ def _should_skip_check(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerFn
return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking

def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
def on_train_epoch_end(self, trainer, pl_module) -> None:
if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
return
self._run_early_stopping_check(trainer)
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/pruning.py
@@ -373,7 +373,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModul
self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
self._original_layers[id_]["names"].append((i, name))

def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
def on_train_epoch_end(self, trainer, pl_module: LightningModule):
current_epoch = trainer.current_epoch
prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
@@ -235,7 +235,7 @@ def on_train_epoch_start(self) -> None:
Called in the training loop at the very beginning of the epoch.
"""

def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
def on_train_epoch_end(self, unused: Optional = None) -> None:
"""
Called in the training loop at the very end of the epoch.
"""
10 changes: 9 additions & 1 deletion pytorch_lightning/trainer/callback_hook.py
@@ -96,7 +96,15 @@ def on_train_epoch_end(self, outputs: EPOCH_OUTPUT):
outputs: List of outputs on each ``train`` epoch
"""
for callback in self.callbacks:
callback.on_train_epoch_end(self, self.lightning_module, outputs)
if is_param_in_hook_signature(callback.on_train_epoch_end, "outputs"):
warning_cache.warn(
"The signature of `Callback.on_train_epoch_end` has changed in v1.3."
" `outputs` parameter has been removed."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
callback.on_train_epoch_end(self, self.lightning_module, outputs)
else:
callback.on_train_epoch_end(self, self.lightning_module)

def on_validation_epoch_start(self):
"""Called when the epoch begins."""
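The dispatch above keys off `is_param_in_hook_signature` to detect callbacks that still accept `outputs`. As a rough sketch of that kind of check using only the standard library (this approximates, and is not, the real helper in `pytorch_lightning.utilities.signature_utils`, which may also handle cases such as `*args`):

```python
import inspect
from typing import Callable


def param_in_signature(hook_fx: Callable, name: str) -> bool:
    # Illustration only: check whether the hook declares a parameter with this name.
    return name in inspect.signature(hook_fx).parameters


class OldStyle:
    def on_train_epoch_end(self, trainer, pl_module, outputs):
        ...


class NewStyle:
    def on_train_epoch_end(self, trainer, pl_module):
        ...


assert param_in_signature(OldStyle().on_train_epoch_end, "outputs")
assert not param_in_signature(NewStyle().on_train_epoch_end, "outputs")
```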
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -1211,6 +1211,11 @@ def _cache_logged_metrics(self):
self.logger_connector.cache_logged_metrics()

def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
# Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook
# This was done to manage the deprecation of an argument to on_train_epoch_end
# If making changes to this function, ensure that those changes are also made to
# TrainLoop._on_train_epoch_end_hook

# set hook_name to model + reset Result obj
skip = self._reset_result_and_set_hook_fx_name(hook_name)

67 changes: 62 additions & 5 deletions pytorch_lightning/trainer/training_loop.py
@@ -31,6 +31,7 @@
from pytorch_lightning.utilities.grads import grad_norm
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache


@@ -197,16 +198,14 @@ def reset_train_val_dataloaders(self, model) -> None:

def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):

hook_overridden = self._should_add_batch_output_to_epoch_output()

# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(batch_end_outputs):
sample_output = opt_outputs[-1]

# decide if we need to reduce at the end of the epoch automatically
auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
hook_overridden = (
is_overridden("training_epoch_end", model=self.trainer.lightning_module)
or is_overridden("on_train_epoch_end", model=self.trainer.lightning_module)
)

# only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
if not (hook_overridden or auto_reduce_tng_result):
Expand All @@ -218,6 +217,22 @@ def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):

epoch_output[opt_idx].append(opt_outputs)

def _should_add_batch_output_to_epoch_output(self) -> bool:
# We add to the epoch outputs if
# 1. The model defines training_epoch_end OR
# 2. The model overrides on_train_epoch_end which has `outputs` in the signature
# TODO: in v1.5 this only needs to check if training_epoch_end is overridden
lightning_module = self.trainer.lightning_module
if is_overridden("training_epoch_end", model=lightning_module):
return True

if is_overridden("on_train_epoch_end", model=lightning_module):
model_hook_fx = getattr(lightning_module, "on_train_epoch_end")
if is_param_in_hook_signature(model_hook_fx, "outputs"):
return True

return False

def get_optimizers_iterable(self, batch_idx=None):
"""
Generates an iterable with (idx, optimizer) for each optimizer.
@@ -593,9 +608,51 @@ def on_train_epoch_end(self, epoch_output: List[List[List[Result]]]) -> None:
self.trainer.logger_connector.cache_logged_metrics()

# call train epoch end hooks
self.trainer.call_hook('on_train_epoch_end', processed_epoch_output)
self._on_train_epoch_end_hook(processed_epoch_output)
self.trainer.call_hook('on_epoch_end')

def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
# We cannot rely on Trainer.call_hook because the signatures might be different across
# lightning module and callback
# As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`

# This implementation is copied from Trainer.call_hook
hook_name = "on_train_epoch_end"

# set hook_name to model + reset Result obj
skip = self.trainer._reset_result_and_set_hook_fx_name(hook_name)

# always profile hooks
with self.trainer.profiler.profile(hook_name):

# first call trainer hook
if hasattr(self.trainer, hook_name):
trainer_hook = getattr(self.trainer, hook_name)
trainer_hook(processed_epoch_output)

# next call hook in lightningModule
model_ref = self.trainer.lightning_module
if is_overridden(hook_name, model_ref):
hook_fx = getattr(model_ref, hook_name)
if is_param_in_hook_signature(hook_fx, "outputs"):
self.warning_cache.warn(
"The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
" `outputs` parameter has been deprecated."
" Support for the old signature will be removed in v1.5", DeprecationWarning
)
model_ref.on_train_epoch_end(processed_epoch_output)
else:
model_ref.on_train_epoch_end()

# if the PL module doesn't have the hook then call the accelerator
# used to auto-reduce things for the user with Results obj
elif hasattr(self.trainer.accelerator, hook_name):
Contributor (review comment on this line): Not a huge fan of this. Better to use call_hook and maybe perform the signature analysis somewhere else.

@ananthsub (Contributor Author, May 4, 2021): From the comment above, call_hook enforces that the accelerator, trainer, and module all take exactly the same arguments for the hook, which might not be the case here. This is the same pattern @kaushikb11 followed in #6120.

I'm not really a fan either, but call_hook calls across 3 distinct interfaces which aren't enforced to be compatible.

Maybe something we can look at for v1.4 is how to simplify/strengthen this? Perhaps the techniques @SkafteNicki used for metrics collections could apply here, but that seems beyond the scope of this PR.

One thing I can do is add comments to Trainer.call_hook to indicate that this override is applied in the training loop, and that any changes to call_hook must also be applied here.
accelerator_hook = getattr(self.trainer.accelerator, hook_name)
accelerator_hook()

if not skip:
self.trainer._cache_logged_metrics()

def run_training_batch(self, batch, batch_idx, dataloader_idx):
# track grad norms
grad_norm_dic = {}
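Since the hook no longer receives `outputs`, user code that needs per-step results at epoch end can collect them itself (or keep using `training_epoch_end`). A minimal sketch of that pattern, not taken from this diff, with illustrative names:

```python
import torch
import pytorch_lightning as pl


class CachingModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self._step_losses = []

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # cache whatever is needed at the end of the epoch
        self._step_losses.append(loss.detach())
        return loss

    def on_train_epoch_end(self):  # new signature: no `outputs` argument
        epoch_mean = torch.stack(self._step_losses).mean()
        self.log("train_loss_epoch", epoch_mean)
        self._step_losses.clear()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```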
5 changes: 1 addition & 4 deletions tests/callbacks/test_callback_hook_outputs.py
@@ -34,9 +34,6 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
assert 'x' in outputs

def on_train_epoch_end(self, trainer, pl_module, outputs):
assert len(outputs) == trainer.num_training_batches

class TestModel(BoringModel):

def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
@@ -48,7 +45,7 @@ def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx
def on_test_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
assert 'x' in outputs

def on_train_epoch_end(self, outputs) -> None:
def training_epoch_end(self, outputs) -> None:
assert len(outputs) == self.trainer.num_training_batches

model = TestModel()
2 changes: 1 addition & 1 deletion tests/callbacks/test_finetuning_callback.py
@@ -27,7 +27,7 @@

class TestBackboneFinetuningCallback(BackboneFinetuning):

def on_train_epoch_end(self, trainer, pl_module, outputs):
def on_train_epoch_end(self, trainer, pl_module):
epoch = trainer.current_epoch
if self.unfreeze_backbone_at_epoch <= epoch:
optimizer = trainer.optimizers[0]
47 changes: 47 additions & 0 deletions tests/deprecated_api/test_remove_1-5.py
@@ -216,6 +216,53 @@ def test_v1_5_0_model_checkpoint_period(tmpdir):
ModelCheckpoint(dirpath=tmpdir, period=1)


def test_v1_5_0_old_on_train_epoch_end(tmpdir):
callback_warning_cache.clear()

class OldSignature(Callback):

def on_train_epoch_end(self, trainer, pl_module, outputs): # noqa
...

class OldSignatureModel(BoringModel):

def on_train_epoch_end(self, outputs): # noqa
...

model = BoringModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())

with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)

callback_warning_cache.clear()

model = OldSignatureModel()

with pytest.deprecated_call(match="old signature will be removed in v1.5"):
trainer.fit(model)

trainer.train_loop.warning_cache.clear()

class NewSignature(Callback):

def on_train_epoch_end(self, trainer, pl_module):
...

trainer.callbacks = [NewSignature()]
with no_deprecated_call(match="`Callback.on_train_epoch_end` signature has changed in v1.3."):
trainer.fit(model)

class NewSignatureModel(BoringModel):

def on_train_epoch_end(self):
...

model = NewSignatureModel()
with no_deprecated_call(match="`ModelHooks.on_train_epoch_end` signature has changed in v1.3."):
trainer.fit(model)


def test_v1_5_0_old_on_validation_epoch_end(tmpdir):
callback_warning_cache.clear()

30 changes: 9 additions & 21 deletions tests/models/test_hooks.py
@@ -17,7 +17,7 @@
import pytest
import torch

from pytorch_lightning import Callback, Trainer
from pytorch_lightning import Trainer
from tests.helpers import BoringDataModule, BoringModel, RandomDataset
from tests.helpers.runif import RunIf

@@ -92,21 +92,17 @@ def training_epoch_end(self, outputs):
def test_training_epoch_end_metrics_collection_on_override(tmpdir):
""" Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """

class LoggingCallback(Callback):
class OverriddenModel(BoringModel):

def on_train_epoch_start(self, trainer, pl_module):
def __init__(self):
super().__init__()
self.len_outputs = 0

def on_train_epoch_end(self, trainer, pl_module, outputs):
self.len_outputs = len(outputs)

class OverriddenModel(BoringModel):

def on_train_epoch_start(self):
self.num_train_batches = 0

def training_epoch_end(self, outputs): # Overridden
return
def training_epoch_end(self, outputs):
self.len_outputs = len(outputs)

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.num_train_batches += 1
@@ -123,22 +119,14 @@ def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
not_overridden_model = NotOverriddenModel()
not_overridden_model.training_epoch_end = None

callback = LoggingCallback()
trainer = Trainer(
max_epochs=1,
default_root_dir=tmpdir,
overfit_batches=2,
callbacks=[callback],
)

trainer.fit(overridden_model)
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook
# if training_epoch_end is overridden
assert callback.len_outputs == overridden_model.num_train_batches

trainer.fit(not_overridden_model)
# outputs from on_train_batch_end should be empty
assert callback.len_outputs == 0
assert overridden_model.len_outputs == overridden_model.num_train_batches


@RunIf(min_gpus=1)
@@ -334,9 +322,9 @@ def on_train_epoch_start(self):
self.called.append("on_train_epoch_start")
super().on_train_epoch_start()

def on_train_epoch_end(self, outputs):
def on_train_epoch_end(self):
self.called.append("on_train_epoch_end")
super().on_train_epoch_end(outputs)
super().on_train_epoch_end()

def on_validation_start(self):
self.called.append("on_validation_start")
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_logger_connector.py
@@ -678,7 +678,7 @@ def _assert_epoch_end(self, stage):
acc.reset.assert_not_called()
ap.reset.assert_not_called()

def on_train_epoch_end(self, outputs):
def on_train_epoch_end(self):
self._assert_epoch_end('train')

def on_validation_epoch_end(self, outputs):
2 changes: 1 addition & 1 deletion tests/trainer/logging_/test_train_loop_logging.py
@@ -599,7 +599,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data
# with func = np.mean if on_epoch else func = np.max
self.count += 1

def on_train_epoch_end(self, trainer, pl_module, outputs):
def on_train_epoch_end(self, trainer, pl_module):
self.make_logging(
pl_module, 'on_train_epoch_end', 8, on_steps=[False], on_epochs=self.choices, prob_bars=self.choices
)
5 changes: 0 additions & 5 deletions tests/trainer/test_training_loop.py
@@ -155,11 +155,6 @@ def training_epoch_end(self, outputs):
[HookedModel._check_output(output) for output in outputs]
super().training_epoch_end(outputs)

def on_train_epoch_end(self, outputs):
assert len(outputs) == 2
[HookedModel._check_output(output) for output in outputs]
super().on_train_epoch_end(outputs)

model = HookedModel()

# fit model