passing batch outputs to on_train_batch_end #4369

Merged (19 commits) on Jan 26, 2021
Changes from 18 commits
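
For orientation: this change forwards the processed per-batch training_step outputs into the on_train_batch_end hook. Below is a minimal sketch of how a user-side module could consume them, assuming the module-level hook signature exercised in the new test further down (on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx)); the model and its logic are illustrative only and are not part of this PR.

import torch
import pytorch_lightning as pl


class BatchOutputTracker(pl.LightningModule):
    # illustrative module; only the hook signature comes from this PR's tests
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.batch_outputs = []

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        return {"loss": loss}

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # after this PR, `outputs` carries the processed training_step result(s)
        # for this batch (grouped per optimizer / truncated-BPTT step)
        self.batch_outputs.append(outputs)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

Fitting such a module with a Trainer would then append one entry per training batch.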
43 changes: 25 additions & 18 deletions pytorch_lightning/trainer/training_loop.py
@@ -226,13 +226,13 @@ def on_train_epoch_start(self, epoch):
self.trainer.call_hook("on_epoch_start")
self.trainer.call_hook("on_train_epoch_start")

def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
# hook
self.trainer.call_hook('on_batch_end')
self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)

# figure out what to track for epoch end
self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)

# reset batch logger internals
self.trainer.logger_connector.on_train_batch_end()
@@ -244,12 +244,27 @@ def reset_train_val_dataloaders(self, model):
if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
self.trainer.reset_val_dataloader(model)

def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):

# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
for opt_idx, opt_outputs in enumerate(batch_end_outputs):
sample_output = opt_outputs[-1]

# decide if we need to reduce at the end of the epoch automatically
auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
hook_overridden = (
is_overridden("training_epoch_end", model=self.trainer.get_model()) or
is_overridden("on_train_epoch_end", model=self.trainer.get_model())
)

# only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
if not(hook_overridden or auto_reduce_tng_result):
continue

# with 1 step (no tbptt) don't use a sequence at epoch end
if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
opt_outputs = opt_outputs[0]

epoch_output[opt_idx].append(opt_outputs)

def get_optimizers_iterable(self):
@@ -537,17 +552,14 @@ def run_training_epoch(self):
if batch_output.signal == -1:
break

# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
epoch_end_outputs = self.process_train_step_outputs(
batch_end_outputs = self.process_train_step_outputs(
batch_output.training_step_output_for_epoch_end,
self.early_stopping_accumulator,
self.checkpoint_accumulator,
)

# hook
# TODO: add outputs to batches
A Contributor commented on this TODO:
anyone knows what this todo means?
self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)

# -----------------------------------------
# SAVE METRICS TO LOGGERS
@@ -897,7 +909,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
# the training step outputs a list per optimizer. The list contains the outputs at each time step
# when no TBPTT is used, then the list has 1 item per batch
# when TBPTT IS used, then the list has n items (1 per time step)
epoch_end_outputs = []
batch_end_outputs = []
for optimizer_idx_outputs in all_train_step_outputs:
# extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
if len(optimizer_idx_outputs) == 0:
@@ -912,14 +924,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])

# decide if we need to reduce at the end of the epoch automatically
auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end

# only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
epoch_end_outputs.append(optimizer_idx_outputs)
batch_end_outputs.append(optimizer_idx_outputs)

return epoch_end_outputs
return batch_end_outputs

def prepare_optimizers(self):
# in manual optimization we loop over all optimizers at once
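As the comments in process_train_step_outputs note, the training step produces one list per optimizer, with one entry per truncated-BPTT time step (a single entry when TBPTT is off), and track_epoch_end_reduce_metrics now flattens the single-step case before appending to epoch_output. The following standalone sketch illustrates those shapes; the placeholder dicts stand in for the Result objects or training_step dicts and are illustrative only.

epoch_output = [[]]  # one accumulator list per optimizer

# no truncated BPTT: each optimizer has a single-entry list for the batch
batch_end_outputs = [
    [{"loss": 0.7}],  # optimizer 0
]

# with truncated BPTT over 3 time steps it would instead look like
# [[{"loss": 0.7}, {"loss": 0.6}, {"loss": 0.5}]]

for opt_idx, opt_outputs in enumerate(batch_end_outputs):
    # with 1 step (no tbptt) don't keep the sequence wrapper at epoch end;
    # the real loop also skips this flattening when the single entry is a Result object
    if isinstance(opt_outputs, list) and len(opt_outputs) == 1:
        opt_outputs = opt_outputs[0]
    epoch_output[opt_idx].append(opt_outputs)

print(epoch_output)  # [[{'loss': 0.7}]]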
54 changes: 54 additions & 0 deletions tests/models/test_hooks.py
@@ -18,8 +18,10 @@
import pytest
import torch


from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
import pytorch_lightning as pl
from tests.base import BoringModel, EvalModelTemplate, RandomDataset


@@ -90,6 +92,58 @@ def training_epoch_end(self, outputs):
assert metrics[f'epoch_metric_{i}'] == i


def test_training_epoch_end_metrics_collection_on_override(tmpdir):
""" Test that batch end metrics are collected when training_epoch_end is overridden at the end of an epoch. """
num_epochs = 1

class LoggingCallback(pl.Callback):

def on_train_epoch_end(self, trainer, pl_module):
@Borda (Member) commented on Jan 26, 2021:
@SeanNaren flake8 found this bug, but it was still merged, could pls fix it in #5666
most likely on_train_epoch_end >> on_train_epoch_start as *_end is here twice
self.len_outputs = 0

def on_train_epoch_end(self, trainer, pl_module, outputs):
self.len_outputs = len(outputs[0])

class OverriddenModel(EvalModelTemplate):

def on_train_epoch_start(self):
self.num_train_batches = 0

def training_epoch_end(self, outputs): # Overridden
pass
return

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.num_train_batches += 1

class NotOverriddenModel(EvalModelTemplate):

def on_train_epoch_start(self):
self.num_train_batches = 0

def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.num_train_batches += 1

overridden_model = OverriddenModel()
not_overridden_model = NotOverriddenModel()

callback = LoggingCallback()
trainer = Trainer(
max_epochs=num_epochs,
default_root_dir=tmpdir,
overfit_batches=2,
callbacks=[callback],
)

result = trainer.fit(overridden_model)
assert callback.len_outputs == overridden_model.num_train_batches
# outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden

result = trainer.fit(not_overridden_model)
assert callback.len_outputs == 0
# outputs from on_train_batch_end should be empty


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_transfer_batch_hook():

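Following up on @Borda's review note above: in the merged test, LoggingCallback defines on_train_epoch_end twice, so the second definition silently replaces the first. Below is a sketch of what the callback was presumably meant to look like, with the first definition renamed to on_train_epoch_start as the reviewer suggests; this is an assumption based on that note, not code from this PR or from #5666.

import pytorch_lightning as pl


class LoggingCallback(pl.Callback):

    def on_train_epoch_start(self, trainer, pl_module):
        # reset the counter as the epoch starts (previously a duplicate on_train_epoch_end)
        self.len_outputs = 0

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        # `outputs` holds the per-batch results collected during the epoch;
        # outputs[0] indexes the first (and here only) optimizer
        self.len_outputs = len(outputs[0])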