Skip to content

Commit

Permalink
Loop Refactor 1/N - Training Loop (#7871)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
Co-authored-by: Justus Schock <justus.schock@posteo.de>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
  • Loading branch information
9 people authored Jun 15, 2021
1 parent 560b197 commit 971908a
Show file tree
Hide file tree
Showing 20 changed files with 1,641 additions and 80 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))
* Refactored the logic around manual and automatic optimization inside the optimizer loop ([#7526](https://github.com/PyTorchLightning/pytorch-lightning/pull/7526))
* Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
* Refactored "should run validation" logic when the trainer is signaled to stop ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))

* Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
* Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
* Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871))

- Refactored logging
* Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -1346,7 +1346,7 @@ def training_step(...):

# backward
self._running_manual_backward = True
self.trainer.train_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
self.trainer.fit_loop.training_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
self._running_manual_backward = False

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
Expand Down Expand Up @@ -1445,7 +1445,7 @@ def optimizer_step(
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
to ``optimizer.step()`` function as shown in the examples. This ensures that
``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within
:meth:`~pytorch_lightning.trainer.training_loop.TrainLoop.run_training_batch`.
:meth:`~pytorch_lightning.trainer.fit_loop.training_loop.batch_loop.TrainingBatchLoop.advance`.
Args:
epoch: Current epoch
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def toggle_model(self, sync_grad: bool = True):
during the accumulation phase.
Setting `sync_grad` to False will block this synchronization and improve performance.
"""
with self._trainer.train_loop.block_ddp_sync_behaviour(not sync_grad):
with self._trainer.fit_loop.training_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
self._toggle_model()
yield
self._untoggle_model()
Expand Down
18 changes: 18 additions & 0 deletions pytorch_lightning/loops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.base import Loop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop # noqa: F401
from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop # noqa: F401
122 changes: 122 additions & 0 deletions pytorch_lightning/loops/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Optional
from weakref import proxy

from deprecate import void

import pytorch_lightning as pl


class Loop(ABC):
"""
Basic Loops interface. All classes derived from this must implement the following properties and methods:
* :attr:`done` (property): Condition to break the loop
* :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
* :attr:`advance` (method): Implements one step of the loop
This class implements the following loop structure:
.. codeblock:: python
on_run_start()
while not done:
on_advance_start()
advance()
on_advance_end()
on_run_end()
"""

def __init__(self) -> None:
self.iteration_count: int = 0
self.trainer: Optional['pl.Trainer'] = None

@property
@abstractmethod
def done(self) -> bool:
"""Property indicating when loop is finished"""

@property
def skip(self) -> bool:
"""Determine whether to return immediately from the call to :meth:`run`."""
return False

def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects Loop with all the necessary things like connectors and accelerators."""
self.trainer = proxy(trainer)

@abstractmethod
def reset(self) -> None:
"""Resets the internal state of the loop at the beginning of each call to :attr:`run`."""

def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
"""
The main entry point to the loop.
Will frequently check the :attr:`done` condition and calls :attr:`advance`
until :attr:`done` evaluates to ``True``.
Returns:
the output of :attr`on_run_end` (often outputs collected from each step of the loop)
"""
if self.skip:
return

self.reset()
self.on_run_start(*args, **kwargs)

while not self.done:
try:
self.on_advance_start(*args, **kwargs)
self.advance(*args, **kwargs)
self.on_advance_end()
self.iteration_count += 1
except StopIteration:
break

output = self.on_run_end()
self.teardown()
return output

def on_run_start(self, *args: Any, **kwargs: Any) -> None:
"""
Hook to be called as the first thing after entering :attr:`run` (except the state reset).
Accepts all arguments passed to :attr:`run`.
"""
void(*args, **kwargs)

def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
"""
Hook to be called each time before :attr:`advance` is called. Accepts all arguments passed to :attr`run`.
"""
void(*args, **kwargs)

@abstractmethod
def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs a single step. Accepts all arguments passed to :attr:`run`."""

def on_advance_end(self) -> None:
"""Hook to be called each time after :attr:`advance` is called."""

def on_run_end(self) -> Any:
"""Hook to be called at the end of the run. Its return argument is returned from :attr:`run`."""

def teardown(self) -> None:
"""The very last method called inside :meth:`run`. Use to release memory etc."""
Loading

0 comments on commit 971908a

Please sign in to comment.