Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor access to trainer attributes in LightningModule #5730

Merged
merged 12 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))


- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))


- Refactored Accelerators and Plugins
* Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
* Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))
Expand Down
21 changes: 18 additions & 3 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ class LightningModule(
"current_epoch",
"global_step",
"running_stage",
"global_rank",
"local_rank",
"logger",
] + DeviceDtypeModuleMixin.__jit_unused_properties__

def __init__(self, *args, **kwargs):
Expand All @@ -83,9 +86,6 @@ def __init__(self, *args, **kwargs):
#: Pointer to the trainer object
self.trainer = None

#: Pointer to the logger object
self.logger = None

self._distrib_type = None
self._device_type = None

Expand Down Expand Up @@ -132,6 +132,16 @@ def global_step(self) -> int:
"""Total training batches seen across all epochs"""
return self.trainer.global_step if self.trainer else 0

@property
def global_rank(self) -> int:
""" The index of the current process across all nodes and devices. """
return self.trainer.global_rank if self.trainer else 0

@property
def local_rank(self) -> int:
""" The index of the current process within a single node. """
return self.trainer.local_rank if self.trainer else 0

@example_input_array.setter
def example_input_array(self, example: Any) -> None:
self._example_input_array = example
Expand Down Expand Up @@ -163,6 +173,11 @@ def automatic_optimization(self) -> bool:
def automatic_optimization(self, automatic_optimization: bool) -> None:
self._automatic_optimization = automatic_optimization

@property
def logger(self):
""" Reference to the logger object in the Trainer. """
return self.trainer.logger if self.trainer else None

def print(self, *args, **kwargs) -> None:
r"""
Prints only from process 0. Use this in any distributed mode to log only once.
Expand Down
6 changes: 2 additions & 4 deletions pytorch_lightning/trainer/connectors/model_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.

"""
from weakref import proxy


class ModelConnector:
Expand All @@ -30,17 +31,14 @@ def copy_trainer_model_properties(self, model):
self.trainer.train_loop.automatic_optimization = automatic_optimization

for m in [model, ref_model]:
m.trainer = self.trainer
m.logger = self.trainer.logger
m.trainer = proxy(self.trainer)
m._device_type = str(self.trainer._device_type)
m._distrib_type = str(self.trainer._distrib_type)
m.use_amp = self.trainer.amp_backend is not None
m.testing = self.trainer.testing
m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
m.tpu_global_core_rank = self.trainer.tpu_global_core_rank
m.precision = self.trainer.precision
m.global_rank = self.trainer.global_rank
m.local_rank = self.trainer.local_rank

def get_model(self):
return self._get_reference_model(self.trainer.model)
Expand Down
2 changes: 0 additions & 2 deletions pytorch_lightning/tuner/tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,10 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):
val_dataloaders=val_dataloaders,
datamodule=datamodule,
)
model.logger = self.trainer.logger # reset logger binding

# Run learning rate finder:
if self.trainer.auto_lr_find:
self.internal_find_lr(model)
model.logger = self.trainer.logger # reset logger binding

def scale_batch_size(
self,
Expand Down
54 changes: 53 additions & 1 deletion tests/core/test_lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import patch
from unittest.mock import patch, Mock, PropertyMock

import pytest
from torch.optim import Adam, SGD

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel


def test_property_current_epoch():
Borda marked this conversation as resolved.
Show resolved Hide resolved
""" Test that the current_epoch in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.current_epoch == 0

trainer = Mock(current_epoch=123)
model.trainer = trainer
assert model.current_epoch == 123


def test_property_global_step():
""" Test that the global_step in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.global_step == 0

trainer = Mock(global_step=123)
model.trainer = trainer
assert model.global_step == 123


def test_property_global_rank():
""" Test that the global rank in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.global_rank == 0

trainer = Mock(global_rank=123)
model.trainer = trainer
assert model.global_rank == 123


def test_property_local_rank():
""" Test that the local rank in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.local_rank == 0

trainer = Mock(local_rank=123)
model.trainer = trainer
assert model.local_rank == 123


def test_property_logger(tmpdir):
""" Test that the logger in LightningModule is accessible via the Trainer. """
model = BoringModel()
assert model.logger is None

logger = TensorBoardLogger(tmpdir)
trainer = Mock(logger=logger)
model.trainer = trainer
assert model.logger == logger


def test_automatic_optimization(tmpdir):
class TestModel(BoringModel):
def optimizer_step(self, *_, **__):
Expand Down
2 changes: 2 additions & 0 deletions tests/trainer/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def test_trainer_reset_correctly(tmpdir):
assert attributes_before[key] == attributes_after[key], \
f'Attribute {key} was not reset correctly after learning rate finder'

assert model.trainer == trainer


@pytest.mark.parametrize('use_hparams', [False, True])
def test_trainer_arg_bool(tmpdir, use_hparams):
Expand Down