Avoid deprecation warning after #9901 (#9951)

Merged 4 commits on Oct 16, 2021
Changes from 2 commits
pytorch_lightning/plugins/training_type/ipu.py (2 changes: 1 addition & 1 deletion)
@@ -285,7 +285,7 @@ def on_test_end(self):
     def on_predict_end(self):
         self._detach_models()

-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         # Updates optimizer stats if LR scheduler modified the optimizer state
         optimizer = self.lightning_module.trainer.optimizers[0]
         self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)
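As background for the IPU hunk above: the hook body re-applies the trainer's first optimizer to the compiled training model at the start of every batch, so that any changes an LR scheduler made to the optimizer state take effect. A minimal, self-contained sketch of that refresh pattern, where CompiledModelStub and refresh_optimizer are illustrative stand-ins rather than poptorch or Lightning APIs:

from typing import Any, List


class CompiledModelStub:
    """Stands in for a compiled executor that caches optimizer settings."""

    def __init__(self) -> None:
        self._optimizer: Any = None

    def setOptimizer(self, optimizer: Any) -> None:
        # Receive the (possibly scheduler-modified) optimizer settings.
        self._optimizer = optimizer


def refresh_optimizer(compiled_model: CompiledModelStub, optimizers: List[Any]) -> None:
    # Mirror of the hook above: hand the current first optimizer to the compiled model
    # before each training batch runs.
    compiled_model.setOptimizer(optimizers[0])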
@@ -345,7 +345,7 @@ def on_predict_end(self):
         """Called when predict ends."""
         pass

-    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
         """Called in the training loop before anything happens for that batch."""
         pass
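A plausible reason for widening the signature in both hunks above: call_hook forwards the same positional arguments to the LightningModule hook and to the plugin hook (see the trainer.py hunk below), and those arguments may or may not still include a dataloader index. Defaulting dataloader_idx keeps either call shape valid. A small standalone sketch, not Lightning code:

from typing import Any


class PluginHookSketch:
    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        # The default value makes the third argument optional for callers.
        print(f"batch_idx={batch_idx}, dataloader_idx={dataloader_idx}")


hook = PluginHookSketch()
hook.on_train_batch_start({"x": 1}, 0)      # two-argument call
hook.on_train_batch_start({"x": 1}, 0, 0)   # legacy three-argument call, no TypeError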
pytorch_lightning/trainer/trainer.py (12 changes: 6 additions & 6 deletions)
@@ -1401,14 +1401,14 @@ def call_hook(
         if callable(model_fx):
             output = model_fx(*args, **kwargs)

-        # call the accelerator hook
-        if hook_name not in ("setup", "teardown") and hasattr(self.accelerator, hook_name):
-            accelerator_hook = getattr(self.accelerator, hook_name)
-            accelerator_output = accelerator_hook(*args, **kwargs)
-            # Rely on the accelerator output if lightningModule hook returns nothing
+        # call the ttp hook
+        if hook_name not in ("setup", "teardown") and hasattr(self.training_type_plugin, hook_name):
+            ttp_hook = getattr(self.training_type_plugin, hook_name)
+            ttp_output = ttp_hook(*args, **kwargs)
+            # Rely on the TTP output if lightningModule hook returns nothing
             # Required for cases such as DataParallel where we reduce the output for the user
             # todo: move this data parallel logic into the data parallel plugin
-            output = accelerator_output if output is None else output
+            output = ttp_output if output is None else output

         if pl_module:
             # restore current_fx when nested context
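The trainer.py hunk swaps which component backs the second hook call, but the dispatch pattern is unchanged: run the LightningModule hook, run the plugin-side hook with the same arguments, and keep the module's return value unless it is None, in which case the plugin's output is used. A condensed sketch of that pattern; dispatch_hook and its parameters are illustrative, not the actual Trainer internals:

from typing import Any, Optional


def dispatch_hook(pl_module: Any, plugin: Any, hook_name: str, *args: Any, **kwargs: Any) -> Optional[Any]:
    output = None

    # 1) Call the LightningModule hook if the module defines it.
    module_fx = getattr(pl_module, hook_name, None)
    if callable(module_fx):
        output = module_fx(*args, **kwargs)

    # 2) Call the plugin-side hook with the same arguments.
    if hook_name not in ("setup", "teardown"):
        plugin_fx = getattr(plugin, hook_name, None)
        if callable(plugin_fx):
            plugin_output = plugin_fx(*args, **kwargs)
            # Fall back to the plugin's output only when the module hook returned nothing,
            # e.g. where the plugin reduces outputs on the user's behalf.
            output = plugin_output if output is None else output

    return output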
tests/loops/test_training_loop.py (2 changes: 1 addition & 1 deletion)
@@ -86,7 +86,7 @@ def run_training(**trainer_kwargs):
 @pytest.mark.parametrize(["max_epochs", "batch_idx_"], [(2, 5), (3, 8), (4, 12)])
 def test_on_train_batch_start_return_minus_one(max_epochs, batch_idx_, tmpdir):
     class CurrentModel(BoringModel):
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, batch, batch_idx):
             if batch_idx == batch_idx_:
                 return -1
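For completeness, the override style the updated test exercises: on_train_batch_start now takes only batch and batch_idx, and returning -1 asks the training loop to skip the rest of the current epoch. A hedged usage sketch; the BoringModel import path and the Trainer arguments are assumptions about the test setup of that era:

from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # import path is an assumption


class StopAfterFiveBatches(BoringModel):
    def on_train_batch_start(self, batch, batch_idx):
        # No dataloader_idx parameter; returning -1 ends the current epoch early.
        if batch_idx == 5:
            return -1


if __name__ == "__main__":
    trainer = Trainer(max_epochs=1, limit_train_batches=10)
    trainer.fit(StopAfterFiveBatches())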