From 2a48730ee68a7a579a4da7f7ca5573b3b3d59f40 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 1 Feb 2021 02:07:36 +0100
Subject: [PATCH 01/12] rank access

---
 pytorch_lightning/core/lightning.py           | 10 ++++++++++
 .../trainer/connectors/model_connector.py     |  2 --
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index c453bd5d607d6..b219e8396536f 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -67,6 +67,8 @@ class LightningModule(
         "current_epoch",
         "global_step",
         "running_stage",
+        "global_rank",
+        "local_rank",
     ] + DeviceDtypeModuleMixin.__jit_unused_properties__
 
     def __init__(self, *args, **kwargs):
@@ -132,6 +134,14 @@ def global_step(self) -> int:
         """Total training batches seen across all epochs"""
         return self.trainer.global_step if self.trainer else 0
 
+    @property
+    def global_rank(self):
+        return self.trainer.global_rank if self.trainer else 0
+
+    @property
+    def local_rank(self):
+        return self.trainer.local_rank if self.trainer else 0
+
     @example_input_array.setter
     def example_input_array(self, example: Any) -> None:
         self._example_input_array = example
diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py
index a3759d1075ee5..f14b60801781b 100644
--- a/pytorch_lightning/trainer/connectors/model_connector.py
+++ b/pytorch_lightning/trainer/connectors/model_connector.py
@@ -39,8 +39,6 @@ def copy_trainer_model_properties(self, model):
             m.tpu_local_core_rank = self.trainer.tpu_local_core_rank
             m.tpu_global_core_rank = self.trainer.tpu_global_core_rank
             m.precision = self.trainer.precision
-            m.global_rank = self.trainer.global_rank
-            m.local_rank = self.trainer.local_rank
 
     def get_model(self):
         return self._get_reference_model(self.trainer.model)

From fb37c4f89c7d55017b77c84695390c57ba2faf31 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 1 Feb 2021 02:35:32 +0100
Subject: [PATCH 02/12] tests for property

---
 tests/core/test_lightning_module.py | 42 ++++++++++++++++++++++++++++-
 1 file changed, 41 insertions(+), 1 deletion(-)

diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 9cea8cf28c07f..8a9e57f09b3a7 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import patch
+from unittest.mock import patch, Mock, PropertyMock
 
 import pytest
 from torch.optim import Adam, SGD
@@ -21,6 +21,46 @@
 from tests.base import BoringModel
 
 
+def test_property_current_epoch():
+    """ Test that the current_epoch in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.current_epoch == 0
+
+    trainer = Mock(current_epoch=123)
+    model.trainer = trainer
+    assert model.current_epoch == 123
+
+
+def test_property_global_step():
+    """ Test that the global_step in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.global_step == 0
+
+    trainer = Mock(global_step=123)
+    model.trainer = trainer
+    assert model.global_step == 123
+
+
+def test_property_global_rank():
+    """ Test that the global rank in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.global_rank == 0
+
+    trainer = Mock(global_rank=123)
+    model.trainer = trainer
+    assert model.global_rank == 123
+
+
+def test_property_local_rank():
+    """ Test that the local rank in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.local_rank == 0
+
+    trainer = Mock(local_rank=123)
+    model.trainer = trainer
+    assert model.local_rank == 123
+
+
 def test_automatic_optimization(tmpdir):
     class TestModel(BoringModel):
         def optimizer_step(self, *_, **__):
""" + model = BoringModel() + assert model.global_rank == 0 + + trainer = Mock(global_rank=123) + model.trainer = trainer + assert model.global_rank == 123 + + +def test_property_local_rank(): + """ Test that the local rank in LightningModule is accessible via the Trainer. """ + model = BoringModel() + assert model.local_rank == 0 + + trainer = Mock(local_rank=123) + model.trainer = trainer + assert model.local_rank == 123 + + def test_automatic_optimization(tmpdir): class TestModel(BoringModel): def optimizer_step(self, *_, **__): From 20245ebafdf0669cce0b5e096e289f3aec72dcc0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 1 Feb 2021 02:52:08 +0100 Subject: [PATCH 03/12] weekref --- pytorch_lightning/trainer/connectors/model_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index f14b60801781b..0db31edd33f12 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -17,6 +17,7 @@ Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU. """ +from weakref import proxy class ModelConnector: @@ -30,7 +31,7 @@ def copy_trainer_model_properties(self, model): self.trainer.train_loop.automatic_optimization = automatic_optimization for m in [model, ref_model]: - m.trainer = self.trainer + m.trainer = proxy(self.trainer) m.logger = self.trainer.logger m._device_type = str(self.trainer._device_type) m._distrib_type = str(self.trainer._distrib_type) From bd160b7fe11139af6a862776cad7e75243aab273 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 1 Feb 2021 03:19:57 +0100 Subject: [PATCH 04/12] logger --- pytorch_lightning/core/lightning.py | 14 +++++++++----- .../trainer/connectors/model_connector.py | 1 - pytorch_lightning/tuner/tuning.py | 2 -- tests/core/test_lightning_module.py | 12 ++++++++++++ tests/trainer/test_lr_finder.py | 2 ++ 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index b219e8396536f..92c2484252df4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -85,9 +85,6 @@ def __init__(self, *args, **kwargs): #: Pointer to the trainer object self.trainer = None - #: Pointer to the logger object - self.logger = None - self._distrib_type = None self._device_type = None @@ -135,11 +132,13 @@ def global_step(self) -> int: return self.trainer.global_step if self.trainer else 0 @property - def global_rank(self): + def global_rank(self) -> int: + """ The index of the current process across all nodes and devices. """ return self.trainer.global_rank if self.trainer else 0 @property - def local_rank(self): + def local_rank(self) -> int: + """ The index of the current process within a single node. """ return self.trainer.local_rank if self.trainer else 0 @example_input_array.setter @@ -173,6 +172,11 @@ def automatic_optimization(self) -> bool: def automatic_optimization(self, automatic_optimization: bool) -> None: self._automatic_optimization = automatic_optimization + @property + def logger(self): + """ Reference to the logger object in the Trainer. """ + return self.trainer.logger if self.trainer else None + def print(self, *args, **kwargs) -> None: r""" Prints only from process 0. Use this in any distributed mode to log only once. 
From bd160b7fe11139af6a862776cad7e75243aab273 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 1 Feb 2021 03:19:57 +0100
Subject: [PATCH 04/12] logger

---
 pytorch_lightning/core/lightning.py           | 14 +++++++++-----
 .../trainer/connectors/model_connector.py     |  1 -
 pytorch_lightning/tuner/tuning.py             |  2 --
 tests/core/test_lightning_module.py           | 12 ++++++++++++
 tests/trainer/test_lr_finder.py               |  2 ++
 5 files changed, 23 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index b219e8396536f..92c2484252df4 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -85,9 +85,6 @@ def __init__(self, *args, **kwargs):
         #: Pointer to the trainer object
         self.trainer = None
 
-        #: Pointer to the logger object
-        self.logger = None
-
         self._distrib_type = None
         self._device_type = None
 
@@ -135,11 +132,13 @@ def global_step(self) -> int:
         return self.trainer.global_step if self.trainer else 0
 
     @property
-    def global_rank(self):
+    def global_rank(self) -> int:
+        """ The index of the current process across all nodes and devices. """
         return self.trainer.global_rank if self.trainer else 0
 
     @property
-    def local_rank(self):
+    def local_rank(self) -> int:
+        """ The index of the current process within a single node. """
         return self.trainer.local_rank if self.trainer else 0
 
     @example_input_array.setter
@@ -173,6 +172,11 @@ def automatic_optimization(self) -> bool:
     def automatic_optimization(self, automatic_optimization: bool) -> None:
         self._automatic_optimization = automatic_optimization
 
+    @property
+    def logger(self):
+        """ Reference to the logger object in the Trainer. """
+        return self.trainer.logger if self.trainer else None
+
     def print(self, *args, **kwargs) -> None:
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once.
diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py
index 0db31edd33f12..673e8765ed51f 100644
--- a/pytorch_lightning/trainer/connectors/model_connector.py
+++ b/pytorch_lightning/trainer/connectors/model_connector.py
@@ -32,7 +32,6 @@ def copy_trainer_model_properties(self, model):
 
         for m in [model, ref_model]:
             m.trainer = proxy(self.trainer)
-            m.logger = self.trainer.logger
             m._device_type = str(self.trainer._device_type)
             m._distrib_type = str(self.trainer._distrib_type)
             m.use_amp = self.trainer.amp_backend is not None
diff --git a/pytorch_lightning/tuner/tuning.py b/pytorch_lightning/tuner/tuning.py
index dae3fed868520..0567399970ae7 100644
--- a/pytorch_lightning/tuner/tuning.py
+++ b/pytorch_lightning/tuner/tuning.py
@@ -50,12 +50,10 @@ def tune(self, model, train_dataloader, val_dataloaders, datamodule):
             val_dataloaders=val_dataloaders,
             datamodule=datamodule,
         )
-        model.logger = self.trainer.logger  # reset logger binding
 
         # Run learning rate finder:
         if self.trainer.auto_lr_find:
             self.internal_find_lr(model)
-            model.logger = self.trainer.logger  # reset logger binding
 
     def scale_batch_size(
         self,
diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 8a9e57f09b3a7..8895d8475b656 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -17,6 +17,7 @@
 from torch.optim import Adam, SGD
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import BoringModel
 
@@ -61,6 +62,17 @@ def test_property_local_rank():
     assert model.local_rank == 123
 
 
+def test_property_logger(tmpdir):
+    """ Test that the logger in LightningModule is accessible via the Trainer. """
+    model = BoringModel()
+    assert model.logger is None
+
+    logger = TensorBoardLogger(tmpdir)
+    trainer = Mock(logger=logger)
+    model.trainer = trainer
+    assert model.logger == logger
+
+
 def test_automatic_optimization(tmpdir):
     class TestModel(BoringModel):
         def optimizer_step(self, *_, **__):
diff --git a/tests/trainer/test_lr_finder.py b/tests/trainer/test_lr_finder.py
index 3b59095fcf393..228246fb18e4d 100755
--- a/tests/trainer/test_lr_finder.py
+++ b/tests/trainer/test_lr_finder.py
@@ -90,6 +90,8 @@ def test_trainer_reset_correctly(tmpdir):
         assert attributes_before[key] == attributes_after[key], \
             f'Attribute {key} was not reset correctly after learning rate finder'
 
+    assert model.trainer == trainer
+
 
 @pytest.mark.parametrize('use_hparams', [False, True])
 def test_trainer_arg_bool(tmpdir, use_hparams):

From 9cc9402b85a0710d48da421b8531f604f8a6a49e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 1 Feb 2021 03:38:42 +0100
Subject: [PATCH 05/12] changelog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 36ec81b9f3524..c2244c72dc9ed 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -107,6 +107,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
 
+- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#xxxx](https://github.com/PyTorchLightning/pytorch-lightning/pull/xxxx))
+
 - Refactored Accelerators and Plugins
     * Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
     * Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))

From dfeb757ebe87a082a828f1d4fca44b08f70e4bf7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 1 Feb 2021 03:47:59 +0100
Subject: [PATCH 06/12] torchscript

---
 pytorch_lightning/core/lightning.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index 92c2484252df4..965dba8ad3a30 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -69,6 +69,7 @@ class LightningModule(
         "running_stage",
         "global_rank",
         "local_rank",
+        "logger",
     ] + DeviceDtypeModuleMixin.__jit_unused_properties__
 
     def __init__(self, *args, **kwargs):

From e697f774e64536da1e91c9c489262edd6f7873f6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Mon, 1 Feb 2021 03:49:15 +0100
Subject: [PATCH 07/12] changelog

---
 CHANGELOG.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index c2244c72dc9ed..bc1eb7c7e1877 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -107,7 +107,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516))
 
-- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#xxxx](https://github.com/PyTorchLightning/pytorch-lightning/pull/xxxx))
+- Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))
+
 
 - Refactored Accelerators and Plugins
     * Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
     * Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))

From b07ad864d56c6a22ee7ee67b82b3501844d788e2 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 1 Feb 2021 09:33:37 +0100
Subject: [PATCH 08/12] chlog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bc1eb7c7e1877..2bd713cb1d729 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -110,6 +110,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))
 
+- Changed trainer-related attributes in LightningModule as read-only property: `global_rank`, `local_rank`, `logger` ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))
+
+
 
 - Refactored Accelerators and Plugins
     * Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
     * Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))

From fade48d2b60bcd829ad76b48dda1abded1cb93db Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 1 Feb 2021 09:35:29 +0100
Subject: [PATCH 09/12] .

---
 CHANGELOG.md | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2bd713cb1d729..bc1eb7c7e1877 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -110,9 +110,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Made `LightningModule.global_rank`, `LightningModule.local_rank` and `LightningModule.logger` read-only properties ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))
 
-- Changed trainer-related attributes in LightningModule as read-only property: `global_rank`, `local_rank`, `logger` ([#5730](https://github.com/PyTorchLightning/pytorch-lightning/pull/5730))
-
-
 
 - Refactored Accelerators and Plugins
     * Added base classes for plugins ([#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715))
     * Added parallel plugins for DP, DDP, DDPSpawn, DDP2 and Horovod ([#5714](https://github.com/PyTorchLightning/pytorch-lightning/pull/5714))

From 13eb0c3e25b2ec1f7e59239208a1dc836bbc0560 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 1 Feb 2021 12:04:46 +0100
Subject: [PATCH 10/12] amp

---
 pytorch_lightning/plugins/precision/native_amp.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index daba223169fc6..d5f5fc707c066 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -18,9 +18,14 @@
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
-from pytorch_lightning.utilities import AMPType
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+if _NATIVE_AMP_AVAILABLE:
+    from torch.cuda.amp import autocast
+else:
+    autocast = None
+
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
     def __init__(self):
@@ -74,6 +79,6 @@ def backward(
         return closure_loss
 
     @contextmanager
-    def train_step_context(self) -> Generator[torch.cuda.amp.autocast, None, None]:
+    def train_step_context(self) -> Generator[autocast, None, None]:
         """Enable autocast context"""
         yield torch.cuda.amp.autocast()
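Patch 10 needs `autocast` only for a return-type annotation, so it imports the name behind the `_NATIVE_AMP_AVAILABLE` flag and falls back to `None`, which keeps the module importable on PyTorch builds without native AMP. A rough stand-alone sketch of the same guard; the try/except below is a simplified stand-in for the version check behind `_NATIVE_AMP_AVAILABLE` in `pytorch_lightning.utilities`:

    from contextlib import contextmanager
    from typing import Generator

    import torch

    try:
        from torch.cuda.amp import autocast
        _NATIVE_AMP_AVAILABLE = True
    except ImportError:
        # Older PyTorch: keep the name defined so the annotation still resolves.
        autocast = None
        _NATIVE_AMP_AVAILABLE = False


    @contextmanager
    def train_step_context() -> Generator[autocast, None, None]:
        # The annotation uses the guarded alias, so this definition stays valid
        # even when torch.cuda.amp is missing; calling it still requires native AMP.
        yield torch.cuda.amp.autocast()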
From 4fda04a4f57d0e11821a7a56887557e7574d7999 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 1 Feb 2021 12:06:48 +0100
Subject: [PATCH 11/12] yapf

---
 pytorch_lightning/overrides/data_parallel.py              | 4 ++--
 pytorch_lightning/plugins/base_plugin.py                  | 3 ++-
 pytorch_lightning/plugins/precision/native_amp.py         | 1 +
 pytorch_lightning/plugins/precision/precision_plugin.py   | 3 ++-
 pytorch_lightning/plugins/precision/sharded_native_amp.py | 1 +
 5 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py
index 8d1710e471197..b027502f99e8a 100644
--- a/pytorch_lightning/overrides/data_parallel.py
+++ b/pytorch_lightning/overrides/data_parallel.py
@@ -30,8 +30,7 @@ class LightningDataParallel(DataParallel):
     def __init__(self, module: LightningModule, *args, **kwargs):
         warnings.warn(
             "The usage of `LightningDataParallel` is deprecated since v1.2 and will be removed in v1.4."
-            " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.",
-            DeprecationWarning
+            " From now on we recommend to directly subclass `torch.nn.parallel.DataParallel`.", DeprecationWarning
         )
         super().__init__(LightningParallelModule(module), *args, **kwargs)
 
@@ -67,6 +66,7 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         pl_module: the model to wrap
 
     """
+
     def __init__(self, pl_module: LightningModule):
         super().__init__(pl_module)
 
diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py
index 0160afa559496..bca155b750047 100644
--- a/pytorch_lightning/plugins/base_plugin.py
+++ b/pytorch_lightning/plugins/base_plugin.py
@@ -22,7 +22,8 @@ class Plugin(ABC):
     """Basic Plugin class to derive precision and training type plugins from."""
 
     @abstractmethod
-    def connect(self, model: torch.nn.Module, *args: Sequence, **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]:
+    def connect(self, model: torch.nn.Module, *args: Sequence,
+                **kwargs: Sequence) -> Optional[Tuple[torch.nn.Module, Sequence, Sequence]]:
         """Connects the plugin with the accelerator (and thereby with trainer and model).
         Will be called by the accelerator.
         """
diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py
index d5f5fc707c066..8cdaba833af85 100644
--- a/pytorch_lightning/plugins/precision/native_amp.py
+++ b/pytorch_lightning/plugins/precision/native_amp.py
@@ -28,6 +28,7 @@
 
 
 class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
+
     def __init__(self):
         self.backend = AMPType.NATIVE
         self.scaler = torch.cuda.amp.GradScaler()
diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py
index 031b588737614..3e74442e92277 100644
--- a/pytorch_lightning/plugins/precision/precision_plugin.py
+++ b/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -37,7 +37,8 @@ def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Ten
             for p in group["params"]:
                 yield p
 
-    def connect(self, model: torch.nn.Module, optimizers: Sequence, lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]:
+    def connect(self, model: torch.nn.Module, optimizers: Sequence,
+                lr_schedulers: Sequence) -> Tuple[torch.nn.Module, Sequence, Sequence]:
         """Connects this plugin to the accelerator and the training process"""
         return model, optimizers, lr_schedulers
 
diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py
index ef8e1b8a95efe..b3b01fc720d2b 100644
--- a/pytorch_lightning/plugins/precision/sharded_native_amp.py
+++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -26,6 +26,7 @@
 class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
     """Mixed Precision for Sharded Training
     """
+
    def __init__(self):
         super().__init__()
         self.scaler = ShardedGradScaler()
From 6ef068a9cef7073347894b50eade46b0d47f5f85 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 1 Feb 2021 12:11:46 +0100
Subject: [PATCH 12/12] flake8

---
 tests/core/test_lightning_module.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py
index 8895d8475b656..17d25b6c9b75a 100644
--- a/tests/core/test_lightning_module.py
+++ b/tests/core/test_lightning_module.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import patch, Mock, PropertyMock
+from unittest.mock import patch, Mock
 
 import pytest
 from torch.optim import Adam, SGD
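Taken together, the series makes `global_rank`, `local_rank` and `logger` read-only properties on `LightningModule` that defer to the attached Trainer (falling back to `0`, `0` and `None` when no Trainer is set), and the model connector now hands the module a weakref proxy of the Trainer instead of copying these attributes onto it. A rough sketch of how user code reads the properties after this change; the module body and the logged metric are illustrative, and dataloaders and optimizer setup are omitted:

    import torch
    from pytorch_lightning import LightningModule


    class IllustrativeModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).sum()
            # global_rank / local_rank / logger are read through the Trainer;
            # assigning to them (e.g. self.logger = ...) now raises AttributeError.
            if self.global_rank == 0 and self.logger is not None:
                self.logger.log_metrics({"train_loss": float(loss)}, step=self.global_step)
            return loss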