From 5246055756235da4ce6335498afa7c4b8ccaddbc Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Thu, 3 Jun 2021 22:29:48 +0100
Subject: [PATCH 1/5] Modify API to ensure hooks defined in the accelerator are called as expected

---
 pytorch_lightning/trainer/trainer.py       | 10 ++++++----
 pytorch_lightning/trainer/training_loop.py |  5 ++---
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index ded6e3395e30c..4f6dbb75933f7 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1237,11 +1237,13 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
                 hook_fx = getattr(model_ref, hook_name)
                 output = hook_fx(*args, **kwargs)

-            # if the PL module doesn't have the hook then call the accelerator
-            # used to auto-reduce things for the user with Results obj
-            elif hasattr(self.accelerator, hook_name):
+            # call the accelerator hook
+            if hasattr(self.accelerator, hook_name):
                 accelerator_hook = getattr(self.accelerator, hook_name)
-                output = accelerator_hook(*args, **kwargs)
+                accelerator_output = accelerator_hook(*args, **kwargs)
+                # Rely on the accelerator output if lightningModule hook returns nothing
+                if output is None:
+                    output = accelerator_output

         if not skip:
             self._cache_logged_metrics()
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 156fea5d37ac8..1d7eba01aebb1 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -624,9 +624,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
                 else:
                     model_ref.on_train_epoch_end()

-            # if the PL module doesn't have the hook then call the accelerator
-            # used to auto-reduce things for the user with Results obj
-            elif hasattr(self.trainer.accelerator, hook_name):
+            # call the accelerator hook
+            if hasattr(self.trainer.accelerator, hook_name):
                 accelerator_hook = getattr(self.trainer.accelerator, hook_name)
                 accelerator_hook()


From 18ceedce4c3ee6a8af6f5d606c0f3579f2feaf0f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 4 Jun 2021 12:44:55 +0200
Subject: [PATCH 2/5] handle step_end in dp

---
 pytorch_lightning/plugins/training_type/dp.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/plugins/training_type/dp.py b/pytorch_lightning/plugins/training_type/dp.py
index 18aeb6a451d4a..2787ab5644ccd 100644
--- a/pytorch_lightning/plugins/training_type/dp.py
+++ b/pytorch_lightning/plugins/training_type/dp.py
@@ -19,6 +19,7 @@
 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION


@@ -101,10 +102,16 @@ def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)

     def training_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("training_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

     def validation_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("validation_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

     def test_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("test_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

From d9a4b546d32658dc97e7821022de6b34ed28ae46 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Fri, 4 Jun 2021 12:02:48 +0100
Subject: [PATCH 3/5] Add changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a30ddc6530790..94bb4041bc36f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788))


+- Accelerator hooks are called regardless if `LightningModule` overrides the same hooks ([#7826](https://github.com/PyTorchLightning/pytorch-lightning/pull/7826))
+
+
 ### Deprecated



From 6a55739e77086fd91752a7aa1e84cbd4cae33e7e Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Fri, 4 Jun 2021 14:38:04 +0100
Subject: [PATCH 4/5] Update pytorch_lightning/trainer/trainer.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/trainer/trainer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 4f6dbb75933f7..bbf5c436cbe2f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1242,8 +1242,7 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
                 accelerator_hook = getattr(self.accelerator, hook_name)
                 accelerator_output = accelerator_hook(*args, **kwargs)
                 # Rely on the accelerator output if lightningModule hook returns nothing
-                if output is None:
-                    output = accelerator_output
+                output = output or accelerator_output

         if not skip:
             self._cache_logged_metrics()

From e66be326dc53bf223183ec419e9efabd7c214ac0 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Fri, 4 Jun 2021 14:52:02 +0100
Subject: [PATCH 5/5] Add todo and explanation

---
 pytorch_lightning/trainer/trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index bbf5c436cbe2f..b9846af644e82 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1242,7 +1242,9 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
                 accelerator_hook = getattr(self.accelerator, hook_name)
                 accelerator_output = accelerator_hook(*args, **kwargs)
                 # Rely on the accelerator output if lightningModule hook returns nothing
-                output = output or accelerator_output
+                # Required for cases such as DataParallel where we reduce the output for the user
+                # todo: move this data parallel logic into the data parallel plugin
+                output = accelerator_output if output is None else output

         if not skip:
             self._cache_logged_metrics()
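
Taken together, patches 1, 4 and 5 settle on the following precedence in `Trainer.call_hook`: the `LightningModule` hook runs first, the accelerator hook now always runs as well, and the accelerator's return value is used only when the `LightningModule` hook returned `None`. The snippet below is a minimal sketch of that rule using toy stand-in objects; the class and function names are illustrative, not Lightning APIs.

    # Illustrative stand-ins only -- not the actual Trainer/Accelerator classes.
    class _ToyModule:
        def training_step_end(self, output):
            return None  # user hook that returns nothing

    class _ToyAccelerator:
        def training_step_end(self, output):
            return sum(output) / len(output)  # e.g. auto-reduce per-device outputs

    def call_hook_sketch(module, accelerator, hook_name, *args, **kwargs):
        output = None
        # 1) call the LightningModule hook if present
        if hasattr(module, hook_name):
            output = getattr(module, hook_name)(*args, **kwargs)
        # 2) the accelerator hook now runs regardless (patch 1)
        if hasattr(accelerator, hook_name):
            accelerator_output = getattr(accelerator, hook_name)(*args, **kwargs)
            # 3) explicit None check (patch 5); unlike `or`, this preserves
            #    falsy-but-valid user outputs such as 0 or an empty dict
            output = accelerator_output if output is None else output
        return output

    print(call_hook_sketch(_ToyModule(), _ToyAccelerator(), "training_step_end", [1.0, 3.0]))  # -> 2.0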
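
Patch 2 makes the DataParallel plugin auto-reduce `*_step_end` outputs only when the user has not overridden the corresponding hook on their `LightningModule`; otherwise the per-device output is passed through untouched. Below is a rough sketch of that decision, with `torch.mean` standing in for the plugin's real `self.reduce`; the helper `dp_step_end_sketch` and the toy modules are assumptions for illustration, while `is_overridden` is the helper imported in the patch (import path as of this version of Lightning).

    import torch
    from pytorch_lightning import LightningModule
    from pytorch_lightning.utilities.model_helpers import is_overridden

    class AutoReduced(LightningModule):
        pass  # does not override training_step_end -> plugin reduces for the user

    class CustomReduced(LightningModule):
        def training_step_end(self, output):
            # user combines the per-device outputs themselves elsewhere in the loop
            return output.max()

    def dp_step_end_sketch(lightning_module, output):
        # mirrors the shape of the patched DataParallelPlugin.training_step_end,
        # with torch.mean standing in for the plugin's self.reduce
        if not is_overridden("training_step_end", lightning_module):
            return torch.mean(output) if isinstance(output, torch.Tensor) else output
        return output

    per_device_losses = torch.tensor([0.4, 0.6])
    print(dp_step_end_sketch(AutoReduced(), per_device_losses))    # reduced -> tensor(0.5000)
    print(dp_step_end_sketch(CustomReduced(), per_device_losses))  # passed through unchanged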