Loop Refactor 5/N - Prediction Loop (#7700)
* integrate d180bb2

* Minor changes

* Refactor loop logic into logger connector

* Refactor test

* Tighter fx validator

* Add back split idx

* Typing

* update

* Conflict

* Fix tests

* resolve grad_norm

* update

* move to train loop

* Bye grad_norm_dict parameter

* Fix sync test

* update

* Fix bug when validation is run mid epoch

* fix grad_norm_dict test

* Fix fx_validator test

* fix grad_norm_dict test

* Fix order bug

* Detach tensors in test

* resolve some tests

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove pdb

* resolve flake8

* Update test

* more tests

* Revert last thomas' changes

* resolve 1 test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Refactor context restoration

* integrate latest changes from logger connector refactor poc

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* integrate latest changes from logger connector refactor poc

* Minor changes

* update changelog

* Remove unused argument

* Update CHANGELOG

* Copy call_hook changes

* Docs

* Fix ref

* move to cpu

* Bad merge

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove pdb

* remove pdb

* Refactor to

* Avoid partial

* trigger ci

* Bad merge

* integrate latest logger connector changes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove grad norm dicts list

* Diff

* properties first

* Bad merge

* Reuse metrics_to_scalars

* Use active loop

* Move to device

* resolve test

* integrate latest changes from logger connector poc

* define union

* define union

* Update logger connector

* Update result

* Update imports

* Update after rename

* Refactor reduce_fx and op

* Fix test after rename

* mypy

* integrate latest logger connector refactor poc changes

* Fix test

* Refactor test

* Deprecate `self.log(sync_dist_op)` in favor of `self.log(reduce_fx)`

* Undo field

* add redundant return

* rename

rename files and classes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename

* Replace code

* Fix names and imports

* Remove metric_attribute

* imports

* loop hygiene

* yapf on loops

* protected new loop trigger

* rename NEW LOOP guard

* integrate latest logger connector changes

* integrate latest logger connector changes (eval loop)

* resolve todo dataloading reset

* re-add notebooks

* add missing init

* bad merge

* remove NEW_LOOP guard

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* flake8

* exclude coverage

* integrate #7917, remove teardown from training loop

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update "accumulated_batches_reached" condition

based on if iter count was updated or not

* remove public loop properties

* make skip backward protected again

* typing base loop

* typing fit loop

* typing training_batch_loop

* typing evaluation loop

* typing prediction loop

* typing training epoch loop

* dataloader_loop

* evaluation_dataloader_loop

* prediction_dataloader_loop

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* integrate train loop changes from master

* integrate eval loop changes from master

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tpipes moving model to cpu and leaving it there.

* don't reset fit loop

* fix test iteration count <-> batch_idx reset

* replace torch.Tensor -> Tensor

* fix attribute error to block_ddp_sync_behaviour

* fix flake8 and yapf conflict

* remove redundant override

* add classes

Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
Co-authored-by: Justus Schock <justus.schock@posteo.de>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* trainer changes

* connect

* clean up

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test renaming

* rename evaluation loop to evaluation epoch loop

* minor docstring improvements

* update chlog

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* try ci fix

* update code owners for pl/loops

* update mock path

* re-order

* simplify dataloader reset

* simplify get_dataloaders()

* save predictions on_run_end()

* improve skip condition re-routing

* re-order

* remove unused type import

* check which assert is failing

* pig

* hobbit

* teardown for evaluation

* Revert "hobbit"

This reverts commit e81b0db.

* Revert "pig"

This reverts commit 33d89e0.

* Revert "check which assert is failing"

This reverts commit b7483b4.

* free memory in fit loop teardown

* update docstring

* period

* remove dead code

* else carlos

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* update chlog

* unused imp

* move default construction in run_evaluation

* add something for lawyer to read

* switch typehint for eval loop trainer property

* add missing imports

* remove a todo that needs more discussion

* combine _get_num_dataloaders with the property

* Update pytorch_lightning/loops/dataloader/dataloader_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* black + yapf

* avoid coverage on old unused eval loop

* empty space in docstring

Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>

* resolve todo for args forwarding

* weekproxy trainer

* fix check for num dataloaders kwargs

* clean up num prediction dataloaders property

* free memory

* rm notebooks folder

* rm old file

* revert changes to old eval loop

* bad merge

* undo teardown

* setup signature

* remove file for notes

* free memory

* chlog

* Revert "weekproxy trainer"

This reverts commit d4e6969.

* connect trainer

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean up max batches and dataloaders

* max batches handling

* no grad handling

* unused argument

* protected attrs

* unused imports

* undo unintentional rename

* consistent naming

* capitalization in docstring

* list all args

* Update pytorch_lightning/loops/prediction_epoch_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/loops/prediction_epoch_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/loops/prediction_epoch_loop.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: tchaton <thomas@grid.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justus Schock <justus.schock@posteo.de>
Co-authored-by: Justus Schock <justus.schock@rwth-aachen.de>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
8 people committed Jun 23, 2021
1 parent 58b47da commit a45ab00
Showing 11 changed files with 321 additions and 64 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -148,6 +148,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationDataLoaderLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990))
* Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056))
* Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
* Refactored prediction loop interface; added new classes `PredictionDataLoaderLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700))


- Refactored logging
2 changes: 1 addition & 1 deletion pytorch_lightning/callbacks/prediction_writer.py
@@ -109,7 +109,7 @@ def on_predict_batch_end(
        if not self.interval.on_batch:
            return
        is_distributed = trainer.accelerator_connector.is_distributed
        batch_indices = trainer.predict_loop.batch_indices if is_distributed else None
        batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices if is_distributed else None
        self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

    def on_predict_epoch_end(
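For reference, a minimal sketch of a writer callback that consumes the per-batch indices now exposed via `trainer.predict_loop.epoch_loop.current_batch_indices`. The callback name and output directory are illustrative assumptions, not part of this commit:

import os

import torch

from pytorch_lightning.callbacks import BasePredictionWriter


class IndexedPredictionWriter(BasePredictionWriter):
    """Hypothetical callback saving each prediction batch together with its sample indices."""

    def __init__(self, output_dir: str):
        super().__init__(write_interval="batch")
        self.output_dir = output_dir

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        # `batch_indices` is only populated in distributed runs (see the diff above) and may be None otherwise
        payload = {"prediction": prediction, "batch_indices": batch_indices}
        torch.save(payload, os.path.join(self.output_dir, f"dl{dataloader_idx}_batch{batch_idx}.pt"))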
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/base.py
@@ -14,7 +14,6 @@

from abc import ABC, abstractmethod
from typing import Any, Optional
from weakref import proxy

from deprecate import void

@@ -59,7 +58,8 @@ def skip(self) -> bool:

    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
        """Connects Loop with all the necessary things like connectors and accelerators."""
        self.trainer = proxy(trainer)
        # TODO(@justusschock): Make the trainer a weakref/proxy
        self.trainer = trainer

    def on_skip(self) -> Optional[Any]:
        """
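The reverted `proxy(trainer)` assignment presumably hit the usual weak-reference pitfall, which is why the TODO above keeps a plain reference for now. A standalone illustration of that pitfall in CPython (an assumption about the motivation, not code from this commit):

import weakref


class Trainer:

    def __init__(self):
        self.state = "idle"


class Loop:

    def connect(self, trainer):
        # like the removed line above, keep only a weak proxy to the trainer
        self.trainer = weakref.proxy(trainer)


loop = Loop()
loop.connect(Trainer())  # nothing keeps the Trainer alive once connect() returns
try:
    print(loop.trainer.state)
except ReferenceError:
    # in CPython the Trainer is freed immediately, leaving a dead proxy behind
    print("trainer was garbage collected; a plain reference avoids this")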
@@ -70,7 +70,6 @@ def predictions(self):
    def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
        """Connects the loop to everything necessary (like trainer and accelerators)"""
        super().connect(trainer, *args, **kwargs)
        # TODO: Make the trainer a weakref/proxy
        self.epoch_loop.connect(trainer)

    @property
148 changes: 148 additions & 0 deletions pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py
@@ -0,0 +1,148 @@
from typing import Any, List, Optional, Sequence, Union

from deprecate.utils import void
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT


class PredictionDataLoaderLoop(DataLoaderLoop):
    """Loop to run over dataloaders for prediction"""

    def __init__(self):
        super().__init__()
        self.epoch_loop: PredictionEpochLoop = PredictionEpochLoop()
        self.predictions: Optional[List[List[Any]]] = None
        self.epoch_batch_indices: Optional[List[List[int]]] = None
        self._return_predictions: bool = False

    @property
    def return_predictions(self) -> bool:
        """Whether to return the predictions or not"""
        return self._return_predictions

    @return_predictions.setter
    def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
        # ``DDPSpawnPlugin`` plugins and their derivatives don't support returning predictions.
        is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
        if return_predictions and is_ddp_spawn:
            raise MisconfigurationException(
                "`return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or children class. "
                f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}."
            )
        # For plugins other than ``DDPSpawnPlugin``, `return_predictions` is True by default unless the user decides otherwise.
        self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions

    @property
    def num_dataloaders(self) -> int:
        """Returns the number of prediction dataloaders"""
        # case where user does:
        # return dl1, dl2
        dataloaders = self.dataloaders
        length = len(dataloaders)
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    @property
    def max_batches(self) -> List[int]:
        """The max number of batches this loop will run for each dataloader."""
        max_batches = self.trainer.num_predict_batches
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(self.dataloaders)
        return max_batches

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns all prediction dataloaders"""
        return self.trainer.predict_dataloaders

    @property
    def done(self) -> bool:
        """Whether prediction is finished: Max batches run or all dataloaders processed"""
        return self.current_dataloader_idx >= len(self.dataloaders)

    @property
    def skip(self) -> bool:
        return sum(self.max_batches) == 0

    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
        """Connects the loop with all necessary things (like trainer)"""
        super().connect(trainer, *args, **kwargs)
        self.epoch_loop.connect(trainer, *args, **kwargs)

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run"""
        super().reset()
        self.predictions = []
        self.epoch_batch_indices = []

    def on_run_start(self) -> None:
        """Calls ``on_predict_start`` hook"""
        self.on_predict_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Predicts one entire dataloader"""
        void(*args, **kwargs)
        dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
        dataloader_iter = enumerate(dataloader)
        dl_max_batches = self.max_batches[self.current_dataloader_idx]

        dl_predictions, dl_batch_indices = self.epoch_loop.run(
            dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, self.return_predictions
        )
        self.predictions.append(dl_predictions)
        self.epoch_batch_indices.append(dl_batch_indices)

    def on_run_end(self) -> Union[List[Any], List[List[Any]]]:
        """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders"""
        results = self.on_predict_epoch_end()
        self.on_predict_end()
        return results

    def on_predict_start(self) -> None:
        """
        Sets model to eval mode and disables gradients. Also calls ``on_predict_start`` and
        ``on_predict_epoch_start`` hooks.
        """
        # enable eval mode + no grads
        self.on_predict_model_eval()
        self.trainer.lightning_module.zero_grad()

        # hook
        self.trainer.call_hook("on_predict_start")
        self.trainer.call_hook("on_predict_epoch_start")

    def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls ``on_predict_epoch_end`` hook.

        Returns:
            the results for all dataloaders
        """
        self.trainer.profiler.describe()

        results = self.predictions

        self.trainer.call_hook("on_predict_epoch_end", results)

        if self.return_predictions:
            return results[0] if self.num_dataloaders == 1 else results

    def on_predict_end(self) -> None:
        """Resets previous gradient status and calls ``on_predict_end`` hook"""
        # clear memory. the predictions are extracted in `on_predict_epoch_end`.
        self.predictions = []
        self.epoch_batch_indices = []

        # hook
        self.trainer.call_hook("on_predict_end")

    def on_predict_model_eval(self):
        """Calls ``on_predict_model_eval`` hook"""
        model_ref = self.trainer.lightning_module
        model_ref.on_predict_model_eval()
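As a usage reference, a self-contained sketch of the public entry point that now runs through `PredictionDataLoaderLoop`; the toy module and random dataset are assumptions for illustration only:

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class ToyModel(pl.LightningModule):
    """Hypothetical module used only to exercise the prediction loop."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.layer(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        # with a single dataloader, `_build_kwargs` does not forward `dataloader_idx`
        (x,) = batch
        return self(x)


if __name__ == "__main__":
    dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4)
    trainer = pl.Trainer(logger=False)
    # `return_predictions` maps onto `PredictionDataLoaderLoop.return_predictions` above
    predictions = trainer.predict(ToyModel(), dataloaders=dataloader, return_predictions=True)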
6 changes: 1 addition & 5 deletions pytorch_lightning/loops/fit_loop.py
@@ -16,7 +16,6 @@
from contextlib import suppress
from typing import Any, List, Optional, Tuple

from deprecate import void
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -167,10 +166,7 @@ def skip(self) -> bool:

    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
        """Connects the loop with necessary arguments like the trainer"""
        # TODO(@justusschock): Do we want to forward *args and **kwargs to the inner loop here?
        # TODO(@justusschock): Can we make the trainer a weakref/proxy?
        void(*args, **kwargs)
        self.trainer = trainer
        super().connect(trainer, *args, **kwargs)
        self.training_loop.connect(trainer)
        self.validation_loop.connect(trainer)

151 changes: 151 additions & 0 deletions pytorch_lightning/loops/prediction_epoch_loop.py
@@ -0,0 +1,151 @@
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple

from deprecate import void

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
from pytorch_lightning.utilities.warnings import WarningCache


class PredictionEpochLoop(Loop):
    """Loop performing prediction on arbitrary sequentially used dataloaders."""

    def __init__(self) -> None:
        super().__init__()
        self.return_predictions: bool = False
        self.predictions: List[Any] = []
        self.current_batch_indices: List[int] = []
        self._dl_max_batches: Optional[int] = None
        self._num_dataloaders: Optional[int] = None
        self._warning_cache = WarningCache()
        self._all_batch_indices: List[int] = []

    @property
    def done(self) -> bool:
        """Ends prediction when the iteration count exceeds the total number of available batches"""
        return self.iteration_count >= self._dl_max_batches

    @property
    def should_store_predictions(self) -> bool:
        """Whether the predictions should be stored for later usage (e.g. aggregation or returning)"""
        any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
        return self.return_predictions or any_pred

    def reset(self) -> None:
        """Resets the loop's internal state"""
        self.iteration_count = 0
        self._all_batch_indices: List[int] = []
        self.predictions: List[Any] = []

    def on_run_start(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
        return_predictions: bool = False
    ) -> None:
        """
        Prepares the loop's internal state

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions
        """
        void(dataloader_iter, dataloader_idx)
        self._dl_max_batches = dl_max_batches
        self._num_dataloaders = num_dataloaders
        self.return_predictions = return_predictions

    def advance(
        self,
        dataloader_iter: Iterator,
        dataloader_idx: int,
        dl_max_batches: int,
        num_dataloaders: int,
        return_predictions: bool = False
    ) -> None:
        """
        Runs one prediction step.

        Args:
            dataloader_iter: the iterator over the current dataloader
            dataloader_idx: the index of the current dataloader
            dl_max_batches: the maximum number of batches the current loader can produce
            num_dataloaders: the total number of dataloaders
            return_predictions: whether to return the obtained predictions
        """
        batch_idx, batch = next(dataloader_iter)
        if batch is None:
            raise StopIteration

        with self.trainer.profiler.profile("predict_step"):
            self._predict_step(batch, batch_idx, dataloader_idx)

    def on_run_end(self) -> Tuple[Any, Any]:
        """Returns the predictions and the corresponding batch indices"""
        return self.predictions, self._all_batch_indices

    def teardown(self) -> None:
        """Frees memory of collected predictions."""
        self.predictions = []
        self._all_batch_indices = []

    def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        """Runs the actual predict step together with all the
        necessary bookkeeping and the hooks tied to the predict step.

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch
        """
        # configure step_kwargs
        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)

        # extract batch_indices and store them
        self._store_batch_indices(dataloader_idx)

        model_ref = self.trainer.lightning_module

        self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)

        model_ref._current_fx_name = "predict_step"
        predictions = self.trainer.accelerator.predict_step(step_kwargs)

        if predictions is None:
            self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")

        self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)

        if self.should_store_predictions:
            self.predictions.append(predictions)

    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
        """
        Assembles the keyword arguments for the ``predict_step``

        Args:
            batch: the current batch to run the prediction on
            batch_idx: the index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the dictionary containing all the keyword arguments for the predict step
        """
        step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
        if self._num_dataloaders > 1:
            step_kwargs['dataloader_idx'] = dataloader_idx
        return step_kwargs

    def _store_batch_indices(self, dataloader_idx: int) -> None:
        """Stores the batch indices if the predictions should be stored"""
        batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
            self.current_batch_indices = batch_sampler.batch_indices
            if self.should_store_predictions:
                self._all_batch_indices.append(batch_sampler.batch_indices)
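A standalone restatement (an assumed reproduction, not part of the commit) of the `_build_kwargs` contract above: `dataloader_idx` is only forwarded to `predict_step` when more than one prediction dataloader is configured:

from collections import OrderedDict
from typing import Any, Dict


def build_predict_kwargs(batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> Dict[str, Any]:
    # mirrors `PredictionEpochLoop._build_kwargs`, with the dataloader count passed explicitly
    step_kwargs = OrderedDict([("batch", batch), ("batch_idx", batch_idx)])
    if num_dataloaders > 1:
        step_kwargs["dataloader_idx"] = dataloader_idx
    return step_kwargs


assert list(build_predict_kwargs("batch", 0, 0, num_dataloaders=1)) == ["batch", "batch_idx"]
assert list(build_predict_kwargs("batch", 0, 2, num_dataloaders=3)) == ["batch", "batch_idx", "dataloader_idx"]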
11 changes: 3 additions & 8 deletions pytorch_lightning/loops/training_batch_loop.py
@@ -23,7 +23,6 @@
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.plugins import ParallelPlugin
@@ -67,11 +66,6 @@ def optimizer_freq_cumsum(self) -> int:
            self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
        return self._optimizer_freq_cumsum

    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
        # TODO(@justusschock): can we make this a weakref/proxy?
        void(*args, **kwargs)
        self.trainer = trainer

    def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
        """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks
@@ -96,8 +90,9 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
            return AttributeDict(signal=-1)

        super().run(batch, batch_idx, dataloader_idx)

        return AttributeDict(signal=0, training_step_output=self.batch_outputs)
        output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
        self.batch_outputs = None  # free memory
        return output

    def reset(self) -> None:
        """Resets the loop state"""