diff --git a/CHANGELOG.md b/CHANGELOG.md index d569b031e3841..484c1362f9d83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,9 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Progress tracking - * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) + * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) * Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) - * Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244)) + * Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) - Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431)) @@ -92,6 +92,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault-tolerant training * Added `{,load_}state_dict` to `ResultCollection` ([#7948](https://github.com/PyTorchLightning/pytorch-lightning/pull/7948)) * Added `{,load_}state_dict` to `Loops` ([#8197](https://github.com/PyTorchLightning/pytorch-lightning/pull/8197)) + * Set `Loop.restarting=False` at the end of the first iteration ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) + * Save the loops state with the checkpoint (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) + * Save a checkpoint to restore the state on exception (opt-in) ([#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) - Added `rank_zero_only` to `LightningModule.log` function ([#7966](https://github.com/PyTorchLightning/pytorch-lightning/pull/7966)) @@ -393,8 +396,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Removed deprecated `optimizer` argument in `LightningModule.manual_backward()`; Toggling optimizers in manual optimization should be done using `LightningModule.{un}toggle_optimizer()` ([#8287](https://github.com/PyTorchLightning/pytorch-lightning/pull/8287)) - - ### Fixed - Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877)) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 639aa62784493..1efd67bb26f8e 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -46,6 +46,7 @@ class Loop(ABC): """ def __init__(self) -> None: + # TODO: replace by progress tracking self.iteration_count: int = 0 self.restarting = False self._trainer: Optional['pl.Trainer'] = None @@ -56,8 +57,8 @@ def trainer(self) -> Optional['pl.Trainer']: @trainer.setter def trainer(self, trainer: 'pl.Trainer'): - """Connect the Trainer to this loop and all children.""" - if not isinstance(trainer, pl.Trainer) and trainer is not None: + """Connects this loop's trainer and its children""" + if not isinstance(trainer, pl.Trainer): raise MisconfigurationException( f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}." ) @@ -112,6 +113,7 @@ def run(self, *args: Any, **kwargs: Any) -> Optional[Any]: self.advance(*args, **kwargs) self.on_advance_end() self.iteration_count += 1 + self.restarting = False except StopIteration: break @@ -158,7 +160,7 @@ def on_save_checkpoint(self) -> Dict: """ return {} - def on_load_checkpoint(self, state_dict: Dict): + def on_load_checkpoint(self, state_dict: Dict) -> None: """Called when loading a model checkpoint, use to reload loop state.""" def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = "") -> Dict: @@ -183,14 +185,14 @@ def state_dict(self, destination: Optional[Dict] = None, prefix: Optional[str] = return destination - def load_state_dict(self, state_dict: Dict, prefix="", restart_progress: bool = True): + def load_state_dict(self, state_dict: Dict, prefix: str = "", restart_progress: bool = True) -> None: """ Loads the state of this loop and all its children. 
""" self._load_from_state_dict(state_dict.copy(), prefix, restart_progress) for k, v in self.__dict__.items(): if isinstance(v, Loop): v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress) - def _load_from_state_dict(self, state_dict, prefix, restart_progress): + def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: bool) -> None: for k, v in self.__dict__.items(): if isinstance(v, BaseProgress): v.load_state_dict(state_dict[prefix + k]) diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index 886fce3bb3961..3e5a8081f9eca 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -23,12 +23,11 @@ from torch import Tensor from torch.optim import Optimizer -import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loops.base import Loop from pytorch_lightning.plugins import ParallelPlugin from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import BatchProgress, OptimizationProgress +from pytorch_lightning.trainer.progress import OptimizationProgress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AMPType, AttributeDict, DeviceType, grad_norm from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -50,7 +49,6 @@ def __init__(self) -> None: self.running_loss: TensorRunningAccum = TensorRunningAccum(window_length=20) self.batch_idx: int = 0 self.split_idx: Optional[int] = None - self.progress = BatchProgress() self.optim_progress = OptimizationProgress() self._warning_cache: WarningCache = WarningCache() @@ -59,21 +57,6 @@ def __init__(self) -> None: self._remaining_splits: Optional[List[Any]] = None self._skip_backward: bool = False - def connect( - self, - trainer: 'pl.Trainer', - *args: Any, - progress: Optional[BatchProgress] = None, - optim_progress: Optional[OptimizationProgress] = None, - **kwargs: Any - ) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - if optim_progress is not None: - self.optim_progress = optim_progress - @property def done(self) -> bool: """Returns if all batch splits have been processed already""" @@ -109,6 +92,8 @@ def run(self, batch: Any, batch_idx: int, dataloader_idx: int) -> AttributeDict: if response == -1: return AttributeDict(signal=-1) + self.trainer.fit_loop.epoch_loop.batch_progress.increment_started() + super().run(batch, batch_idx, dataloader_idx) output = AttributeDict(signal=0, training_step_output=self.batch_outputs) self.batch_outputs = None # free memory @@ -149,6 +134,13 @@ def advance(self, batch, batch_idx, dataloader_idx): if self.trainer.lightning_module.automatic_optimization: for opt_idx, optimizer in self.get_active_optimizers(batch_idx): + # handle optimization restart + if self.restarting: + if opt_idx < self.optim_progress.optimizer_idx: + continue + + self.optim_progress.optimizer_idx = opt_idx + result = self._run_optimization(batch_idx, split_batch, opt_idx, optimizer) if result: self.batch_outputs[opt_idx].append(result.training_step_output) @@ -395,6 +387,8 @@ def _optimizer_step( # wraps into LightningOptimizer only for running step optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, 
opt_idx) + self.optim_progress.optimizer.step.increment_ready() + # model hook model_ref.optimizer_step( self.trainer.current_epoch, @@ -407,13 +401,17 @@ def _optimizer_step( using_lbfgs=is_lbfgs, ) + self.optim_progress.optimizer.step.increment_completed() + def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None: """Calls the ``on_before_zero_grad`` hook. Args: optimizer: the current optimizer """ + self.optim_progress.optimizer.zero_grad.increment_ready() self.trainer.call_hook('on_before_zero_grad', optimizer) + self.optim_progress.optimizer.zero_grad.increment_started() def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None: """Zeroes out all gradients of parameters optimized by the current optimizer. @@ -424,6 +422,7 @@ def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: the index of the current optimizer """ self.trainer.accelerator.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx) + self.optim_progress.optimizer.zero_grad.increment_completed() def _track_and_norm_grad(self, optimizer: torch.optim.Optimizer) -> Dict[str, Tensor]: """Tracks gradient norms and clips the gradients of all parameters optimized by the current optimizer. diff --git a/pytorch_lightning/loops/dataloader/dataloader_loop.py b/pytorch_lightning/loops/dataloader/dataloader_loop.py index ce255b73d0bba..65521aea547d8 100644 --- a/pytorch_lightning/loops/dataloader/dataloader_loop.py +++ b/pytorch_lightning/loops/dataloader/dataloader_loop.py @@ -13,16 +13,21 @@ # limitations under the License. from abc import abstractmethod -from typing import Sequence +from typing import Any, Sequence from torch.utils.data import DataLoader from pytorch_lightning.loops.base import Loop +from pytorch_lightning.trainer.progress import DataLoaderProgress class DataLoaderLoop(Loop): """Base class to loop over all dataloaders""" + def __init__(self): + super().__init__() + self.dataloader_progress = DataLoaderProgress() + @property @abstractmethod def dataloaders(self) -> Sequence[DataLoader]: @@ -31,7 +36,7 @@ def dataloaders(self) -> Sequence[DataLoader]: @property def current_dataloader_idx(self) -> int: """Returns the index of the current dataloader""" - return self.iteration_count + return self.dataloader_progress.current.ready - 1 @property def current_dataloader(self) -> DataLoader: @@ -46,8 +51,15 @@ def num_dataloaders(self) -> int: @property def done(self) -> bool: """Returns whether all dataloaders have been processed""" - return self.current_dataloader_idx >= self.num_dataloaders + return self.dataloader_progress.current.completed >= self.num_dataloaders def reset(self) -> None: """Resets the internal state""" - self.iteration_count = 0 + if not self.restarting: + self.dataloader_progress.current.reset() + + def on_advance_start(self, *args: Any, **kwargs: Any) -> None: + self.dataloader_progress.increment_ready() + + def on_advance_end(self) -> None: + self.dataloader_progress.increment_completed() diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 2f6e14b93b767..eab89eaf415b8 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -21,7 +21,6 @@ from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import 
ResultCollection -from pytorch_lightning.trainer.progress import EpochLoopProgress from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -33,8 +32,6 @@ class EvaluationLoop(DataLoaderLoop): def __init__(self): super().__init__() self.outputs = [] - self.progress = EpochLoopProgress() - self.epoch_loop = EvaluationEpochLoop() self._results = ResultCollection(training=False) @@ -66,19 +63,15 @@ def predictions(self): """Returns the predictions from all dataloaders""" return self.epoch_loop.predictions - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any - ) -> None: + def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) + self.epoch_loop.connect(trainer) @property def done(self) -> bool: """Returns whether all dataloaders are processed or evaluation should be skipped altogether""" - return (self.current_dataloader_idx >= len(self.dataloaders)) or self.skip + return super().done or self.skip @property def skip(self) -> bool: @@ -88,7 +81,6 @@ def skip(self) -> bool: def reset(self) -> None: """Resets the internal state of the loop""" - self.iteration_count = 0 self._max_batches = self.get_max_batches() # bookkeeping self.outputs = [] @@ -96,6 +88,8 @@ def reset(self) -> None: if isinstance(self._max_batches, int): self._max_batches = [self._max_batches] * len(self.dataloaders) + super().reset() + def on_skip(self) -> List: return [] diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py index 55647e5d7f2a3..e1de8669ddf68 100644 --- a/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -7,7 +7,6 @@ from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop from pytorch_lightning.plugins import DDPSpawnPlugin -from pytorch_lightning.trainer.progress import EpochLoopProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import _PREDICT_OUTPUT @@ -19,8 +18,6 @@ def __init__(self): super().__init__() self.predictions: Optional[List[List[Any]]] = None self.epoch_batch_indices: Optional[List[List[int]]] = None - self.progress = EpochLoopProgress() - self.epoch_loop = PredictionEpochLoop() self._results = None # for `trainer._results` access @@ -67,23 +64,14 @@ def dataloaders(self) -> Sequence[DataLoader]: """Returns all prediction dataloaders""" return self.trainer.predict_dataloaders - @property - def done(self) -> bool: - """Whether prediction is finished: Max batches run or all dataloaders processed""" - return self.current_dataloader_idx >= len(self.dataloaders) - @property def skip(self) -> bool: return sum(self.max_batches) == 0 - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochLoopProgress] = None, **kwargs: Any - ) -> None: + def connect(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress 
is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) + self.epoch_loop.connect(trainer) def reset(self) -> None: """Resets the internal state of the loop for a new run""" diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 6f52ef8a5915e..bd697d8cc8653 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -18,10 +18,9 @@ from deprecate import void from torch import Tensor -import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import EpochProgress +from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import PredictionCollection from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -40,29 +39,23 @@ def __init__(self) -> None: self._dl_max_batches: Optional[int] = None self._num_dataloaders: Optional[int] = None self.outputs: List[STEP_OUTPUT] = [] - self.progress = EpochProgress() - - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any - ) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress + self.batch_progress = Progress() @property def done(self) -> bool: """Returns ``True`` if the current iteration count reaches the number of dataloader batches.""" - return self.iteration_count >= self._dl_max_batches + return self.batch_progress.current.completed >= self._dl_max_batches def reset(self) -> None: """Resets the loop's internal state.""" - self.iteration_count = 0 self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) self._dl_max_batches = None self._num_dataloaders = None self.outputs = [] + if not self.restarting: + self.batch_progress.current.reset() + def on_run_start( self, dataloader_iter: Iterator, @@ -110,17 +103,25 @@ def advance( with self.trainer.profiler.profile("evaluation_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) + self.batch_progress.increment_ready() + # hook self.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) + self.batch_progress.increment_started() + # lightning module methods with self.trainer.profiler.profile("evaluation_step_and_end"): output = self.evaluation_step(batch, batch_idx, dataloader_idx) output = self.evaluation_step_end(output) + self.batch_progress.increment_processed() + # hook + store predictions self.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx) + self.batch_progress.increment_completed() + # log batch metrics self.trainer.logger_connector.update_eval_step_metrics() diff --git a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index ea03be5ef0096..da1aa0e42f210 100644 --- a/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -3,10 +3,9 @@ from deprecate import void -import pytorch_lightning as pl from pytorch_lightning.loops.base import Loop from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper -from 
pytorch_lightning.trainer.progress import EpochProgress +from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.utilities.warnings import WarningCache @@ -18,25 +17,17 @@ def __init__(self) -> None: self.return_predictions: bool = False self.predictions: List[Any] = [] self.current_batch_indices: List[int] = [] - self.progress = EpochProgress() + self.batch_progress = Progress() self._dl_max_batches: Optional[int] = None self._num_dataloaders: Optional[int] = None self._warning_cache = WarningCache() self._all_batch_indices: List[int] = [] - def connect( - self, trainer: "pl.Trainer", *args: Any, progress: Optional[EpochProgress] = None, **kwargs: Any - ) -> None: - """Connects the loop with necessary arguments like the trainer""" - super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - @property def done(self) -> bool: """Ends prediction when the iteration count exceeds the total number of available batches""" - return self.iteration_count >= self._dl_max_batches + return self.batch_progress.current.completed >= self._dl_max_batches @property def should_store_predictions(self) -> bool: @@ -46,9 +37,9 @@ def should_store_predictions(self) -> bool: def reset(self) -> None: """Resets the loops internal state""" - self.iteration_count = 0 self._all_batch_indices: List[int] = [] self.predictions: List[Any] = [] + self.batch_progress.current.reset() def on_run_start( self, @@ -98,6 +89,8 @@ def advance( with self.trainer.profiler.profile("predict_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=dataloader_idx) + self.batch_progress.increment_ready() + with self.trainer.profiler.profile("predict_step"): self._predict_step(batch, batch_idx, dataloader_idx) @@ -129,14 +122,20 @@ def _predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None self.trainer.call_hook("on_predict_batch_start", batch, batch_idx, dataloader_idx) + self.batch_progress.increment_started() + model_ref._current_fx_name = "predict_step" predictions = self.trainer.accelerator.predict_step(step_kwargs) + self.batch_progress.increment_processed() + if predictions is None: self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...") self.trainer.call_hook("on_predict_batch_end", predictions, batch, batch_idx, dataloader_idx) + self.batch_progress.increment_completed() + if self.should_store_predictions: self.predictions.append(predictions) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 9063e2153d31b..d9a2e6bb8cbb3 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -20,7 +20,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import TrainingEpochProgress +from pytorch_lightning.trainer.progress import Progress, SchedulerProgress from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature @@ -38,14 +38,14 @@ def __init__(self, min_steps: int, max_steps: int): self.global_step: int = 0 # the total batch index across all epochs 
self.total_batch_idx: int = 0 - # the current batch index in the loop that runs over the dataloader(s) - self.iteration_count: int = 0 # the current split index when the batch gets split into chunks in truncated backprop through time self.split_idx: Optional[int] = None # the number of batches seen this run, updates immediately after batch_loop.run() + # TODO: replace by progress tracking self.batches_seen: int = 0 self.is_last_batch: Optional[bool] = None - self.progress = TrainingEpochProgress() + self.batch_progress = Progress() + self.scheduler_progress = SchedulerProgress() self.batch_loop = TrainingBatchLoop() self.val_loop = loops.EvaluationLoop() @@ -69,19 +69,11 @@ def done(self) -> bool: max_steps_reached = self.max_steps is not None and self.global_step >= self.max_steps return max_steps_reached or self.trainer.should_stop or self._num_training_batches_reached(self.is_last_batch) - def connect( - self, - trainer: 'pl.Trainer', - *args: Any, - progress: Optional[TrainingEpochProgress] = None, - **kwargs: Any - ) -> None: + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.batch_loop.connect(trainer, progress=self.progress.batch, optim_progress=self.progress.optim) - self.val_loop.connect(trainer, progress=self.progress.val) + self.batch_loop.connect(trainer) + self.val_loop.connect(trainer) def reset(self) -> None: """Resets the internal state of the loop for a new run""" @@ -93,11 +85,19 @@ def reset(self) -> None: # track epoch output self._epoch_output = [[] for _ in range(self.batch_loop.num_active_optimizers(self.total_batch_idx))] + if self.restarting: + self.iteration_count = self.batches_seen = self.batch_progress.current.completed + else: + self.batch_progress.current.reset() + self.scheduler_progress.current.reset() + self.batch_loop.optim_progress.reset_on_epoch() + def on_run_start(self, *args: Any, **kwargs: Any) -> None: # hook self.trainer.logger_connector.on_epoch_start() self.trainer.call_hook("on_epoch_start") self.trainer.call_hook("on_train_epoch_start") + self.trainer.fit_loop.epoch_progress.increment_started() def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: """Runs a single training batch. 
@@ -117,10 +117,14 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: with self.trainer.profiler.profile("training_batch_to_device"): batch = self.trainer.accelerator.batch_to_device(batch, dataloader_idx=self._dataloader_idx) + self.batch_progress.increment_ready() + with self.trainer.profiler.profile("run_training_batch"): batch_output = self.batch_loop.run(batch, self.iteration_count, self._dataloader_idx) self.batches_seen += 1 + self.batch_progress.increment_processed() + # when returning -1 from train_step, we end epoch early if batch_output.signal == -1: raise StopIteration @@ -141,6 +145,8 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None: self.trainer.call_hook('on_batch_end') self.trainer.logger_connector.on_batch_end() + self.batch_progress.increment_completed() + # figure out what to track for epoch end self._track_epoch_end_reduce_metrics(self._epoch_output, batch_end_outputs) @@ -216,6 +222,8 @@ def on_run_end(self) -> List[List[STEP_OUTPUT]]: 'HINT: remove the return statement in training_epoch_end' ) + self.trainer.fit_loop.epoch_progress.increment_processed() + # call train epoch end hooks self._on_train_epoch_end_hook(processed_outputs) self.trainer.call_hook('on_epoch_end') @@ -432,10 +440,3 @@ def _save_loggers_on_train_batch_end(self) -> None: should_flush_logs = self.trainer.logger_connector.should_flush_logs if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None: self.trainer.logger.save() - - def state_dict(self) -> Dict: - return {"batch_loop": self.batch_loop.state_dict(), "val_loop": self.val_loop.state_dict()} - - def load_state_dict(self, state_dict: Dict) -> None: - self.batch_loop.load_state_dict(state_dict["batch_loop"]) - self.val_loop.load_state_dict(state_dict["val_loop"]) diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 15ebde94c8997..7df0d1445e3b3 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -14,13 +14,13 @@ import logging from contextlib import suppress -from typing import Any, Dict, Optional +from typing import Any, Optional import pytorch_lightning as pl from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection -from pytorch_lightning.trainer.progress import FitLoopProgress +from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_info @@ -51,19 +51,19 @@ def __init__( super().__init__() self.max_epochs = 1000 if (max_epochs is None and max_steps is None) else max_epochs self.min_epochs = 1 if (min_epochs is None and min_steps is None) else min_epochs - self.progress = FitLoopProgress() + self.epoch_progress = Progress() self.epoch_loop = TrainingEpochLoop(min_steps, max_steps) @property def current_epoch(self) -> int: """Return the current epoch""" - return self.iteration_count + return self.epoch_progress.current.completed @current_epoch.setter def current_epoch(self, value: int) -> None: """Setter for the current epoch""" - self.iteration_count = value + self.epoch_progress.current.completed = value @property def global_step(self) -> int: @@ -83,7 +83,7 @@ def total_batch_idx(self) -> int: @property def batch_idx(self) -> int: """Returns the number of batches already run within this epoch""" - return self.epoch_loop.iteration_count + 
return self.epoch_loop.batch_progress.current.ready - 1 @property def split_idx(self) -> int: @@ -169,14 +169,10 @@ def skip(self) -> bool: """Whether we should skip the training and immediately return from the call to :meth:`run`.""" return self.done or self.trainer.num_training_batches == 0 - def connect( - self, trainer: 'pl.Trainer', *args: Any, progress: Optional[FitLoopProgress] = None, **kwargs: Any - ) -> None: + def connect(self, trainer: 'pl.Trainer', *args: Any, **kwargs: Any) -> None: """Connects the loop with necessary arguments like the trainer""" super().connect(trainer, *args, **kwargs) - if progress is not None: - self.progress = progress - self.epoch_loop.connect(trainer, progress=self.progress.epoch) + self.epoch_loop.connect(trainer) def reset(self) -> None: """Resets the internal state of this loop""" @@ -207,6 +203,8 @@ def on_advance_start(self) -> None: window_length=self.trainer.accumulate_grad_batches ) + self.epoch_progress.increment_ready() + def advance(self) -> None: """Runs one whole epoch.""" train_dataloader = self.trainer.accelerator.process_dataloader(self.trainer.train_dataloader) @@ -230,14 +228,14 @@ def advance(self) -> None: def on_advance_end(self) -> None: """Updates the LR schedulers and does some internal bookkeeping""" - if self.epoch_loop.batches_seen == 0: - return + if self.epoch_loop.batches_seen != 0: + did_train_only = not self.trainer.enable_validation or self.epoch_loop.val_loop.skip + if did_train_only: + self.global_step -= 1 + self._check_checkpoint_callback(True) + self.global_step += 1 - did_train_only = not self.trainer.enable_validation or self.epoch_loop.val_loop.skip - if did_train_only: - self.global_step -= 1 - self._check_checkpoint_callback(True) - self.global_step += 1 + self.epoch_progress.increment_completed() def on_run_end(self) -> None: """Calls the ``on_train_end`` hook""" @@ -287,11 +285,5 @@ def _check_checkpoint_callback(self, should_update: bool, is_last: bool = False) for cb in callbacks: cb.on_validation_end(self.trainer, model) - def state_dict(self) -> Dict: - return {"epoch_loop": self.epoch_loop.state_dict()} - - def load_state_dict(self, state_dict: Dict) -> None: - self.epoch_loop.load_state_dict(state_dict["epoch_loop"]) - def teardown(self) -> None: self.epoch_loop.teardown() diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 1eccfbe52cc3d..1f282bc7c240c 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -23,6 +23,7 @@ from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS if _OMEGACONF_AVAILABLE: @@ -160,8 +161,8 @@ def restore_training_state(self) -> None: # restore precision plugin (scaler etc.) self.trainer.precision_plugin.on_load_checkpoint(self._loaded_checkpoint) - # restore progress (loops etc.) 
- self.restore_progress() + # restore loops and their progress + self.restore_loops() self.restore_optimizers_and_schedulers() @@ -179,10 +180,10 @@ def restore_callbacks(self) -> None: ) self.trainer.on_load_checkpoint(self._loaded_checkpoint) - def restore_progress(self) -> None: + def restore_loops(self) -> None: """ - Restores the training progress from the pre-loaded checkpoint. This currently includes only the global step - and current epoch. + Restores the loop progress from the pre-loaded checkpoint. + Calls hooks on the loops to give it a chance to restore its state from the checkpoint. """ if not self._loaded_checkpoint: return @@ -209,6 +210,13 @@ def restore_progress(self) -> None: " consider using an end of epoch checkpoint." ) + state_dict = self._loaded_checkpoint.get("loops") + if state_dict: + self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"]) + self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"]) + self.trainer.test_loop.load_state_dict(state_dict["test_loop"]) + self.trainer.predict_loop.load_state_dict(state_dict["predict_loop"]) + def restore_optimizers_and_schedulers(self) -> None: """ Restores the optimizers and learning rate scheduler states from the pre-loaded checkpoint. """ if not self._loaded_checkpoint: @@ -331,6 +339,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: 'pytorch-lightning_version': pl.__version__, 'state_dict': self.trainer.accelerator.lightning_module_state_dict(), } + if _fault_tolerant_enabled(): + checkpoint["loops"] = self._get_loops_state_dict() if not weights_only: # dump callbacks @@ -428,3 +438,11 @@ def save_checkpoint(self, filepath, weights_only: bool = False) -> None: """ _checkpoint = self.dump_checkpoint(weights_only) self.trainer.accelerator.save_checkpoint(_checkpoint, filepath) + + def _get_loops_state_dict(self): + return { + "fit_loop": self.trainer.fit_loop.state_dict(), + "validate_loop": self.trainer.validate_loop.state_dict(), + "test_loop": self.trainer.test_loop.state_dict(), + "predict_loop": self.trainer.predict_loop.state_dict(), + } diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 06ae55a1ca672..4c49b6e028cb4 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -83,6 +83,8 @@ def update_learning_rates( # update LR old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_ready() + if lr_scheduler['reduce_on_plateau']: lr_scheduler['scheduler'].step(monitor_val) else: @@ -90,6 +92,8 @@ def update_learning_rates( new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr'] + self.trainer.fit_loop.epoch_loop.scheduler_progress.increment_completed() + if self.trainer.dev_debugger.enabled: self.trainer.dev_debugger.track_lr_schedulers_update( self.trainer.fit_loop.batch_idx, diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index 1098957033855..fe9f90613ea9c 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -98,26 +98,18 @@ class Progress(BaseProgress): current: Tracker = field(default_factory=Tracker) def increment_ready(self) -> None: - if self.total.ready is None or self.current.ready is None: - return self.total.ready += 1 self.current.ready += 1 def increment_started(self) -> None: - if self.total.started is None or 
self.current.started is None: - return self.total.started += 1 self.current.started += 1 def increment_processed(self) -> None: - if self.total.processed is None or self.current.processed is None: - return self.total.processed += 1 self.current.processed += 1 def increment_completed(self) -> None: - if self.total.completed is None or self.current.completed is None: - return self.total.completed += 1 self.current.completed += 1 @@ -130,36 +122,34 @@ def load_state_dict(self, state_dict: dict) -> None: self.current.load_state_dict(state_dict["current"]) -class BatchProgress(Progress): +@dataclass +class DataLoaderProgress(Progress): """ - Tracks the batch progress + Tracks the dataloader progress + These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: - total: Tracks the total epoch progress - current: Tracks the current epoch progress + total: Tracks the total dataloader progress + current: Tracks the current dataloader progress """ + total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) + current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) + @dataclass -class EpochProgress(Progress): +class SchedulerProgress(Progress): """ - Tracks the epoch progress + Tracks the scheduler progress These counters are local to a trainer rank. By default, they are not globally synced across all ranks. Args: - total: Tracks the total epoch progress - current: Tracks the current epoch progress - batch: Tracks batch progress. + total: Tracks the total scheduler progress + current: Tracks the current scheduler progress """ - batch: BatchProgress = field(default_factory=BatchProgress) - - def reset_on_epoch(self) -> None: - self.batch.current.reset() - - def load_state_dict(self, state_dict: dict) -> None: - super().load_state_dict(state_dict) - self.batch.load_state_dict(state_dict["batch"]) + total: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) + current: Tracker = field(default_factory=lambda: Tracker(started=None, processed=None)) @dataclass @@ -172,7 +162,7 @@ class OptimizerProgress(BaseProgress): zero_grad: Tracks ``optimizer.zero_grad`` calls. """ - step: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None)) + step: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) zero_grad: Progress = field(default_factory=lambda: Progress.from_defaults(processed=None)) def reset_on_epoch(self) -> None: @@ -191,88 +181,21 @@ class OptimizationProgress(BaseProgress): Args: optimizer: Tracks optimizer progress. - scheduler: Tracks scheduler progress. + optimizer_idx: The index of the current optimizer. """ # TODO: support for multiple optimizers optimizer: OptimizerProgress = field(default_factory=OptimizerProgress) - scheduler: Progress = field(default_factory=lambda: Progress.from_defaults(started=None, processed=None)) + optimizer_idx: int = 0 @property def optimizer_steps(self) -> int: return self.optimizer.step.total.completed - @property - def scheduler_steps(self) -> int: - return self.scheduler.total.completed - def reset_on_epoch(self) -> None: self.optimizer.reset_on_epoch() - self.scheduler.current.reset() + self.optimizer_idx = 0 def load_state_dict(self, state_dict: dict) -> None: self.optimizer.load_state_dict(state_dict["optimizer"]) - self.scheduler.load_state_dict(state_dict["scheduler"]) - - -@dataclass -class EpochLoopProgress(BaseProgress): - """ - Tracks epoch loop progress. 
- These counters are local to a trainer rank. By default, they are not globally synced across all ranks. - - Args: - epoch: Tracks epochs progress. - """ - - epoch: EpochProgress = field(default_factory=EpochProgress) - - def increment_epoch_completed(self) -> None: - self.epoch.increment_completed() - self.reset_on_epoch() - - def reset_on_epoch(self) -> None: - self.epoch.reset_on_epoch() - self.epoch.current.reset() - - def load_state_dict(self, state_dict: dict) -> None: - self.epoch.load_state_dict(state_dict["epoch"]) - - -@dataclass -class TrainingEpochProgress(EpochProgress): - """ - Extends ``EpochProgress`` with training specific attributes - - Args: - total: Tracks the total epoch progress. - current: Tracks the current epoch progress. - batch: Tracks batch progress. - optim: Tracks optimization progress. - val: Tracks val_loop progress. - """ - - optim: OptimizationProgress = field(default_factory=OptimizationProgress) - val: EpochLoopProgress = field(default_factory=EpochLoopProgress) - - def load_state_dict(self, state_dict: dict) -> None: - super().load_state_dict(state_dict) - self.optim.load_state_dict(state_dict["optim"]) - self.val.load_state_dict(state_dict["val"]) - - -@dataclass -class FitLoopProgress(EpochLoopProgress): - """ - Extends ``EpochLoopProgress`` with fit specific attributes - - Args: - epoch: Tracks epochs progress. - """ - - epoch: TrainingEpochProgress = field(default_factory=TrainingEpochProgress) - - def reset_on_epoch(self) -> None: - # do not reset `epoch.current` as it should track the number of epochs this `fit` call - self.epoch.reset_on_epoch() - self.epoch.optim.reset_on_epoch() + self.optimizer_idx = state_dict["optimizer_idx"] diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ac7a41e3808f2..0f4dc6bc96cb8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Trainer to automate the training.""" import logging +import os import traceback import warnings from datetime import timedelta @@ -57,7 +58,6 @@ from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin -from pytorch_lightning.trainer.progress import EpochLoopProgress, FitLoopProgress from pytorch_lightning.trainer.properties import TrainerProperties from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin @@ -77,6 +77,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _fault_tolerant_enabled from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS @@ -359,10 +360,10 @@ def __init__( self.validate_loop = EvaluationLoop() self.test_loop = EvaluationLoop() self.predict_loop = PredictionLoop() - self.fit_loop.connect(self, progress=FitLoopProgress()) - self.validate_loop.connect(self, progress=EpochLoopProgress()) - self.test_loop.connect(self, progress=EpochLoopProgress()) - self.predict_loop.connect(self, progress=EpochLoopProgress()) + self.fit_loop.connect(self) + self.validate_loop.connect(self) + self.test_loop.connect(self) + self.predict_loop.connect(self) # training state if weights_summary is not None and weights_summary not in ModelSummary.MODES: @@ -1020,6 +1021,7 @@ def _run_train(self) -> None: self.training_type_plugin.reconciliate_processes(traceback.format_exc()) # give accelerators a chance to finish self.accelerator.on_train_end() + self._on_expection() # reset bookkeeping self.state.stage = None raise @@ -1259,3 +1261,10 @@ def _log_device_info(self) -> None: "IPU available but not used. Set the `ipus` flag in your trainer" " `Trainer(ipus=8)` or script `--ipus=8`." ) + + def _on_expection(self): + if not self.is_global_zero or not _fault_tolerant_enabled(): + return + # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. 
+ file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") + self.save_checkpoint(file_path) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 3247baf30ebff..75c453f9d995e 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -14,6 +14,7 @@ """General utilities""" import importlib import operator +import os import platform import sys from importlib.util import find_spec @@ -99,3 +100,7 @@ def _compare_version(package: str, op, version) -> bool: _IPU_AVAILABLE = poptorch.ipuHardwareIsAvailable() else: _IPU_AVAILABLE = False + + +def _fault_tolerant_enabled(): + return os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") == "1" diff --git a/tests/loops/test_loop_progress_integration.py b/tests/loops/test_loop_progress_integration.py deleted file mode 100644 index 986ea2543d6d8..0000000000000 --- a/tests/loops/test_loop_progress_integration.py +++ /dev/null @@ -1,22 +0,0 @@ -from pytorch_lightning import Trainer - - -def test_loop_progress_integration(): - trainer = Trainer() - fit_loop = trainer.fit_loop - # check identities inside the fit loop - assert fit_loop.progress.epoch is fit_loop.epoch_loop.progress - assert fit_loop.epoch_loop.progress.batch is fit_loop.epoch_loop.batch_loop.progress - assert fit_loop.epoch_loop.progress.optim is fit_loop.epoch_loop.batch_loop.optim_progress - assert fit_loop.epoch_loop.progress.val is fit_loop.epoch_loop.val_loop.progress - assert fit_loop.epoch_loop.val_loop.progress.epoch is fit_loop.epoch_loop.val_loop.epoch_loop.progress - # check identities inside the evaluation and predict loops - assert trainer.validate_loop.progress.epoch is trainer.validate_loop.epoch_loop.progress - assert trainer.test_loop.progress.epoch is trainer.test_loop.epoch_loop.progress - assert trainer.predict_loop.progress.epoch is trainer.predict_loop.epoch_loop.progress - # check no progresses are shared - assert trainer.fit_loop.progress is not trainer.validate_loop.progress - assert trainer.validate_loop.progress is not trainer.test_loop.progress - assert trainer.test_loop.progress is not trainer.predict_loop.progress - # check the validation progresses are not shared - assert trainer.fit_loop.epoch_loop.val_loop.progress is not trainer.validate_loop.progress diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index eed23a89a8b36..f014f8c619b54 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import pytest from pytorch_lightning.loops import FitLoop @@ -33,214 +32,89 @@ def test_loops_state_dict(): def test_loops_state_dict_structure(): trainer = Trainer() - # structure saved by the checkpoint connector - state_dict = { - "fit_loop": trainer.fit_loop.state_dict(), - "validate_loop": trainer.validate_loop.state_dict(), - "test_loop": trainer.test_loop.state_dict(), - "predict_loop": trainer.predict_loop.state_dict(), - } - # todo (tchaton) Update this once new progress as been added. 
+ state_dict = trainer.checkpoint_connector._get_loops_state_dict() # yapf: disable expected = { "fit_loop": { - "epoch_loop": { - "batch_loop": { - "state_dict": {}, - "progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - "optim_progress": { - "optimizer": { - "step": { - "total": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - }, - "zero_grad": { - "total": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": None, - "completed": 0, - }, - }, - }, - "scheduler": { - "total": { - "ready": 0, - "started": None, - "processed": None, - "completed": 0, - }, - "current": { - "ready": 0, - "started": None, - "processed": None, - "completed": 0, - }, - }, - }, - }, - "val_loop": { - "state_dict": {}, - "progress": { - "epoch": { - "total": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "batch": { - "total": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - } + "state_dict": {}, + "epoch_progress": { + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + + "epoch_loop.state_dict": {}, + "epoch_loop.batch_progress": { + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, + "epoch_loop.scheduler_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + }, + + "epoch_loop.batch_loop.optim_progress": { + "optimizer": { + "step": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, - "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "batch": { - "total": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, + "zero_grad": { + "current": {"ready": 0, "started": 0, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": None, "completed": 0}, }, }, - } + "optimizer_idx": 0, + }, + "epoch_loop.batch_loop.state_dict": {}, + + "epoch_loop.val_loop.state_dict": {}, + "epoch_loop.val_loop.dataloader_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, + }, + + "epoch_loop.val_loop.epoch_loop.state_dict": {}, + "epoch_loop.val_loop.epoch_loop.batch_progress": { + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + }, }, - "validate_loop": { + "predict_loop": { "state_dict": {}, - "progress": { - "epoch": { - "total": {"ready": 0, "started": 0, "processed": 
0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - } + "dataloader_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, + "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_loop.batch_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, "test_loop": { "state_dict": {}, - "progress": { - "epoch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - } + "dataloader_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_loop.batch_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, - "predict_loop": { + "validate_loop": { "state_dict": {}, - "progress": { - "epoch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": { - "ready": 0, - "started": 0, - "processed": 0, - "completed": 0, - }, - }, - } + "dataloader_progress": { + "current": {"ready": 0, "started": None, "processed": None, "completed": 0}, + "total": {"ready": 0, "started": None, "processed": None, "completed": 0}, }, + "epoch_loop.state_dict": {}, - "epoch_loop.progress": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "epoch_loop.batch_progress": { "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "batch": { - "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, - }, + "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, } diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index aa1a0a74750a3..695a0c7be16a0 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -11,32 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import OrderedDict +import os from copy import deepcopy from dataclasses import dataclass from typing import Any, Dict, Iterator +from unittest import mock +from unittest.mock import ANY + +import pytest +import torch from pytorch_lightning.loops.base import Loop from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.trainer.trainer import Trainer +from tests.helpers import BoringModel -def _collect_loop_progress(loop: Loop) -> Dict[str, Any]: - """Return the progress for the current loop and its children.""" - progress = {} - for k, v in loop.__dict__.items(): - if isinstance(v, BaseProgress): - progress[k] = v - elif isinstance(v, Loop): - progress[k] = _collect_loop_progress(v) - return progress +class CustomException(Exception): + pass def test_loop_restore(): - class CustomExpection(Exception): - pass - class Simple(Loop): def __init__(self, dataset: Iterator): @@ -53,12 +49,10 @@ def done(self) -> bool: def reset(self) -> None: self.iter_dataset = iter(self.dataset) - if self.restarting: for _ in range(self.iteration_count): next(self.iter_dataset) self.iteration_count += 1 - self.restarting = False else: self.outputs = [] @@ -66,7 +60,7 @@ def advance(self) -> None: value = next(self.iter_dataset) if self.iteration_count == 5: - raise CustomExpection + raise CustomException self.outputs.append(value) @@ -85,7 +79,7 @@ def load_state_dict(self, state_dict: Dict) -> None: try: loop.run() state_dict = {} - except CustomExpection: + except CustomException: state_dict = loop.state_dict() loop = Simple(data) @@ -102,15 +96,8 @@ def test_loop_hierarchy(): @dataclass class SimpleProgress(BaseProgress): - increment: int = 0 - def state_dict(self): - return {"increment": self.increment} - - def load_state_dict(self, state_dict): - self.increment = state_dict["increment"] - class Simple(Loop): def __init__(self, a): @@ -123,18 +110,16 @@ def advance(self, *args: Any, **kwargs: Any) -> None: if not loop: return loop.run() - self.progress.increment += 1 - @property - def skip(self) -> bool: - return False + def on_advance_end(self): + self.progress.increment += 1 @property def done(self) -> bool: - return self.iteration_count > 0 + return self.progress.increment > 0 def reset(self) -> None: - self.restarting = False + ... 
def on_save_checkpoint(self) -> Dict: return {"a": self.a} @@ -142,80 +127,358 @@ def on_save_checkpoint(self) -> Dict: def on_load_checkpoint(self, state_dict: Dict) -> None: self.a = state_dict["a"] - grand_loop_parent = Simple(0) loop_parent = Simple(1) loop_child = Simple(2) - loop_parent.loop_child = loop_child - assert not loop_parent.skip - - state_dict = loop_parent.state_dict() - - loop_progress = _collect_loop_progress(loop_parent) - assert loop_progress["progress"] == loop_parent.progress - assert loop_progress["loop_child"]["progress"] == loop_child.progress - - loop_progress = _collect_loop_progress(loop_child) - assert loop_progress["progress"] == loop_child.progress - + # check the trainer reference is propagated loop_parent.trainer = Trainer() - assert loop_child.trainer == loop_parent.trainer - - assert state_dict == OrderedDict([('state_dict', { - 'a': 1 - }), ('progress', { - 'increment': 0 - }), ('loop_child.state_dict', { - 'a': 2 - }), ('loop_child.progress', { - 'increment': 0 - })]) + assert loop_child.trainer is loop_parent.trainer - loop_parent.progress + state_dict = loop_parent.state_dict() + assert state_dict == { + 'state_dict': { + 'a': 1 + }, + 'progress': { + 'increment': 0 + }, + 'loop_child.state_dict': { + 'a': 2 + }, + 'loop_child.progress': { + 'increment': 0 + }, + } state_dict["loop_child.state_dict"]["a"] = 3 - + # check restarting after `load_state_dict` loop_parent.load_state_dict(state_dict) assert loop_parent.restarting loop_parent.run() + # check the new state after `run` + state_dict = loop_parent.state_dict() + assert state_dict == { + 'state_dict': { + 'a': 1 + }, + 'progress': { + 'increment': 1 + }, + 'loop_child.state_dict': { + 'a': 3 + }, + 'loop_child.progress': { + 'increment': 1 + }, + } + loop_parent_copy = deepcopy(loop_parent) assert loop_parent_copy.state_dict() == loop_parent.state_dict() - assert loop_parent_copy.on_save_checkpoint() == {'a': 1} - assert loop_parent_copy.loop_child.on_save_checkpoint() == {'a': 3} - - assert not loop_parent.restarting - - state_dict = loop_parent.state_dict() - assert state_dict == OrderedDict([('state_dict', { - 'a': 1 - }), ('progress', { - 'increment': 1 - }), ('loop_child.state_dict', { - 'a': 3 - }), ('loop_child.progress', { - 'increment': 0 - })]) + assert loop_parent_copy.on_save_checkpoint() == state_dict['state_dict'] + assert loop_parent_copy.loop_child.on_save_checkpoint() == state_dict['loop_child.state_dict'] loop_parent = Simple(1) loop_child = Simple(2) loop_parent.loop_child = loop_child loop_parent.load_state_dict(state_dict) assert loop_parent.progress.increment == 1 - assert loop_parent.loop_child.progress.increment == 0 + assert loop_parent.loop_child.progress.increment == 1 del loop_parent.loop_child state_dict = loop_parent.state_dict() - assert state_dict == OrderedDict([('state_dict', {'a': 1}), ('progress', {'increment': 1})]) + assert state_dict == {'state_dict': {'a': 1}, 'progress': {'increment': 1}} - grand_loop_parent = Simple(0) - loop_parent = Simple(1) - loop_child = Simple(2) - grand_loop_parent.loop_child = loop_parent - loop_parent.loop_child = loop_child - grand_loop_parent.trainer = Trainer() - assert loop_child.trainer is not None +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.parametrize("stop_epoch", (1, 2)) +@pytest.mark.parametrize("stop_batch", (1, 2)) +@pytest.mark.parametrize("n_dataloaders,stop_dataloader", [(2, 0), (2, 1), (3, 2)]) +def test_loop_restart_progress_multiple_dataloaders(tmpdir, n_dataloaders, 
stop_dataloader, stop_epoch, stop_batch): + n_batches = 5 + n_epochs = 3 + + class ValidationModel(BoringModel): + + def __init__(self): + super().__init__() + + def validation_step(self, batch, batch_idx, dataloader_idx): + if self.current_epoch == stop_epoch and batch_idx == stop_batch and dataloader_idx == stop_dataloader: + raise CustomException + return super().validation_step(batch, batch_idx) + + def val_dataloader(self): + return [super(ValidationModel, self).val_dataloader() for _ in range(n_dataloaders)] + + model = ValidationModel() + model.validation_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=n_epochs, + limit_train_batches=1, + limit_val_batches=n_batches, + num_sanity_val_steps=0, + ) + + # simulate a failure + try: + trainer.fit(model) + except CustomException: + pass + + ckpt_path = str(tmpdir / '.pl_auto_save.ckpt') + checkpoint = torch.load(ckpt_path)["loops"]["fit_loop"] + + total_dataloader = stop_epoch * n_dataloaders + stop_dataloader + expected = { + "total": { + "ready": total_dataloader + 1, + "started": None, + "processed": None, + "completed": total_dataloader + }, + "current": { + "ready": stop_dataloader + 1, + "started": None, + "processed": None, + "completed": stop_dataloader, + }, + } + assert checkpoint["epoch_loop.val_loop.dataloader_progress"] == expected + + trainer.fit_loop.load_state_dict(checkpoint, restart_progress=False) + + # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch + nbe_total_val_batch = stop_epoch * n_dataloaders * n_batches + be_total_val_batch = stop_dataloader * n_batches + stop_batch + total_val_batch = nbe_total_val_batch + be_total_val_batch + expected = { + "total": { + "ready": total_val_batch + 1, + "started": total_val_batch + 1, + "processed": total_val_batch, + "completed": total_val_batch + }, + "current": { + "ready": stop_batch + 1, + "started": stop_batch + 1, + "processed": stop_batch, + "completed": stop_batch, + }, + } + assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected + + trainer.fit_loop.load_state_dict(checkpoint) + expected = { + "total": { + "ready": total_val_batch, + "started": total_val_batch, + "processed": total_val_batch, + "completed": total_val_batch + }, + "current": { + "ready": stop_batch, + "started": stop_batch, + "processed": stop_batch, + "completed": stop_batch + }, + } + assert trainer.fit_loop.epoch_loop.val_loop.epoch_loop.batch_progress.state_dict() == expected + + +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +@pytest.mark.parametrize("accumulate_grad_batches", (1, 2, 3)) +@pytest.mark.parametrize("n_optimizers", (1, 3, 5)) +@pytest.mark.parametrize("stop_epoch", (1, 2)) +@pytest.mark.parametrize("stop_batch", (1, 2)) +@pytest.mark.parametrize("stop_optimizer", (1, 2)) +def test_loop_state_on_exception(accumulate_grad_batches, stop_epoch, stop_batch, stop_optimizer, n_optimizers, tmpdir): + stop_optimizer = stop_optimizer if stop_optimizer < n_optimizers else 0 + n_epochs = 3 + n_batches = 3 + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + if n_optimizers > 1: + self.configure_optimizers = self.configure_optimizers_multiple + + def training_step(self, batch, batch_idx, optimizer_idx=0): + if self.trainer.current_epoch == stop_epoch and batch_idx == stop_batch and optimizer_idx == stop_optimizer: + raise CustomException + return super().training_step(batch, batch_idx) + + def configure_optimizers_multiple(self): + 
optimizers = [torch.optim.Adam(self.layer.parameters(), lr=0.1) for _ in range(n_optimizers)] + + lr_scheduler_0 = torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=1) + lr_scheduler_1 = torch.optim.lr_scheduler.StepLR(optimizers[1], step_size=1) + # no scheduler for optimizer_2 + lr_schedulers = [lr_scheduler_0, {"scheduler": lr_scheduler_1, "interval": "step"}] + + return optimizers, lr_schedulers + + model = TestModel() + model.training_epoch_end = None + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=n_epochs, + limit_train_batches=n_batches, + limit_val_batches=0, + accumulate_grad_batches=accumulate_grad_batches, + progress_bar_refresh_rate=0, + logger=False, + checkpoint_callback=False, + ) + + # simulate a failure + try: + trainer.fit(model) + except CustomException: + pass + + ckpt_path = str(tmpdir / ".pl_auto_save.ckpt") + checkpoint = torch.load(ckpt_path) + + optim_progress = trainer.fit_loop.epoch_loop.batch_loop.optim_progress + sch_progress = trainer.fit_loop.epoch_loop.scheduler_progress + + # `nbe_`: non-breaking epoch, as in, no exception will be raised. `be_`: breaking epoch + nbe_batches_completed = stop_epoch * n_batches + be_batches_completed = stop_batch + be_batches_ready = stop_batch + 1 + # lightning applies leftover accumulated gradients when the epoch ends + has_leftover_accumulation_batches = n_batches % accumulate_grad_batches != 0 + # number of batches that will call `optimizer.step()` during non-breaking and breaking epochs + nbe_stepping_batches = nbe_batches_completed // accumulate_grad_batches + be_stepping_batches = be_batches_completed // accumulate_grad_batches + + nbe_total_opt_steps = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers + does_last_be_batch_step = be_batches_ready % accumulate_grad_batches == 0 or has_leftover_accumulation_batches + be_total_opt_steps = be_stepping_batches * n_optimizers + does_last_be_batch_step * stop_optimizer + assert optim_progress.optimizer_steps == nbe_total_opt_steps + be_total_opt_steps + assert optim_progress.optimizer.step.current.completed == be_total_opt_steps + has_opt_stepped_in_be = stop_batch + 1 >= accumulate_grad_batches + + nbe_total_zero_grad = (nbe_stepping_batches + has_leftover_accumulation_batches) * n_optimizers + does_last_be_batch_zero_grad = be_batches_completed % accumulate_grad_batches == 0 + # `max` because the first batch always zero-grads + be_total_zero_grad = max(1, be_stepping_batches) * n_optimizers + stop_optimizer * does_last_be_batch_zero_grad + assert optim_progress.optimizer.zero_grad.total.completed == nbe_total_zero_grad + be_total_zero_grad + assert optim_progress.optimizer.zero_grad.current.completed == be_total_zero_grad + + nbe_sch_steps = stop_epoch + be_sch_steps = 0 # the current epoch did not complete + if n_optimizers > 1: + # assumes that the scheduler config is unchanged + # `* 1` because there is only one step-level scheduler + nbe_sch_steps = stop_epoch + nbe_stepping_batches + has_leftover_accumulation_batches * 1 + # `0 +` for the epoch-level scheduler + be_sch_steps = 0 + be_stepping_batches + assert sch_progress.total.completed == nbe_sch_steps + be_sch_steps + assert sch_progress.current.completed == be_sch_steps + + # yapf: disable + expected = { + "state_dict": ANY, + "epoch_progress": { + "total": { + "ready": stop_epoch + 1, + "started": stop_epoch + 1, + "processed": stop_epoch, + "completed": stop_epoch, + }, + "current": { + "ready": stop_epoch + 1, + "started": stop_epoch + 1, + "processed": stop_epoch, + 
"completed": stop_epoch, + }, + }, + "epoch_loop.state_dict": ANY, + "epoch_loop.batch_progress": { + "total": { + "ready": nbe_batches_completed + be_batches_completed + 1, + "started": nbe_batches_completed + be_batches_completed + 1, + "processed": nbe_batches_completed + be_batches_completed, + "completed": nbe_batches_completed + be_batches_completed, + }, + "current": { + "ready": stop_batch + 1, + "started": stop_batch + 1, + "processed": stop_batch, + "completed": stop_batch, + }, + }, + "epoch_loop.scheduler_progress": { + "total": { + "ready": nbe_sch_steps + be_sch_steps, + "started": None, + "processed": None, + "completed": nbe_sch_steps + be_sch_steps, + }, + "current": { + "ready": be_sch_steps, + "started": None, + "processed": None, + "completed": be_sch_steps, + }, + }, + "epoch_loop.batch_loop.state_dict": ANY, + "epoch_loop.batch_loop.optim_progress": { + "optimizer_idx": stop_optimizer, + "optimizer": { + "step": { + "total": { + "ready": nbe_total_opt_steps + be_total_opt_steps + has_opt_stepped_in_be, + "started": None, + "processed": None, + "completed": nbe_total_opt_steps + be_total_opt_steps, + }, + "current": { + "ready": be_total_opt_steps + has_opt_stepped_in_be, + "started": None, + "processed": None, + "completed": be_total_opt_steps, + }, + }, + "zero_grad": { + "total": { + "ready": nbe_total_zero_grad + be_total_zero_grad, + "started": nbe_total_zero_grad + be_total_zero_grad, + "processed": None, + "completed": nbe_total_zero_grad + be_total_zero_grad, + }, + "current": { + "ready": be_total_zero_grad, + "started": be_total_zero_grad, + "processed": None, + "completed": be_total_zero_grad, + }, + }, + }, + }, + "epoch_loop.val_loop.state_dict": ANY, + "epoch_loop.val_loop.dataloader_progress": ANY, + "epoch_loop.val_loop.epoch_loop.state_dict": ANY, + "epoch_loop.val_loop.epoch_loop.batch_progress": ANY, + } + # yapf: enable + assert checkpoint["loops"]["fit_loop"] == expected + + trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"], restart_progress=False) + assert trainer.fit_loop.state_dict() == checkpoint["loops"]["fit_loop"] + + trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) + state_dict = trainer.fit_loop.state_dict() + assert state_dict != checkpoint["loops"]["fit_loop"] + # TODO(@carmocca): do not reset for total + assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch diff --git a/tests/trainer/test_progress.py b/tests/trainer/test_progress.py index a3bbd5a36a2c1..4057a2a686134 100644 --- a/tests/trainer/test_progress.py +++ b/tests/trainer/test_progress.py @@ -15,18 +15,10 @@ import pytest -from pytorch_lightning.trainer.progress import ( - BatchProgress, - EpochLoopProgress, - EpochProgress, - FitLoopProgress, - OptimizerProgress, - Progress, - Tracker, -) - - -def test_progress_geattr_setattr(): +from pytorch_lightning.trainer.progress import BaseProgress, OptimizerProgress, Progress, Tracker + + +def test_progress_getattr_setattr(): p = Tracker(ready=10, completed=None) # can read assert p.completed is None @@ -70,56 +62,25 @@ def test_base_progress_from_defaults(): assert actual == expected -def test_epoch_loop_progress_increment_epoch(): - p = EpochLoopProgress() - p.increment_epoch_completed() - p.increment_epoch_completed() - assert p.epoch.total == Tracker(completed=2) - assert p.epoch.current == Tracker() - assert p.epoch.batch.current == Tracker() - - def test_epoch_loop_progress_increment_sequence(): """Test 
sequences for incrementing batches reads and epochs.""" - batch = BatchProgress(total=Tracker(started=None)) - epoch = EpochProgress(batch=batch) - loop = EpochLoopProgress(epoch=epoch) + batch = Progress() batch.increment_ready() - assert batch.total == Tracker(ready=1, started=None) + assert batch.total == Tracker(ready=1) assert batch.current == Tracker(ready=1) batch.increment_started() - assert batch.total == Tracker(ready=1, started=None) - assert batch.current == Tracker(ready=1) + assert batch.total == Tracker(ready=1, started=1) + assert batch.current == Tracker(ready=1, started=1) batch.increment_processed() - assert batch.total == Tracker(ready=1, started=None, processed=1) - assert batch.current == Tracker(ready=1, processed=1) + assert batch.total == Tracker(ready=1, started=1, processed=1) + assert batch.current == Tracker(ready=1, started=1, processed=1) batch.increment_completed() - assert batch.total == Tracker(ready=1, started=None, processed=1, completed=1) - assert batch.current == Tracker(ready=1, processed=1, completed=1) - - assert epoch.total == Tracker() - assert epoch.current == Tracker() - loop.increment_epoch_completed() - assert batch.total == Tracker(ready=1, started=None, processed=1, completed=1) - assert batch.current == Tracker() - assert epoch.total == Tracker(completed=1) - assert epoch.current == Tracker() - - batch.increment_ready() - assert batch.total == Tracker(ready=2, started=None, processed=1, completed=1) - assert batch.current == Tracker(ready=1) - assert epoch.total == Tracker(completed=1) - assert epoch.current == Tracker() - - loop.reset_on_epoch() - assert batch.total == Tracker(ready=2, started=None, processed=1, completed=1) - assert batch.current == Tracker() - assert epoch.total == Tracker(completed=1) - assert epoch.current == Tracker() + assert batch.total == Tracker(ready=1, started=1, processed=1, completed=1) + assert batch.current == Tracker(ready=1, started=1, processed=1, completed=1) def test_optimizer_progress_default_factory(): @@ -135,93 +96,7 @@ def test_optimizer_progress_default_factory(): assert p2.step.total.completed == 0 -def test_fit_loop_progress_serialization(): - fit_loop = FitLoopProgress() - _ = deepcopy(fit_loop) - fit_loop.epoch.increment_completed() # check `TrainingEpochProgress.load_state_dict` calls `super` - - state_dict = fit_loop.state_dict() - # yapf: disable - assert state_dict == { - 'epoch': { - # number of epochs across `fit` calls - 'total': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0}, - # number of epochs this `fit` call - 'current': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0}, - 'batch': { - # number of batches across `fit` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # number of batches this epoch - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - }, - # `fit` optimization progress - 'optim': { - # optimizers progress - 'optimizer': { - 'step': { - # `optimizer.step` calls across `fit` calls - 'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0}, - # `optimizer.step` calls this epoch - 'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0}, - }, - 'zero_grad': { - # `optimizer.zero_grad` calls across `fit` calls - 'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0}, - # `optimizer.zero_grad` calls this epoch - 'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0}, - }, - }, - 'scheduler': { - # `scheduler.step` calls across `fit` calls 
- 'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': None}, - # `scheduler.step` calls this epoch - 'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': None}, - }, - }, - # `fit` validation progress - 'val': { - 'epoch': { - # number of `validation` calls across `fit` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # number of `validation` calls this `fit` call - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - 'batch': { - # number of batches across `fit` `validation` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # number of batches this `fit` `validation` call - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - }, - } - }, - } - } - # yapf: enable - - new_loop = FitLoopProgress.from_state_dict(state_dict) - assert fit_loop == new_loop - - -def test_epoch_loop_progress_serialization(): - loop = EpochLoopProgress() - _ = deepcopy(loop) - state_dict = loop.state_dict() - - # yapf: disable - assert state_dict == { - 'epoch': { - # number of times `validate` has been called - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # either 0 or 1 as `max_epochs` does not apply to the `validate` loop - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - 'batch': { - # number of batches across `validate` calls - 'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - # number of batches this `validate` call - 'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0}, - }, - } - } - # yapf: enable - - new_loop = EpochLoopProgress.from_state_dict(state_dict) - assert loop == new_loop +def test_deepcopy(): + _ = deepcopy(BaseProgress()) + _ = deepcopy(Progress()) + _ = deepcopy(Tracker())
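
The simplified `Progress`/`Tracker` API that these tests now target can also be exercised on its own. Below is a small standalone sketch (not part of the patch) that mirrors the assertions in `test_epoch_loop_progress_increment_sequence`, using only the `pytorch_lightning.trainer.progress` names imported in the test file above; the exact `state_dict()` layout shown is inferred from the expected dictionaries asserted in these tests.

```python
# Minimal sketch of the Progress/Tracker counters asserted in the tests above.
# Each `increment_*` call bumps both the cumulative `total` tracker and the
# `current` tracker; `state_dict()` serializes them as nested plain dicts.
from pytorch_lightning.trainer.progress import Progress, Tracker

p = Progress()
p.increment_ready()
p.increment_started()

assert p.total == Tracker(ready=1, started=1)
assert p.current == Tracker(ready=1, started=1)
assert p.state_dict() == {
    "total": {"ready": 1, "started": 1, "processed": 0, "completed": 0},
    "current": {"ready": 1, "started": 1, "processed": 0, "completed": 0},
}
```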