[WIP] Trainer.call_hook re-design #9029

Closed · wants to merge 5 commits
Changes from all commits
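For orientation: this PR changes `Trainer.call_hook` to take a bound hook method (on the `Trainer`, the `LightningModule`, or the `Accelerator`) instead of a hook-name string, so each hook owner is dispatched explicitly at the call site. The new `call_hook` body is not part of this diff; the following is only a hedged sketch of what such a callable-based dispatcher could look like, inferred from the call sites below (the `_current_fx_name` bookkeeping and profiler usage it absorbs are visible in the removed lines).

from typing import Any, Callable, Optional


def call_hook(trainer, hook: Callable, *args: Any, **kwargs: Any) -> Optional[Any]:
    # Hypothetical sketch, not the PR's implementation: dispatch one bound hook method.
    owner = getattr(hook, "__self__", None)  # Trainer, LightningModule, or Accelerator instance
    hook_name = getattr(hook, "__name__", str(hook))

    module = trainer.lightning_module
    is_module_hook = owner is module
    if is_module_hook:
        # stands in for the explicit `model._current_fx_name = "..."` assignments removed below
        prev_fx_name = module._current_fx_name
        module._current_fx_name = hook_name

    with trainer.profiler.profile(hook_name):
        output = hook(*args, **kwargs)

    if is_module_hook:
        module._current_fx_name = prev_fx_name
    return output

With that shape, a call site such as the old call_hook("on_train_start") becomes three explicit dispatches: one to the Trainer (callbacks), one to the LightningModule, and one to the Accelerator, as the hunks below show.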
2 changes: 1 addition & 1 deletion pytorch_lightning/core/hooks.py
@@ -99,7 +99,7 @@ def on_pretrain_routine_end(self) -> None:
 
         """
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Optional[int]:
         """
         Called in the training loop before anything happens for that batch.
 
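The return annotation change above matters because the training loop already treats a return value of -1 from `on_train_batch_start` as a request to skip the rest of the batch (see the `if response == -1` check in `training_batch_loop.py` below). A hedged example of hypothetical user code exercising that contract, not part of this PR (fragment, other LightningModule methods omitted):

import pytorch_lightning as pl


class SkipNonFiniteBatches(pl.LightningModule):
    def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int):
        x, _ = batch
        if not x.isfinite().all():
            # returning -1 signals the batch loop to skip this batch entirely
            return -1
        return None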
26 changes: 14 additions & 12 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -88,12 +88,12 @@ def run(self, batch: Any, batch_idx: int) -> AttributeDict:
 
         # hook
         self.trainer.logger_connector.on_batch_start()
-        response = self.trainer.call_hook("on_batch_start")
-        if response == -1:
-            return AttributeDict(signal=-1)
+        self.trainer.call_hook(self.trainer.on_batch_start)
 
         # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
+        self.trainer.call_hook(self.trainer.on_train_batch_start, batch, batch_idx, 0)
+        response = self.trainer.call_hook(self.trainer.lightning_module.on_train_batch_start, batch, batch_idx, 0)
+        self.trainer.call_hook(self.trainer.accelerator.on_train_batch_start, batch, batch_idx, 0)
         if response == -1:
             return AttributeDict(signal=-1)
 
@@ -274,15 +274,14 @@ def _training_step(
         with self.trainer.profiler.profile("model_forward"):
             step_kwargs = self._build_kwargs(split_batch, batch_idx, opt_idx, hiddens)
 
-            # manually capture logged metrics
-            model_ref._current_fx_name = "training_step"
             with self.trainer.profiler.profile("training_step"):
-                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
-                self.trainer.accelerator.post_training_step()
+                training_step_output = self.trainer.call_hook(self.trainer.accelerator.training_step, step_kwargs)
+                self.trainer.accelerator.post_training_step()
 
-            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
+            model_output = self.trainer.call_hook(model_ref.training_step_end, training_step_output)
+            accel_output = self.trainer.call_hook(self.trainer.accelerator.training_step_end, training_step_output)
+            training_step_output = accel_output if model_output is None else model_output
 
-            _check_training_step_output(self.trainer.lightning_module, training_step_output)
+            _check_training_step_output(model_ref, training_step_output)
 
             training_step_output, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
             if training_step_output is None:
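On the `training_step_end` change above: the hook is now invoked on both the LightningModule and the Accelerator, and the module's return value wins whenever it is not None. A hedged sketch of hypothetical user code that relies on this, the usual dp/ddp2 reduction pattern (fragment, other methods omitted; not part of this PR):

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def training_step_end(self, step_output):
        # with dp/ddp2 the accelerator gathers per-device results into `step_output`;
        # reduce them here so the loop receives a single scalar loss
        step_output["loss"] = step_output["loss"].mean()
        return step_output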
@@ -347,7 +346,10 @@ def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
             optimizer: the current optimizer
         """
         self.optim_progress.optimizer.zero_grad.increment_ready()
-        self.trainer.call_hook("on_before_zero_grad", optimizer)
+
+        self.trainer.call_hook(self.trainer.on_before_zero_grad, optimizer)
+        self.trainer.call_hook(self.trainer.lightning_module.on_before_zero_grad, optimizer)
+
         self.optim_progress.optimizer.zero_grad.increment_started()
 
     def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
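A note on the trainer-level targets used in this file (e.g. `self.trainer.on_batch_start`, `self.trainer.on_before_zero_grad`): in this version of Lightning the Trainer exposes callback-dispatching hook methods, so passing the bound Trainer method to `call_hook` routes the hook to every registered Callback. Roughly, as a hedged sketch (the real implementations live on the Trainer's callback-hook mixin and are not shown in this diff):

class TrainerCallbackHookMixin:
    def on_batch_start(self) -> None:
        # dispatch the hook to every registered Callback
        for callback in self.callbacks:
            callback.on_batch_start(self, self.lightning_module)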
48 changes: 31 additions & 17 deletions pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -173,16 +173,20 @@ def on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
         self._results.to(device=self.trainer.lightning_module.device)
 
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_start", *args, **kwargs)
+            self.trainer.call_hook(self.trainer.on_test_start, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.lightning_module.on_test_start, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.accelerator.on_test_start, *args, **kwargs)
         else:
-            self.trainer.call_hook("on_validation_start", *args, **kwargs)
+            self.trainer.call_hook(self.trainer.on_validation_start, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.lightning_module.on_validation_start, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.accelerator.on_validation_start, *args, **kwargs)
 
     def on_evaluation_model_eval(self) -> None:
         """Sets model to eval mode"""
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_model_eval")
+            self.trainer.call_hook(self.trainer.lightning_module.on_test_model_eval)
         else:
-            self.trainer.call_hook("on_validation_model_eval")
+            self.trainer.call_hook(self.trainer.lightning_module.on_validation_model_eval)
 
     def on_evaluation_model_train(self) -> None:
         """Sets model to train mode"""
@@ -195,22 +199,29 @@ def on_evaluation_model_train(self) -> None:
     def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_{validation/test}_end`` hook"""
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_end", *args, **kwargs)
+            self.trainer.call_hook(self.trainer.on_test_end, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.lightning_module.on_test_end, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.accelerator.on_test_end, *args, **kwargs)
         else:
-            self.trainer.call_hook("on_validation_end", *args, **kwargs)
+            self.trainer.call_hook(self.trainer.on_validation_end, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.lightning_module.on_validation_end, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.accelerator.on_validation_end, *args, **kwargs)
 
         # reset any `torchmetrics.Metric` and the logger connector state
         self.trainer.logger_connector.reset(metrics=True)
 
     def on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks"""
         self.trainer.logger_connector.on_epoch_start()
-        self.trainer.call_hook("on_epoch_start", *args, **kwargs)
+        self.trainer.call_hook(self.trainer.on_epoch_start, *args, **kwargs)
+        self.trainer.call_hook(self.trainer.lightning_module.on_epoch_start, *args, **kwargs)
 
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_epoch_start", *args, **kwargs)
+            self.trainer.call_hook(self.trainer.on_test_epoch_start, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.lightning_module.on_test_epoch_start, *args, **kwargs)
         else:
-            self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)
+            self.trainer.call_hook(self.trainer.on_validation_epoch_start, *args, **kwargs)
+            self.trainer.call_hook(self.trainer.lightning_module.on_validation_epoch_start, *args, **kwargs)
 
     def _should_track_batch_outputs_for_epoch_end(self) -> bool:
         """Whether the batch outputs should be stored for later usage"""
@@ -232,19 +243,22 @@ def evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
 
         if self.trainer.testing:
             if is_overridden("test_epoch_end", model):
-                model._current_fx_name = "test_epoch_end"
-                model.test_epoch_end(outputs)
-
+                self.trainer.call_hook(model.test_epoch_end, outputs)
         else:
             if is_overridden("validation_epoch_end", model):
-                model._current_fx_name = "validation_epoch_end"
-                model.validation_epoch_end(outputs)
+                self.trainer.call_hook(model.validation_epoch_end, outputs)
 
     def on_evaluation_epoch_end(self) -> None:
         """Runs ``on_{validation/test}_epoch_end`` hook"""
-        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
-        self.trainer.call_hook(hook_name)
-        self.trainer.call_hook("on_epoch_end")
+        if self.trainer.testing:
+            self.trainer.call_hook(self.trainer.on_test_epoch_end)
+            self.trainer.call_hook(self.trainer.lightning_module.on_test_epoch_end)
+        else:
+            self.trainer.call_hook(self.trainer.on_validation_epoch_end)
+            self.trainer.call_hook(self.trainer.lightning_module.on_validation_epoch_end)
+
+        self.trainer.call_hook(self.trainer.on_epoch_end)
+        self.trainer.call_hook(self.trainer.lightning_module.on_epoch_end)
         self.trainer.logger_connector.on_epoch_end()
 
     def teardown(self) -> None:
15 changes: 11 additions & 4 deletions pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -110,8 +110,12 @@ def on_predict_start(self) -> None:
         self.trainer.lightning_module.zero_grad()
 
         # hook
-        self.trainer.call_hook("on_predict_start")
-        self.trainer.call_hook("on_predict_epoch_start")
+        self.trainer.call_hook(self.trainer.on_predict_start)
+        self.trainer.call_hook(self.trainer.lightning_module.on_predict_start)
+        self.trainer.call_hook(self.trainer.accelerator.on_predict_start)
+
+        self.trainer.call_hook(self.trainer.on_predict_epoch_start)
+        self.trainer.call_hook(self.trainer.lightning_module.on_predict_epoch_start)
 
     def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
         """Calls ``on_predict_epoch_end`` hook.
@@ -121,7 +125,8 @@ def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
         """
         results = self.predictions
 
-        self.trainer.call_hook("on_predict_epoch_end", results)
+        self.trainer.call_hook(self.trainer.on_predict_epoch_end, results)
+        self.trainer.call_hook(self.trainer.lightning_module.on_predict_epoch_end, results)
 
         if self.return_predictions:
             return results[0] if self.num_dataloaders == 1 else results
@@ -133,7 +138,9 @@ def on_predict_end(self) -> None:
         self.epoch_batch_indices = []
 
         # hook
-        self.trainer.call_hook("on_predict_end")
+        self.trainer.call_hook(self.trainer.on_predict_end)
+        self.trainer.call_hook(self.trainer.lightning_module.on_predict_end)
+        self.trainer.call_hook(self.trainer.accelerator.on_predict_end)
 
     def on_predict_model_eval(self):
         """Calls ``on_predict_model_eval`` hook"""
38 changes: 25 additions & 13 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -141,20 +141,20 @@ def evaluation_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Op
         step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)
 
         if self.trainer.testing:
-            self.trainer.lightning_module._current_fx_name = "test_step"
             with self.trainer.profiler.profile("test_step"):
-                output = self.trainer.accelerator.test_step(step_kwargs)
+                output = self.trainer.call_hook(self.trainer.accelerator.test_step, step_kwargs)
         else:
-            self.trainer.lightning_module._current_fx_name = "validation_step"
             with self.trainer.profiler.profile("validation_step"):
-                output = self.trainer.accelerator.validation_step(step_kwargs)
-
+                output = self.trainer.call_hook(self.trainer.accelerator.validation_step, step_kwargs)
         return output
 
     def evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         """Calls the `{validation/test}_step_end` hook"""
-        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
-        output = self.trainer.call_hook(hook_name, *args, **kwargs)
+        if self.trainer.testing:
+            model_output = self.trainer.call_hook(self.trainer.lightning_module.test_step_end, *args, **kwargs)
+            accel_output = self.trainer.call_hook(self.trainer.accelerator.test_step_end, *args, **kwargs)
+        else:
+            model_output = self.trainer.call_hook(self.trainer.lightning_module.validation_step_end, *args, **kwargs)
+            accel_output = self.trainer.call_hook(self.trainer.accelerator.validation_step_end, *args, **kwargs)
+        output = accel_output if model_output is None else model_output
         return output
 
     def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
@@ -174,9 +174,13 @@ def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx:
         self.trainer.logger_connector.on_evaluation_batch_start(batch, batch_idx, dataloader_idx, self._num_dataloaders)
 
         if self.trainer.testing:
-            self.trainer.call_hook("on_test_batch_start", batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook(self.trainer.on_test_batch_start, batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook(self.trainer.lightning_module.on_test_batch_start, batch, batch_idx, dataloader_idx)
        else:
-            self.trainer.call_hook("on_validation_batch_start", batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook(self.trainer.on_validation_batch_start, batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook(
+                self.trainer.lightning_module.on_validation_batch_start, batch, batch_idx, dataloader_idx
+            )
 
     def on_evaluation_batch_end(
         self, output: Optional[STEP_OUTPUT], batch: Any, batch_idx: int, dataloader_idx: int
@@ -189,8 +193,16 @@ def on_evaluation_batch_end(
             batch_idx: The index of the current batch
             dataloader_idx: Index of the dataloader producing the current batch
         """
-        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
-        self.trainer.call_hook(hook_name, output, batch, batch_idx, dataloader_idx)
+        if self.trainer.testing:
+            self.trainer.call_hook(self.trainer.on_test_batch_end, output, batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook(
+                self.trainer.lightning_module.on_test_batch_end, output, batch, batch_idx, dataloader_idx
+            )
+        else:
+            self.trainer.call_hook(self.trainer.on_validation_batch_end, output, batch, batch_idx, dataloader_idx)
+            self.trainer.call_hook(
+                self.trainer.lightning_module.on_validation_batch_end, output, batch, batch_idx, dataloader_idx
+            )
 
         self.trainer.logger_connector.on_batch_end()
 
9 changes: 5 additions & 4 deletions pytorch_lightning/loops/epoch/prediction_epoch_loop.py
@@ -123,19 +123,20 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None
 
         model_ref = self.trainer.lightning_module
 
-        self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook(self.trainer.on_predict_batch_start, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook(model_ref.on_predict_batch_start, batch, batch_idx, dataloader_idx)
 
         self.batch_progress.increment_started()
 
-        model_ref._current_fx_name = "predict_step"
-        predictions = self.trainer.accelerator.predict_step(step_kwargs)
+        predictions = self.trainer.call_hook(self.trainer.accelerator.predict_step, step_kwargs)
 
         self.batch_progress.increment_processed()
 
         if predictions is None:
            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")
 
-        self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook(self.trainer.on_predict_batch_end, predictions, batch, batch_idx, dataloader_idx)
+        self.trainer.call_hook(model_ref.on_predict_batch_end, predictions, batch, batch_idx, dataloader_idx)
 
         self.batch_progress.increment_completed()
 
39 changes: 24 additions & 15 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -103,10 +103,14 @@ def reset(self) -> None:
         self.batch_loop.optim_progress.reset_on_epoch()
 
     def on_run_start(self, *args: Any, **kwargs: Any) -> None:
-        # hook
         self.trainer.logger_connector.on_epoch_start()
-        self.trainer.call_hook("on_epoch_start")
-        self.trainer.call_hook("on_train_epoch_start")
+
+        self.trainer.call_hook(self.trainer.on_epoch_start)
+        self.trainer.call_hook(self.trainer.lightning_module.on_epoch_start)
+
+        self.trainer.call_hook(self.trainer.on_train_epoch_start)
+        self.trainer.call_hook(self.trainer.lightning_module.on_train_epoch_start)
+
         self.trainer.fit_loop.epoch_progress.increment_started()
 
     def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
@@ -162,8 +166,14 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
 
         # hook
         if not isinstance(self.batch_loop, IteratorBatchProcessor):
-            self.trainer.call_hook("on_train_batch_end", processed_batch_end_outputs, batch, self.batch_idx, 0)
-            self.trainer.call_hook("on_batch_end")
+            self.trainer.call_hook(
+                self.trainer.on_train_batch_end, processed_batch_end_outputs, batch, self.batch_idx, 0
+            )
+            self.trainer.call_hook(
+                self.trainer.lightning_module.on_train_batch_end, processed_batch_end_outputs, batch, self.batch_idx, 0
+            )
+
+            self.trainer.call_hook(self.trainer.on_batch_end)
         self.trainer.logger_connector.on_batch_end()
 
         self.batch_progress.increment_completed()
@@ -225,13 +235,8 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
         model = self.trainer.lightning_module
 
         if is_overridden("training_epoch_end", model):
-            # run training_epoch_end
-            # refresh the result for custom logging at the epoch level
-            model._current_fx_name = "training_epoch_end"
-
-            # lightningmodule hook
-            training_epoch_end_output = model.training_epoch_end(processed_outputs)
-
+            # hook
+            training_epoch_end_output = self.trainer.call_hook(model.training_epoch_end, processed_outputs)
             if training_epoch_end_output is not None:
                 raise MisconfigurationException(
                     "training_epoch_end expects a return of None. "
@@ -240,9 +245,13 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]:
 
         self.trainer.fit_loop.epoch_progress.increment_processed()
 
-        # call train epoch end hooks
-        self.trainer.call_hook("on_train_epoch_end")
-        self.trainer.call_hook("on_epoch_end")
+        self.trainer.call_hook(self.trainer.on_train_epoch_end)
+        self.trainer.call_hook(model.on_train_epoch_end)
+        self.trainer.call_hook(self.trainer.accelerator.on_train_epoch_end)
+
+        self.trainer.call_hook(self.trainer.on_epoch_end)
+        self.trainer.call_hook(model.on_epoch_end)
+
         self.trainer.logger_connector.on_epoch_end()
 
         if self._num_training_batches_reached(self.is_last_batch):
12 changes: 7 additions & 5 deletions pytorch_lightning/loops/fit_loop.py
@@ -165,7 +165,10 @@ def reset(self) -> None:
     def on_run_start(self) -> None:
         """Calls the ``on_train_start`` hook."""
         self._results.to(device=self.trainer.lightning_module.device)
-        self.trainer.call_hook("on_train_start")
+
+        self.trainer.call_hook(self.trainer.on_train_start)
+        self.trainer.call_hook(self.trainer.lightning_module.on_train_start)
+        self.trainer.call_hook(self.trainer.accelerator.on_train_start)
 
     def on_advance_start(self) -> None:
         """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and ``on_train_epoch_start``"""
@@ -223,10 +226,9 @@ def on_run_end(self) -> None:
         self.current_epoch -= 1
 
         # hook
-        self.trainer.call_hook("on_train_end")
-
-        # give accelerators a chance to finish
-        self.trainer.accelerator.on_train_end()
+        self.trainer.call_hook(self.trainer.on_train_end)
+        self.trainer.call_hook(self.trainer.lightning_module.on_train_end)
+        self.trainer.call_hook(self.trainer.accelerator.on_train_end)
 
     def should_accumulate(self) -> bool:
         """Whether the gradients should be accumulated"""