From 5fcca4e43b243cd9fdb08050b285fb052856f13b Mon Sep 17 00:00:00 2001
From: Swetha Mandava
Date: Tue, 26 Jan 2021 05:01:46 -0800
Subject: [PATCH] passing batch outputs to on_train_batch_end (#4369)

* passing batch outputs to on_train_batch_end
* styling
* updating epoch end logic
* also condition on on_train_epoch_end hooks
* more readable
* pep8
* pep8
* readability suggestion accepted

Co-authored-by: Jirka Borovec

* adding test_training_epoch_end_metrics_collection_on_override test
* fix formatting
* fix formatting

Co-authored-by: Swetha Mandava
Co-authored-by: Jirka Borovec
Co-authored-by: Sean Naren
Co-authored-by: Roger Shieh
---
 pytorch_lightning/trainer/training_loop.py | 43 +++++++++--------
 tests/models/test_hooks.py                 | 50 ++++++++++++++++++++++
 2 files changed, 75 insertions(+), 18 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 0925bc78a9533..ab8e7f56b1afa 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -226,13 +226,13 @@ def on_train_epoch_start(self, epoch):
         self.trainer.call_hook("on_epoch_start")
         self.trainer.call_hook("on_train_epoch_start")
 
-    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
         # hook
         self.trainer.call_hook('on_batch_end')
-        self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook('on_train_batch_end', batch_end_outputs, batch, batch_idx, dataloader_idx)
 
         # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)
+        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)
 
         # reset batch logger internals
         self.trainer.logger_connector.on_train_batch_end()
@@ -244,12 +244,27 @@ def reset_train_val_dataloaders(self, model):
         if self.trainer.val_dataloaders is None and not self.trainer.reload_dataloaders_every_epoch:
             self.trainer.reset_val_dataloader(model)
 
-    def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
+    def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs):
+
         # track the outputs to reduce at the end of the epoch
-        for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
+        for opt_idx, opt_outputs in enumerate(batch_end_outputs):
+            sample_output = opt_outputs[-1]
+
+            # decide if we need to reduce at the end of the epoch automatically
+            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
+            hook_overridden = (
+                is_overridden("training_epoch_end", model=self.trainer.get_model()) or
+                is_overridden("on_train_epoch_end", model=self.trainer.get_model())
+            )
+
+            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
+            if not (hook_overridden or auto_reduce_tng_result):
+                continue
+
             # with 1 step (no tbptt) don't use a sequence at epoch end
             if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
                 opt_outputs = opt_outputs[0]
+
             epoch_output[opt_idx].append(opt_outputs)
 
     def get_optimizers_iterable(self):
@@ -537,17 +552,14 @@ def run_training_epoch(self):
             if batch_output.signal == -1:
                 break
 
-            # only track outputs when user implements training_epoch_end
-            # otherwise we will build up unnecessary memory
-            epoch_end_outputs = self.process_train_step_outputs(
+            batch_end_outputs = self.process_train_step_outputs(
                 batch_output.training_step_output_for_epoch_end,
                 self.early_stopping_accumulator,
                 self.checkpoint_accumulator,
             )
-
-            # hook
-            self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
+            # TODO: add outputs to batches
+            self.on_train_batch_end(epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx)
 
             # -----------------------------------------
             # SAVE METRICS TO LOGGERS
@@ -901,7 +913,7 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
         # the training step outputs a list per optimizer. The list contains the outputs at each time step
         # when no TBPTT is used, then the list has 1 item per batch
         # when TBPTT IS used, then the list has n items (1 per time step)
-        epoch_end_outputs = []
+        batch_end_outputs = []
         for optimizer_idx_outputs in all_train_step_outputs:
             # extract one representative sample from each time step (1 if no tbptt) and 0th optimizer
             if len(optimizer_idx_outputs) == 0:
@@ -916,14 +928,9 @@ def process_train_step_outputs(self, all_train_step_outputs, early_stopping_accu
             if isinstance(sample_output, dict) and "checkpoint_on" in sample_output:
                 checkpoint_accumulator.accumulate(sample_output["checkpoint_on"])
 
-            # decide if we need to reduce at the end of the epoch automatically
-            auto_reduce_tng_result = isinstance(sample_output, Result) and sample_output.should_reduce_on_epoch_end
-
-            # only track when a) it needs to be autoreduced OR b) the user wants to manually reduce on epoch end
-            if is_overridden("training_epoch_end", model=self.trainer.get_model()) or auto_reduce_tng_result:
-                epoch_end_outputs.append(optimizer_idx_outputs)
+            batch_end_outputs.append(optimizer_idx_outputs)
 
-        return epoch_end_outputs
+        return batch_end_outputs
 
     def prepare_optimizers(self):
         # in manual optimization we loop over all optimizers at once
diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py
index 5352e749c5e55..62d17515119cd 100644
--- a/tests/models/test_hooks.py
+++ b/tests/models/test_hooks.py
@@ -18,8 +18,10 @@
 import pytest
 import torch
+
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
+import pytorch_lightning as pl
 from tests.base import BoringModel, EvalModelTemplate, RandomDataset
@@ -90,6 +92,54 @@ def training_epoch_end(self, outputs):
         assert metrics[f'epoch_metric_{i}'] == i
 
 
+def test_training_epoch_end_metrics_collection_on_override(tmpdir):
+    """ Test that batch-end metrics are collected at the end of an epoch when training_epoch_end is overridden. """
""" + num_epochs = 1 + + class LoggingCallback(pl.Callback): + + def on_train_epoch_end(self, trainer, pl_module): + self.len_outputs = 0 + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.len_outputs = len(outputs[0]) + + class OverriddenModel(EvalModelTemplate): + + def on_train_epoch_start(self): + self.num_train_batches = 0 + + def training_epoch_end(self, outputs): # Overridden + pass + return + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.num_train_batches += 1 + + class NotOverriddenModel(EvalModelTemplate): + + def on_train_epoch_start(self): + self.num_train_batches = 0 + + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + self.num_train_batches += 1 + + overridden_model = OverriddenModel() + not_overridden_model = NotOverriddenModel() + + callback = LoggingCallback() + trainer = Trainer( + max_epochs=num_epochs, + default_root_dir=tmpdir, + overfit_batches=2, + callbacks=[callback], + ) + + result = trainer.fit(overridden_model) + assert callback.len_outputs == overridden_model.num_train_batches + # outputs from on_train_batch_end should be accessible in on_train_epoch_end hook if training_epoch_end is overridden + + result = trainer.fit(not_overridden_model) + assert callback.len_outputs == 0 + # outputs from on_train_batch_end should be empty + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_transfer_batch_hook():