restrict public interface of training loop #8024

Merged
merged 23 commits on Jun 23, 2021

Changes from 13 commits

Commits (23)
9d48e99
active optimizers
awaelchli Jun 18, 2021
515687a
check checkpoint callback
awaelchli Jun 18, 2021
ea5c885
epoch loop properties
awaelchli Jun 18, 2021
c37907d
epoch loop methods
awaelchli Jun 18, 2021
9ad8e39
training_batch_loop
awaelchli Jun 18, 2021
7715650
changelog
awaelchli Jun 18, 2021
f026740
update chlog
awaelchli Jun 18, 2021
a9a3296
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2021
ab8c6ca
unused imports
awaelchli Jun 18, 2021
5199c85
Merge remote-tracking branch 'origin/refactor/loops/public-interface'…
awaelchli Jun 18, 2021
5f7b748
yapf
awaelchli Jun 18, 2021
b3ec83b
backward
awaelchli Jun 18, 2021
acdeabe
fix missing string reference
awaelchli Jun 18, 2021
a07a3a3
Merge branch 'master' into refactor/loops/public-interface
awaelchli Jun 18, 2021
a41e426
Merge branch 'master' into refactor/loops/public-interface
awaelchli Jun 18, 2021
5671662
is_last_batch remains public
awaelchli Jun 18, 2021
16809f3
remove dead code
awaelchli Jun 18, 2021
bf3c6ac
Merge branch 'master' into refactor/loops/public-interface
awaelchli Jun 21, 2021
1bc5355
Merge branch 'master' into refactor/loops/public-interface
carmocca Jun 21, 2021
48ed87e
Merge branch 'master' into refactor/loops/public-interface
awaelchli Jun 22, 2021
79f79d2
Merge branch 'master' into refactor/loops/public-interface
justusschock Jun 23, 2021
a7cb382
Merge branch 'master' into refactor/loops/public-interface
tchaton Jun 23, 2021
ccccf15
remove unused imports
awaelchli Jun 23, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -139,6 +139,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
* Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871))
* Removed `pytorch_lightning/trainer/training_loop.py` ([#7985](https://github.com/PyTorchLightning/pytorch-lightning/pull/7985))
* Restricted public access to several internal functions ([#8024](https://github.com/PyTorchLightning/pytorch-lightning/pull/8024))

- Refactored logging
* Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/finetuning.py
@@ -285,7 +285,7 @@ def _store(

def on_train_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
for opt_idx, optimizer in trainer.train_loop.get_active_optimizers():
for opt_idx, optimizer in trainer.fit_loop.training_loop.batch_loop.get_active_optimizers():
num_param_groups = len(optimizer.param_groups)
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
current_param_groups = optimizer.param_groups
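The finetuning callback now reaches the active optimizers through the loop hierarchy (`trainer.fit_loop.training_loop.batch_loop`) instead of the old `trainer.train_loop.get_active_optimizers()` call. A minimal sketch of a user callback following the same pattern; the callback name and the print-based logging are illustrative and not part of this PR, while the attribute path mirrors the updated call in the diff above:

```python
from pytorch_lightning.callbacks import Callback


class ActiveOptimizerLogger(Callback):
    """Illustrative callback (not part of this PR) that inspects the active optimizers."""

    def on_train_epoch_start(self, trainer, pl_module):
        # Same access path as the updated on_train_epoch_start hook above.
        batch_loop = trainer.fit_loop.training_loop.batch_loop
        for opt_idx, optimizer in batch_loop.get_active_optimizers():
            print(
                f"epoch={trainer.current_epoch} opt_idx={opt_idx} "
                f"param_groups={len(optimizer.param_groups)}"
            )
```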
13 changes: 4 additions & 9 deletions pytorch_lightning/loops/fit_loop.py
@@ -14,10 +14,9 @@

import logging
from contextlib import suppress
from typing import Any, List, Optional, Tuple
from typing import Any, Optional

from deprecate import void
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.loops.base import Loop
@@ -226,7 +225,7 @@ def on_advance_end(self) -> None:
)
if did_train_only:
self.global_step -= 1
self.check_checkpoint_callback(True)
self._check_checkpoint_callback(True)
self.global_step += 1

def on_run_end(self) -> None:
@@ -241,7 +240,7 @@ def on_run_end(self) -> None:
# when a checkpoint was saved at the last step
self.training_loop.global_step -= 1
# TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
self.check_checkpoint_callback(should_update=True, is_last=True)
self._check_checkpoint_callback(should_update=True, is_last=True)
self.training_loop.global_step += 1

# hook
@@ -266,11 +265,7 @@ def should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated"""
return self.training_loop.batch_loop.should_accumulate()

def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]:
"""Generates a list of active optimizers"""
return self.training_loop.batch_loop.get_active_optimizers(batch_idx)

def check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
"""Checks if checkpointing needs to be done"""
# TODO: bake this logic into the ModelCheckpoint callback
if should_update and self.trainer.checkpoint_connector.has_trained:
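What stays public on `FitLoop` (for example `should_accumulate`) simply delegates to `self.training_loop.batch_loop`, while helpers like `_check_checkpoint_callback` and the old `get_active_optimizers` delegation drop out of the public surface. As a rough standalone illustration of what the accumulation check decides, here is a sketch with assumed semantics and an illustrative signature; the real `TrainingBatchLoop.should_accumulate` derives these values from trainer state:

```python
def should_accumulate(batch_idx: int, accumulate_grad_batches: int, num_training_batches: int) -> bool:
    """Sketch of the accumulation decision (assumed semantics, illustrative signature)."""
    # Keep accumulating until the accumulation window closes ...
    accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
    # ... or the epoch runs out of batches.
    is_final_batch = (batch_idx + 1) == num_training_batches
    return not (accumulation_done or is_final_batch)


# Accumulate over 4 batches in a 10-batch epoch:
assert should_accumulate(batch_idx=0, accumulate_grad_batches=4, num_training_batches=10)
assert not should_accumulate(batch_idx=3, accumulate_grad_batches=4, num_training_batches=10)
assert not should_accumulate(batch_idx=9, accumulate_grad_batches=4, num_training_batches=10)
```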
92 changes: 30 additions & 62 deletions pytorch_lightning/loops/training_batch_loop.py
@@ -114,7 +114,7 @@ def on_run_start(self, batch: Any, batch_idx: int, dataloader_idx: int):
dataloader_idx: the index of the dataloader producing the current batch
"""
void(batch_idx, dataloader_idx)
self._remaining_splits = list(enumerate(self.tbptt_split_batch(batch)))
self._remaining_splits = list(enumerate(self._tbptt_split_batch(batch)))

def advance(self, batch, batch_idx, dataloader_idx):
"""Runs the train step together with optimization (if necessary) on the current batch split
@@ -162,10 +162,10 @@ def _run_optimization(
# opt_idx=0 to opt_idx=None in the signature here

# toggle model params
self.run_optimization_start(opt_idx, optimizer)
self._run_optimization_start(opt_idx, optimizer)

result = AttributeDict()
closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)
closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)

if self.should_accumulate():
# For gradient accumulation
@@ -184,24 +184,24 @@
# gradient update with accumulated gradients
else:
if self.trainer.lightning_module.automatic_optimization:
self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
if len(self.trainer.optimizers) > 1:
# revert back to previous state
self.trainer.lightning_module.untoggle_optimizer(opt_idx)
else:
result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
result = self._training_step(split_batch, batch_idx, opt_idx, self._hiddens)

if not result:
# user decided to skip optimization
return result

# update running loss + reset accumulated loss
self.update_running_loss(result.loss)
self._update_running_loss(result.loss)

self._process_closure_result(result)
return result

def training_step_and_backward_closure(
def _training_step_and_backward_closure(
self,
split_batch: Any,
batch_idx: int,
@@ -226,10 +226,10 @@ def training_step_and_backward_closure(
return_result.update(result)
return return_result.loss

def make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
def _make_closure(self, *closure_args: Any, **closure_kwargs: Any) -> Callable:
""" Wraps the training step closure into a partial object which will be called within ``optimizer.step``. """
partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs)
return update_wrapper(partial_func, self.training_step_and_backward_closure)
partial_func = partial(self._training_step_and_backward_closure, *closure_args, **closure_kwargs)
return update_wrapper(partial_func, self._training_step_and_backward_closure)

def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None:
"""Checks if the closure results is finite and optionally breaks if it is not
@@ -244,7 +244,7 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -
if self.trainer.terminate_on_nan:
self._check_finite(opt_closure_result.loss)

def on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None:
def _on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None:
"""Calls ``on_after_backward`` hook and tracks loss history

Args:
@@ -281,7 +281,13 @@ def _check_training_step_output(self, training_step_output: STEP_OUTPUT) -> None
"a dict with key 'loss' or None (where the step will be skipped)."
)

def training_step(self, split_batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> Optional[AttributeDict]:
def _training_step(
self,
split_batch: Any,
batch_idx: int,
opt_idx: int,
hiddens: Tensor,
) -> Optional[AttributeDict]:
"""Performs the actual train step with the tied hooks.

Args:
@@ -360,7 +366,7 @@ def _process_training_step_output(self, training_step_output: STEP_OUTPUT) -> Op
results.cpu()
return results

def optimizer_step(
def _optimizer_step(
self, optimizer: torch.optim.Optimizer, opt_idx: int, batch_idx: int, train_step_and_backward_closure: Callable
) -> None:
"""Performs the optimizer step and some sanity checking.
@@ -399,15 +405,15 @@ def optimizer_step(
using_lbfgs=is_lbfgs,
)

def on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
"""Calls the ``on_before_zero_grad`` hook.

Args:
optimizer: the current optimizer
"""
self.trainer.call_hook('on_before_zero_grad', optimizer)

def optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
"""Zeroes out all gradients of parameters optimized by the current optimizer.

Args:
Expand All @@ -417,7 +423,7 @@ def optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
"""
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

def track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.

Args:
@@ -457,7 +463,7 @@ def should_accumulate(self) -> bool:
is_final_batch = self._num_training_batches_reached()
return not (accumulation_done or is_final_batch)

def tbptt_split_batch(self, batch: Any) -> List[Any]:
def _tbptt_split_batch(self, batch: Any) -> List[Any]:
"""Splits a single batch into a list of sequence steps for tbptt.

Args:
@@ -470,45 +476,7 @@ def tbptt_split_batch(self, batch: Any) -> List[Any]:
splits = model_ref.tbptt_split_batch(batch, self.trainer.truncated_bptt_steps)
return splits

def build_train_args(self, batch: Any, batch_idx: int, opt_idx: int, hiddens: Tensor) -> List[Any]:
"""Builds arguments for train step

Args:
batch: the current batch to train on
batch_idx: the index of the current batch
opt_idx: the index of the current optimizer
hiddens: the hidden state of the previous RNN iteration

Returns:
the positional arguments for training
"""
# enable not needing to add opt_idx to training_step
args = [batch, batch_idx]

if len(self.trainer.optimizers) > 1:
if self.trainer.has_arg("training_step", "optimizer_idx"):
if not self.trainer.lightning_module.automatic_optimization:
self.warning_cache.warn(
"`training_step` hook signature has changed in v1.3."
" `optimizer_idx` argument has been removed in case of manual optimization. Support for"
" the old signature will be removed in v1.5", DeprecationWarning
)
args.append(opt_idx)
elif not self.trainer.has_arg(
"training_step", "optimizer_idx"
) and self.trainer.lightning_module.automatic_optimization:
raise ValueError(
f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but"
' `training_step` is missing the `optimizer_idx` argument.'
)

# pass hiddens if using tbptt
if self.trainer.truncated_bptt_steps is not None:
args.append(hiddens)

return args

def run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
"""Toggles the optimizer to ensure the correct one is used and prevend dangling grads.

Args:
@@ -556,14 +524,14 @@ def training_step_and_backward(
"""Wrap forward, zero_grad and backward in a closure so second order methods work"""
with self.trainer.profiler.profile("training_step_and_backward"):
# lightning module hook
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
result = self._training_step(split_batch, batch_idx, opt_idx, hiddens)

if not self._skip_backward and self.trainer.lightning_module.automatic_optimization:
is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0

if is_first_batch_to_accumulate:
self.on_before_zero_grad(optimizer)
self.optimizer_zero_grad(batch_idx, optimizer, opt_idx)
self._on_before_zero_grad(optimizer)
self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

# backward pass
if result is not None:
Expand All @@ -573,7 +541,7 @@ def training_step_and_backward(
# hook - call this hook only
# when gradients have finished to accumulate
if not self.should_accumulate():
self.on_after_backward(batch_idx, result.loss)
self._on_after_backward(batch_idx, result.loss)

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
@@ -621,12 +589,12 @@ def backward(

if not self.should_accumulate():
# track gradients
grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
if grad_norm_dict:
self.trainer.lightning_module._current_fx_name = "on_after_backward"
self.trainer.lightning_module.log_grad_norm(grad_norm_dict)

def update_running_loss(self, current_loss: Tensor) -> None:
def _update_running_loss(self, current_loss: Tensor) -> None:
"""Updates the running loss value with the current value"""
if self.trainer.lightning_module.automatic_optimization:
# track total loss for logging (avoid mem leaks)
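The `_make_closure` change above keeps the closure construction private: the training step and backward pass are bound into a zero-argument callable via `functools.partial` plus `update_wrapper`, which `optimizer.step` can invoke (and re-invoke, as second-order optimizers like LBFGS do). A minimal self-contained sketch of that pattern, using a toy model and helper names that are illustrative rather than taken from this PR:

```python
from functools import partial, update_wrapper
from typing import Callable

import torch
import torch.nn.functional as F


def make_closure(step_fn: Callable[..., torch.Tensor], *args, **kwargs) -> Callable[[], torch.Tensor]:
    """Bind the step arguments into a zero-argument closure, as _make_closure does above."""
    closure = partial(step_fn, *args, **kwargs)
    # Copy the wrapped function's metadata onto the partial object.
    return update_wrapper(closure, step_fn)


model = torch.nn.Linear(4, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)


def training_step_and_backward(batch_x: torch.Tensor, batch_y: torch.Tensor) -> torch.Tensor:
    # Forward, zero_grad and backward live inside the closure so the optimizer
    # can re-evaluate the loss as often as it needs.
    optimizer.zero_grad()
    loss = F.mse_loss(model(batch_x), batch_y)
    loss.backward()
    return loss


# LBFGS calls the closure internally, possibly several times per step.
optimizer.step(make_closure(training_step_and_backward, x, y))
```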