diff --git a/CHANGELOG.md b/CHANGELOG.md index bd2d65a17ffaa..e8feb6001b081 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -128,6 +128,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014)) +- Added `PrecisionPlugin.{pre,post}_backward` ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328)) + + +- Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831)) + + - Added `max_depth` parameter in `ModelSummary` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062)) @@ -235,7 +241,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822)) -- Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831)) +- The `on_after_backward` hook is now called on accumulating iterations ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328)) + + +- The mixed precision loss is no longer unscaled before the `on_after_backward` hook ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328)) + + +- The `TrainingTypePlugin.{pre,post}_backward` hooks no longer take the `optimizer, opt_idx, should_accumulate` arguments ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328)) + + +- The `PrecisionPlugin.backward` hook no longer returns a value ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328)) + + +- The `PrecisionPlugin.backward` hook no longer takes a `should_accumulate` argument ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328)) - `LightningCLI` now aborts with a clearer message if config already exists and disables save config during `fast_dev_run`([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963)) @@ -386,6 +404,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `self.optimizers()` not returning a single optimizer if it had been wrapped ([#8326](https://github.com/PyTorchLightning/pytorch-lightning/pull/8326)) +- Fixed the `on_after_backward` hook not getting called when using manual optimization and no plugins ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328)) + + +- Fixed the `LightningModule.backward` hook only getting called with the `apex` plugin when using manual optimization ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328)) + + - Fixed moving batch to device before sending it to the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks ([#7378](https://github.com/PyTorchLightning/pytorch-lightning/pull/7378)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index faae0b27b519e..ef2deae1a5515 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -274,9 +274,6 @@ def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OU def backward( self, closure_loss: Tensor, - optimizer: Optimizer, - optimizer_idx: int, - should_accumulate: bool, *args: Any, **kwargs: Any, ) -> Tensor: @@ -284,17 +281,16 @@ def backward( Args: closure_loss: a tensor holding the loss value to backpropagate - should_accumulate: whether to accumulate gradients """ - self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, optimizer_idx) + self.training_type_plugin.pre_backward(closure_loss) + closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss) - output = self.precision_plugin.backward( - self.lightning_module, closure_loss, optimizer, optimizer_idx, should_accumulate, *args, **kwargs - ) + self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs) - self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, optimizer_idx) + closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss) + self.training_type_plugin.post_backward(closure_loss) - return output + return closure_loss def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None: """performs the actual optimizer step. @@ -362,7 +358,7 @@ def setup_precision_plugin(self) -> None: model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers) self.model = model self.optimizers = optimizers - self.schedulers = schedulers + self.lr_schedulers = schedulers @property def amp_backend(self) -> Optional[LightningEnum]: diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 2c5ea69aeea70..3ec5fe0e8ffb1 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1447,9 +1447,11 @@ def training_step(...): self._verify_is_manual_optimization('manual_backward') # backward - self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs) + self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, None, None, *args, **kwargs) - def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None: + def backward( + self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs + ) -> None: """ Called to perform backward on the loss returned in :meth:`training_step`. 
Override this hook with your own implementation if you need to. @@ -1457,8 +1459,8 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args Args: loss: The loss tensor returned by :meth:`training_step`. If gradient accumulation is used, the loss here holds the normalized value (scaled by 1 / accumulation steps). - optimizer: Current optimizer being used - optimizer_idx: Index of the current optimizer being used + optimizer: Current optimizer being used. ``None`` if using manual optimization. + optimizer_idx: Index of the current optimizer being used. ``None`` if using manual optimization. Example:: diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 41ad9280ffaf7..886fce3bb3961 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -260,20 +260,6 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) - if self.trainer.terminate_on_nan: self._check_finite(opt_closure_result.loss) - def _on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None: - """Calls ``on_after_backward`` hook and tracks loss history - - Args: - batch_idx: the index of the current batch - untouched_loss: the original loss value - """ - - # insert after step hook - self.trainer.call_hook("on_after_backward") - - # when in dev debugging track the losses - self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) - def _check_training_step_output(self, training_step_output: STEP_OUTPUT) -> None: """Sanity checks that training produced a valid output and optimizer step has already been called in manual optimization. @@ -559,10 +545,8 @@ def training_step_and_backward( with self.trainer.profiler.profile("backward"): self.backward(result, optimizer, opt_idx) - # hook - call this hook only - # when gradients have finished to accumulate - if not self.should_accumulate(): - self._on_after_backward(batch_idx, result.loss) + # when in dev debugging track the losses + self.trainer.dev_debugger.track_train_loss_history(batch_idx, result.loss) # check if loss or model weights are nan if self.trainer.terminate_on_nan: @@ -587,26 +571,26 @@ def _check_finite(self, loss: Tensor) -> None: detect_nan_parameters(model) def backward( - self, result: STEP_OUTPUT, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any + self, + result: STEP_OUTPUT, + optimizer: Optional[torch.optim.Optimizer], + *args: Any, + **kwargs: Any, ) -> None: """Performs the backward step. Args: result: The output of the trainstep (including the loss value) - optimizer: The optimizer optimizing the gradients to call backward for - opt_idx: the index of the current optimizer + optimizer: Current optimizer being used. ``None`` if using manual optimization. + opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization. 
""" - should_accumulate = self.should_accumulate() - # backward can be called manually in the training loop if isinstance(result, Tensor): - self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs) + self.trainer.accelerator.backward(result, optimizer, *args, **kwargs) else: - result.closure_loss = self.trainer.accelerator.backward( - result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs - ) + result.closure_loss = self.trainer.accelerator.backward(result.closure_loss, optimizer, *args, **kwargs) - if not should_accumulate: + if not self.should_accumulate(): # track gradients grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer) if grad_norm_dict: diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index d960c06fb4c9e..022dd3ee392e8 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, ContextManager, Dict, Sequence +from typing import Any, Callable, Dict, Optional, Sequence import torch from torch import Tensor @@ -51,43 +51,20 @@ def backward( self, model: 'pl.LightningModule', closure_loss: Tensor, - optimizer: Optimizer, - opt_idx: int, - should_accumulate: bool, + optimizer: Optional[Optimizer], *args: Any, **kwargs: Any, - ) -> Tensor: - """performs the actual backpropagation + ) -> None: + """Run before precision plugin executes backward Args: model: the model to be optimized closure_loss: the loss value obtained from the closure - optimizer: the optimizer to perform the step lateron - opt_idx: the optimizer index - should_accumulate: whether to accumulate gradients or not - + optimizer: current optimizer being used. ``None`` if using manual optimization """ - opt = model.trainer.optimizers if optimizer is None else optimizer - scaled_loss: ContextManager[Tensor] = amp.scale_loss(closure_loss, opt) - - # enter apex context - closure_loss = scaled_loss.__enter__() - - # do backward pass - # TODO: not entirely sure, why we need this - if model is not None and isinstance(model, pl.LightningModule): - model.backward(closure_loss, optimizer, opt_idx, **kwargs) - else: - closure_loss.backward(*args, **kwargs) - - # exit amp context - error = scaled_loss.__exit__(None, None, None) - if error: - raise Exception("apex unscale error") - - # once backward has been applied, release graph - closure_loss = closure_loss.detach() - return closure_loss + opt = optimizer or model.trainer.optimizers + with amp.scale_loss(closure_loss, opt) as closure_loss: + super().backward(model, closure_loss, optimizer, *args, **kwargs) @staticmethod def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Sequence[Any]) -> None: @@ -124,10 +101,6 @@ def pre_optimizer_step( """ # apex amp does not support closures. 
lambda_closure() - - if not pl_module.automatic_optimization: - pl_module.trainer.call_hook("on_after_backward") - optimizer.step(**kwargs) return False diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 3edcc6866e219..4809b4e8c2c79 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -41,27 +41,19 @@ def pre_optimizer_step( lambda_closure: Callable, **kwargs: Any, ) -> bool: - deepspeed_engine = pl_module.trainer.model # DeepSpeed not support closures. lambda_closure() - - if not pl_module.automatic_optimization: - pl_module.trainer.call_hook("on_after_backward") - + deepspeed_engine = pl_module.trainer.model deepspeed_engine.step() - return False def backward( self, model: 'pl.LightningModule', closure_loss: Tensor, - optimizer: Optimizer, - opt_idx: int, - should_accumulate: bool, *args: Any, **kwargs: Any, - ) -> Tensor: + ) -> None: if is_overridden('backward', model): warning_cache.warn( "You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles" @@ -70,9 +62,6 @@ def backward( # todo: hack around for deepspeed engine to call backward deepspeed_engine = model.trainer.model deepspeed_engine.backward(closure_loss, *args, **kwargs) - # once backward has been applied, release graph - closure_loss = closure_loss.detach() - return closure_loss def clip_gradients( self, diff --git a/pytorch_lightning/plugins/precision/ipu_precision.py b/pytorch_lightning/plugins/precision/ipu_precision.py index e6983966e166b..163f32a9ab5ae 100644 --- a/pytorch_lightning/plugins/precision/ipu_precision.py +++ b/pytorch_lightning/plugins/precision/ipu_precision.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, Optional, Union -from torch import Tensor from torch.nn import Module from torch.optim import Optimizer @@ -21,6 +20,10 @@ from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.warnings import WarningCache + +warning_cache = WarningCache() class IPUPrecisionPlugin(PrecisionPlugin): @@ -29,18 +32,12 @@ def __init__(self, precision: int) -> None: super().__init__() self.precision = precision - def backward( - self, - model: 'pl.LightningModule', - closure_loss: Tensor, - optimizer: Optimizer, - opt_idx: int, - should_accumulate: bool, - *args: Any, - **kwargs: Any, - ) -> Tensor: - # IPU internally manages bwd step. - return closure_loss + def backward(self, model: 'pl.LightningModule', *args: Any, **kwargs: Any) -> None: + if is_overridden('backward', model): + warning_cache.warn( + "You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle" + " the backward logic internally." 
+ ) def clip_gradients( self, diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index e25f46d9ec239..3e37de6b5d07b 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -37,35 +37,12 @@ def __init__(self) -> None: self.backend = AMPType.NATIVE self.scaler = torch.cuda.amp.GradScaler() - def backward( + def pre_backward( self, model: 'pl.LightningModule', closure_loss: torch.Tensor, - optimizer: Optimizer, - opt_idx: int, - should_accumulate: bool, - *args: Any, - **kwargs: Any, ) -> torch.Tensor: - """performs the actual backpropagation - - Args: - model: the model to be optimized - closure_loss: the loss value obtained from the closure - optimizer: the optimizer to perform the step lateron - opt_idx: the optimizer's index - should_accumulate: whether to accumulate gradients or not - - """ - closure_loss = self.scaler.scale(closure_loss) - - closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs) - - # unscale gradient to allow analyze within `on_after_backward` - if not should_accumulate and model.automatic_optimization: - self.scaler.unscale_(optimizer) - - return closure_loss + return self.scaler.scale(closure_loss) def pre_optimizer_step( self, @@ -75,27 +52,21 @@ def pre_optimizer_step( lambda_closure: Callable, **kwargs: Any, ) -> bool: - """always called before the optimizer step. - Checks that the optimizer is not LBFGS, as this one is not supported by native amp - """ if isinstance(optimizer, LBFGS): raise MisconfigurationException( f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})." " To request, please file a Github issue in PyTorch and tag @mcarilli" ) - - if not pl_module.automatic_optimization: - self.scaler.unscale_(optimizer) - pl_module.trainer.call_hook("on_after_backward") - self.scaler.step(optimizer) - self.scaler.update() - else: + # TODO: Add `on_before_optimizer_step` + # self.scaler.unscale_(optimizer) + # pl_module.trainer.call_hook("on_before_optimizer_step") + if pl_module.automatic_optimization: result = lambda_closure() - # lambda_closure returning None indicates that backward has been skipped - if result is not None: - self.scaler.step(optimizer) - self.scaler.update() - + if result is None: + # lambda_closure returning None indicates that backward has been skipped + return False + self.scaler.step(optimizer) + self.scaler.update() return False @contextmanager diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index e8dccbed741fa..edba2e4b15240 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -51,37 +51,54 @@ def connect( """Connects this plugin to the accelerator and the training process""" return model, optimizers, lr_schedulers + def pre_backward( + self, + model: 'pl.LightningModule', + closure_loss: Tensor, + ) -> Tensor: + """Run before precision plugin executes backward + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + """ + return closure_loss + def backward( self, model: 'pl.LightningModule', closure_loss: Tensor, - optimizer: Optimizer, - opt_idx: int, - should_accumulate: bool, + optimizer: Optional[Optimizer], *args: Any, **kwargs: Any, - ) -> Tensor: - """performs the actual backpropagation + ) -> None: + """Performs the 
actual backpropagation Args: model: the model to be optimized closure_loss: the loss value obtained from the closure - optimizer: the optimizer to perform the step lateron - opt_idx: the optimizer's index - should_accumulate: whether to accumulate gradients or not - + optimizer: current optimizer being used. ``None`` if using manual optimization """ - automatic_optimization = model.automatic_optimization - # do backward pass - if automatic_optimization: - model.backward(closure_loss, optimizer, opt_idx) + if model is not None and isinstance(model, pl.LightningModule): + model.backward(closure_loss, optimizer, *args, **kwargs) else: closure_loss.backward(*args, **kwargs) + def post_backward( + self, + model: 'pl.LightningModule', + closure_loss: Tensor, + ) -> Tensor: + """Run after precision plugin executes backward + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + """ # once backward has been applied, release graph closure_loss = closure_loss.detach() - + model.trainer.call_hook("on_after_backward") return closure_loss def pre_optimizer_step( diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index bbea7c9d4d514..9e0fbad33dd9c 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -27,7 +27,6 @@ import torch import torch.distributed from torch.nn.parallel.distributed import DistributedDataParallel -from torch.optim import Optimizer from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -361,7 +360,7 @@ def barrier(self, *args, **kwargs) -> None: def broadcast(self, obj: object, src: int = 0) -> object: return self.dist.broadcast(obj) - def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward""" if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index b08006f3360a5..64f10f30ea5a8 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -20,7 +20,6 @@ import torch.distributed import torch.multiprocessing as mp from torch.nn.parallel.distributed import DistributedDataParallel -from torch.optim import Optimizer from pytorch_lightning.distributed.dist import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule @@ -336,7 +335,7 @@ def model_to_device(self): torch.cuda.set_device(self.root_device) self.model.to(self.root_device) - def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward""" if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: prepare_for_backward(self.model, closure_loss) diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index 6c5e2d3a15ddd..6443a8ca29096 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -15,7 +15,7 @@ from typing import Any, List, 
Optional, Union import torch -from torch.optim.lr_scheduler import _LRScheduler, Optimizer +from torch.optim.lr_scheduler import _LRScheduler from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin @@ -195,7 +195,7 @@ def all_gather( gathered_result = list(gathered.split(1, dim=0)) return gathered_result - def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + def post_backward(self, closure_loss: torch.Tensor) -> None: # synchronize all horovod optimizers. for optimizer in self.lightning_module.trainer.optimizers: optimizer.synchronize() diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py index 79a612b4473de..218e84c72c920 100644 --- a/pytorch_lightning/plugins/training_type/sharded.py +++ b/pytorch_lightning/plugins/training_type/sharded.py @@ -14,7 +14,6 @@ from typing import Optional import torch -from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer @@ -94,7 +93,7 @@ def lightning_module(self) -> 'pl.LightningModule': ) return unwrap_lightning_module_sharded(self._model) - def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + def pre_backward(self, closure_loss: torch.Tensor) -> None: pass def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index c583ac756cd0f..f756f52f2f693 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -14,7 +14,6 @@ from typing import Optional import torch -from torch.optim import Optimizer import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin @@ -79,7 +78,7 @@ def lightning_module(self) -> 'pl.LightningModule': ) return unwrap_lightning_module_sharded(self._model) - def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + def pre_backward(self, closure_loss: torch.Tensor) -> None: pass def post_training_step(self): diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index e7ca73bc9f40d..e49d170a93d66 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -111,10 +111,10 @@ def reduce_boolean_decision(self, decision: bool) -> bool: """Reduce the early stopping decision across all processes""" return decision - def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + def pre_backward(self, closure_loss: torch.Tensor) -> None: """Run before precision plugin executes backward""" - def post_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): + def post_backward(self, closure_loss: torch.Tensor) -> None: """Run after precision plugin executes backward""" def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int, **kwargs) -> None: diff --git a/tests/plugins/test_amp_plugins.py b/tests/plugins/test_amp_plugins.py index cf58427b071ce..9d9c029a65629 100644 --- a/tests/plugins/test_amp_plugins.py +++ 
b/tests/plugins/test_amp_plugins.py @@ -72,6 +72,11 @@ def test_amp_apex_ddp( class GradientUnscaleBoringModel(BoringModel): def on_after_backward(self): + # TODO: replace with `on_before_optimizer_step` so we don't need to check accumulate and unscale manually + if self.trainer.fit_loop.should_accumulate(): + return + opt = self.optimizers() + self.trainer.precision_plugin.scaler.unscale_(opt) norm = torch.nn.utils.clip_grad_norm_(self.parameters(), 2) if not (torch.isinf(norm) or torch.isnan(norm)): assert norm.item() < 15.
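
Illustrative note (not part of the patch above): after this refactor, `Accelerator.backward` runs `TrainingTypePlugin.pre_backward`, then `PrecisionPlugin.pre_backward`, `PrecisionPlugin.backward`, `PrecisionPlugin.post_backward` (which detaches the loss and calls `on_after_backward`), and finally `TrainingTypePlugin.post_backward`. A minimal sketch of a custom precision plugin built on the new hooks — the `TimedPrecisionPlugin` name and its timing logic are hypothetical, not part of the library::

    import time

    import pytorch_lightning as pl
    from torch import Tensor
    from pytorch_lightning.plugins import PrecisionPlugin
    from pytorch_lightning.utilities import rank_zero_info


    class TimedPrecisionPlugin(PrecisionPlugin):
        """Hypothetical plugin that measures how long each backward pass takes."""

        def pre_backward(self, model: 'pl.LightningModule', closure_loss: Tensor) -> Tensor:
            # runs before `backward`; the returned tensor is the one that gets backpropagated
            self._start = time.perf_counter()
            return closure_loss

        def post_backward(self, model: 'pl.LightningModule', closure_loss: Tensor) -> Tensor:
            # the base implementation detaches the loss and calls the `on_after_backward` hook
            closure_loss = super().post_backward(model, closure_loss)
            rank_zero_info(f"backward took {time.perf_counter() - self._start:.4f}s")
            return closure_loss

Such a plugin could be passed as ``Trainer(plugins=[TimedPrecisionPlugin()])`` in the default 32-bit setup; the mixed precision plugins changed above override the same hooks (e.g. `NativeMixedPrecisionPlugin.pre_backward` now scales the loss there).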
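
From the `LightningModule` side, `backward` now receives ``None`` for both `optimizer` and `optimizer_idx` under manual optimization, and `on_after_backward` fires after every backward call, including gradient-accumulation iterations; with native AMP the gradients are still scaled at that point, which is why the updated test above unscales them manually. A hypothetical model (not part of the patch) exercising these semantics — it assumes a dataloader yielding float batches of shape ``(N, 32)``::

    import torch
    from pytorch_lightning import LightningModule


    class ManualOptModel(LightningModule):
        """Hypothetical model used only to illustrate the new hook semantics."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)
            self.automatic_optimization = False

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            opt.zero_grad()
            loss = self.layer(batch).sum()
            # routes through the plugin hooks above and triggers `on_after_backward`
            self.manual_backward(loss)
            opt.step()
            return loss

        def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
            # with manual optimization, both arguments are now ``None``
            assert optimizer is None and optimizer_idx is None
            super().backward(loss, optimizer, optimizer_idx, *args, **kwargs)

        def on_after_backward(self):
            # called after every backward, also on accumulation iterations;
            # with native AMP the gradients here are still scaled
            pass

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)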