diff --git a/CHANGELOG.md b/CHANGELOG.md
index 666438b872b9c..c1ac8a689ce15 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -148,6 +148,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Refactored evaluation loop interface; added new classes `DataLoaderLoop`, `EvaluationDataLoaderLoop`, `EvaluationEpochLoop` ([#7990](https://github.com/PyTorchLightning/pytorch-lightning/pull/7990))
     * Removed `pytorch_lightning/trainer/evaluation_loop.py` ([#8056](https://github.com/PyTorchLightning/pytorch-lightning/pull/8056))
     * Refactored trainer `_run_*` functions and separate evaluation loops ([#8065](https://github.com/PyTorchLightning/pytorch-lightning/pull/8065))
+    * Refactored prediction loop interface; added new classes `PredictionDataLoaderLoop`, `PredictionEpochLoop` ([#7700](https://github.com/PyTorchLightning/pytorch-lightning/pull/7700))

 - Refactored logging
diff --git a/pytorch_lightning/callbacks/prediction_writer.py b/pytorch_lightning/callbacks/prediction_writer.py
index cbcff74ff0278..962877cc5a658 100644
--- a/pytorch_lightning/callbacks/prediction_writer.py
+++ b/pytorch_lightning/callbacks/prediction_writer.py
@@ -109,7 +109,7 @@ def on_predict_batch_end(
         if not self.interval.on_batch:
             return
         is_distributed = trainer.accelerator_connector.is_distributed
-        batch_indices = trainer.predict_loop.batch_indices if is_distributed else None
+        batch_indices = trainer.predict_loop.epoch_loop.current_batch_indices if is_distributed else None
         self.write_on_batch_end(trainer, pl_module, outputs, batch_indices, batch, batch_idx, dataloader_idx)

     def on_predict_epoch_end(
diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py
index c823620a244aa..1d976aa3cd079 100644
--- a/pytorch_lightning/loops/base.py
+++ b/pytorch_lightning/loops/base.py
@@ -14,7 +14,6 @@

 from abc import ABC, abstractmethod
 from typing import Any, Optional
-from weakref import proxy

 from deprecate import void

@@ -59,7 +58,8 @@ def skip(self) -> bool:

     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects Loop with all the necessary things like connectors and accelerators."""
-        self.trainer = proxy(trainer)
+        # TODO(@justusschock): Make the trainer a weakref/proxy
+        self.trainer = trainer

     def on_skip(self) -> Optional[Any]:
         """
diff --git a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py
index 53e7b00b83b16..e5565d6a8912b 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_dataloader_loop.py
@@ -70,7 +70,6 @@ def predictions(self):
     def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
         """Connects the loop to everything necessary (like trainer and accelerators)"""
         super().connect(trainer, *args, **kwargs)
-        # TODO: Make the trainer a weakref/proxy
         self.epoch_loop.connect(trainer)

     @property
diff --git a/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py b/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py
new file mode 100644
index 0000000000000..80077e1e2aaae
--- /dev/null
+++ b/pytorch_lightning/loops/dataloader/prediction_dataloader_loop.py
@@ -0,0 +1,148 @@
+from typing import Any, List, Optional, Sequence, Union
+
+from deprecate.utils import void
+from torch.utils.data import DataLoader
+
+import pytorch_lightning as pl
+from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
+from pytorch_lightning.loops.prediction_epoch_loop import PredictionEpochLoop
+from pytorch_lightning.plugins import DDPSpawnPlugin
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.types import _PREDICT_OUTPUT
+
+
+class PredictionDataLoaderLoop(DataLoaderLoop):
+    """Loop to run over dataloaders for prediction"""
+
+    def __init__(self):
+        super().__init__()
+        self.epoch_loop: PredictionEpochLoop = PredictionEpochLoop()
+        self.predictions: Optional[List[List[Any]]] = None
+        self.epoch_batch_indices: Optional[List[List[int]]] = None
+        self._return_predictions: bool = False
+
+    @property
+    def return_predictions(self) -> bool:
+        """Whether to return the predictions or not"""
+        return self._return_predictions
+
+    @return_predictions.setter
+    def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
+        # ``DDPSpawnPlugin`` and its derivatives don't support returning predictions.
+        is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
+        if return_predictions and is_ddp_spawn:
+            raise MisconfigurationException(
+                "`return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or its children classes. "
+                f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}."
+            )
+        # For plugins other than ``DDPSpawnPlugin``, `return_predictions` defaults to True unless the user decides otherwise.
+        self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions
+
+    @property
+    def num_dataloaders(self) -> int:
+        """Returns the number of prediction dataloaders"""
+        # case where user does:
+        # return dl1, dl2
+        dataloaders = self.dataloaders
+        length = len(dataloaders)
+        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
+            length = len(dataloaders[0])
+        return length
+
+    @property
+    def max_batches(self) -> List[int]:
+        """The max number of batches this loop will run for each dataloader."""
+        max_batches = self.trainer.num_predict_batches
+        if isinstance(max_batches, int):
+            max_batches = [max_batches] * len(self.dataloaders)
+        return max_batches
+
+    @property
+    def dataloaders(self) -> Sequence[DataLoader]:
+        """Returns all prediction dataloaders"""
+        return self.trainer.predict_dataloaders
+
+    @property
+    def done(self) -> bool:
+        """Whether prediction is finished: max batches run or all dataloaders processed"""
+        return self.current_dataloader_idx >= len(self.dataloaders)
+
+    @property
+    def skip(self) -> bool:
+        return sum(self.max_batches) == 0
+
+    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
+        """Connects the loop with all necessary things (like the trainer)"""
+        super().connect(trainer, *args, **kwargs)
+        self.epoch_loop.connect(trainer, *args, **kwargs)
+
+    def reset(self) -> None:
+        """Resets the internal state of the loop for a new run"""
+        super().reset()
+        self.predictions = []
+        self.epoch_batch_indices = []
+
+    def on_run_start(self) -> None:
+        """Calls the ``on_predict_start`` hook"""
+        self.on_predict_start()
+
+    def advance(self, *args: Any, **kwargs: Any) -> None:
+        """Predicts one entire dataloader"""
+        void(*args, **kwargs)
+        dataloader = self.trainer.accelerator.process_dataloader(self.current_dataloader)
+        dataloader_iter = enumerate(dataloader)
+        dl_max_batches = self.max_batches[self.current_dataloader_idx]
+
+        dl_predictions, dl_batch_indices = self.epoch_loop.run(
+            dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, self.return_predictions
+        )
+        self.predictions.append(dl_predictions)
+        self.epoch_batch_indices.append(dl_batch_indices)
+
+    def on_run_end(self) -> Union[List[Any], List[List[Any]]]:
+        """Calls the ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns the results from all dataloaders"""
+        results = self.on_predict_epoch_end()
+        self.on_predict_end()
+        return results
+
+    def on_predict_start(self) -> None:
+        """
+        Sets the model to eval mode and disables gradients. Also calls the ``on_predict_start`` and
+        ``on_predict_epoch_start`` hooks.
+        """
+        # enable eval mode + no grads
+        self.on_predict_model_eval()
+        self.trainer.lightning_module.zero_grad()
+
+        # hook
+        self.trainer.call_hook("on_predict_start")
+        self.trainer.call_hook("on_predict_epoch_start")
+
+    def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
+        """Calls the ``on_predict_epoch_end`` hook.
+
+        Returns:
+            the results for all dataloaders
+        """
+        self.trainer.profiler.describe()
+
+        results = self.predictions
+
+        self.trainer.call_hook("on_predict_epoch_end", results)
+
+        if self.return_predictions:
+            return results[0] if self.num_dataloaders == 1 else results
+
+    def on_predict_end(self) -> None:
+        """Clears the stored predictions and batch indices and calls the ``on_predict_end`` hook"""
+        # clear memory. the predictions are extracted in `on_predict_epoch_end`.
+        self.predictions = []
+        self.epoch_batch_indices = []
+
+        # hook
+        self.trainer.call_hook("on_predict_end")
+
+    def on_predict_model_eval(self):
+        """Calls the ``on_predict_model_eval`` hook"""
+        model_ref = self.trainer.lightning_module
+        model_ref.on_predict_model_eval()
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 3d726a7cc3d6c..d05a061660e93 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -16,7 +16,6 @@
 from contextlib import suppress
 from typing import Any, List, Optional, Tuple

-from deprecate import void
 from torch.optim import Optimizer

 import pytorch_lightning as pl
@@ -167,10 +166,7 @@ def skip(self) -> bool:

     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects the loop with necessary arguments like the trainer"""
-        # TODO(@justusschock): Do we want to forward *args and **kwargs to the inner loop here?
-        # TODO(@justusschock): Can we make the trainer a weakref/proxy?
-        void(*args, **kwargs)
-        self.trainer = trainer
+        super().connect(trainer, *args, **kwargs)
         self.training_loop.connect(trainer)
         self.validation_loop.connect(trainer)

diff --git a/pytorch_lightning/loops/prediction_epoch_loop.py b/pytorch_lightning/loops/prediction_epoch_loop.py
new file mode 100644
index 0000000000000..258a81648a3e0
--- /dev/null
+++ b/pytorch_lightning/loops/prediction_epoch_loop.py
@@ -0,0 +1,151 @@
+from collections import OrderedDict
+from typing import Any, Dict, Iterator, List, Optional, Tuple
+
+from deprecate import void
+
+from pytorch_lightning.loops.base import Loop
+from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper
+from pytorch_lightning.utilities.warnings import WarningCache
+
+
+class PredictionEpochLoop(Loop):
+    """Loop performing prediction on arbitrary, sequentially used dataloaders."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.return_predictions: bool = False
+        self.predictions: List[Any] = []
+        self.current_batch_indices: List[int] = []
+        self._dl_max_batches: Optional[int] = None
+        self._num_dataloaders: Optional[int] = None
+        self._warning_cache = WarningCache()
+        self._all_batch_indices: List[int] = []
+
+    @property
+    def done(self) -> bool:
+        """Ends prediction when the iteration count reaches the total number of available batches"""
+        return self.iteration_count >= self._dl_max_batches
+
+    @property
+    def should_store_predictions(self) -> bool:
+        """Whether the predictions should be stored for later usage (e.g. aggregation or returning)"""
+        any_pred = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
+        return self.return_predictions or any_pred
+
+    def reset(self) -> None:
+        """Resets the loop's internal state"""
+        self.iteration_count = 0
+        self._all_batch_indices: List[int] = []
+        self.predictions: List[Any] = []
+
+    def on_run_start(
+        self,
+        dataloader_iter: Iterator,
+        dataloader_idx: int,
+        dl_max_batches: int,
+        num_dataloaders: int,
+        return_predictions: bool = False
+    ) -> None:
+        """
+        Prepares the loop's internal state
+
+        Args:
+            dataloader_iter: the iterator over the current dataloader
+            dataloader_idx: the index of the current dataloader
+            dl_max_batches: the maximum number of batches the current loader can produce
+            num_dataloaders: the total number of dataloaders
+            return_predictions: whether to return the obtained predictions
+        """
+        void(dataloader_iter, dataloader_idx)
+        self._dl_max_batches = dl_max_batches
+        self._num_dataloaders = num_dataloaders
+        self.return_predictions = return_predictions
+
+    def advance(
+        self,
+        dataloader_iter: Iterator,
+        dataloader_idx: int,
+        dl_max_batches: int,
+        num_dataloaders: int,
+        return_predictions: bool = False
+    ) -> None:
+        """
+        Runs one prediction step.
+
+        Args:
+            dataloader_iter: the iterator over the current dataloader
+            dataloader_idx: the index of the current dataloader
+            dl_max_batches: the maximum number of batches the current loader can produce
+            num_dataloaders: the total number of dataloaders
+            return_predictions: whether to return the obtained predictions
+        """
+        batch_idx, batch = next(dataloader_iter)
+        if batch is None:
+            raise StopIteration
+
+        with self.trainer.profiler.profile("predict_step"):
+            self._predict_step(batch, batch_idx, dataloader_idx)
+
+    def on_run_end(self) -> Tuple[Any, Any]:
+        """Returns the predictions and the corresponding batch indices"""
+        return self.predictions, self._all_batch_indices
+
+    def teardown(self) -> None:
+        """Frees memory of collected predictions."""
+        self.predictions = []
+        self._all_batch_indices = []
+
+    def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+        """Runs the actual predict step together with all the
+        necessary bookkeeping and the hooks tied to the predict step.
+
+        Args:
+            batch: the current batch to run the prediction on
+            batch_idx: the index of the current batch
+            dataloader_idx: the index of the dataloader producing the current batch
+        """
+        # configure step_kwargs
+        step_kwargs = self._build_kwargs(batch, batch_idx, dataloader_idx)
+
+        # extract batch_indices and store them
+        self._store_batch_indices(dataloader_idx)
+
+        model_ref = self.trainer.lightning_module
+
+        self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx)
+
+        model_ref._current_fx_name = "predict_step"
+        predictions = self.trainer.accelerator.predict_step(step_kwargs)
+
+        if predictions is None:
+            self._warning_cache.warn("predict returned None; if it was on purpose, ignore this warning...")
+
+        self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx)
+
+        if self.should_store_predictions:
+            self.predictions.append(predictions)
+
+    def _build_kwargs(self, batch: Any, batch_idx: int, dataloader_idx: int) -> Dict[str, Any]:
+        """
+        Assembles the keyword arguments for the ``predict_step``
+
+        Args:
+            batch: the current batch to run the prediction on
+            batch_idx: the index of the current batch
+            dataloader_idx: the index of the dataloader producing the current batch
+
+        Returns:
+            the dictionary containing all the keyword arguments for the predict step
+        """
+        step_kwargs = OrderedDict([('batch', batch), ('batch_idx', batch_idx)])
+        if self._num_dataloaders > 1:
+            step_kwargs['dataloader_idx'] = dataloader_idx
+        return step_kwargs
+
+    def _store_batch_indices(self, dataloader_idx: int) -> None:
+        """Stores the batch indices if the predictions should be stored"""
+        batch_sampler = self.trainer.predict_dataloaders[dataloader_idx].batch_sampler
+        if isinstance(batch_sampler, IndexBatchSamplerWrapper):
+            self.current_batch_indices = batch_sampler.batch_indices
+            if self.should_store_predictions:
+                self._all_batch_indices.append(batch_sampler.batch_indices)
diff --git a/pytorch_lightning/loops/training_batch_loop.py b/pytorch_lightning/loops/training_batch_loop.py
index e0ff96ac9d43f..2956da83faebd 100644
--- a/pytorch_lightning/loops/training_batch_loop.py
+++ b/pytorch_lightning/loops/training_batch_loop.py
@@ -23,7 +23,6 @@
 from torch import Tensor
 from torch.optim import Optimizer

-import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.plugins import ParallelPlugin
@@ -67,11 +66,6 @@ def optimizer_freq_cumsum(self) -> int:
         self._optimizer_freq_cumsum = np.cumsum(self.trainer.optimizer_frequencies)
         return self._optimizer_freq_cumsum

-    def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
-        # TODO(@justusschock): can we make this a weakref/proxy?
-        void(*args, **kwargs)
-        self.trainer = trainer
-
     def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
         """Runs all the data splits and the ``on_batch_start`` and ``on_train_batch_start`` hooks

@@ -96,8 +90,9 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict:
             return AttributeDict(signal=-1)

         super().run(batch, batch_idx, dataloader_idx)
-
-        return AttributeDict(signal=0, training_step_output=self.batch_outputs)
+        output = AttributeDict(signal=0, training_step_output=self.batch_outputs)
+        self.batch_outputs = None  # free memory
+        return output

     def reset(self) -> None:
         """Resets the loop state"""
diff --git a/pytorch_lightning/loops/training_epoch_loop.py b/pytorch_lightning/loops/training_epoch_loop.py
index 2a7000bc17ca8..b26f82557dfec 100644
--- a/pytorch_lightning/loops/training_epoch_loop.py
+++ b/pytorch_lightning/loops/training_epoch_loop.py
@@ -76,10 +76,7 @@ def done(self) -> bool:

     def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None:
         """Connects the loop with all necessary parts like trainer and accelerators"""
-
-        # TODO(@justusschock): should we forward *args and **kwargs to lower loops?
-        # TODO(@justusschock): can we make the trainer a proxy here?
-        self.trainer = trainer
+        super().connect(trainer, *args, **kwargs)
         self.batch_loop = TrainingBatchLoop()
         self.batch_loop.connect(trainer)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 009b5ea228056..dd201b49e427b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -28,6 +28,7 @@
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.loops.dataloader.evaluation_dataloader_loop import EvaluationDataLoaderLoop
+from pytorch_lightning.loops.dataloader.prediction_dataloader_loop import PredictionDataLoaderLoop
 from pytorch_lightning.loops.fit_loop import FitLoop
 from pytorch_lightning.plugins import Plugin
 from pytorch_lightning.plugins.environments import ClusterEnvironment
@@ -56,7 +57,6 @@
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
-from pytorch_lightning.trainer.predict_loop import PredictLoop
 from pytorch_lightning.trainer.properties import TrainerProperties
 from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
@@ -345,11 +345,11 @@ def __init__(
         self.fit_loop = FitLoop(min_epochs, max_epochs, min_steps, max_steps)
         self.validation_loop = EvaluationDataLoaderLoop()
         self.test_loop = EvaluationDataLoaderLoop()
-        self.predict_loop = PredictLoop(self)
-
+        self.predict_loop = PredictionDataLoaderLoop()
         self.fit_loop.connect(self)
         self.validation_loop.connect(self)
         self.test_loop.connect(self)
+        self.predict_loop.connect(self)

         # training state
         if weights_summary is not None and weights_summary not in ModelSummary.MODES:
@@ -392,8 +392,6 @@ def __init__(
             truncated_bptt_steps,
             terminate_on_nan,
         )
-
-        self.predict_loop.on_trainer_init()
         self._setup_on_init(num_sanity_val_steps)

         # configure tuner
@@ -424,12 +422,12 @@ def __init__(
         # Callback system
         self.on_init_end()

-        self._log_device_info()
-
     def _setup_on_init(
         self,
         num_sanity_val_steps: int,
     ) -> None:
+        self._log_device_info()
+
         self.should_stop = False
         self.state = TrainerState()
         self.num_training_batches = 0
@@ -453,6 +451,9 @@ def _setup_on_init(
         # when true, print evaluation results in .validate() and .test()
         self.verbose_evaluate = True

+        self.num_predict_batches = []
+        self.predicted_ckpt_path = None
+
     def _setup_fit(self, model, train_dataloader=None, val_dataloaders=None, datamodule=None):
         # clean hparams
         if hasattr(model, "hparams"):
@@ -1019,42 +1020,9 @@ def _run_evaluate(self) -> _EVALUATE_OUTPUT:
         return eval_loop_results

     def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
-        # prepare dataloaders
-        dataloaders, max_batches = self.predict_loop.get_predict_dataloaders()
-
-        # check if we want to skip this evaluation
-        if self.predict_loop.should_skip_predict(max_batches):
-            return []
-
-        # set up the eval loop
-        self.predict_loop.setup(max_batches, dataloaders)
-
-        # call hook
-        self.predict_loop.on_predict_start()
-
-        # run validation/testing
-        for dataloader_idx, dataloader in enumerate(dataloaders):
-            dataloader = self.accelerator.process_dataloader(dataloader)
-            dl_max_batches = self.predict_loop.max_batches[dataloader_idx]
-            for batch_idx, batch in enumerate(dataloader):
-                if batch is None:
-                    continue
-
-                # stop short when running on limited batches
-                if batch_idx >= dl_max_batches:
-                    break
-
-                # lightning module methods
-                with self.profiler.profile("predict_step"):
-                    self.predict_loop.predict_step(batch, batch_idx, dataloader_idx)
-
-        # call hook
-        results = self.predict_loop.on_predict_epoch_end()
-
-        # call hook
-        self.predict_loop.on_predict_end()
-
-        return results
+        self.reset_predict_dataloader(self.lightning_module)
+        with torch.no_grad():
+            return self.predict_loop.run()

     def _run_sanity_check(self, ref_model):
         using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
diff --git a/setup.cfg b/setup.cfg
index 74e02d932dc3c..92e6f526944f8 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -45,6 +45,8 @@ exclude_lines =
 omit =
     pytorch_lightning/cluster_environments/*.py
    pytorch_lightning/utilities/distributed.py
+    pytorch_lightning/trainer/evaluation_loop.py
+    pytorch_lightning/trainer/predict_loop.py
    pytorch_lightning/tuner/auto_gpu_select.py
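
For reviewers of this patch, the snippet below is a minimal usage sketch of the behaviour the refactor preserves: `Trainer.predict` now drives `PredictionDataLoaderLoop.run()`, which in turn runs one `PredictionEpochLoop` per dataloader. It is illustrative only and not part of the patch; `LitModel` and the dataset sizes are made up, while the Trainer/LightningModule APIs are the ones this change targets.

# Illustrative sketch only -- not part of the diff above.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # invoked by PredictionEpochLoop._predict_step through the accelerator
        x, = batch
        return self.layer(x)


loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)

trainer = pl.Trainer()
# Trainer.__init__ now instantiates PredictionDataLoaderLoop and connects it;
# Trainer._run_predict reduces to `self.predict_loop.run()` under torch.no_grad().
predictions = trainer.predict(LitModel(), dataloaders=loader)
# With a single dataloader, on_predict_epoch_end returns results[0]:
# a list of 4 tensors (one per batch), each of shape [16, 2].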