Refactor plugins backward #8328

Merged · 6 commits · Jul 8, 2021
Changes from 4 commits
26 changes: 25 additions & 1 deletion CHANGELOG.md
@@ -128,6 +128,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added XLA Profiler ([#8014](https://github.com/PyTorchLightning/pytorch-lightning/pull/8014))


- Added `PrecisionPlugin.{pre,post}_backward` ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831))


- Added `max_depth` parameter in `ModelSummary` ([#8062](https://github.com/PyTorchLightning/pytorch-lightning/pull/8062))


@@ -235,7 +241,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved profilers to their own file ([#7822](https://github.com/PyTorchLightning/pytorch-lightning/pull/7822))


- Added `on_load_checkpoint` and `on_save_checkpoint` hooks to the `PrecisionPlugin` base class ([#7831](https://github.com/PyTorchLightning/pytorch-lightning/pull/7831))
- The `on_after_backward` hook is now called on accumulating iterations ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The mixed precision gradients are no longer unscaled before the `on_after_backward` hook ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The `TrainingTypePlugin.{pre,post}_backward` hooks no longer take the `optimizer, opt_idx, should_accumulate` arguments ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The `PrecisionPlugin.backward` hook no longer returns a value ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- The `PrecisionPlugin.backward` hook no longer takes a `should_accumulate` argument ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- `LightningCLI` now aborts with a clearer message if the config already exists, and disables saving the config during `fast_dev_run` ([#7963](https://github.com/PyTorchLightning/pytorch-lightning/pull/7963))
@@ -386,6 +404,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `self.optimizers()` not returning a single optimizer if it had been wrapped ([#8326](https://github.com/PyTorchLightning/pytorch-lightning/pull/8326))


- Fixed the `on_after_backward` hook not getting called when using manual optimization and no plugins ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- Fixed the `LightningModule.backward` hook only getting called with the `apex` plugin when using manual optimization ([#8328](https://github.com/PyTorchLightning/pytorch-lightning/pull/8328))


- Fixed moving batch to device before sending it to the `on_*_batch_start`/`on_*_batch_end` callbacks and model hooks ([#7378](https://github.com/PyTorchLightning/pytorch-lightning/pull/7378))


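The changelog entries above introduce `PrecisionPlugin.{pre,post}_backward` as the extension points around the backward call. Below is a minimal sketch of a custom precision plugin built on them, assuming the signatures visible in the diffs that follow (`pre_backward(module, loss) -> Tensor`, `post_backward(module, loss) -> Tensor`) and that the base-class hooks return the loss unchanged; the class name and logging are illustrative only, not part of the PR:

```python
import pytorch_lightning as pl
from torch import Tensor
from pytorch_lightning.plugins import PrecisionPlugin


class LoggingPrecisionPlugin(PrecisionPlugin):
    """Illustrative plugin: observe the loss right before and after backward."""

    def pre_backward(self, model: 'pl.LightningModule', closure_loss: Tensor) -> Tensor:
        # Runs before `backward`; may return a transformed loss (native AMP scales it here).
        print(f"loss entering backward: {closure_loss.item():.4f}")
        return super().pre_backward(model, closure_loss)

    def post_backward(self, model: 'pl.LightningModule', closure_loss: Tensor) -> Tensor:
        # Runs after `backward`; gradients are populated at this point.
        print("backward finished")
        return super().post_backward(model, closure_loss)
```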
18 changes: 7 additions & 11 deletions pytorch_lightning/accelerators/accelerator.py
@@ -274,27 +274,23 @@ def validation_step_end(self, output: Optional[STEP_OUTPUT]) -> Optional[STEP_OU
def backward(
self,
closure_loss: Tensor,
optimizer: Optimizer,
optimizer_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> Tensor:
"""Forwards backward-calls to the precision plugin.

Args:
closure_loss: a tensor holding the loss value to backpropagate
should_accumulate: whether to accumulate gradients
"""
self.training_type_plugin.pre_backward(closure_loss, should_accumulate, optimizer, optimizer_idx)
self.training_type_plugin.pre_backward(closure_loss)
closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

output = self.precision_plugin.backward(
self.lightning_module, closure_loss, optimizer, optimizer_idx, should_accumulate, *args, **kwargs
)
self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)

self.training_type_plugin.post_backward(closure_loss, should_accumulate, optimizer, optimizer_idx)
closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
self.training_type_plugin.post_backward(closure_loss)

return output
return closure_loss

def optimizer_step(self, optimizer: Optimizer, opt_idx: int, lambda_closure: Callable, **kwargs: Any) -> None:
"""performs the actual optimizer step.
@@ -362,7 +358,7 @@ def setup_precision_plugin(self) -> None:
model, optimizers, schedulers = self.precision_plugin.connect(self.model, self.optimizers, self.lr_schedulers)
self.model = model
self.optimizers = optimizers
self.schedulers = schedulers
self.lr_schedulers = schedulers

@property
def amp_backend(self) -> Optional[LightningEnum]:
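For orientation, the call order that the refactored `Accelerator.backward` now follows, written out as a standalone sketch (a paraphrase of the diff above, not new behavior):

```python
from torch import Tensor


def accelerator_backward_sketch(accelerator, closure_loss: Tensor, *args, **kwargs) -> Tensor:
    # 1. training-type plugin hook, now called with only the loss
    accelerator.training_type_plugin.pre_backward(closure_loss)
    # 2. precision plugin may transform the loss (e.g. native AMP scales it)
    closure_loss = accelerator.precision_plugin.pre_backward(accelerator.lightning_module, closure_loss)
    # 3. the backward pass itself; it no longer returns a value
    accelerator.precision_plugin.backward(accelerator.lightning_module, closure_loss, *args, **kwargs)
    # 4. precision plugin post-processing of the loss
    closure_loss = accelerator.precision_plugin.post_backward(accelerator.lightning_module, closure_loss)
    # 5. training-type plugin hook, again loss-only
    accelerator.training_type_plugin.post_backward(closure_loss)
    return closure_loss
```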
10 changes: 6 additions & 4 deletions pytorch_lightning/core/lightning.py
@@ -1447,18 +1447,20 @@ def training_step(...):
self._verify_is_manual_optimization('manual_backward')

# backward
self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, None, None, *args, **kwargs)

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
def backward(
self, loss: Tensor, optimizer: Optional[Optimizer], optimizer_idx: Optional[int], *args, **kwargs
) -> None:
"""
Called to perform backward on the loss returned in :meth:`training_step`.
Override this hook with your own implementation if you need to.

Args:
loss: The loss tensor returned by :meth:`training_step`. If gradient accumulation is used, the loss here
holds the normalized value (scaled by 1 / accumulation steps).
optimizer: Current optimizer being used
optimizer_idx: Index of the current optimizer being used
optimizer: Current optimizer being used. ``None`` if using manual optimization.
optimizer_idx: Index of the current optimizer being used. ``None`` if using manual optimization.

Example::

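As a usage illustration of the relaxed `backward` signature, here is a minimal manual-optimization module; the module and its layer are hypothetical, and only the `optimizer=None` / `optimizer_idx=None` behavior comes from the change above:

```python
import torch
import pytorch_lightning as pl


class ManualOptModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.layer(batch).sum()
        # Dispatches through `LightningModule.backward` with optimizer=None, optimizer_idx=None
        self.manual_backward(loss)
        opt.step()

    def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
        # Under manual optimization both arguments are None after this PR.
        loss.backward(*args, **kwargs)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```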
40 changes: 12 additions & 28 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -260,20 +260,6 @@ def _process_closure_result(self, opt_closure_result: Optional[AttributeDict]) -
if self.trainer.terminate_on_nan:
self._check_finite(opt_closure_result.loss)

def _on_after_backward(self, batch_idx: int, untouched_loss: Tensor) -> None:
"""Calls ``on_after_backward`` hook and tracks loss history

Args:
batch_idx: the index of the current batch
untouched_loss: the original loss value
"""

# insert after step hook
self.trainer.call_hook("on_after_backward")

# when in dev debugging track the losses
self.trainer.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())

def _check_training_step_output(self, training_step_output: STEP_OUTPUT) -> None:
"""Sanity checks that training produced a valid output and optimizer step has already been called in manual
optimization.
@@ -559,10 +545,8 @@ def training_step_and_backward(
with self.trainer.profiler.profile("backward"):
self.backward(result, optimizer, opt_idx)

# hook - call this hook only
# when gradients have finished to accumulate
if not self.should_accumulate():
self._on_after_backward(batch_idx, result.loss)
# when in dev debugging track the losses
self.trainer.dev_debugger.track_train_loss_history(batch_idx, result.loss)

# check if loss or model weights are nan
if self.trainer.terminate_on_nan:
@@ -587,26 +571,26 @@ def _check_finite(self, loss: Tensor) -> None:
detect_nan_parameters(model)

def backward(
self, result: STEP_OUTPUT, optimizer: torch.optim.Optimizer, opt_idx: int, *args: Any, **kwargs: Any
self,
result: STEP_OUTPUT,
optimizer: Optional[torch.optim.Optimizer],
*args: Any,
**kwargs: Any,
) -> None:
"""Performs the backward step.

Args:
result: The output of the trainstep (including the loss value)
optimizer: The optimizer optimizing the gradients to call backward for
opt_idx: the index of the current optimizer
optimizer: Current optimizer being used. ``None`` if using manual optimization.
opt_idx: Index of the current optimizer being used. ``None`` if using manual optimization.
"""
should_accumulate = self.should_accumulate()

# backward can be called manually in the training loop
if isinstance(result, Tensor):
self.trainer.accelerator.backward(result, optimizer, opt_idx, should_accumulate, *args, **kwargs)
self.trainer.accelerator.backward(result, optimizer, *args, **kwargs)
else:
result.closure_loss = self.trainer.accelerator.backward(
result.closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs
)
result.closure_loss = self.trainer.accelerator.backward(result.closure_loss, optimizer, *args, **kwargs)

if not should_accumulate:
if not self.should_accumulate():
# track gradients
grad_norm_dict = self._track_and_norm_grad(optimizer=optimizer)
if grad_norm_dict:
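With the changes in this file, `on_after_backward` fires after every backward call, including accumulation iterations, and (per the changelog) without AMP unscaling beforehand. A sketch of a callback that observes this; the callback class and its logging are hypothetical:

```python
import pytorch_lightning as pl


class GradInspector(pl.Callback):
    def on_after_backward(self, trainer, pl_module):
        # Since this PR the hook also fires on accumulating iterations, and under
        # native AMP the gradients are still scaled at this point.
        total = sum(p.grad.abs().sum().item() for p in pl_module.parameters() if p.grad is not None)
        print(f"sum |grad| after backward: {total:.3f}")


# With gradient accumulation over 4 batches the callback still runs every batch.
trainer = pl.Trainer(accumulate_grad_batches=4, callbacks=[GradInspector()])
```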
43 changes: 8 additions & 35 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, ContextManager, Dict, Sequence
from typing import Any, Callable, Dict, Optional, Sequence

import torch
from torch import Tensor
@@ -51,43 +51,20 @@ def backward(
self,
model: 'pl.LightningModule',
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
optimizer: Optional[Optimizer],
*args: Any,
**kwargs: Any,
) -> Tensor:
"""performs the actual backpropagation
) -> None:
"""Run before precision plugin executes backward

Args:
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: the optimizer to perform the step lateron
opt_idx: the optimizer index
should_accumulate: whether to accumulate gradients or not

optimizer: the optimizer to perform the step later on
"""
opt = model.trainer.optimizers if optimizer is None else optimizer
scaled_loss: ContextManager[Tensor] = amp.scale_loss(closure_loss, opt)

# enter apex context
closure_loss = scaled_loss.__enter__()

# do backward pass
# TODO: not entirely sure, why we need this
if model is not None and isinstance(model, pl.LightningModule):
model.backward(closure_loss, optimizer, opt_idx, **kwargs)
else:
closure_loss.backward(*args, **kwargs)

# exit amp context
error = scaled_loss.__exit__(None, None, None)
if error:
raise Exception("apex unscale error")

# once backward has been applied, release graph
closure_loss = closure_loss.detach()
return closure_loss
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
super().backward(model, closure_loss, optimizer, *args, **kwargs)

@staticmethod
def reinit_scheduler_properties(optimizers: Sequence[Optimizer], schedulers: Sequence[Any]) -> None:
@@ -124,10 +101,6 @@ def pre_optimizer_step(
"""
# apex amp does not support closures.
lambda_closure()

if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")

optimizer.step(**kwargs)
return False

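For reference, the plain-apex pattern that the simplified plugin now wraps via `amp.scale_loss` (requires NVIDIA apex and a CUDA device; the model, optimizer, and data below are placeholders):

```python
import torch
from apex import amp

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

x, y = torch.randn(8, 32, device="cuda"), torch.randn(8, 2, device="cuda")
loss = torch.nn.functional.mse_loss(model(x), y)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # backward runs on the scaled loss, as in the plugin's `with` block
optimizer.step()
```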
15 changes: 2 additions & 13 deletions pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -41,27 +41,19 @@ def pre_optimizer_step(
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
deepspeed_engine = pl_module.trainer.model
# DeepSpeed not support closures.
lambda_closure()

if not pl_module.automatic_optimization:
pl_module.trainer.call_hook("on_after_backward")

deepspeed_engine = pl_module.trainer.model
deepspeed_engine.step()

return False

def backward(
self,
model: 'pl.LightningModule',
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> Tensor:
) -> None:
if is_overridden('backward', model):
warning_cache.warn(
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
@@ -70,9 +62,6 @@ def backward(
# todo: hack around for deepspeed engine to call backward
deepspeed_engine = model.trainer.model
deepspeed_engine.backward(closure_loss, *args, **kwargs)
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
return closure_loss

def clip_gradients(
self,
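A rough raw-DeepSpeed sketch of the two engine calls the plugin reduces to. The `deepspeed.initialize` arguments and config here are assumptions (the keyword may be `config_params` on older DeepSpeed versions), and inside Lightning the engine is created by the DeepSpeed training-type plugin rather than by user code:

```python
import torch
import deepspeed  # assumes deepspeed is installed

model = torch.nn.Linear(32, 2)
engine, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config={"train_batch_size": 8, "optimizer": {"type": "SGD", "params": {"lr": 0.1}}},
)

batch = torch.randn(8, 32, device=engine.device)
loss = engine(batch).sum()
engine.backward(loss)  # mirrors DeepSpeedPrecisionPlugin.backward
engine.step()          # mirrors the engine.step() call in pre_optimizer_step
```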
3 changes: 0 additions & 3 deletions pytorch_lightning/plugins/precision/ipu_precision.py
@@ -33,9 +33,6 @@ def backward(
self,
model: 'pl.LightningModule',
closure_loss: Tensor,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> Tensor:
51 changes: 11 additions & 40 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -37,35 +37,12 @@ def __init__(self) -> None:
self.backend = AMPType.NATIVE
self.scaler = torch.cuda.amp.GradScaler()

def backward(
def pre_backward(
self,
model: 'pl.LightningModule',
closure_loss: torch.Tensor,
optimizer: Optimizer,
opt_idx: int,
should_accumulate: bool,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
"""performs the actual backpropagation

Args:
model: the model to be optimized
closure_loss: the loss value obtained from the closure
optimizer: the optimizer to perform the step lateron
opt_idx: the optimizer's index
should_accumulate: whether to accumulate gradients or not

"""
closure_loss = self.scaler.scale(closure_loss)

closure_loss = super().backward(model, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs)

# unscale gradient to allow analyze within `on_after_backward`
if not should_accumulate and model.automatic_optimization:
self.scaler.unscale_(optimizer)

return closure_loss
return self.scaler.scale(closure_loss)

def pre_optimizer_step(
self,
@@ -75,27 +52,21 @@ def pre_optimizer_step(
lambda_closure: Callable,
**kwargs: Any,
) -> bool:
"""always called before the optimizer step.
Checks that the optimizer is not LBFGS, as this one is not supported by native amp
"""
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"native PyTorch amp and lbfgs are not compatible (optimizer {optimizer_idx})."
" To request, please file a Github issue in PyTorch and tag @mcarilli"
)

if not pl_module.automatic_optimization:
self.scaler.unscale_(optimizer)
pl_module.trainer.call_hook("on_after_backward")
self.scaler.step(optimizer)
self.scaler.update()
else:
# TODO: Add `on_before_optimizer_step`
# self.scaler.unscale_(optimizer)
# pl_module.trainer.call_hook("on_before_optimizer_step")
if pl_module.automatic_optimization:
result = lambda_closure()
# lambda_closure returning None indicates that backward has been skipped
if result is not None:
self.scaler.step(optimizer)
self.scaler.update()

if result is None:
# lambda_closure returning None indicates that backward has been skipped
return False
self.scaler.step(optimizer)
self.scaler.update()
return False

@contextmanager
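For comparison, the plain `torch.cuda.amp` loop the native plugin mirrors after this change: the loss is scaled in `pre_backward`, and gradients are only unscaled as part of `scaler.step` right before the optimizer update (requires a CUDA device; the model and data are placeholders):

```python
import torch

model = torch.nn.Linear(32, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for _ in range(4):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(torch.randn(8, 32, device="cuda")).sum()
    scaler.scale(loss).backward()  # scaled loss -> scaled gradients (what `pre_backward` does)
    scaler.step(optimizer)         # unscales internally; skips the step on inf/NaN gradients
    scaler.update()
```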