From 848ea5942e2f99e557a78af5b9e0a97cc978e9a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 31 Jan 2021 04:14:12 +0100 Subject: [PATCH 1/8] running stage --- pytorch_lightning/core/lightning.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index bd97b7951cfa8..c7984deaa1924 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,6 +38,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -103,7 +104,6 @@ def __init__(self, *args, **kwargs): self._running_manual_backward = False self._current_hook_fx_name = None self._current_dataloader_idx = None - self.running_stage = None self._automatic_optimization: bool = True def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]: @@ -169,6 +169,10 @@ def automatic_optimization(self) -> bool: """ return self._automatic_optimization + @property + def running_stage(self) -> RunningStage: + return self.trainer._running_stage if self.trainer else None + @automatic_optimization.setter def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization From 89bb3c4353d6dd6078a147446da46f8c299c5f98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 31 Jan 2021 04:24:07 +0100 Subject: [PATCH 2/8] circular import --- pytorch_lightning/core/lightning.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c7984deaa1924..f0c25f6c70f2c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,7 +38,6 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -170,7 +169,7 @@ def automatic_optimization(self) -> bool: return self._automatic_optimization @property - def running_stage(self) -> RunningStage: + def running_stage(self): return self.trainer._running_stage if self.trainer else None @automatic_optimization.setter From e8e6130128521832f448de351c5f50ea80179b0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Feb 2021 01:51:59 +0100 Subject: [PATCH 3/8] running stage cleanup --- pytorch_lightning/trainer/trainer.py | 26 +++++++--------------- pytorch_lightning/trainer/training_loop.py | 4 ++-- 2 files changed, 10 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 10545a075cb32..d8adb093313e1 100644 --- a/pytorch_lightning/trainer/trainer.py 
+++ b/pytorch_lightning/trainer/trainer.py @@ -450,7 +450,7 @@ def fit( # bookkeeping # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified. if self._running_stage is None: - self._set_running_stage(RunningStage.TRAINING, model) + self._running_stage = RunningStage.TRAINING # set local properties on the model self.model_connector.copy_trainer_model_properties(model) @@ -531,7 +531,7 @@ def fit( if self._state != TrainerState.INTERRUPTED: self._state = TrainerState.FINISHED - self._set_running_stage(None, model) + self._running_stage = None return self.accelerator.results or 1 @@ -564,14 +564,6 @@ def train_or_test_or_predict(self): return results - def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule): - """ - This function is used to set the running_state on both - the trainer and the model - """ - model_ref.running_stage = stage - self._running_stage = stage - def _pre_training_routine(self): # wait for all to join if on distributed self.accelerator.barrier("setup_training") @@ -614,7 +606,7 @@ def run_train(self): self.run_sanity_check(self.lightning_module) # set stage for logging - self._set_running_stage(RunningStage.TRAINING, self.lightning_module) + self._running_stage = RunningStage.TRAINING self.checkpoint_connector.has_trained = False @@ -678,9 +670,7 @@ def run_train(self): def run_evaluation(self, max_batches=None, on_epoch=False): # used to know if we are logging for val, test + reset cached results - self._set_running_stage( - RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module - ) + self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING self.logger_connector.reset() # bookkeeping @@ -907,7 +897,7 @@ def test( # -------------------- self.verbose_test = verbose - self._set_running_stage(RunningStage.TESTING, model or self.lightning_module) + self._running_stage = RunningStage.TESTING # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: @@ -924,7 +914,7 @@ def test( results = self.__test_using_best_weights(ckpt_path, test_dataloaders) self.teardown('test') - self._set_running_stage(None, model or self.lightning_module) + self._running_stage = None return results def __test_using_best_weights(self, ckpt_path, test_dataloaders): @@ -1016,7 +1006,7 @@ def predict( model = model or self.lightning_module - self._set_running_stage(RunningStage.PREDICTING, model) + self._running_stage = RunningStage.PREDICTING if dataloaders and datamodule: raise MisconfigurationException( @@ -1033,7 +1023,7 @@ def predict( self.model = model results = self.fit(model) - self._set_running_stage(None, model) + self._running_stage = None return results diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9d10a1f67c5dc..d2298c8c4e860 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -517,7 +517,7 @@ def run_training_epoch(self): self.trainer.run_evaluation() # reset stage to train - self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module) + self.trainer._running_stage = RunningStage.TRAINING # ----------------------------------------- # SAVE LOGGERS (ie: Tensorboard, etc...) 
@@ -564,7 +564,7 @@ def run_training_epoch(self): self.trainer.run_evaluation(on_epoch=True) # reset stage to train - self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module) + self.trainer._running_stage = RunningStage.TRAINING should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches) should_train_only = self.trainer.disable_validation or should_skip_eval From bab769156e7279cb03d0e3ac294ed7a4f4b135ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 19 Feb 2021 01:56:41 +0100 Subject: [PATCH 4/8] fix unused import --- pytorch_lightning/trainer/trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index d8adb093313e1..cf3bfd7a3e5a3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -59,7 +59,6 @@ from pytorch_lightning.utilities import DeviceType, rank_zero_warn from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.debugging import InternalDebugger -from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.model_helpers import is_overridden From cbdf2a854ffc70e7ccaa6d6b4be5d24313ebdef1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 20 Feb 2021 20:50:50 +0100 Subject: [PATCH 5/8] fix running stage access --- tests/models/test_restore.py | 2 +- tests/overrides/test_data_parallel.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index d28ab6177f21c..a3f88e37bb09a 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -453,7 +453,7 @@ def on_train_start(self): # haven't trained with the new loaded model dp_model = new_trainer.model dp_model.eval() - dp_model.module.module.running_stage = RunningStage.EVALUATING + new_trainer._running_stage = RunningStage.EVALUATING dataloader = self.train_dataloader() tpipes.run_prediction(self.trainer.lightning_module, dataloader) diff --git a/tests/overrides/test_data_parallel.py b/tests/overrides/test_data_parallel.py index 64481bd70390d..90bb6fac88457 100644 --- a/tests/overrides/test_data_parallel.py +++ b/tests/overrides/test_data_parallel.py @@ -1,4 +1,4 @@ -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock import pytest import torch @@ -103,7 +103,8 @@ def training_step(self, batch, batch_idx): return {"loss": loss} model = TestModel() - model.running_stage = RunningStage.TRAINING + model.trainer = Mock() + model.trainer._running_stage = RunningStage.TRAINING batch = torch.rand(2, 32).cuda() batch_idx = 0 @@ -146,7 +147,8 @@ def training_step(self, batch, batch_idx): model = TestModel() model.to(device) - model.running_stage = RunningStage.TRAINING + model.trainer = Mock() + model.trainer._running_stage = RunningStage.TRAINING batch = torch.rand(2, 32).to(device) batch_idx = 0 From 65b0fe269c6547213e34b6a88b97bee31cdfe8c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 21 Feb 2021 02:54:11 +0100 Subject: [PATCH 6/8] add return type --- pytorch_lightning/core/lightning.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 
f0c25f6c70f2c..3b79dd30d0913 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,6 +38,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result +from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -169,7 +170,7 @@ def automatic_optimization(self) -> bool: return self._automatic_optimization @property - def running_stage(self): + def running_stage(self) -> Optional[RunningStage]: return self.trainer._running_stage if self.trainer else None @automatic_optimization.setter From 56446866ef6772c6ffe9addb67ab8b86dc1e36fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 21 Feb 2021 02:54:11 +0100 Subject: [PATCH 7/8] Revert "add return type" This reverts commit 65b0fe269c6547213e34b6a88b97bee31cdfe8c7. --- pytorch_lightning/core/lightning.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 3b79dd30d0913..f0c25f6c70f2c 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -38,7 +38,6 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITIVE_TYPES from pytorch_lightning.core.step_result import Result -from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin @@ -170,7 +169,7 @@ def automatic_optimization(self) -> bool: return self._automatic_optimization @property - def running_stage(self) -> Optional[RunningStage]: + def running_stage(self): return self.trainer._running_stage if self.trainer else None @automatic_optimization.setter From 74ddc390fd757f0d364ed7c6c33dcb755d9b7318 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 21 Feb 2021 03:08:56 +0100 Subject: [PATCH 8/8] try fix typing --- pytorch_lightning/core/lightning.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index f0c25f6c70f2c..4a2d919da2a88 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -24,7 +24,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch from torch import ScriptModule, Tensor @@ -44,6 +44,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args +if TYPE_CHECKING: + from pytorch_lightning.trainer.states import RunningStage + class LightningModule( ABC, @@ -169,7 +172,7 @@ def automatic_optimization(self) -> bool: return self._automatic_optimization @property - def running_stage(self): + def running_stage(self) -> 
Optional["RunningStage"]: return self.trainer._running_stage if self.trainer else None @automatic_optimization.setter