[IPU] Call accelerator hooks regardless if LM hook overridden 1/n (#7826)

* Modify API to ensure hooks defined in the accelerator are called as expected

* handle step_end in dp

* Add changelog

* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Add todo and explanation

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
3 people authored Jun 4, 2021
1 parent 51d370f commit 7c7182d
Showing 4 changed files with 22 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -122,6 +122,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `Trainer.fit` now raises an error when using manual optimization with unsupported features such as `gradient_clip_val` or `accumulate_grad_batches` ([#7788](https://github.com/PyTorchLightning/pytorch-lightning/pull/7788))


+- Accelerator hooks are called regardless if `LightningModule` overrides the same hooks ([#7826](https://github.com/PyTorchLightning/pytorch-lightning/pull/7826))


 - Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))


13 changes: 10 additions & 3 deletions pytorch_lightning/plugins/training_type/dp.py
@@ -19,6 +19,7 @@
 from pytorch_lightning.overrides.data_parallel import LightningParallelModule
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 from pytorch_lightning.utilities.apply_func import apply_to_collection
+from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import _METRIC_COLLECTION


@@ -101,10 +102,16 @@ def predict_step(self, *args, **kwargs):
         return self.model(*args, **kwargs)

     def training_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("training_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

     def validation_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("validation_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output

     def test_step_end(self, output):
-        return self.reduce(output)
+        if not is_overridden("test_step_end", self.lightning_module):
+            return self.reduce(output)
+        return output
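The practical effect of the dp.py change is that a `LightningModule` which overrides a `*_step_end` hook now receives the un-reduced per-GPU outputs and is responsible for reducing them itself. Below is a minimal, hypothetical sketch (not part of this commit) of such a module; the layer, loss, and optimizer are placeholders chosen only for illustration.

import torch
import pytorch_lightning as pl


class LitRegressor(pl.LightningModule):
    """Hypothetical module used only to illustrate the new DP behaviour."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        # each DP replica returns its own scalar loss
        return {"loss": torch.nn.functional.mse_loss(self.layer(x), y)}

    def training_step_end(self, step_output):
        # Under DP, `step_output["loss"]` arrives with one value per GPU.
        # Because this hook is overridden, the plugin no longer reduces it
        # automatically, so the mean over replicas is taken here instead.
        return {"loss": step_output["loss"].mean()}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

Running this with something like `Trainer(gpus=2, accelerator="dp")` exercises the new code path; a module that does not override the hook keeps the previous auto-reduction behaviour.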
11 changes: 7 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -1237,11 +1237,14 @@ def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
                 hook_fx = getattr(model_ref, hook_name)
                 output = hook_fx(*args, **kwargs)

-            # if the PL module doesn't have the hook then call the accelerator
-            # used to auto-reduce things for the user with Results obj
-            elif hasattr(self.accelerator, hook_name):
+            # call the accelerator hook
+            if hasattr(self.accelerator, hook_name):
                 accelerator_hook = getattr(self.accelerator, hook_name)
-                output = accelerator_hook(*args, **kwargs)
+                accelerator_output = accelerator_hook(*args, **kwargs)
+                # Rely on the accelerator output if lightningModule hook returns nothing
+                # Required for cases such as DataParallel where we reduce the output for the user
+                # todo: move this data parallel logic into the data parallel plugin
+                output = accelerator_output if output is None else output

         if not skip:
             self._cache_logged_metrics()
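The comments in the new hunk summarize the precedence rule; a condensed, hypothetical paraphrase of the dispatch order is sketched below. The free-standing function name is invented for illustration, and profiling, the trainer-level hook call, and metric caching from the real `call_hook` are omitted.

from pytorch_lightning.utilities.model_helpers import is_overridden


def call_hook_sketch(trainer, model, hook_name, *args, **kwargs):
    output = None

    # the LightningModule hook runs first, if the user overrode it
    if is_overridden(hook_name, model):
        output = getattr(model, hook_name)(*args, **kwargs)

    # the accelerator hook now runs unconditionally (no longer an `elif`)
    if hasattr(trainer.accelerator, hook_name):
        accelerator_output = getattr(trainer.accelerator, hook_name)(*args, **kwargs)
        # the module's return value wins; the accelerator result is only a fallback
        output = accelerator_output if output is None else output

    return output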
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/training_loop.py
@@ -634,9 +634,8 @@ def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
                 else:
                     model_ref.on_train_epoch_end()

-            # if the PL module doesn't have the hook then call the accelerator
-            # used to auto-reduce things for the user with Results obj
-            elif hasattr(self.trainer.accelerator, hook_name):
+            # call the accelerator hook
+            if hasattr(self.trainer.accelerator, hook_name):
                 accelerator_hook = getattr(self.trainer.accelerator, hook_name)
                 accelerator_hook()
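The same `elif`-to-`if` change applies to this epoch-end hook. The toy stand-ins below (not from the commit) show the observable difference: with the new control flow both hooks fire, whereas the old `elif` skipped the accelerator hook whenever the module overrode its own.

class ModuleStub:
    # stands in for a LightningModule that overrides on_train_epoch_end
    def on_train_epoch_end(self):
        print("module hook ran")


class AcceleratorStub:
    # stands in for an accelerator (e.g. the IPU one) that also defines the hook
    def on_train_epoch_end(self):
        print("accelerator hook ran")


def epoch_end_dispatch(module, accelerator):
    # mirrors the new control flow: the accelerator hook is no longer
    # guarded by an `elif` on the module override
    if hasattr(module, "on_train_epoch_end"):
        module.on_train_epoch_end()
    if hasattr(accelerator, "on_train_epoch_end"):
        accelerator.on_train_epoch_end()


epoch_end_dispatch(ModuleStub(), AcceleratorStub())  # prints both lines now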

