From 341adad81927cf089dce9ca7a904629c6963cc30 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 16 Jun 2021 10:00:33 +0200
Subject: [PATCH] Loop Refactor 2/N - Remove Old Training Loop (#7985)

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 CHANGELOG.md                               |   1 +
 pytorch_lightning/trainer/trainer.py       |   4 +-
 pytorch_lightning/trainer/training_loop.py | 923 ---------------------
 setup.cfg                                  |   1 -
 4 files changed, 3 insertions(+), 926 deletions(-)
 delete mode 100644 pytorch_lightning/trainer/training_loop.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 0a46427fe623c..c35dda4876dd1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -125,6 +125,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
     * Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
     * Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871))
+    * Removed `pytorch_lightning/trainer/training_loop.py` ([#7985](https://github.com/PyTorchLightning/pytorch-lightning/pull/7985))
 
 - Refactored logging
     * Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 1f97a5f2c5767..158fe20beee77 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1229,10 +1229,10 @@ def _call_teardown_hook(self, model: LightningModule) -> None:
         model._current_dataloader_idx = None
 
     def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
-        # Note this implementation is copy/pasted into the TrainLoop class in TrainLoop._on_train_epoch_end_hook
+        # Note this implementation is copy/pasted into the TrainingEpochLoop class in TrainingEpochLoop._on_train_epoch_end_hook
         # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end
         # If making changes to this function, ensure that those changes are also made to
-        # TrainLoop._on_train_epoch_end_hook
+        # TrainingEpochLoop._on_train_epoch_end_hook
         if self.lightning_module:
             prev_fx_name = self.lightning_module._current_fx_name
             self.lightning_module._current_fx_name = hook_name

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
deleted file mode 100644
index f76568454b7ac..0000000000000
--- a/pytorch_lightning/trainer/training_loop.py
+++ /dev/null
@@ -1,923 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import OrderedDict
-from contextlib import contextmanager, suppress
-from functools import partial, update_wrapper
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
-
-import numpy as np
-import torch
-from torch.optim import Optimizer
-
-from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.plugins import ParallelPlugin
-from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.supporters import TensorRunningAccum
-from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, DeviceType
-from pytorch_lightning.utilities.distributed import rank_zero_info
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
-from pytorch_lightning.utilities.grads import grad_norm
-from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.parsing import AttributeDict
-from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.warnings import WarningCache
-
-
-class TrainLoop:
-
-    def __init__(
-        self,
-        trainer,
-        max_epochs: Optional[int],
-        min_epochs: Optional[int],
-        max_steps: Optional[int],
-        min_steps: Optional[int],
-        num_sanity_val_steps: int,
-    ):
-        self.trainer = trainer
-        self.accumulated_loss = None
-        self.warning_cache = WarningCache()
-        self.running_loss = TensorRunningAccum(window_length=20)
-        self._skip_backward = False
-        self._optimizer_freq_cumsum = None
-        self._hiddens = None
-
-        self.global_step = 0
-        self.current_epoch = 0
-        self.trainer.should_stop = False
-
-        # the total batch index across all epochs
-        self.total_batch_idx = 0
-        # the current batch index in the loop that runs over the dataloader(s)
-        self.batch_idx = 0
-        # the current split index when the batch gets split into chunks in truncated backprop through time
-        self.split_idx = None
-
-        self.trainer.num_training_batches = 0
-        self.trainer.train_dataloader = None
-
-        # If neither max_epochs nor max_steps is set, then use the existing default of max_epochs = 1000
-        self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
-        # If neither min_epochs nor min_steps is set, then use the existing default of min_epochs = 1
-        self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
-        self.max_steps = max_steps
-        self.min_steps = min_steps
-
-        if num_sanity_val_steps == -1:
-            self.trainer.num_sanity_val_steps = float("inf")
-        else:
-            self.trainer.num_sanity_val_steps = num_sanity_val_steps
-
-        self.results = ResultCollection(training=True)
-
-    @property
-    def num_active_optimizers(self) -> int:
-        return len(self.get_active_optimizers())
-
-    @property
-    def optimizer_freq_cumsum(self):
-        if self._optimizer_freq_cumsum is None:
-            self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
-        return self._optimizer_freq_cumsum
-
-    def should_skip_training(self) -> bool:
-        should_by_max_steps = self.max_steps is not None and self.global_step >= self.max_steps
-        should_by_epoch = self.max_epochs is not None and self.current_epoch >= self.max_epochs
-        return should_by_max_steps or should_by_epoch or self.trainer.num_training_batches == 0
-
-    def on_train_start(self):
-        self.results.to(device=self.trainer.lightning_module.device)
-
-        self.trainer.call_hook("on_train_start")
-
-    def on_train_end(self):
-        # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
-        # when a checkpoint was saved at the last step
-        self.global_step -= 1
-        self.check_checkpoint_callback(should_update=True, is_last=True)
-        self.global_step += 1
-
-        # hook
-        self.trainer.call_hook("on_train_end")
-
-        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
-        # It might be related to xla tensors blocked when moving to the cpu
-        # kill loggers
-        if self.trainer.logger is not None:
-            self.trainer.logger.finalize("success")
-
-        # summarize profile results
-        self.trainer.profiler.describe()
-
-        # give accelerators a chance to finish
-        self.trainer.accelerator.on_train_end()
-
-        # reset bookkeeping
-        self.trainer.state.stage = None
-
-    def check_checkpoint_callback(self, should_update, is_last=False):
-        # TODO bake this logic into the ModelCheckpoint callback
-        if should_update and self.trainer.checkpoint_connector.has_trained:
-            callbacks = self.trainer.checkpoint_callbacks
-
-            if is_last and any(cb.save_last and cb.verbose for cb in callbacks):
-                rank_zero_info("Saving latest checkpoint...")
-
-            model = self.trainer.lightning_module
-
-            for cb in callbacks:
-                cb.on_validation_end(self.trainer, model)
-
-    def on_train_epoch_start(self, epoch):
-
-        # update training progress in trainer
-        self.current_epoch = epoch
-
-        model = self.trainer.lightning_module
-
-        # reset train dataloader
-        if epoch != 0 and self.trainer.reload_dataloaders_every_epoch:
-            self.trainer.reset_train_dataloader(model)
-
-        # todo: specify the possible exception
-        with suppress(Exception):
-            # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(epoch)
-
-        # adjust gradient accumulation according to the accumulation_scheduler
-        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
-
-        # stores accumulated grad fractions per batch
-        self.accumulated_loss = TensorRunningAccum(window_length=self.trainer.accumulate_grad_batches)
-
-        # hook
-        self.trainer.logger_connector.on_epoch_start()
-        self.trainer.call_hook("on_epoch_start")
-        self.trainer.call_hook("on_train_epoch_start")
-
-    def on_train_batch_end(self, epoch_output, batch_end_outputs, batch, batch_idx, dataloader_idx):
-        batch_end_outputs = [opt_idx_out for opt_idx_out in batch_end_outputs if len(opt_idx_out)]
-
-        processed_batch_end_outputs = TrainLoop._prepare_outputs(batch_end_outputs, batch_mode=True)
-
-        # hook
-        self.trainer.call_hook('on_train_batch_end', processed_batch_end_outputs, batch, batch_idx, dataloader_idx)
-        self.trainer.call_hook('on_batch_end')
-        self.trainer.logger_connector.on_batch_end()
-
-        # figure out what to track for epoch end
-        self.track_epoch_end_reduce_metrics(epoch_output, batch_end_outputs)
-
-    def reset_train_val_dataloaders(self, model) -> None:
-        """
-        Resets train and val dataloaders if none are attached to the trainer.
-
-        The val dataloader must be initialized before the training loop starts, as the training loop
-        inspects the val dataloader to determine whether to run the evaluation loop.
- """ - if self.trainer.train_dataloader is None: - self.trainer.reset_train_dataloader(model) - - if self.trainer.val_dataloaders is None: - self.trainer.reset_val_dataloader(model) - - def track_epoch_end_reduce_metrics(self, epoch_output, batch_end_outputs): - hook_overridden = self._should_add_batch_output_to_epoch_output() - if not hook_overridden: - return - - # track the outputs to reduce at the end of the epoch - for opt_idx, opt_outputs in enumerate(batch_end_outputs): - # with 1 step (no tbptt) don't use a sequence at epoch end - if ( - isinstance(opt_outputs, list) and len(opt_outputs) == 1 - and not isinstance(opt_outputs[0], ResultCollection) - ): - opt_outputs = opt_outputs[0] - - epoch_output[opt_idx].append(opt_outputs) - - def _should_add_batch_output_to_epoch_output(self) -> bool: - # We add to the epoch outputs if - # 1. The model defines training_epoch_end OR - # 2. The model overrides on_train_epoch_end which has `outputs` in the signature - # TODO: in v1.5 this only needs to check if training_epoch_end is overridden - lightning_module = self.trainer.lightning_module - if is_overridden("training_epoch_end", lightning_module): - return True - - if is_overridden("on_train_epoch_end", lightning_module): - model_hook_fx = getattr(lightning_module, "on_train_epoch_end") - if is_param_in_hook_signature(model_hook_fx, "outputs"): - return True - - return False - - def get_active_optimizers(self, batch_idx: Optional[int] = None) -> List[Tuple[int, Optimizer]]: - """ - Returns the currently active optimizers. When multiple optimizers are used with different frequencies, - only one of the optimizers is active at a time. - - Returns: - A list of tuples (opt_idx, optimizer) of currently active optimizers. - """ - if not self.trainer.optimizer_frequencies: - # call training_step once per optimizer - return list(enumerate(self.trainer.optimizers)) - - batch_idx = self.total_batch_idx if batch_idx is None else batch_idx - optimizers_loop_length = self.optimizer_freq_cumsum[-1] - current_place_in_loop = batch_idx % optimizers_loop_length - - # find optimzier index by looking for the first {item > current_place} in the cumsum list - opt_idx = int(np.argmax(self.optimizer_freq_cumsum > current_place_in_loop)) - return [(opt_idx, self.trainer.optimizers[opt_idx])] - - def on_after_backward(self, batch_idx, untouched_loss): - # insert after step hook - self.trainer.call_hook("on_after_backward") - - # when in dev debugging track the losses - self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) - - def _check_training_step_output(self, training_step_output): - if isinstance(training_step_output, torch.Tensor) and not self.trainer.lightning_module.automatic_optimization: - if training_step_output.grad_fn is None: - # TODO: Find why - RuntimeError: Expected to mark a variable ready only once ... - raise MisconfigurationException("In manual optimization, `training_step` should not return a Tensor") - elif self.trainer.lightning_module.automatic_optimization: - if not any(( - isinstance(training_step_output, torch.Tensor), - (isinstance(training_step_output, Mapping) - and 'loss' in training_step_output), training_step_output is None - )): - raise MisconfigurationException( - "In automatic optimization, `training_step` must either return a Tensor, " - "a dict with key 'loss' or None (where the step will be skipped)." 
-                )
-
-    def training_step(self, split_batch, batch_idx, opt_idx, hiddens):
-        # give the PL module a result for logging
-        model_ref = self.trainer.lightning_module
-
-        with self.trainer.profiler.profile("model_forward"):
-            step_kwargs = self._build_kwargs(split_batch, batch_idx, opt_idx, hiddens)
-
-            # manually capture logged metrics
-            model_ref._current_fx_name = 'training_step'
-            with self.trainer.profiler.profile("training_step"):
-                training_step_output = self.trainer.accelerator.training_step(step_kwargs)
-                self.trainer.accelerator.post_training_step()
-
-            training_step_output = self.trainer.call_hook("training_step_end", training_step_output)
-
-            self._check_training_step_output(training_step_output)
-
-            training_step_output = self._process_training_step_output(training_step_output)
-            if training_step_output is None:
-                return
-
-        closure_loss = None
-        loss = None
-        if self.trainer.lightning_module.automatic_optimization:
-            # accumulate loss. if accumulate_grad_batches==1, no effect
-            closure_loss = training_step_output.minimize / self.trainer.accumulate_grad_batches
-            # the loss will get scaled for amp. avoid any modifications to it
-            loss = closure_loss.detach().clone()
-        return AttributeDict(closure_loss=closure_loss, loss=loss, training_step_output=training_step_output)
-
-    def _process_training_step_output(self, training_step_output):
-        if training_step_output is None:
-            return None
-
-        results = self.results
-        loss = None
-        hiddens = None
-        results.extra = {}
-
-        # handle dict return
-        if isinstance(training_step_output, dict):
-            loss = training_step_output.pop("loss", None)
-            hiddens = training_step_output.pop("hiddens", None)
-            if hiddens is not None:
-                hiddens = hiddens.detach()
-            results.extra = training_step_output
-
-        # handle scalar return
-        elif isinstance(training_step_output, torch.Tensor):
-            loss = training_step_output
-
-        # map to results under the hood
-        results.minimize = loss
-        self._hiddens = hiddens
-
-        if self.trainer.move_metrics_to_cpu:
-            results.cpu()
-        return results
-
-    @staticmethod
-    def _prepare_outputs(
-        outputs: List[List[List['ResultCollection']]],
-        batch_mode: bool,
-    ) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]:
-        """
-        Extract required information from batch or epoch end results.
-
-        Args:
-            outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions:
-                ``[optimizer outs][batch outs][tbptt steps]``.
-
-            batch_mode: If True, ignore the batch output dimension.
-
-        Returns:
-            The cleaned outputs with ``ResultCollection`` objects converted to dictionaries.
-            All list dimensions of size one will be collapsed.
- """ - processed_outputs = [] - for opt_outputs in outputs: - # handle an edge case where an optimizer output is the empty list - if len(opt_outputs) == 0: - continue - - processed_batch_outputs = [] - - if batch_mode: - opt_outputs = [opt_outputs] - - for batch_outputs in opt_outputs: - processed_tbptt_outputs = [] - - if isinstance(batch_outputs, ResultCollection): - batch_outputs = [batch_outputs] - - for tbptt_output in batch_outputs: - out = tbptt_output.extra - if tbptt_output.minimize is not None: - out['loss'] = tbptt_output.minimize.detach() - processed_tbptt_outputs.append(out) - - # if there was only one tbptt step then we can collapse that dimension - if len(processed_tbptt_outputs) == 1: - processed_tbptt_outputs = processed_tbptt_outputs[0] - processed_batch_outputs.append(processed_tbptt_outputs) - - # batch_outputs should be just one dict (or a list of dicts if using tbptt) per optimizer - if batch_mode: - processed_batch_outputs = processed_batch_outputs[0] - processed_outputs.append(processed_batch_outputs) - - # if there is only one optimiser then we collapse that dimension - if len(processed_outputs) == 1: - processed_outputs = processed_outputs[0] - return processed_outputs - - def optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure): - model_ref = self.trainer.lightning_module - - is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) - using_native_amp = self.trainer.amp_backend == AMPType.NATIVE - - # native amp + lbfgs is a no go right now - if using_native_amp and is_lbfgs: - raise MisconfigurationException( - 'native PyTorch amp and lbfgs are not compatible.' - ' To request, please file a Github issue in PyTorch and tag @mcarilli' - ) - - # wraps into LightningOptimizer only for running step - optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) - - # model hook - model_ref.optimizer_step( - self.trainer.current_epoch, - batch_idx, - optimizer, - opt_idx, - train_step_and_backward_closure, - on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE, - using_native_amp=using_native_amp, - using_lbfgs=is_lbfgs, - ) - - def on_before_zero_grad(self, optimizer): - self.trainer.call_hook('on_before_zero_grad', optimizer) - - def optimizer_zero_grad(self, batch_idx, optimizer, opt_idx): - self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) - - def track_and_norm_grad(self, optimizer) -> dict: - # track gradient norms - grad_norm_dict = {} - if (self.global_step + 1) % self.trainer.log_every_n_steps == 0 and float(self.trainer.track_grad_norm) > 0: - grad_norm_dict = grad_norm(self.trainer.lightning_module, self.trainer.track_grad_norm) - - # clip gradients - self.trainer.accelerator.clip_gradients( - optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm - ) - return grad_norm_dict - - def _tbptt_split_batch(self, batch: Any) -> List[Any]: - splits = [batch] - truncated_bptt_enabled = self._truncated_bptt_enabled() - if truncated_bptt_enabled: - model_ref = self.trainer.lightning_module - with self.trainer.profiler.profile("tbptt_split_batch"): - splits = model_ref.tbptt_split_batch(batch, self._truncated_bptt_steps()) - return splits - - def run_training_epoch(self): - # modify dataloader if needed (ddp, etc...) 
-        train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader)
-
-        # track epoch output
-        epoch_output = [[] for _ in range(self.num_active_optimizers)]
-
-        train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
-        dataloader_idx = 0
-        batch_idx = None
-
-        for batch_idx, (batch, is_last_batch) in train_dataloader:
-            self.batch_idx = batch_idx
-
-            # ------------------------------------
-            # TRAINING_STEP + TRAINING_STEP_END
-            # ------------------------------------
-            with self.trainer.profiler.profile("run_training_batch"):
-                batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
-
-            # when returning -1 from train_step, we end epoch early
-            if batch_output.signal == -1:
-                break
-
-            # hook
-            self.on_train_batch_end(
-                epoch_output,
-                batch_output.training_step_output,
-                batch,
-                batch_idx,
-                dataloader_idx,
-            )
-
-            # -----------------------------------------
-            # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
-            # -----------------------------------------
-            self.trainer.logger_connector.update_train_step_metrics()
-
-            # -----------------------------------------
-            # VALIDATE IF NEEDED
-            # -----------------------------------------
-            should_check_val = self._should_check_val_fx(batch_idx, is_last_batch)
-            if should_check_val:
-                self.trainer.validating = True
-                self.trainer._run_evaluation()
-                self.trainer.training = True
-
-            # -----------------------------------------
-            # SAVE LOGGERS (ie: Tensorboard, etc...)
-            # -----------------------------------------
-            self.save_loggers_on_train_batch_end()
-
-            # update LR schedulers
-            self.update_lr_schedulers('step')
-            self.trainer.checkpoint_connector.has_trained = True
-
-            self.total_batch_idx += 1
-
-            # progress global step according to grads progress
-            self.increment_accumulated_grad_global_step()
-
-            max_steps_reached = (self.max_steps is not None and self.max_steps <= self.global_step)
-            if max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(is_last_batch):
-                break
-
-        if batch_idx is None:
-            # dataloader/iterator did not produce a batch
-            return
-
-        # handle epoch_output on epoch end
-        self.on_train_epoch_end(epoch_output)
-
-        # the global step is manually decreased here due to backwards compatibility with existing loggers
-        # as they expect that the same step is used when logging epoch end metrics even when the batch loop has
-        # finished. this means the attribute does not exactly track the number of optimizer steps applied.
-        # TODO(@carmocca): deprecate and rename so users don't get confused
-        self.global_step -= 1
-        # log epoch metrics
-        self.trainer.logger_connector.update_train_epoch_metrics()
-        self.global_step += 1
-
-        self.update_lr_schedulers('epoch')
-
-        did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.should_skip_evaluation(
-            self.trainer.num_val_batches
-        )
-        if did_train_only:
-            self.global_step -= 1
-            self.check_checkpoint_callback(True)
-            self.global_step += 1
-
-    def on_train_epoch_end(self, epoch_output: List[List[List['ResultCollection']]]) -> None:
-        # inform logger the batch loop has finished
-        self.trainer.logger_connector.epoch_end_reached()
-
-        # prepare epoch output
-        processed_epoch_output = TrainLoop._prepare_outputs(epoch_output, batch_mode=False)
-
-        # get the model and call model.training_epoch_end
-        model = self.trainer.lightning_module
-
-        if is_overridden('training_epoch_end', model):
-            # run training_epoch_end
-            # refresh the result for custom logging at the epoch level
-            model._current_fx_name = 'training_epoch_end'
-            training_epoch_end_output = model.training_epoch_end(processed_epoch_output)
-
-            if training_epoch_end_output is not None:
-                raise MisconfigurationException(
-                    'training_epoch_end expects a return of None. '
-                    'HINT: remove the return statement in training_epoch_end'
-                )
-
-        # call train epoch end hooks
-        self._on_train_epoch_end_hook(processed_epoch_output)
-        self.trainer.call_hook('on_epoch_end')
-        self.trainer.logger_connector.on_epoch_end()
-
-    def _on_train_epoch_end_hook(self, processed_epoch_output) -> None:
-        # We cannot rely on Trainer.call_hook because the signatures might be different across
-        # lightning module and callback
-        # As a result, we need to inspect if the module accepts `outputs` in `on_train_epoch_end`
-
-        # This implementation is copied from Trainer.call_hook
-        hook_name = "on_train_epoch_end"
-        prev_fx_name = self.trainer.lightning_module._current_fx_name
-        self.trainer.lightning_module._current_fx_name = hook_name
-
-        # always profile hooks
-        with self.trainer.profiler.profile(hook_name):
-
-            # first call trainer hook
-            if hasattr(self.trainer, hook_name):
-                trainer_hook = getattr(self.trainer, hook_name)
-                trainer_hook(processed_epoch_output)
-
-            # next call hook in lightningModule
-            model_ref = self.trainer.lightning_module
-            if is_overridden(hook_name, model_ref):
-                hook_fx = getattr(model_ref, hook_name)
-                if is_param_in_hook_signature(hook_fx, "outputs"):
-                    self.warning_cache.warn(
-                        "The signature of `ModelHooks.on_train_epoch_end` has changed in v1.3."
-                        " `outputs` parameter has been deprecated."
-                        " Support for the old signature will be removed in v1.5", DeprecationWarning
-                    )
-                    model_ref.on_train_epoch_end(processed_epoch_output)
-                else:
-                    model_ref.on_train_epoch_end()
-
-            # call the accelerator hook
-            if hasattr(self.trainer.accelerator, hook_name):
-                accelerator_hook = getattr(self.trainer.accelerator, hook_name)
-                accelerator_hook()
-
-        # restore current_fx when nested context
-        self.trainer.lightning_module._current_fx_name = prev_fx_name
-
-    def run_training_batch(self, batch, batch_idx, dataloader_idx):
-        # bookkeeping
-        self._hiddens = None
-
-        optimizers = list(enumerate(self.trainer.optimizers))
-
-        # track all outputs across time and num of optimizers
-        batch_outputs = [[] for _ in range(len(optimizers))]
-
-        if batch is None:
-            self.warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
-            return AttributeDict(signal=0, training_step_output=batch_outputs)
-
-        # hook
-        self.trainer.logger_connector.on_batch_start()
-        response = self.trainer.call_hook("on_batch_start")
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, dataloader_idx)
-        if response == -1:
-            return AttributeDict(signal=-1)
-
-        # lightning module hook
-        splits = self._tbptt_split_batch(batch)
-
-        for split_idx, split_batch in enumerate(splits):
-            self.split_idx = split_idx
-
-            # let logger connector extract batch size
-            self.trainer.logger_connector.on_train_split_start(batch_idx, split_idx, split_batch)
-
-            if self.trainer.lightning_module.automatic_optimization:
-                for opt_idx, optimizer in self.get_active_optimizers(batch_idx):
-                    result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer)
-                    if result:
-                        batch_outputs[opt_idx].append(result.training_step_output)
-            else:
-                # in manual optimization, there is no looping over optimizers
-                result = self._run_optimization(batch_idx, split_batch)
-                if result:
-                    batch_outputs[0].append(result.training_step_output)
-
-        return AttributeDict(signal=0, training_step_output=batch_outputs)
-
-    def _run_optimization(self, batch_idx, split_batch, opt_idx=0, optimizer=None):
-        # TODO: In v1.5, when optimizer_idx gets removed from training_step in manual_optimization, change
-        #   opt_idx=0 to opt_idx=None in the signature here
-
-        # toggle model params
-        self.run_optimization_start(opt_idx, optimizer)
-
-        result = AttributeDict()
-        closure = self.make_closure(split_batch, batch_idx, opt_idx, optimizer, self._hiddens, result)
-
-        if self.should_accumulate():
-            # For gradient accumulation
-
-            # -------------------
-            # calculate loss (train step + train step end)
-            # -------------------
-            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
-            # automatic_optimization=False: don't block synchronization here
-            with self.block_ddp_sync_behaviour():
-                closure()
-
-        # ------------------------------
-        # BACKWARD PASS
-        # ------------------------------
-        # gradient update with accumulated gradients
-        else:
-            if self.trainer.lightning_module.automatic_optimization:
-                self.optimizer_step(optimizer, opt_idx, batch_idx, closure)
-                if len(self.trainer.optimizers) > 1:
-                    # revert back to previous state
-                    self.trainer.lightning_module.untoggle_optimizer(opt_idx)
-            else:
-                result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)
-
-            if not result:
-                # user decided to skip optimization
-                return result
-
-            # update running loss + reset accumulated loss
-            self.update_running_loss(result.loss)
-
-        self._process_closure_result(result)
-        return result
-
-    def training_step_and_backward_closure(
-        self,
-        split_batch: Any,
-        batch_idx: int,
-        opt_idx: int,
-        optimizer: Optimizer,
-        hiddens,
-        return_result: AttributeDict,
-    ) -> Optional[torch.Tensor]:
-
-        result = self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, hiddens)
-        if result is not None:
-            return_result.update(result)
-            return return_result.loss
-
-    def make_closure(self, *closure_args, **closure_kwargs: Any) -> Callable:
-        """ Wraps the training step closure into a partial object which will be called within ``optimizer.step``. """
""" - partial_func = partial(self.training_step_and_backward_closure, *closure_args, **closure_kwargs) - return update_wrapper(partial_func, self.training_step_and_backward_closure) - - @contextmanager - def block_ddp_sync_behaviour(self, should_block_sync: bool = False): - """ - automatic_optimization = True - Blocks ddp sync gradients behaviour on backwards pass. - This is useful for skipping sync when accumulating gradients, reducing communication overhead - - automatic_optimization = False - do not block ddp gradient sync when using manual optimization - as gradients are needed within the training step - - Returns: - context manager with sync behaviour off - - """ - if ( - isinstance(self.trainer.training_type_plugin, ParallelPlugin) - and (self.trainer.lightning_module.automatic_optimization or should_block_sync) - ): - with self.trainer.training_type_plugin.block_backward_sync(): - yield None - else: - yield None - - def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -> None: - if not opt_closure_result: - return - - # check if loss or model weights are nan - if self.trainer.terminate_on_nan: - self._check_finite(opt_closure_result.loss) - - def training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens): - """Wrap forward, zero_grad and backward in a closure so second order methods work""" - with self.trainer.profiler.profile("training_step_and_backward"): - # lightning module hook - result = self.training_step(split_batch, batch_idx, opt_idx, hiddens) - - if not self._skip_backward and self.trainer.lightning_module.automatic_optimization: - is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0 - - if is_first_batch_to_accumulate: - self.on_before_zero_grad(optimizer) - self.optimizer_zero_grad(batch_idx, optimizer, opt_idx) - - # backward pass - if result is not None: - with self.trainer.profiler.profile("backward"): - self.backward(result, optimizer, opt_idx) - - # hook - call this hook only - # when gradients have finished to accumulate - if not self.should_accumulate(): - self.on_after_backward(batch_idx, result.loss) - - # check if loss or model weights are nan - if self.trainer.terminate_on_nan: - self._check_finite(result.loss) - - else: - self.warning_cache.warn( - "training_step returned None. If this was on purpose, ignore this warning..." 
-                    )
-
-        return result
-
-    def _check_finite(self, loss: torch.Tensor) -> None:
-        if not torch.isfinite(loss).all():
-            raise ValueError(f'The loss returned in `training_step` is {loss}.')
-        model = self.trainer.lightning_module
-        detect_nan_parameters(model)
-
-    def backward(self, result, optimizer, opt_idx, *args, **kwargs):
-        self.trainer.dev_debugger.track_event("backward_call")
-
-        should_accumulate = self.should_accumulate()
-
-        # backward can be called manually in the training loop
-        if isinstance(result, torch.Tensor):
-            self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
-        else:
-            result.closure_loss = self.trainer.accelerator.backward(
-                result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
-            )
-
-            if not self.should_accumulate():
-                # track gradients
-                grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)
-                if grad_norm_dict:
-                    self.trainer.lightning_module._current_fx_name = "on_after_backward"
-                    self.trainer.lightning_module.log_grad_norm(grad_norm_dict)
-
-    def update_lr_schedulers(self, interval: str) -> None:
-        if interval == "step":
-            finished_accumulation = self._accumulated_batches_reached()
-            finished_epoch = self._num_training_batches_reached()
-            if not finished_accumulation and not finished_epoch:
-                return
-        self.trainer.optimizer_connector.update_learning_rates(
-            interval=interval,
-            opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
-        )
-
-    def increment_accumulated_grad_global_step(self):
-        num_accumulated_batches_reached = self._accumulated_batches_reached()
-        num_training_batches_reached = self._num_training_batches_reached()
-
-        # progress global step according to grads progress
-        if num_accumulated_batches_reached or num_training_batches_reached:
-            self.global_step = self.trainer.accelerator.update_global_step(self.total_batch_idx, self.global_step)
-
-    def _accumulated_batches_reached(self):
-        return (self.batch_idx + 1) % self.trainer.accumulate_grad_batches == 0
-
-    def _num_training_batches_reached(self, is_last_batch=False):
-        return (self.batch_idx + 1) == self.trainer.num_training_batches or is_last_batch
-
-    def should_accumulate(self):
-        # checks if backward or backward + optimizer step (via closure)
-        accumulation_done = self._accumulated_batches_reached()
-        is_final_batch = self._num_training_batches_reached()
-        return not (accumulation_done or is_final_batch)
-
-    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
-        """ Decide if we should run validation. """
""" - if not self.trainer.enable_validation: - return False - - is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 - if not is_val_check_epoch: - return False - - # val_check_batch is inf for iterable datasets with no length defined - is_infinite_dataset = self.trainer.val_check_batch == float('inf') - if is_last_batch and is_infinite_dataset: - return True - - if self.trainer.should_stop: - return True - - # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch - is_val_check_batch = is_last_batch - if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset: - is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0 - elif self.trainer.val_check_batch != float('inf'): - is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0 - return is_val_check_batch - - def _build_kwargs(self, batch, batch_idx, opt_idx, hiddens): - # enable not needing to add opt_idx to training_step - step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)]) - - lightning_module = self.trainer.lightning_module - - if len(self.trainer.optimizers) > 1: - training_step_fx = getattr(lightning_module, "training_step") - has_opt_idx_in_train_step = is_param_in_hook_signature(training_step_fx, "optimizer_idx") - if has_opt_idx_in_train_step: - if not lightning_module.automatic_optimization: - self.warning_cache.warn( - "`training_step` hook signature has changed in v1.3." - " `optimizer_idx` argument has been removed in case of manual optimization. Support for" - " the old signature will be removed in v1.5", DeprecationWarning - ) - step_kwargs['optimizer_idx'] = opt_idx - elif not has_opt_idx_in_train_step and self.trainer.lightning_module.automatic_optimization: - raise ValueError( - f"Your LightningModule defines {len(self.trainer.optimizers)} optimizers but" - ' `training_step` is missing the `optimizer_idx` argument.' - ) - - # pass hiddens if using tbptt - if self._truncated_bptt_enabled(): - step_kwargs['hiddens'] = hiddens - - return step_kwargs - - def _truncated_bptt_enabled(self) -> bool: - """ Temporary tbptt utilities until this flag is fully migrated to the lightning module. """ - return self._truncated_bptt_steps() > 0 - - def _truncated_bptt_steps(self) -> int: - lightning_module = self.trainer.lightning_module - # Give precedence to the LightningModule as the Trainer flag will be removed in v1.5 - if lightning_module.truncated_bptt_steps > 0: - return lightning_module.truncated_bptt_steps - return self.trainer.truncated_bptt_steps or 0 - - def save_loggers_on_train_batch_end(self): - # when loggers should save to disk - should_flush_logs = self.trainer.logger_connector.should_flush_logs - if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: - self.trainer.logger.save() - - def run_optimization_start(self, opt_idx, optimizer): - # make sure only the gradients of the current optimizer's parameters are calculated - # in the training step to prevent dangling gradients in multiple-optimizer setup. 
-        if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
-            model = self.trainer.lightning_module
-            model.toggle_optimizer(optimizer, opt_idx)
-
-    def update_running_loss(self, current_loss: torch.Tensor) -> None:
-        if self.trainer.lightning_module.automatic_optimization:
-            # track total loss for logging (avoid mem leaks)
-            self.accumulated_loss.append(current_loss)
-
-        accumulated_loss = self.accumulated_loss.mean()
-
-        if accumulated_loss is not None:
-            # calculate running loss for display
-            self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
-
-        # reset for next set of accumulated grads
-        self.accumulated_loss.reset()

diff --git a/setup.cfg b/setup.cfg
index 594ebcb594f3b..74e02d932dc3c 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,7 +45,6 @@ exclude_lines =
 omit =
     pytorch_lightning/cluster_environments/*.py
    pytorch_lightning/utilities/distributed.py
-    pytorch_lightning/trainer/training_loop.py
    pytorch_lightning/tuner/auto_gpu_select.py
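
--
Editor's appendix - illustrative sketches (not part of the commit)

The notes below restate, as standalone Python, a few pieces of the logic that
this patch deletes. Function names and the small driver snippets are invented
for illustration; only the arithmetic is taken from the removed TrainLoop.

Sketch 1: optimizer selection. With multiple optimizers and
`optimizer_frequencies`, get_active_optimizers() activated one optimizer at a
time by walking a cumulative-sum table. A minimal sketch, assuming a plain
list of frequencies instead of the Trainer attributes:

    import numpy as np

    def active_optimizer_index(batch_idx: int, frequencies: list) -> int:
        # e.g. frequencies=[2, 1]: batches 0-1 use optimizer 0, batch 2 uses
        # optimizer 1, and the 3-batch cycle then repeats
        freq_cumsum = np.cumsum(frequencies)
        current_place_in_loop = batch_idx % freq_cumsum[-1]
        # first index whose cumulative frequency exceeds the current position
        return int(np.argmax(freq_cumsum > current_place_in_loop))

    assert [active_optimizer_index(i, [2, 1]) for i in range(6)] == [0, 0, 1, 0, 0, 1]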
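
Sketch 2: gradient accumulation. should_accumulate() decided whether the
current batch only contributes gradients (the closure runs, optimizer.step is
skipped) or closes an accumulation window. The same modulo test, with the
trainer state passed in as plain arguments:

    def should_accumulate(batch_idx: int, accumulate_grad_batches: int,
                          num_training_batches: int, is_last_batch: bool = False) -> bool:
        accumulation_done = (batch_idx + 1) % accumulate_grad_batches == 0
        is_final_batch = (batch_idx + 1) == num_training_batches or is_last_batch
        return not (accumulation_done or is_final_batch)

    # with accumulate_grad_batches=4 and 10 batches, the optimizer steps after
    # batches 3, 7 and 9 (the final window is cut short by the epoch end)
    steps = [i for i in range(10) if not should_accumulate(i, 4, 10)]
    assert steps == [3, 7, 9]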
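
Sketch 3: validation cadence. _should_check_val_fx() gated validation on
val_check_batch, with a special case for iterable datasets of unknown length.
Reduced here to that core decision; the epoch-level, limit_train_batches and
should_stop checks from the original are omitted:

    def should_check_val(batch_idx: int, val_check_batch: float, is_last_batch: bool) -> bool:
        if val_check_batch == float('inf'):
            # iterable dataset with no length: validate only on the last batch
            return is_last_batch
        return (batch_idx + 1) % val_check_batch == 0

    assert should_check_val(49, 50, is_last_batch=False)  # every 50th batch
    assert not should_check_val(10, float('inf'), is_last_batch=False)
    assert should_check_val(10, float('inf'), is_last_batch=True)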
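
Sketch 4: the optimizer closure. make_closure() wrapped
training_step_and_backward_closure in functools.partial and copied the wrapped
function's metadata onto the partial object, which keeps profiler and debugger
output readable when the closure runs inside optimizer.step. The toy step
function below stands in for the real forward/backward:

    from functools import partial, update_wrapper

    def training_step_and_backward_closure(batch, batch_idx):
        # stand-in for forward + backward; returns the loss
        return 0.0

    closure = update_wrapper(
        partial(training_step_and_backward_closure, "dummy_batch", 0),
        training_step_and_backward_closure,
    )
    assert closure.__name__ == "training_step_and_backward_closure"
    closure()  # runs like the original function, with the arguments bound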
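
Sketch 5: progress-bar loss smoothing. update_running_loss() appended the
(already accumulation-scaled) step losses to a windowed accumulator and
multiplied the window mean back up by accumulate_grad_batches.
TensorRunningAccum is approximated here with a fixed-size deque of floats:

    from collections import deque

    class RunningMean:
        def __init__(self, window_length: int = 20):
            self.window = deque(maxlen=window_length)

        def append(self, value: float) -> None:
            self.window.append(value)

        def mean(self) -> float:
            return sum(self.window) / len(self.window)

    accumulated, running = RunningMean(), RunningMean()
    accumulate_grad_batches = 2
    for step_loss in [0.9, 0.8, 0.7, 0.6]:
        # training_step already divided each loss by accumulate_grad_batches
        accumulated.append(step_loss / accumulate_grad_batches)
    # undo that division when reporting, exactly as update_running_loss did
    running.append(accumulated.mean() * accumulate_grad_batches)
    print(f"running loss: {running.mean():.3f}")  # 0.750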