diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a3803b4e8134..3f02e8dd3de3f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) * Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140)) * Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362)) + * Do not reset the progress tracking dataclasses total counters ([#8475](https://github.com/PyTorchLightning/pytorch-lightning/pull/8475)) - Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431)) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index d3b6ce8a03c02..c4ec3c8dbc69a 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -18,7 +18,7 @@ from deprecate import void import pytorch_lightning as pl -from pytorch_lightning.trainer.progress import BaseProgress, Tracker +from pytorch_lightning.trainer.progress import BaseProgress, Progress from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -195,11 +195,6 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress: if isinstance(v, BaseProgress): v.load_state_dict(state_dict[prefix + k]) if restart_progress: - - def restart(tracker: Tracker): - tracker.reset_on_restart() - - apply_to_collection(v, Tracker, restart) - + apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart()) self.on_load_checkpoint(state_dict[prefix + "state_dict"]) self.restarting = True diff --git a/pytorch_lightning/trainer/progress.py b/pytorch_lightning/trainer/progress.py index fe9f90613ea9c..8120cfd31eeb3 100644 --- a/pytorch_lightning/trainer/progress.py +++ b/pytorch_lightning/trainer/progress.py @@ -65,12 +65,12 @@ def __setattr__(self, key: str, value: int) -> None: raise AttributeError(f"The '{key}' attribute is meant to be unused") return super().__setattr__(key, value) - def __repr__(self): + def __repr__(self) -> str: # hide `None` fields args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None] return f"{self.__class__.__name__}({', '.join(args)})" - def reset_on_restart(self): + def reset_on_restart(self) -> None: """Reset the progress on restart""" value = self.completed if self.processed is None else self.processed diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index 41d706e081ba0..26ef2c6fcc6cc 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -351,8 +351,8 @@ def val_dataloader(self): trainer.fit_loop.load_state_dict(checkpoint) expected = { "total": { - "ready": total_val_batch, - "started": total_val_batch, + "ready": total_val_batch + 1, + "started": total_val_batch + 1, "processed": total_val_batch, "completed": total_val_batch }, @@ -555,6 +555,5 @@ def configure_optimizers_multiple(self): trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"]) state_dict = trainer.fit_loop.state_dict() assert state_dict != checkpoint["loops"]["fit_loop"] - # TODO(@carmocca): do not reset for total - assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1 assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch