Skip to content

Commit

Permalink
Fix setup callback hook to pass LightningModule through (#4608)
Browse files Browse the repository at this point in the history
* Fix setup callback hook

* Update CHANGELOG.md

* Update test_trainer.py

* Update test_trainer.py

* Update test_trainer.py

* fix chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
ananthsub and Borda committed Nov 14, 2020
1 parent 2d78d9b commit d096a2e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608))




Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ class TrainerCallbackHookMixin(ABC):
callbacks: List[Callback] = []
get_model: Callable

def setup(self, stage: str):
def setup(self, model, stage: str):
"""Called in the beginning of fit and test"""
for callback in self.callbacks:
callback.setup(self, self.get_model(), stage)
callback.setup(self, model, stage)

def teardown(self, stage: str):
"""Called at the end of fit and test"""
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def call_setup_hook(self, model):
called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
if not called:
self.datamodule.setup(stage_name)
self.setup(stage_name)
self.setup(model, stage_name)
model.setup(stage_name)

def _reset_result_and_set_hook_fx_name(self, hook_name):
Expand Down
3 changes: 2 additions & 1 deletion tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,7 +1431,8 @@ def setup(self, stage):
self.stage = stage

class TrainerSubclass(Trainer):
def setup(self, stage):
def setup(self, model, stage):
assert model is not None
self.stage = stage

model = CurrentModel()
Expand Down

0 comments on commit d096a2e

Please sign in to comment.