Share the training step output data via ClosureResult #9349

Merged: 42 commits, Sep 10, 2021

Changes from all commits (42 commits)
b6d210c
WIP
carmocca Sep 6, 2021
f97f5f0
WIP
carmocca Sep 6, 2021
573f92a
WIP
carmocca Sep 7, 2021
718c971
WIP
carmocca Sep 7, 2021
096129e
Remove hiddens
carmocca Sep 7, 2021
bd90a94
Add closure result tests
carmocca Sep 7, 2021
bc87081
Fix tests
carmocca Sep 7, 2021
15c5972
Fail if closure is not executed
carmocca Sep 7, 2021
9803cec
Remove deepcopy
carmocca Sep 7, 2021
9f5f13a
Merge branch 'master' into bugfix/closure-result
carmocca Sep 7, 2021
24a70b9
Remove debugging statement
carmocca Sep 7, 2021
8630761
Enforce that the optimizer closure is executed when `optimizer_step` …
carmocca Sep 7, 2021
46d8113
Add comment
carmocca Sep 7, 2021
32492ca
Docs and changelog
carmocca Sep 7, 2021
bd54363
is not None
carmocca Sep 7, 2021
4265449
Improve error message to cover broken optimizers
carmocca Sep 7, 2021
8716d4c
Undo getter
carmocca Sep 7, 2021
3f7ea80
Add license and period
carmocca Sep 8, 2021
ad0cc5a
Merge branch 'master' into refactor/enforce-closure-execution
carmocca Sep 8, 2021
4174e1f
Merge manual loop PR
carmocca Sep 8, 2021
3140d9b
Clone loss
carmocca Sep 8, 2021
6689f3d
Remove apply_accumulation
carmocca Sep 8, 2021
6f9f5af
Move to cpu
carmocca Sep 8, 2021
8108c1f
mypy
carmocca Sep 8, 2021
35006e5
Merge branch 'master' into bugfix/closure-result
carmocca Sep 8, 2021
c4f360a
Handle hiddens with an utility
carmocca Sep 8, 2021
bed7624
Without closure and tests
carmocca Sep 8, 2021
6a81c0c
Remove code to move the ClosureResult
carmocca Sep 8, 2021
a09cfc8
Tests
carmocca Sep 8, 2021
78253e4
Fix TODO
carmocca Sep 8, 2021
26e96a1
Undo docs changes
carmocca Sep 8, 2021
5b50789
Update CHANGELOG
carmocca Sep 8, 2021
bf08c2b
Fix
carmocca Sep 8, 2021
f110bc1
Stricter hiddens returning
carmocca Sep 9, 2021
9d8d587
Drop closure loss
carmocca Sep 9, 2021
5138edf
Logic fix
carmocca Sep 9, 2021
3aaf9dd
Bad rename
carmocca Sep 9, 2021
1f32d52
Fix for raise StopIteration
carmocca Sep 9, 2021
04a54a6
Fix comment
carmocca Sep 9, 2021
a4d49e4
Address comments
carmocca Sep 10, 2021
cdf5edd
model_ref -> lightning_module
carmocca Sep 10, 2021
65ab8d9
Merge branch 'master' into bugfix/closure-result
tchaton Sep 10, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -334,6 +334,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `replace_sampler` missing the batch size under specific conditions ([#9367](https://github.com/PyTorchLightning/pytorch-lightning/pull/9367))


- Fixed bug where the training step output needed to be `deepcopy`-ed ([#9349](https://github.com/PyTorchLightning/pytorch-lightning/pull/9349))


- Fixed freeing data iterators in loop `on_run_end` ([#9386](https://github.com/PyTorchLightning/pytorch-lightning/pull/9386))


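The core of the fix, in brief: the training step output is now carried in a `ClosureResult` whose `loss` field is a detached clone of the graph-attached `closure_loss` (see `_clone_loss` in the `pytorch_lightning/loops/closure.py` diff below), so the loops no longer need to `deepcopy` the whole output to drop the autograd graph. A minimal, PR-independent sketch of why a detached clone is enough:

import torch

# a loss still attached to the autograd graph, as a `training_step` would return it
closure_loss = (torch.randn(4, requires_grad=True) ** 2).sum()

# a detached clone is a graph-free copy that is safe to store or log, without
# copying anything else from the step output
loss = closure_loss.detach().clone()

assert closure_loss.grad_fn is not None
assert loss.grad_fn is None and not loss.requires_grad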
35 changes: 25 additions & 10 deletions pytorch_lightning/loops/batch/manual.py
@@ -14,12 +14,13 @@
from typing import Any, Optional

from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.closure import ClosureResult
from pytorch_lightning.loops.utilities import (
_build_training_step_kwargs,
_check_training_step_output,
_process_training_step_output,
_extract_hiddens,
check_finite_loss,
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection


class ManualOptimization(Loop):
@@ -35,7 +36,7 @@ def __init__(self) -> None:
super().__init__()
self._done: bool = False
self._hiddens: Optional[Any] = None
self._output: Optional[ResultCollection] = None
self._output: Optional[ClosureResult] = None

@property
def done(self) -> bool:
@@ -52,16 +53,16 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
batch_idx: the index of the current batch
"""
assert self.trainer is not None
ligtning_module = self.trainer.lightning_module
lightning_module = self.trainer.lightning_module

with self.trainer.profiler.profile("model_forward"):

step_kwargs = _build_training_step_kwargs(
ligtning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
)

# manually capture logged metrics
ligtning_module._current_fx_name = "training_step"
lightning_module._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()
@@ -70,14 +71,28 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

_check_training_step_output(ligtning_module, training_step_output)
_check_training_step_output(lightning_module, training_step_output)

result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

# TODO: do not use `ClosureResult`
result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)

if self.trainer.terminate_on_nan:
check_finite_loss(result.closure_loss)

if self.trainer.move_metrics_to_cpu:
# hiddens and the training step output are not moved as they are not considered "metrics"
# the user might need them on the correct device for an operation in `training_epoch_end`
assert self.trainer._results is not None
self.trainer._results.cpu()

self._done = True
self._output = result_collection
self._output = result

def on_run_end(self) -> Optional[ResultCollection]:
def on_run_end(self) -> Optional[ClosureResult]:
"""Returns the result of this loop, i.e., the post-processed outputs from the training step."""
output, self._output = self._output, None # free memory
# #9052 added support for raising `StopIteration` in the `training_step`. If that happens, then `advance`
# doesn't finish and `self._output` stays as `None`. If #9415 happens then this would always return a result
return output
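The body of `_extract_hiddens` is not shown in this diff; purely as an assumption based on how it is called above, a rough sketch of what such a helper could do (return `None` when truncated BPTT is disabled, otherwise pull and detach the "hiddens" entry):

from typing import Any, Optional


def extract_hiddens_sketch(training_step_output: Any, truncated_bptt_steps: int) -> Optional[Any]:
    """Hypothetical stand-in for `_extract_hiddens`; not the PR's actual implementation."""
    if truncated_bptt_steps == 0 or not isinstance(training_step_output, dict):
        return None
    hiddens = training_step_output.get("hiddens")
    # detach so the graph of the previous TBPTT split is not kept alive
    return hiddens.detach() if hiddens is not None else None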
7 changes: 3 additions & 4 deletions pytorch_lightning/loops/batch/training_batch_loop.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
from typing import Any, List, Optional, Tuple

import numpy as np
@@ -131,9 +130,9 @@ def advance(self, batch, batch_idx):
self.batch_outputs[k].extend(batch_outputs[k])
else:
# in manual optimization, hand over execution to the ManualOptimization loop
output = self.manual_loop.run(split_batch, batch_idx)
if output is not None:
self.batch_outputs[0].append(deepcopy(output))
result = self.manual_loop.run(split_batch, batch_idx)
if result is not None and result.loss is not None:
self.batch_outputs[0].append(result.drop_closure_loss())

def on_run_end(self) -> None:
self.optimizer_loop._hiddens = None
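For context, the `drop_closure_loss` call above uses the method defined in the `pytorch_lightning/loops/closure.py` diff below: it nulls out the graph-attached loss before the result is stored, so the accumulated batch outputs no longer keep the autograd graph alive. A small illustration, assuming this PR's module layout is importable:

import torch

from pytorch_lightning.loops.closure import ClosureResult  # module path from this PR

result = ClosureResult.from_training_step_output((torch.randn(4, requires_grad=True) ** 2).sum())
stored = result.drop_closure_loss()

assert stored.closure_loss is None      # autograd graph released
assert stored.loss.grad_fn is None      # detached clone kept for the epoch-end hooks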
69 changes: 60 additions & 9 deletions pytorch_lightning/loops/closure.py
@@ -11,32 +11,84 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional

from torch import Tensor

from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.warnings import WarningCache
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, WarningCache


@dataclass
class ClosureResult:
"""A container to hold the result of a :class:`AbstractClosure` call.

It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.

Attributes:
closure_loss: The loss with a graph attached.
loss: A detached copy of the closure loss.
result_collection: A collection of results returned by the closure.
extra: Any keys other than the loss returned.
"""

closure_loss: Optional[Tensor]
loss: Optional[Tensor]
result_collection: Optional[ResultCollection]
loss: Optional[Tensor] = field(init=False, default=None)
extra: Dict[str, Tensor] = field(default_factory=dict)

def __post_init__(self) -> None:
# TODO: remove with the deprecation removal in v1.6
ClosureResult._check_extra_detach_deprecation(self.extra)
self.extra = recursive_detach(self.extra)

self._clone_loss()

def _clone_loss(self) -> None:
if self.closure_loss is not None:
# the loss will get scaled for amp. avoid any modifications to it
self.loss = self.closure_loss.detach().clone()

@classmethod
def from_training_step_output(
cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
) -> "ClosureResult":
closure_loss, extra = None, {}

if isinstance(training_step_output, dict):
# this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
closure_loss = training_step_output.get("loss")
extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
elif isinstance(training_step_output, Tensor):
closure_loss = training_step_output

if closure_loss is not None:
# accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
closure_loss /= normalize

return cls(closure_loss, extra=extra)

@staticmethod
def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> None:
def check_fn(v: Tensor) -> Tensor:
if v.grad_fn is not None:
rank_zero_deprecation(
f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
" but this behaviour will change in v1.6. Please detach it manually:"
" `return {'loss': ..., 'something': something.detach()}`"
)
return v

apply_to_collection(extra, Tensor, check_fn)

def drop_closure_loss(self) -> "ClosureResult":
"""Return itself without the closure loss which could have a `grad_fn`"""
self.closure_loss = None
return self


class AbstractClosure(ABC):
@@ -107,7 +159,7 @@ class Closure(AbstractClosure):

def __init__(
self,
step_fn: Callable[[], Optional[Dict]],
step_fn: Callable[[], ClosureResult],
backward_fn: Optional[Callable[[Tensor], Tensor]] = None,
zero_grad_fn: Optional[Callable[[], None]] = None,
profiler: Optional[BaseProfiler] = None,
Expand All @@ -121,7 +173,6 @@ def __init__(
def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
with self._profiler.profile("training_step_and_backward"):
step_output = self._step_fn()
step_output = ClosureResult(**step_output) if step_output else ClosureResult(None, None, None)

if step_output.closure_loss is None:
self.warning_cache.warn(
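For reference, a small usage sketch of the `from_training_step_output` classmethod shown above, assuming this PR's module layout is importable: a dict-style `training_step` output is split into the graph-attached `closure_loss`, a detached `loss` clone divided by `accumulate_grad_batches`, and detached `extra` entries.

import torch

from pytorch_lightning.loops.closure import ClosureResult  # module path from this PR

step_output = {
    "loss": (torch.randn(4, requires_grad=True) ** 2).sum(),
    "foo": torch.tensor(1.0),
}
result = ClosureResult.from_training_step_output(step_output, normalize=2)

assert result.closure_loss.requires_grad      # still attached to the graph
assert result.loss.grad_fn is None            # detached clone, already divided by `normalize`
assert set(result.extra) == {"foo"}           # everything except "loss" and "hiddens"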
13 changes: 7 additions & 6 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -17,6 +17,7 @@

from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.closure import ClosureResult
from pytorch_lightning.loops.utilities import _prepare_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import Progress, SchedulerProgress
@@ -283,18 +284,18 @@ def _track_epoch_end_reduce_metrics(

@staticmethod
def _prepare_outputs(
outputs: List[List[List["ResultCollection"]]], batch_mode: bool
outputs: List[List[List[ClosureResult]]], batch_mode: bool
) -> Union[List[List[List[Dict]]], List[List[Dict]], List[Dict], Dict]:
"""Extract required information from batch or epoch end results.

Args:
outputs: A 3-dimensional list of ``ResultCollection`` objects with dimensions:
outputs: A 3-dimensional list of ``ClosureResult`` objects with dimensions:
``[optimizer outs][batch outs][tbptt steps]``.

batch_mode: If True, ignore the batch output dimension.

Returns:
The cleaned outputs with ``ResultCollection`` objects converted to dictionaries.
The cleaned outputs with ``ClosureResult`` objects converted to dictionaries.
All list dimensions of size one will be collapsed.
"""
processed_outputs = []
@@ -311,13 +312,13 @@ def _prepare_outputs(
for batch_outputs in opt_outputs:
processed_tbptt_outputs = []

if isinstance(batch_outputs, ResultCollection):
if isinstance(batch_outputs, ClosureResult):
batch_outputs = [batch_outputs]

for tbptt_output in batch_outputs:
out = {}
if tbptt_output.minimize is not None:
out["loss"] = tbptt_output.minimize.detach()
if tbptt_output.loss is not None:
out["loss"] = tbptt_output.loss
out.update(tbptt_output.extra)
processed_tbptt_outputs.append(out)

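A standalone sketch of the per-result conversion `_prepare_outputs` now performs (assuming this PR's module layout is importable and a single-optimizer, no-TBPTT run): each `ClosureResult` is flattened into the plain dict handed to the epoch-end hooks, built from its detached `loss` and its `extra` entries.

import torch

from pytorch_lightning.loops.closure import ClosureResult  # module path from this PR

result = ClosureResult.from_training_step_output(
    {"loss": (torch.randn(4, requires_grad=True) ** 2).sum(), "foo": torch.tensor(3.0)}
)

out = {}
if result.loss is not None:
    out["loss"] = result.loss
out.update(result.extra)

assert set(out) == {"loss", "foo"}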
52 changes: 25 additions & 27 deletions pytorch_lightning/loops/optimizer/optimizer_loop.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Optional

@@ -27,16 +25,16 @@
_block_parallel_sync_behavior,
_build_training_step_kwargs,
_check_training_step_output,
_process_training_step_output,
_extract_hiddens,
check_finite_loss,
)
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm
from pytorch_lightning.utilities import AMPType, DeviceType, grad_norm
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE

_OUTPUTS_TYPE = List[List[Optional[ResultCollection]]]
_OUTPUTS_TYPE = List[List[ClosureResult]]


class OptimizerLoop(Loop):
@@ -80,8 +78,8 @@ def advance(self, batch: Any, *args, **kwargs) -> None: # type: ignore[override
self._optimizers[self.optim_progress.optimizer_idx],
self.optim_progress.optimizer_idx,
)
if result.result_collection is not None:
self.outputs[self.optim_progress.optimizer_idx].append(deepcopy(result.result_collection))
if result.loss is not None:
self.outputs[self.optim_progress.optimizer_idx].append(result.drop_closure_loss())

self.optim_progress.optimizer_idx += 1

@@ -168,7 +166,7 @@ def _make_closure(self, split_batch: Any, batch_idx: int, opt_idx: int, optimize
step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn, profiler=self.trainer.profiler
)

def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], Optional[AttributeDict]]:
def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], ClosureResult]:
"""Build the step function that runs the `training_step` and processes its output."""
return partial(self._training_step, split_batch, batch_idx, opt_idx)

@@ -241,7 +239,7 @@ def _optimizer_step(
train_step_and_backward_closure: the closure function performing the train step and computing the
gradients. By default called by the optimizer (if possible)
"""
model_ref = self.trainer.lightning_module
lightning_module = self.trainer.lightning_module

is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)
using_native_amp = self.trainer.amp_backend is not None and self.trainer.amp_backend == AMPType.NATIVE
Expand All @@ -259,7 +257,7 @@ def _optimizer_step(
self.optim_progress.optimizer.step.increment_ready()

# model hook
model_ref.optimizer_step(
lightning_module.optimizer_step(
self.trainer.current_epoch,
batch_idx,
optimizer,
@@ -293,7 +291,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer,
self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self.optim_progress.optimizer.zero_grad.increment_completed()

def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Optional[AttributeDict]:
def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
"""Performs the actual train step with the tied hooks.

Args:
@@ -302,19 +300,19 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Opti
opt_idx: the index of the current optimizer

Returns:
an AttributeDict containing the loss value and the training step output.
A ``ClosureResult`` containing the training step output.
"""
# give the PL module a result for logging
model_ref = self.trainer.lightning_module
lightning_module = self.trainer.lightning_module

with self.trainer.profiler.profile("model_forward"):

step_kwargs = _build_training_step_kwargs(
self.trainer.lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
)

# manually capture logged metrics
model_ref._current_fx_name = "training_step"
lightning_module._current_fx_name = "training_step"
with self.trainer.profiler.profile("training_step"):
training_step_output = self.trainer.accelerator.training_step(step_kwargs)
self.trainer.accelerator.post_training_step()
@@ -323,20 +321,20 @@ def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Opti

training_step_output = self.trainer.call_hook("training_step_end", training_step_output)

_check_training_step_output(self.trainer.lightning_module, training_step_output)
_check_training_step_output(lightning_module, training_step_output)

self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

result_collection, self._hiddens = _process_training_step_output(self.trainer, training_step_output)
if result_collection is None:
return None
result = ClosureResult.from_training_step_output(training_step_output, self.trainer.accumulate_grad_batches)

# output validation already done, here loss can't be None
assert result_collection.minimize is not None
if self.trainer.terminate_on_nan:
check_finite_loss(result.closure_loss)

if self.trainer.move_metrics_to_cpu:
# hiddens and the training step output are not moved as they are not considered "metrics"
self.trainer._results.cpu()

# accumulate loss. if accumulate_grad_batches==1, no effect
closure_loss = result_collection.minimize / self.trainer.accumulate_grad_batches
# the loss will get scaled for amp. avoid any modifications to it
loss = closure_loss.detach().clone()
return AttributeDict(closure_loss=closure_loss, loss=loss, result_collection=result_collection)
return result

def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, float]:
"""Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer.