Fix setup callback hook to pass LightningModule through (#4608)
* Fix setup callback hook

* Update CHANGELOG.md

* Update test_trainer.py

* Update test_trainer.py

* Update test_trainer.py

* fix chlog

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
2 people authored and tchaton committed Nov 17, 2020
1 parent 5d45a47 commit c3335ec
Showing 6 changed files with 68 additions and 45 deletions.
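
In practice, this change means a callback's `setup` hook can rely on receiving the LightningModule as `pl_module`. A minimal sketch of a callback that depends on this guarantee (the `ParamCounter` callback below is illustrative, not part of this commit):

```python
import pytorch_lightning as pl


class ParamCounter(pl.Callback):
    def setup(self, trainer, pl_module, stage):
        # pl_module is the LightningModule being fitted or tested; after this
        # commit the Trainer passes it through explicitly instead of resolving
        # it via get_model(), so it can be relied upon here.
        n_params = sum(p.numel() for p in pl_module.parameters())
        print(f"[{stage}] {type(pl_module).__name__}: {n_params} parameters")
```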
28 changes: 28 additions & 0 deletions CHANGELOG.md
@@ -13,12 +13,40 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed

- Ci: tpu drop install horovod ([#4622](https://github.com/PyTorchLightning/pytorch-lightning/pull/4622))
- Ci: Added isort import check for the code on pull-request ([#4242](https://github.com/PyTorchLightning/pytorch-lightning/pull/4242))


### Fixed

- Prevent crash if `sync_dist=True` on CPU ([#4626](https://github.com/PyTorchLightning/pytorch-lightning/pull/4626))
- Fixed average pbar Metrics ([#4534](https://github.com/PyTorchLightning/pytorch-lightning/pull/4534))
- Fixed logger docs and api docs ([#3950](https://github.com/PyTorchLightning/pytorch-lightning/pull/3950))
- Fixed `setup` callback hook to correctly pass the LightningModule through ([#4608](https://github.com/PyTorchLightning/pytorch-lightning/pull/4608))




## [unreleased.BugFix] - YYYY-MM-DD

### Added



### Changed



### Deprecated



### Removed



### Fixed



## [1.0.6] - 2020-11-11

24 changes: 9 additions & 15 deletions pytorch_lightning/core/lightning.py
@@ -14,8 +14,6 @@

"""nn.Module with additional great features."""

import os
import tempfile
import collections
import copy
import inspect
@@ -25,28 +25,23 @@
import types
from abc import ABC
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, Mapping
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer

from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.utilities import rank_zero_warn, AMPType
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities.parsing import (
AttributeDict,
collect_init_args,
get_init_args,
)
from torch import ScriptModule, Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer

from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
from pytorch_lightning.utilities.xla_device_utils import XLADeviceUtils

TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/callback_hook.py
@@ -26,10 +26,10 @@ class TrainerCallbackHookMixin(ABC):
callbacks: List[Callback] = []
get_model: Callable

def setup(self, stage: str):
def setup(self, model, stage: str):
"""Called in the beginning of fit and test"""
for callback in self.callbacks:
callback.setup(self, self.get_model(), stage)
callback.setup(self, model, stage)

def teardown(self, stage: str):
"""Called at the end of fit and test"""
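For context, a sketch of the defensive pattern this change makes unnecessary. The `SafeSetup` callback is hypothetical; before the fix, `self.get_model()` was apparently not guaranteed to return the module when `setup` fired, which seems to be the motivation for this commit:

```python
import pytorch_lightning as pl


class SafeSetup(pl.Callback):
    def setup(self, trainer, pl_module, stage):
        # Pre-fix, the mixin forwarded self.get_model(), which was not
        # guaranteed to be attached when setup fired; a None guard like this
        # was the defensive workaround. Post-fix, it is unnecessary.
        if pl_module is None:
            return
        self.stage = stage
```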
44 changes: 22 additions & 22 deletions pytorch_lightning/trainer/trainer.py
@@ -21,46 +21,46 @@
import torch
from torch.utils.data import DataLoader

from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import EvalResult
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.plugins.plugin_connector import PluginConnector
from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.precision_connector import PrecisionConnector
from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning import _logger as log
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.trainer.connectors.precision_connector import PrecisionConnector
from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.plugins.plugin_connector import PluginConnector
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator

# warnings to ignore in trainer
warnings.filterwarnings(
@@ -822,7 +822,7 @@ def call_setup_hook(self, model):
called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
if not called:
self.datamodule.setup(stage_name)
self.setup(stage_name)
self.setup(model, stage_name)
model.setup(stage_name)

def call_hook(self, hook_name, *args, **kwargs):
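To make the hook order in `call_setup_hook` concrete, a small runnable sketch (it assumes the repository's test environment for the `BoringModel` import; `TrackingCallback` and `TrackingModel` are illustrative):

```python
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from tests.base import BoringModel  # test helper referenced in this diff

calls = []


class TrackingCallback(pl.Callback):
    def setup(self, trainer, pl_module, stage):
        calls.append("callback.setup")


class TrackingModel(BoringModel):
    def setup(self, stage):
        calls.append("model.setup")


trainer = Trainer(fast_dev_run=True, callbacks=[TrackingCallback()])
trainer.fit(TrackingModel())

# Per call_setup_hook above: datamodule.setup (if one is attached) fires
# first, then the callbacks' setup, then the module's own setup.
assert calls == ["callback.setup", "model.setup"]
```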
4 changes: 2 additions & 2 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -13,11 +13,11 @@
# limitations under the License.
import collections
import os
from unittest.mock import ANY, call, patch

import pytest
import torch
import torch.nn.functional as F
from unittest.mock import patch, call, ANY

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import APEX_AVAILABLE
@@ -698,7 +698,7 @@ def configure_optimizers(self):

trainer.fit(model)
assert trainer.dev_debugger.count_events('backward_call') == limit_train_batches * 2
# TODO: Remove me on 1.1 - Releases/1.0.x currently has a bug fixed on master - Decided to wait for next feature releases
# todo: Remove me on 1.1 - Releases/1.0.x currently has a bug fixed on master - Decided to wait for next feature releases
# assert trainer.logger_connector.progress_bar_metrics["train_loss_step"] == model._losses[-1]
# assert trainer.logger_connector.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()

9 changes: 5 additions & 4 deletions tests/trainer/test_trainer.py
@@ -20,7 +20,7 @@
from argparse import Namespace
from copy import deepcopy
from pathlib import Path
from unittest.mock import patch, call, ANY
from unittest.mock import ANY, call, patch

import cloudpickle
import pytest
@@ -34,10 +34,10 @@
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE
from tests.base import EvalModelTemplate, BoringModel
from tests.base import BoringModel, EvalModelTemplate


@pytest.mark.parametrize("url_ckpt", [True, False])
@@ -1431,7 +1431,8 @@ def setup(self, stage):
self.stage = stage

class TrainerSubclass(Trainer):
def setup(self, stage):
def setup(self, model, stage):
assert model is not None
self.stage = stage

model = CurrentModel()
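As the updated test shows, any `Trainer` subclass that overrides `setup` must add the `model` parameter. A minimal migration sketch (`MyTrainer` is hypothetical):

```python
from pytorch_lightning import Trainer


class MyTrainer(Trainer):
    # pre-#4608 the signature was `def setup(self, stage)`; the model is
    # now passed through explicitly.
    def setup(self, model, stage):
        assert model is not None  # the LightningModule is available here
        self.stage = stage
```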
