mini refactor for _running_stage access #5724

Merged · 9 commits · Feb 22, 2021
Changes from all commits
pytorch_lightning/core/lightning.py (10 changes: 8 additions, 2 deletions)
@@ -24,7 +24,7 @@
 from argparse import Namespace
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
 from torch import ScriptModule, Tensor
@@ -44,6 +44,9 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
 
+if TYPE_CHECKING:
+    from pytorch_lightning.trainer.states import RunningStage
+
 
 class LightningModule(
     ABC,
@@ -103,7 +106,6 @@ def __init__(self, *args, **kwargs):
         self._running_manual_backward = False
         self._current_hook_fx_name = None
         self._current_dataloader_idx = None
-        self.running_stage = None
         self._automatic_optimization: bool = True
 
     def optimizers(self, use_pl_optimizer: bool = True) -> Union[Optimizer, List[Optimizer], List[LightningOptimizer]]:
@@ -169,6 +171,10 @@ def automatic_optimization(self) -> bool:
         """
         return self._automatic_optimization
 
+    @property
+    def running_stage(self) -> Optional["RunningStage"]:
+        return self.trainer._running_stage if self.trainer else None
+
     @automatic_optimization.setter
     def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization
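The lightning.py change above turns running_stage from an attribute stored on the module into a read-only property that defers to the attached trainer, so the two can no longer drift apart. A minimal sketch of that delegation pattern, using stand-in classes rather than the real LightningModule and Trainer:

# Illustrative sketch only: DemoTrainer and DemoModule are stand-ins, and this
# RunningStage enum is a stub for pytorch_lightning.trainer.states.RunningStage.
from enum import Enum
from typing import Optional

class RunningStage(Enum):
    TRAINING = "train"
    EVALUATING = "eval"
    TESTING = "test"
    PREDICTING = "predict"

class DemoTrainer:
    def __init__(self) -> None:
        self._running_stage: Optional[RunningStage] = None  # single source of truth

class DemoModule:
    def __init__(self) -> None:
        self.trainer: Optional[DemoTrainer] = None  # attached by the trainer, as in fit()

    @property
    def running_stage(self) -> Optional[RunningStage]:
        # mirrors the new LightningModule property: read through to the trainer
        return self.trainer._running_stage if self.trainer else None

model, trainer = DemoModule(), DemoTrainer()
assert model.running_stage is None                    # no trainer attached yet
model.trainer = trainer
trainer._running_stage = RunningStage.TRAINING
assert model.running_stage is RunningStage.TRAINING   # always reflects the trainer

With the property in place, the trainer only has to write its own _running_stage, which is why the _set_running_stage helper in trainer.py below becomes unnecessary.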
pytorch_lightning/trainer/trainer.py (27 changes: 8 additions, 19 deletions)
@@ -59,7 +59,6 @@
 from pytorch_lightning.utilities import DeviceType, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.debugging import InternalDebugger
-from pytorch_lightning.utilities.enums import LightningEnum
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -450,7 +449,7 @@ def fit(
         # bookkeeping
         # we reuse fit in .test() and .predict(). When already set, it shouldn't be modified.
         if self._running_stage is None:
-            self._set_running_stage(RunningStage.TRAINING, model)
+            self._running_stage = RunningStage.TRAINING
 
         # set local properties on the model
         self.model_connector.copy_trainer_model_properties(model)
@@ -531,7 +530,7 @@ def fit(
         if self._state != TrainerState.INTERRUPTED:
             self._state = TrainerState.FINISHED
 
-        self._set_running_stage(None, model)
+        self._running_stage = None
 
         return self.accelerator.results or 1
 
@@ -564,14 +563,6 @@ def train_or_test_or_predict(self):
 
         return results
 
-    def _set_running_stage(self, stage: LightningEnum, model_ref: LightningModule):
-        """
-        This function is used to set the running_state on both
-        the trainer and the model
-        """
-        model_ref.running_stage = stage
-        self._running_stage = stage
-
     def _pre_training_routine(self):
         # wait for all to join if on distributed
         self.accelerator.barrier("setup_training")
@@ -614,7 +605,7 @@ def run_train(self):
         self.run_sanity_check(self.lightning_module)
 
         # set stage for logging
-        self._set_running_stage(RunningStage.TRAINING, self.lightning_module)
+        self._running_stage = RunningStage.TRAINING
 
         self.checkpoint_connector.has_trained = False
 
@@ -678,9 +669,7 @@ def run_train(self):
     def run_evaluation(self, max_batches=None, on_epoch=False):
 
         # used to know if we are logging for val, test + reset cached results
-        self._set_running_stage(
-            RunningStage.TESTING if self.testing else RunningStage.EVALUATING, self.lightning_module
-        )
+        self._running_stage = RunningStage.TESTING if self.testing else RunningStage.EVALUATING
         self.logger_connector.reset()
 
         # bookkeeping
@@ -907,7 +896,7 @@ def test(
         # --------------------
         self.verbose_test = verbose
 
-        self._set_running_stage(RunningStage.TESTING, model or self.lightning_module)
+        self._running_stage = RunningStage.TESTING
 
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
         if test_dataloaders and datamodule:
@@ -924,7 +913,7 @@
             results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
 
         self.teardown('test')
-        self._set_running_stage(None, model or self.lightning_module)
+        self._running_stage = None
         return results
 
     def __test_using_best_weights(self, ckpt_path, test_dataloaders):
@@ -1016,7 +1005,7 @@ def predict(
 
         model = model or self.lightning_module
 
-        self._set_running_stage(RunningStage.PREDICTING, model)
+        self._running_stage = RunningStage.PREDICTING
 
         if dataloaders and datamodule:
             raise MisconfigurationException(
@@ -1033,7 +1022,7 @@
 
         self.model = model
         results = self.fit(model)
-        self._set_running_stage(None, model)
+        self._running_stage = None
 
         return results
 
pytorch_lightning/trainer/training_loop.py (4 changes: 2 additions, 2 deletions)
@@ -517,7 +517,7 @@ def run_training_epoch(self):
                 self.trainer.run_evaluation()
 
                 # reset stage to train
-                self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
+                self.trainer._running_stage = RunningStage.TRAINING
 
             # -----------------------------------------
             # SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -564,7 +564,7 @@ def run_training_epoch(self):
             self.trainer.run_evaluation(on_epoch=True)
 
             # reset stage to train
-            self.trainer._set_running_stage(RunningStage.TRAINING, self.trainer.lightning_module)
+            self.trainer._running_stage = RunningStage.TRAINING
 
         should_skip_eval = self.trainer.evaluation_loop.should_skip_evaluation(self.trainer.num_val_batches)
         should_train_only = self.trainer.disable_validation or should_skip_eval
tests/models/test_restore.py (2 changes: 1 addition, 1 deletion)
@@ -453,7 +453,7 @@ def on_train_start(self):
             # haven't trained with the new loaded model
             dp_model = new_trainer.model
             dp_model.eval()
-            dp_model.module.module.running_stage = RunningStage.EVALUATING
+            new_trainer._running_stage = RunningStage.EVALUATING
 
             dataloader = self.train_dataloader()
             tpipes.run_prediction(self.trainer.lightning_module, dataloader)
tests/overrides/test_data_parallel.py (8 changes: 5 additions, 3 deletions)
@@ -1,4 +1,4 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, Mock
 
 import pytest
 import torch
@@ -103,7 +103,8 @@ def training_step(self, batch, batch_idx):
             return {"loss": loss}
 
     model = TestModel()
-    model.running_stage = RunningStage.TRAINING
+    model.trainer = Mock()
+    model.trainer._running_stage = RunningStage.TRAINING
     batch = torch.rand(2, 32).cuda()
     batch_idx = 0
 
@@ -146,7 +147,8 @@ def training_step(self, batch, batch_idx):
 
     model = TestModel()
     model.to(device)
-    model.running_stage = RunningStage.TRAINING
+    model.trainer = Mock()
+    model.trainer._running_stage = RunningStage.TRAINING
     batch = torch.rand(2, 32).to(device)
     batch_idx = 0
 
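Since running_stage is now resolved through self.trainer, the test updates above attach a Mock() in place of a full Trainer and set _running_stage on it. A hedged, self-contained sketch of that mocking pattern (DemoModule stands in for the tests' TestModel):

# Sketch under the same stand-in assumptions as earlier; only the Mock pattern matters here.
from enum import Enum
from unittest.mock import Mock

class RunningStage(Enum):  # stub for pytorch_lightning.trainer.states.RunningStage
    TRAINING = "train"

class DemoModule:
    trainer = None  # the real tests assign a Mock here instead of a Trainer

    @property
    def running_stage(self):
        return self.trainer._running_stage if self.trainer else None

model = DemoModule()
model.trainer = Mock()                                 # no real Trainer needed in the test
model.trainer._running_stage = RunningStage.TRAINING   # what the updated tests now set
assert model.running_stage is RunningStage.TRAINING    # the property reads it back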