Do not reset Loops total counters #8475

Merged · 2 commits · Jul 19, 2021
CHANGELOG.md (1 change: 1 addition & 0 deletions)

```diff
@@ -36,6 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Added dataclasses for progress tracking ([#6603](https://github.com/PyTorchLightning/pytorch-lightning/pull/6603), [#7574](https://github.com/PyTorchLightning/pytorch-lightning/pull/7574), [#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
 * Add `{,load_}state_dict` to the progress tracking dataclasses ([#8140](https://github.com/PyTorchLightning/pytorch-lightning/pull/8140))
 * Connect the progress tracking dataclasses to the loops ([#8244](https://github.com/PyTorchLightning/pytorch-lightning/pull/8244), [#8362](https://github.com/PyTorchLightning/pytorch-lightning/pull/8362))
+* Do not reset the progress tracking dataclasses total counters ([#8475](https://github.com/PyTorchLightning/pytorch-lightning/pull/8475))
 
 
 - Added support for passing a `LightningDataModule` positionally as the second argument to `trainer.{validate,test,predict}` ([#7431](https://github.com/PyTorchLightning/pytorch-lightning/pull/7431))
```
pytorch_lightning/loops/base.py (9 changes: 2 additions & 7 deletions)

```diff
@@ -18,7 +18,7 @@
 from deprecate import void
 
 import pytorch_lightning as pl
-from pytorch_lightning.trainer.progress import BaseProgress, Tracker
+from pytorch_lightning.trainer.progress import BaseProgress, Progress
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
@@ -195,11 +195,6 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, restart_progress:
             if isinstance(v, BaseProgress):
                 v.load_state_dict(state_dict[prefix + k])
                 if restart_progress:
-
-                    def restart(tracker: Tracker):
-                        tracker.reset_on_restart()
-
-                    apply_to_collection(v, Tracker, restart)
-
+                    apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())
         self.on_load_checkpoint(state_dict[prefix + "state_dict"])
         self.restarting = True
```
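The behavioral change is concentrated in the second hunk. Previously, `_load_from_state_dict` rewound every `Tracker` it found when restarting, including the lifetime `total` trackers; now it visits only `Progress` objects and rewinds each one's `current` tracker, so the `total` counters survive the restart. Below is a minimal, self-contained sketch of the effect using simplified stand-ins for the real dataclasses (the full `reset_on_restart` body is assumed to rewind all stages to the last processed value, consistent with the `progress.py` hunk that follows):

```python
# Simplified stand-ins for the `Tracker` and `Progress` dataclasses defined
# in pytorch_lightning/trainer/progress.py; not the actual source.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Tracker:
    ready: Optional[int] = 0
    started: Optional[int] = 0
    processed: Optional[int] = 0
    completed: int = 0

    def reset_on_restart(self) -> None:
        # Rewind every stage to the last value that fully went through.
        value = self.completed if self.processed is None else self.processed
        if self.ready is not None:
            self.ready = value
        if self.started is not None:
            self.started = value
        if self.processed is not None:
            self.processed = value
        self.completed = value


@dataclass
class Progress:
    total: Tracker = field(default_factory=Tracker)
    current: Tracker = field(default_factory=Tracker)


# Simulate a run interrupted after step 4 was started but only 3 completed.
interrupted = Tracker(ready=4, started=4, processed=3, completed=3)
progress = Progress(
    total=Tracker(**vars(interrupted)), current=Tracker(**vars(interrupted))
)

# New behavior: only the `current` tracker is rewound on restart ...
progress.current.reset_on_restart()
assert progress.current.started == 3
# ... while the lifetime `total` counters keep the in-flight step.
assert progress.total.started == 4
```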
pytorch_lightning/trainer/progress.py (4 changes: 2 additions & 2 deletions)

```diff
@@ -65,12 +65,12 @@ def __setattr__(self, key: str, value: int) -> None:
             raise AttributeError(f"The '{key}' attribute is meant to be unused")
         return super().__setattr__(key, value)
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         # hide `None` fields
         args = [f"{k}={v}" for k, v in self.__dict__.items() if v is not None]
         return f"{self.__class__.__name__}({', '.join(args)})"
 
-    def reset_on_restart(self):
+    def reset_on_restart(self) -> None:
         """Reset the progress on restart"""
         value = self.completed if self.processed is None else self.processed
 
```
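Beyond the added return annotations, the one visible line of `reset_on_restart` is the rule the loop restart above relies on: the rewind point is `processed` when the tracker records that stage, with a fallback to `completed` when `processed` is `None` (some trackers leave stages unused; the `__setattr__` override above raises if such a stage is touched). A tiny standalone illustration of the fallback, with illustrative values not taken from the PR:

```python
from typing import Optional


def rewind_value(processed: Optional[int], completed: int) -> int:
    # Mirrors the line shown in the diff: trackers without a `processed`
    # stage rewind to `completed` instead.
    return completed if processed is None else processed


assert rewind_value(processed=3, completed=3) == 3     # normal tracker
assert rewind_value(processed=None, completed=7) == 7  # stage unused
```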
tests/loops/test_loops.py (7 changes: 3 additions & 4 deletions)

```diff
@@ -351,8 +351,8 @@ def val_dataloader(self):
     trainer.fit_loop.load_state_dict(checkpoint)
     expected = {
         "total": {
-            "ready": total_val_batch,
-            "started": total_val_batch,
+            "ready": total_val_batch + 1,
+            "started": total_val_batch + 1,
             "processed": total_val_batch,
             "completed": total_val_batch
         },
@@ -555,6 +555,5 @@ def configure_optimizers_multiple(self):
     trainer.fit_loop.load_state_dict(checkpoint["loops"]["fit_loop"])
     state_dict = trainer.fit_loop.state_dict()
     assert state_dict != checkpoint["loops"]["fit_loop"]
-    # TODO(@carmocca): do not reset for total
-    assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch
+    assert state_dict["epoch_progress"]["total"]["started"] == stop_epoch + 1
     assert state_dict["epoch_progress"]["current"]["started"] == stop_epoch
```
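The updated expectations encode the new semantics. The iteration that was in flight when training stopped had already been counted as `ready` and `started`, and since `total` is no longer rewound on restore, those two counters stay one ahead of `processed` and `completed`; `current` is still reset, which is why the final assertion keeps `stop_epoch` without the `+ 1`. A hedged numeric walk-through with illustrative values (not taken from the test):

```python
# Suppose training is interrupted while epoch index 1 (== stop_epoch) is
# in flight: that epoch was marked ready/started but never processed.
stop_epoch = 1
total = {"ready": 2, "started": 2, "processed": 1, "completed": 1}
current = dict(total)

# On restore, only `current` is rewound to the last completed value;
# `total` keeps the in-flight increment, hence the `+ 1` in the test.
for stage in ("ready", "started", "processed"):
    current[stage] = current["completed"]

assert total["started"] == stop_epoch + 1
assert current["started"] == stop_epoch
```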