pytorch_lightning.loops file structure: group by dataloader, epoch, and batch loop #8077

Merged: 12 commits, Jun 24, 2021
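
This PR groups the loop implementations into `loops/batch/`, `loops/dataloader/` and `loops/epoch/` packages, renames `EvaluationDataLoaderLoop`/`PredictionDataLoaderLoop` to `EvaluationLoop`/`PredictionLoop`, and renames the `FitLoop` attributes `training_loop`/`validation_loop` to `epoch_loop`/`val_loop`. As orientation for code that reaches into these trainer internals, a minimal sketch of the new access paths (not part of the diff; `trainer` is assumed to be an already-constructed Trainer):

    from pytorch_lightning import Trainer

    trainer = Trainer(max_epochs=1)  # hypothetical trainer, only to illustrate the attribute paths

    # before #8077: trainer.fit_loop.training_loop      after: trainer.fit_loop.epoch_loop
    # before #8077: trainer.fit_loop.validation_loop    after: trainer.fit_loop.val_loop
    epoch_loop = trainer.fit_loop.epoch_loop
    batch_loop = trainer.fit_loop.epoch_loop.batch_loop  # the batch loop now hangs off `epoch_loop`
    val_loop = trainer.fit_loop.val_loop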
6 changes: 3 additions & 3 deletions CHANGELOG.md
@@ -146,13 +146,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Simplified "should run validation" logic ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
* Simplified logic for updating the learning rate for schedulers ([#7682](https://github.com/PyTorchLightning/pytorch-lightning/pull/7682))
* Removed the `on_epoch` guard from the "should stop" validation check ([#7701](https://github.com/PyTorchLightning/pytorch-lightning/pull/7701))
-* Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871))
+* Refactored internal loop interface; added new classes `FitLoop`, `TrainingEpochLoop`, `TrainingBatchLoop` ([#7871](https://github.com/PyTorchLightning/pytorch-lightning/pull/7871), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
* Removed `pytorch_lightning/trainer/training_loop.py` ([#7985](https://github.com/PyTorchLightning/pytorch-lightning/pull/7985))
-* Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationDataLoaderLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990))
+* Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
* Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056))
* Restricted public access to several internal functions ([#8024](https://github.com/PyTorchLightning/pytorch-lightning/pull/8024))
* Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
-* Refactored prediction loop interface; added new classes `PredictionDataLoaderLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700))
+* Refactored prediction loop interface; added new classes `PredictionLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700), [#8077](https://github.com/PyTorchLightning/pytorch-lightning/pull/8077))
* Removed `pytorch_lightning/trainer/predict_loop.py` ([#8094](https://github.com/PyTorchLightning/pytorch-lightning/pull/8094))


2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/finetuning.py
@@ -285,7 +285,7 @@ def _store(

def on_train_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
-for opt_idx, optimizer in trainer.fit_loop.training_loop.batch_loop.get_active_optimizers():
+for opt_idx, optimizer in trainer.fit_loop.epoch_loop.batch_loop.get_active_optimizers():
num_param_groups = len(optimizer.param_groups)
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
current_param_groups = optimizer.param_groups
4 changes: 2 additions & 2 deletions pytorch_lightning/core/lightning.py
@@ -1371,7 +1371,7 @@ def training_step(...):

# backward
self._running_manual_backward = True
-self.trainer.fit_loop.training_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
+self.trainer.fit_loop.epoch_loop.batch_loop.backward(loss, optimizer=None, opt_idx=None, *args, **kwargs)
self._running_manual_backward = False

def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
@@ -1471,7 +1471,7 @@ def optimizer_step(
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
to ``optimizer.step()`` function as shown in the examples. This ensures that
``training_step()``, ``optimizer.zero_grad()``, ``backward()`` are called within
-:meth:`~pytorch_lightning.trainer.fit_loop.training_loop.batch_loop.TrainingBatchLoop.advance`.
+:meth:`~pytorch_lightning.loops.training_batch_loop.TrainingBatchLoop.advance`.

Args:
epoch: Current epoch
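
The docstring above asks overrides to forward `optimizer_closure` to `optimizer.step()`. A minimal sketch of such an override, assuming the hook signature of this release (`MyModel` is a hypothetical module, not part of the PR):

    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                           optimizer_closure, on_tpu=False, using_native_amp=False, using_lbfgs=False):
            # forwarding the closure is what makes training_step(), zero_grad() and backward()
            # run inside TrainingBatchLoop.advance
            optimizer.step(closure=optimizer_closure)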
2 changes: 1 addition & 1 deletion pytorch_lightning/core/optimizer.py
@@ -120,7 +120,7 @@ def toggle_model(self, sync_grad: bool = True):
during the accumulation phase.
Setting `sync_grad` to False will block this synchronization and improve performance.
"""
-with self._trainer.fit_loop.training_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
+with self._trainer.fit_loop.epoch_loop.batch_loop.block_ddp_sync_behaviour(not sync_grad):
self._toggle_model()
yield
self._untoggle_model()
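
For context, the context manager documented above is typically used from a LightningModule running manual optimization. A rough sketch (the two-batch accumulation schedule and the `compute_loss` helper are assumptions for illustration only):

    def training_step(self, batch, batch_idx):
        # assumes self.automatic_optimization = False was set in __init__
        opt = self.optimizers()  # LightningOptimizer wrapping the configured optimizer
        accumulating = (batch_idx % 2) != 0  # hypothetical: accumulate over pairs of batches
        with opt.toggle_model(sync_grad=not accumulating):
            loss = self.compute_loss(batch)  # hypothetical helper
            self.manual_backward(loss)
        if not accumulating:
            opt.step()
            opt.zero_grad()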
7 changes: 3 additions & 4 deletions pytorch_lightning/loops/__init__.py
@@ -13,8 +13,7 @@
# limitations under the License.

from pytorch_lightning.loops.base import Loop # noqa: F401
-from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401
-from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop # noqa: F401
+from pytorch_lightning.loops.batch import TrainingBatchLoop # noqa: F401
+from pytorch_lightning.loops.dataloader import DataLoaderLoop, EvaluationLoop, PredictionLoop # noqa: F401
+from pytorch_lightning.loops.epoch import EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop # noqa: F401
from pytorch_lightning.loops.fit_loop import FitLoop # noqa: F401
-from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop # noqa: F401
-from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop # noqa: F401
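
Because the grouped packages are re-exported from the package root (per the `__init__.py` above), downstream imports of the public loop classes can stay at the top level; a minimal sketch:

    from pytorch_lightning.loops import (
        Loop,               # base class, unchanged
        FitLoop,            # loops/fit_loop.py
        TrainingBatchLoop,  # now under loops/batch/
        DataLoaderLoop, EvaluationLoop, PredictionLoop,               # now under loops/dataloader/
        EvaluationEpochLoop, PredictionEpochLoop, TrainingEpochLoop,  # now under loops/epoch/
    )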
15 changes: 15 additions & 0 deletions pytorch_lightning/loops/batch/__init__.py
@@ -0,0 +1,15 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop # noqa: F401
3 changes: 2 additions & 1 deletion pytorch_lightning/loops/dataloader/__init__.py
@@ -13,4 +13,5 @@
# limitations under the License.

from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop # noqa: F401
-from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop # noqa: F401
+from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop # noqa: F401
+from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop # noqa: F401
pytorch_lightning/loops/dataloader/evaluation_loop.py (renamed from pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py)
@@ -19,14 +19,14 @@

import pytorch_lightning as pl
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
-from pytorch_lightning.loops.evaluation_epoch_loop import EvaluationEpochLoop
+from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT


-class EvaluationDataLoaderLoop(DataLoaderLoop):
+class EvaluationLoop(DataLoaderLoop):
"""Loops over all dataloaders for evaluation."""

def __init__(self):
pytorch_lightning/loops/dataloader/prediction_loop.py (renamed; class `PredictionDataLoaderLoop` becomes `PredictionLoop`)
@@ -5,13 +5,13 @@

import pytorch_lightning as pl
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
-from pytorch_lightning.loops.prediction_epoch_loop import PredictionEpochLoop
+from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT


-class PredictionDataLoaderLoop(DataLoaderLoop):
+class PredictionLoop(DataLoaderLoop):
"""Loop to run over dataloaders for prediction"""

def __init__(self):
17 changes: 17 additions & 0 deletions pytorch_lightning/loops/epoch/__init__.py
@@ -0,0 +1,17 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop # noqa: F401
from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop # noqa: F401
pytorch_lightning/loops/epoch/training_epoch_loop.py (moved from pytorch_lightning/loops/training_epoch_loop.py)
@@ -18,7 +18,7 @@

import pytorch_lightning as pl
from pytorch_lightning.loops.base import Loop
-from pytorch_lightning.loops.training_batch_loop import TrainingBatchLoop
+from pytorch_lightning.loops.batch.training_batch_loop import TrainingBatchLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -175,10 +175,10 @@ def on_advance_end(self):

def _run_validation(self):
# reload dataloaders
-self.trainer.fit_loop.validation_loop.reload_evaluation_dataloaders()
+self.trainer.fit_loop.val_loop.reload_evaluation_dataloaders()

with torch.no_grad():
-self.trainer.fit_loop.validation_loop.run()
+self.trainer.fit_loop.val_loop.run()

def on_run_end(self) -> List[List[STEP_OUTPUT]]:
"""Calls the on_epoch_end hook.
56 changes: 28 additions & 28 deletions pytorch_lightning/loops/fit_loop.py
@@ -18,8 +18,8 @@

import pytorch_lightning as pl
from pytorch_lightning.loops.base import Loop
-from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop
-from pytorch_lightning.loops.training_epoch_loop import TrainingEpochLoop
+from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
+from pytorch_lightning.loops.epoch.training_epoch_loop import TrainingEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_info
@@ -51,15 +51,15 @@ def __init__(
super().__init__()
self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs
self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs
-self.training_loop = TrainingEpochLoop(min_steps, max_steps)
-self.validation_loop = EvaluationDataLoaderLoop()
+self.epoch_loop = TrainingEpochLoop(min_steps, max_steps)
+self.val_loop = EvaluationLoop()

Review thread on this line (marked as resolved by carmocca):
Contributor: Still need to move this inside the epoch loop as we discussed. Or do you want to do that in a separate PR?
Contributor (author): Yes, I have it already implemented locally. There is going to be a circular import due to that, so I would like to discuss the resolution of that in a separate PR.

@property
def results(self) -> ResultCollection:
if self.trainer.training:
-return self.training_loop.results
+return self.epoch_loop.results
elif self.trainer.validating:
-return self.validation_loop.results
+return self.val_loop.results
raise RuntimeError("`FitLoop.results` property isn't defined. Accessed outside of scope")

@property
@@ -75,59 +75,59 @@ def current_epoch(self, value: int) -> None:
@property
def global_step(self) -> int:
"""Returns the global step"""
-return self.training_loop.global_step
+return self.epoch_loop.global_step

@global_step.setter
def global_step(self, value: int) -> None:
"""Sets the global step (forwards to training_loop)"""
self.training_loop.global_step = value
"""Sets the global step (forwards to epoch_loop)"""
self.epoch_loop.global_step = value

@property
def total_batch_idx(self) -> int:
"""Returns the total number of batches already run (across all epochs)"""
-return self.training_loop.total_batch_idx
+return self.epoch_loop.total_batch_idx

@property
def batch_idx(self) -> int:
"""Returns the number of batches already run within this epoch"""
-return self.training_loop.iteration_count
+return self.epoch_loop.iteration_count

@property
def split_idx(self) -> int:
"""Returns the index of the current batch split (within the current batch) for bptt"""
-return self.training_loop.split_idx
+return self.epoch_loop.split_idx

@property
def min_steps(self) -> int:
# TODO(@justusschock): Why aren't we using the attribute in this class?
"""Returns the minimum numnber of steps to run"""
-return self.training_loop.min_steps
+return self.epoch_loop.min_steps

@property
def max_steps(self) -> int:
"""Returns the maximum number of steps to run"""
-return self.training_loop.max_steps
+return self.epoch_loop.max_steps

@max_steps.setter
def max_steps(self, value: int) -> None:
"""Sets the maximum number of steps (forwards to training_loop)"""
"""Sets the maximum number of steps (forwards to epoch_loop)"""
# TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
self.training_loop.max_steps = value
self.epoch_loop.max_steps = value

@property
def running_loss(self) -> TensorRunningAccum:
"""Returns the running loss"""
-return self.training_loop.batch_loop.running_loss
+return self.epoch_loop.batch_loop.running_loss

@property
def _skip_backward(self) -> bool:
""" Determines whether the loop will skip backward during automatic optimization. """
-return self.training_loop.batch_loop._skip_backward
+return self.epoch_loop.batch_loop._skip_backward

@_skip_backward.setter
def _skip_backward(self, value: bool) -> None:
""" Determines whether the loop will skip backward during automatic optimization. """
-self.training_loop.batch_loop._skip_backward = value
+self.epoch_loop.batch_loop._skip_backward = value

@property
def done(self) -> bool:
@@ -165,8 +165,8 @@ def skip(self) -> bool:
def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
"""Connects the loop with necessary arguments like the trainer"""
super().connect(trainer, *args, **kwargs)
-self.training_loop.connect(trainer)
-self.validation_loop.connect(trainer)
+self.epoch_loop.connect(trainer)
+self.val_loop.connect(trainer)

def reset(self) -> None:
"""Resets the internal state of this loop"""
@@ -193,7 +193,7 @@ def on_advance_start(self) -> None:
self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

# stores accumulated grad fractions per batch
-self.training_loop.batch_loop.accumulated_loss = TensorRunningAccum(
+self.epoch_loop.batch_loop.accumulated_loss = TensorRunningAccum(
window_length=self.trainer.accumulate_grad_batches
)

@@ -204,7 +204,7 @@ def advance(self) -> None:

with self.trainer.profiler.profile("run_training_epoch"):
# run train epoch
-epoch_output = self.training_loop.run(train_dataloader)
+epoch_output = self.epoch_loop.run(train_dataloader)

if epoch_output is None:
return
@@ -220,10 +220,10 @@

def on_advance_end(self) -> None:
"""Updates the LR schedulers and does some internal bookkeeping"""
-if self.training_loop.batches_seen == 0:
+if self.epoch_loop.batches_seen == 0:
return

-self.training_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True)
+self.epoch_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True)

did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.skip
if did_train_only:
@@ -241,10 +241,10 @@ def on_run_end(self) -> None:

# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
-self.training_loop.global_step -= 1
+self.epoch_loop.global_step -= 1
# TODO: see discussion/rework https://github.com/PyTorchLightning/pytorch-lightning/issues/7406
self._check_checkpoint_callback(should_update=True, is_last=True)
-self.training_loop.global_step += 1
+self.epoch_loop.global_step += 1

# hook
self.trainer.call_hook("on_train_end")
@@ -266,7 +266,7 @@ def on_run_end(self) -> None:

def should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated"""
-return self.training_loop.batch_loop.should_accumulate()
+return self.epoch_loop.batch_loop.should_accumulate()

def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False):
"""Checks if checkpointing needs to be done"""
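
The resolved review thread on `FitLoop.__init__` above defers moving `val_loop` into the epoch loop because of a circular import. A hypothetical sketch of the kind of cycle being alluded to, using the new package layout (not taken from the PR):

    # loops/epoch/training_epoch_loop.py would gain a new dependency:
    #     from pytorch_lightning.loops.dataloader import EvaluationLoop
    # while loops/dataloader/evaluation_loop.py already imports from the epoch package:
    #     from pytorch_lightning.loops.epoch.evaluation_epoch_loop import EvaluationEpochLoop
    # Importing the dataloader package would then re-enter it via the epoch package while it is
    # still initialising, raising an ImportError; resolving that is left to a follow-up PR.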
@@ -299,7 +299,7 @@ def progress_bar_metrics(self) -> Dict[str, float]:
return self._progress_bar_metrics

def teardown(self):
-self.trainer.fit_loop.training_loop._results.cpu()
-self.trainer.fit_loop.validation_loop._results.cpu()
+self.trainer.fit_loop.epoch_loop._results.cpu()
+self.trainer.fit_loop.val_loop._results.cpu()
self.trainer.validation_loop._results.cpu()
self.trainer.test_loop._results.cpu()
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/properties.py
@@ -29,7 +29,7 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
-from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop
+from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.plugins import ParallelPlugin, PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
@@ -63,8 +63,8 @@ class TrainerProperties(ABC):
logger_connector: LoggerConnector
state: TrainerState
fit_loop: FitLoop
-validation_loop: EvaluationDataLoaderLoop
-test_loop: EvaluationDataLoaderLoop
+validation_loop: EvaluationLoop
+test_loop: EvaluationLoop
"""
Accelerator properties
"""
@@ -489,9 +489,9 @@ def sanity_checking(self, val: bool) -> None:
"""

@property
-def evaluation_loop(self) -> EvaluationDataLoaderLoop:
+def evaluation_loop(self) -> EvaluationLoop:
if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
-return self.fit_loop.validation_loop
+return self.fit_loop.val_loop
elif self.state.fn == TrainerFn.VALIDATING:
return self.validation_loop
elif self.state.fn == TrainerFn.TESTING:
@@ -524,10 +524,10 @@ def min_steps(self) -> Optional[int]:

@property
def is_last_batch(self) -> bool:
-return self.fit_loop.training_loop.is_last_batch
+return self.fit_loop.epoch_loop.is_last_batch

@property
-def _active_loop(self) -> Optional[Union[FitLoop, EvaluationDataLoaderLoop]]:
+def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop]]:
if self.training:
return self.fit_loop
elif self.sanity_checking or self.evaluating: