Avoid deprecation warning after #9901 (#9951)

Merged · 4 commits · Oct 16, 2021
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/gpu.py
@@ -46,6 +46,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         return super().setup(trainer)

     def on_train_start(self) -> None:
+        super().on_train_start()
         # clear cache before training
         torch.cuda.empty_cache()
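The one-line fix above works because Python does not chain overridden methods automatically. A minimal sketch of the failure mode this addition avoids, using hypothetical classes rather than Lightning's real Accelerator hierarchy:

class BaseAccelerator:
    def on_train_start(self) -> None:
        # base-class bookkeeping that every subclass is expected to run
        print("base on_train_start ran")

class GPUAcceleratorSketch(BaseAccelerator):
    def on_train_start(self) -> None:
        # without this super() call, the base bookkeeping is silently skipped
        super().on_train_start()
        print("clearing CUDA cache next")

GPUAcceleratorSketch().on_train_start()  # prints both lines, base-first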

2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ipu.py
@@ -285,7 +285,7 @@ def on_test_end(self):
     def on_predict_end(self):
         self._detach_models()

-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         # Updates optimizer stats if LR scheduler modified the optimizer state
         optimizer = self.lightning_module.trainer.optimizers[0]
         self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)
@@ -345,7 +345,7 @@ def on_predict_end(self):
         """Called when predict ends."""
         pass

-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Called in the training loop before anything happens for that batch."""
         pass
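Both `on_train_batch_start` signature changes in this PR use the same mechanism: giving `dataloader_idx` a default of 0 lets the hook accept the new two-argument call while still tolerating the old three-argument call. A minimal sketch of that compatibility property, with a hypothetical free function standing in for the real hooks:

from typing import Any

def on_train_batch_start(batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
    # dataloader_idx falls back to 0 when the caller omits it
    print(f"batch_idx={batch_idx}, dataloader_idx={dataloader_idx}")

on_train_batch_start("batch", 3)     # new-style call site: two arguments
on_train_batch_start("batch", 3, 1)  # old-style call site still works, no TypeError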

17 changes: 16 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -1401,15 +1401,30 @@ def call_hook(
         if callable(model_fx):
             output = model_fx(*args, **kwargs)

+        # *Bad code alert*
+        # The `Accelerator` mostly calls the `TrainingTypePlugin` but some of those calls are deprecated.
+        # The following logic selectively chooses which hooks are called on each object.
+        # In the case of `setup` and `teardown`, the hooks on the `LightningModule` should not call the hooks of the
+        # same name in these objects as they are meant to be managed outside of the `LightningModule` lifecycle.
+        # All of this should be fixed by #8506
+
         # call the accelerator hook
-        if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
+        if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
             accelerator_hook = getattr(self.accelerator, hook_name)
             accelerator_output = accelerator_hook(*args, **kwargs)
             # Rely on the accelerator output if lightningModule hook returns nothing
             # Required for cases such as DataParallel where we reduce the output for the user
             # todo: move this data parallel logic into the data parallel plugin
             output = accelerator_output if output is None else output

+        # call the ttp hook
+        if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
+            self.training_type_plugin, hook_name
+        ):
+            ttp_hook = getattr(self.training_type_plugin, hook_name)
+            ttp_output = ttp_hook(*args, **kwargs)
+            output = ttp_output if output is None else output
+
         if pl_module:
             # restore current_fx when nested context
             pl_module._current_fx_name = prev_fx_name
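Stripped of Trainer internals, the hunk above implements a name-filtered dispatch with an output fallback: each object is consulted only for the hook names it still owns, and its return value is kept only when nothing earlier in the chain produced one. A standalone sketch of that pattern, with hypothetical owners in place of the real accelerator and training type plugin:

from typing import Any, Callable, Iterable, Optional, Tuple

def dispatch_hook(
    owners: Iterable[Tuple[Any, Callable[[str], bool]]],
    hook_name: str,
    output: Optional[Any] = None,
) -> Optional[Any]:
    # consult each owner only when its predicate admits this hook name
    for owner, handles in owners:
        if handles(hook_name) and hasattr(owner, hook_name):
            result = getattr(owner, hook_name)()
            # keep the earliest non-None output, mirroring
            # `output = accelerator_output if output is None else output`
            output = result if output is None else output
    return output

class OwnerSketch:
    def on_train_start(self) -> Optional[str]:
        return "owner output"

owners = [(OwnerSketch(), lambda name: name in ("on_train_start",))]
print(dispatch_hook(owners, "on_train_start"))  # -> "owner output"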
2 changes: 1 addition & 1 deletion tests/loops/test_training_loop.py
@@ -86,7 +86,7 @@ def run_training(**trainer_kwargs):
 @pytest.mark.parametrize(["max_epochs", "batch_idx_"], [(2, 5), (3, 8), (4, 12)])
 def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_, tmpdir):
     class CurrentModel(BoringModel):
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, batch, batch_idx):
             if batch_idx == batch_idx_:
                 return -1
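For context, the behavior this test exercises is Lightning's documented contract that returning -1 from `on_train_batch_start` skips the rest of the current epoch. A minimal sketch of that contract, with a hypothetical loop in place of the real training loop:

def run_epoch(batches, on_train_batch_start):
    processed = 0
    for batch_idx, batch in enumerate(batches):
        if on_train_batch_start(batch, batch_idx) == -1:
            break  # -1 ends the epoch early, skipping the remaining batches
        processed += 1
    return processed

hook = lambda batch, batch_idx: -1 if batch_idx == 5 else None
assert run_epoch(range(10), hook) == 5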
