passing batch outputs to on_train_batch_end (#4369)
* passing batch outputs to on_train_batch_end

* styling

* updating epoch end logic

* also condition on on_train_epoch_end hooks

* more readable

* pep8

* pep8

* readability suggestion accepted

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* adding test_training_epoch_end_metrics_collection_on_override test

* fix formatting

* fix formatting

Co-authored-by: Swetha Mandava <smandava@nvidia.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>

(cherry picked from commit 5fcca4e)
swethmandava authored and Borda committed Feb 4, 2021
1 parent 5ac2a54 commit 2ea6942
Showing 2 changed files with 79 additions and 19 deletions.
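
In user-facing terms, the change forwards whatever training_step returned for the just-finished batch into the `on_train_batch_end` hook, and hands the accumulated per-batch outputs to `on_train_epoch_end` whenever `training_epoch_end` or `on_train_epoch_end` is overridden. A minimal sketch of the pattern this enables, assuming the callback-hook signatures of this Lightning release (the `OutputInspector` class and its print are illustrative, not part of the commit):

from pytorch_lightning import Callback


class OutputInspector(Callback):
    """Sketch: observe training_step outputs per batch and per epoch."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        # `outputs` holds what training_step returned for this batch
        # (a list per optimizer, one entry per truncated-BPTT step).
        pass

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # Only populated when training_epoch_end or on_train_epoch_end is
        # overridden; otherwise outputs are dropped to save memory.
        print(f"optimizer 0 produced {len(outputs[0])} batch outputs this epoch")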
43 changes: 25 additions & 18 deletions pytorch_lightning/trainer/training_loop.py
@@ -241,13 +241,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")
 
-    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
         # hook
-        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)
         self.trainer.call_hook('on_batch_end')
 
         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)
 
         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
@@ -259,12 +259,27 @@ def reset_train_val_dataloaders(self, model):
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)
 
-    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
 
         # track the outputs to reduce at the end of the epoch
-        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
+            sample_output = opt_outputs[-1]
+
+            # decide if we need to reduce at the end of the epoch automatically
+            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
+            hook_overridden = (
+                is_overridden("training_epoch_end", model=self.trainer.get_model()) or
+                is_overridden("on_train_epoch_end", model=self.trainer.get_model())
+            )
+
+            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
+            if not (hook_overridden or auto_reduce_tng_result):
+                continue
+
             # with 1 step (no tbptt) don't use a sequence at epoch end
             if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                 opt_outputs = opt_outputs[0]
 
             epoch_output[opt_idx].append(opt_outputs)
 
     def get_optimizers_iterable(self):
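
The new `hook_overridden` gate relies on Lightning's `is_overridden` utility. Its job is to detect whether the user's class replaced a hook inherited from the base class; a simplified sketch of that idea, not the library's exact implementation:

from pytorch_lightning import LightningModule


def is_overridden_sketch(method_name, model, parent=LightningModule):
    # Simplified: the hook counts as overridden when the function found on
    # the model's class is not the one defined on the parent class.
    model_fn = getattr(type(model), method_name, None)
    parent_fn = getattr(parent, method_name, None)
    return model_fn is not None and model_fn is not parent_fn

With that gate, per-batch outputs are kept for the epoch only when a Result asks for automatic epoch-end reduction or when the user actually consumes them in an epoch-end hook.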
@@ -548,17 +563,14 @@ def run_training_epoch(self):
             if batch_output.signal == -1:
                 break
 
-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
+            batch_end_outputs = self.process_train_step_outputs(
                 batch_output.training_step_output_for_epoch_end,
                 self.early_stopping_accumulator,
                 self.checkpoint_accumulator,
             )
 
             # hook
-            # TODO: add outputs to batches
-            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)
 
             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
@@ -896,7 +908,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        epoch_end_outputs = []
+        batch_end_outputs = []
         for optimizer_idx_outputs in all_train_step_outputs:
             # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
             if len(optimizer_idx_outputs) == 0:
@@ -911,14 +923,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accumulator, checkpoint_accumulator):
             if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                 checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])
 
-            # decide if we need to reduce at the end of the epoch automatically
-            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-
-            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
-                epoch_end_outputs.append(optimizer_idx_outputs)
+            batch_end_outputs.append(optimizer_idx_outputs)
 
-        return epoch_end_outputs
+        return batch_end_outputs
 
     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
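
Net effect of the hunks above: `process_train_step_outputs` now forwards every batch's outputs to `on_train_batch_end`, while the memory-saving filter moves into `track_epoch_end_reduce_metrics`. A model opts into epoch accumulation by overriding `training_epoch_end`; a hedged sketch of that consumer side (only the relevant hook is shown, and the loss averaging is illustrative):

import torch
from pytorch_lightning import LightningModule


class EpochAveragingModel(LightningModule):
    # Overriding training_epoch_end opts this model into output tracking,
    # so `outputs` arrives with one entry per batch. The usual
    # training_step/configure_optimizers definitions are omitted here.

    def training_epoch_end(self, outputs):
        # With one optimizer and no truncated BPTT, each entry is the dict
        # returned by training_step, e.g. {'loss': tensor(...)}.
        avg_loss = torch.stack([out['loss'] for out in outputs]).mean()
        self.log('train_loss_epoch', avg_loss)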
55 changes: 54 additions & 1 deletion tests/models/test_hooks.py
@@ -18,7 +18,8 @@
 import pytest
 import torch
 
-from pytorch_lightning import Trainer
+
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.accelerators.legacy.gpu_accelerator import GPUAccelerator
 from pytorch_lightning.trainer.states import TrainerState
 from tests.base import BoringModel, EvalModelTemplate, RandomDataset
@@ -91,6 +92,58 @@ def training_epoch_end(self, outputs):
         assert metrics[f'epoch_metric_{i}'] == i
 
 
+def test_training_epoch_end_metrics_collection_on_override(tmpdir):
+    """ Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
+    num_epochs = 1
+
+    class LoggingCallback(Callback):
+
+        def on_train_epoch_end(self, trainer, pl_module):
+            self.len_outputs = 0
+
+        def on_train_epoch_end(self, trainer, pl_module, outputs):
+            self.len_outputs = len(outputs[0])
+
+    class OverriddenModel(EvalModelTemplate):
+
+        def on_train_epoch_start(self):
+            self.num_train_batches = 0
+
+        def training_epoch_end(self, outputs):  # Overridden
+            pass
+            return
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.num_train_batches += 1
+
+    class NotOverriddenModel(EvalModelTemplate):
+
+        def on_train_epoch_start(self):
+            self.num_train_batches = 0
+
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            self.num_train_batches += 1
+
+    overridden_model = OverriddenModel()
+    not_overridden_model = NotOverriddenModel()
+
+    callback = LoggingCallback()
+    trainer = Trainer(
+        max_epochs=num_epochs,
+        default_root_dir=tmpdir,
+        overfit_batches=2,
+        callbacks=[callback],
+    )
+
+    result = trainer.fit(overridden_model)
+    assert callback.len_outputs == overridden_model.num_train_batches
+    # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden
+
+    result = trainer.fit(not_overridden_model)
+    assert callback.len_outputs == 0
+    # outputs from on_train_batch_end should be empty
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_transfer_batch_hook():

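For reference, the new test can be selected on its own through pytest's Python entry point (equivalent to the usual command-line invocation):

import pytest

# Run only the test added by this commit, with verbose output.
pytest.main([
    "tests/models/test_hooks.py::test_training_epoch_end_metrics_collection_on_override",
    "-v",
])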
