diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 5b9c3abc4709a..931a39e07af89 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -11,9 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
+
 from enum import Enum
-from typing import Any, Optional, Union, List
+from typing import Any, Optional, Union
 
 import torch
 from torch.optim import Optimizer
@@ -22,8 +22,8 @@
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict
+from pytorch_lightning.core.lightning import LightningModule
 import torch.distributed as torch_distrib
-from pytorch_lightning import _logger as log
 
 if torch.distributed.is_available():
     from torch.distributed import ReduceOp
@@ -208,6 +208,23 @@ def optimizer_state(self, optimizer: Optimizer) -> dict:
             return self.ddp_plugin.optimizer_state(optimizer)
         return optimizer.state_dict()
 
+    def get_reference_model(self, model) -> LightningModule:
+        """
+        Override to return the base :class:`LightningModule` for accessing its
+        variables and functions when the accelerator has wrapped the model.
+
+        Example::
+            ref_model = accelerator.get_reference_model(model)
+            ref_model.training_step(...)
+
+        Args:
+            model: Accelerator model.
+
+        Returns: Reference :class:`LightningModule`.
+
+        """
+        return model
+
     def __getstate__(self):
         return {
             'trainer': self.trainer,
diff --git a/pytorch_lightning/accelerators/ddp2_accelerator.py b/pytorch_lightning/accelerators/ddp2_accelerator.py
index 9fcbdd4668ee9..142e077cc461c 100644
--- a/pytorch_lightning/accelerators/ddp2_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -218,3 +218,6 @@ def sync_tensor(self,
                     group: Optional[Any] = None,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
+
+    def get_reference_model(self, model) -> LightningModule:
+        return self.ddp_plugin.get_model_from_plugin(model)
diff --git a/pytorch_lightning/accelerators/ddp_accelerator.py b/pytorch_lightning/accelerators/ddp_accelerator.py
index 69d41cd024646..9dcc85594efbc 100644
--- a/pytorch_lightning/accelerators/ddp_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_accelerator.py
@@ -319,3 +319,6 @@ def sync_tensor(self,
 
         """
         return sync_ddp_if_available(tensor, group, reduce_op)
+
+    def get_reference_model(self, model) -> LightningModule:
+        return self.ddp_plugin.get_model_from_plugin(model)
diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
index 2a090a72e2b5a..109393dbc770a 100644
--- a/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -246,3 +246,6 @@ def sync_tensor(self,
                     group: Optional[Any] = None,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
+
+    def get_reference_model(self, model) -> LightningModule:
+        return self.ddp_plugin.get_model_from_plugin(model)
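Note on the new hook: the base ``Accelerator.get_reference_model`` added above is deliberately an identity function, and only backends that wrap the model override it to unwrap. A minimal sketch of that contract, using hypothetical ``Fake*`` stand-ins rather than Lightning's real classes::

    # Hypothetical stand-ins, used only to illustrate the hook's contract.
    class FakeWrapper:
        """Mimics a parallel wrapper that hides the user's module behind `.module`."""
        def __init__(self, module):
            self.module = module

    class FakeBaseAccelerator:
        """Base contract: hand the model back unchanged."""
        def get_reference_model(self, model):
            return model

    class FakeWrappingAccelerator(FakeBaseAccelerator):
        """Wrapping backends override the hook to return the inner module."""
        def get_reference_model(self, model):
            return model.module if isinstance(model, FakeWrapper) else model

    user_module = object()
    assert FakeBaseAccelerator().get_reference_model(user_module) is user_module
    assert FakeWrappingAccelerator().get_reference_model(FakeWrapper(user_module)) is user_module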
diff --git a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
index 2ff9c2b7ddaae..ab221f466f54a 100644
--- a/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -213,3 +213,6 @@ def sync_tensor(self,
                     group: Optional[Any] = None,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
+
+    def get_reference_model(self, model) -> LightningModule:
+        return self.ddp_plugin.get_model_from_plugin(model)
diff --git a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
index eac51393a5f2e..07051f13d0255 100644
--- a/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
+++ b/pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -272,3 +272,6 @@ def sync_tensor(self,
                     group: Optional[Any] = None,
                     reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
         return sync_ddp_if_available(tensor, group, reduce_op)
+
+    def get_reference_model(self, model) -> LightningModule:
+        return self.ddp_plugin.get_model_from_plugin(model)
diff --git a/pytorch_lightning/accelerators/dp_accelerator.py b/pytorch_lightning/accelerators/dp_accelerator.py
index 2f6c5dce97c46..a894afba27b71 100644
--- a/pytorch_lightning/accelerators/dp_accelerator.py
+++ b/pytorch_lightning/accelerators/dp_accelerator.py
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Union
 
 import torch
 from torch import optim
 
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.core.step_result import Result
@@ -172,3 +174,8 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
                 scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
                 if state is not None:
                     scheduler.load_state_dict(state)
+
+    def get_reference_model(self, model) -> LightningModule:
+        if isinstance(model, LightningDataParallel):
+            return model.module
+        return model
diff --git a/pytorch_lightning/plugins/ddp_plugin.py b/pytorch_lightning/plugins/ddp_plugin.py
index 6a38da2e0c2bc..f85714b89434a 100644
--- a/pytorch_lightning/plugins/ddp_plugin.py
+++ b/pytorch_lightning/plugins/ddp_plugin.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Union
 
 import torch.distributed as torch_distrib
 from pytorch_lightning import _logger as log
@@ -108,3 +108,25 @@ def on_before_forward(self, model, *args):
 
     def optimizer_state(self, optimizer: Optimizer) -> dict:
         return optimizer.state_dict()
+
+    def get_model_from_plugin(
+            self,
+            model: Union[LightningDistributedDataParallel, LightningModule]
+    ) -> LightningModule:
+        """
+        Override to return the base :class:`LightningModule` for accessing its
+        variables and functions outside of the parallel wrapper.
+
+        Example::
+            ref_model = ddp_plugin.get_model_from_plugin(model)
+            ref_model.training_step(...)
+
+        Args:
+            model: Model with parallel wrapper.
+
+        Returns: Reference :class:`LightningModule` within parallel wrapper.
+
+        """
+        if isinstance(model, LightningDistributedDataParallel):
+            return model.module
+        return model
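Because the DDP-style accelerators all defer to ``get_model_from_plugin``, a plugin that wraps the model in its own container only needs to override that one hook for the trainer to still reach the underlying LightningModule. A hedged sketch, assuming the class defined in ``pytorch_lightning/plugins/ddp_plugin.py`` is ``DDPPlugin``; ``MyWrapper`` and ``MyDDPPlugin`` are hypothetical names::

    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

    class MyWrapper:  # hypothetical custom parallel wrapper exposing `.module`
        def __init__(self, module):
            self.module = module

    class MyDDPPlugin(DDPPlugin):
        def get_model_from_plugin(self, model):
            # Unwrap the custom container first, then fall back to the default
            # LightningDistributedDataParallel handling in the parent class.
            if isinstance(model, MyWrapper):
                return model.module
            return super().get_model_from_plugin(model)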
diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py
index dbdceb1532288..c5a8c48357b44 100644
--- a/pytorch_lightning/trainer/connectors/model_connector.py
+++ b/pytorch_lightning/trainer/connectors/model_connector.py
@@ -17,10 +17,6 @@
 Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
 
 """
-from pytorch_lightning.overrides.data_parallel import (
-    LightningDistributedDataParallel,
-    LightningDataParallel,
-)
 
 
 class ModelConnector:
@@ -28,12 +24,7 @@ def __init__(self, trainer):
         self.trainer = trainer
 
     def copy_trainer_model_properties(self, model):
-        if isinstance(model, LightningDataParallel):
-            ref_model = model.module
-        elif isinstance(model, LightningDistributedDataParallel):
-            ref_model = model.module
-        else:
-            ref_model = model
+        ref_model = self._get_reference_model(model)
 
         automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
         self.trainer.train_loop.automatic_optimization = automatic_optimization
@@ -55,6 +46,9 @@ def copy_trainer_model_properties(self, model):
             m.local_rank = self.trainer.local_rank
 
     def get_model(self):
-        is_dp_module = isinstance(self.trainer.model, (LightningDistributedDataParallel, LightningDataParallel))
-        model = self.trainer.model.module if is_dp_module else self.trainer.model
+        return self._get_reference_model(self.trainer.model)
+
+    def _get_reference_model(self, model):
+        if self.trainer.accelerator_backend:
+            return self.trainer.accelerator_backend.get_reference_model(model)
         return model
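With this change ``ModelConnector`` no longer special-cases the DP/DDP wrappers itself: it delegates unwrapping to the attached accelerator and falls back to the raw model when no backend has been set up yet. A minimal standalone sketch of that dispatch, using hypothetical ``Fake*`` stand-ins rather than the real Trainer::

    class FakeBackend:  # hypothetical stand-in for trainer.accelerator_backend
        def __init__(self, reference):
            self.reference = reference

        def get_reference_model(self, model):
            return self.reference

    class FakeTrainer:  # hypothetical; mirrors ModelConnector._get_reference_model
        def __init__(self, model, backend=None):
            self.model = model
            self.accelerator_backend = backend

        def get_model(self):
            if self.accelerator_backend:
                return self.accelerator_backend.get_reference_model(self.model)
            return self.model

    raw_model, unwrapped = object(), object()
    assert FakeTrainer(raw_model).get_model() is raw_model                          # no backend attached yet
    assert FakeTrainer(raw_model, FakeBackend(unwrapped)).get_model() is unwrapped  # backend decides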
diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py
new file mode 100644
index 0000000000000..36bed99498e68
--- /dev/null
+++ b/tests/trainer/properties/test_get_model.py
@@ -0,0 +1,110 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+import pytest
+import torch
+
+from pytorch_lightning import Trainer
+from tests.backends.launcher import DDPLauncher
+from tests.base.boring_model import BoringModel
+
+
+class TrainerGetModel(BoringModel):
+    def on_fit_start(self):
+        assert self == self.trainer.get_model()
+
+    def on_fit_end(self):
+        assert self == self.trainer.get_model()
+
+
+def test_get_model(tmpdir):
+    """
+    Tests that :meth:`trainer.get_model` extracts the model correctly
+    """
+
+    model = TrainerGetModel()
+
+    limit_train_batches = 2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+def test_get_model_ddp_cpu(tmpdir):
+    """
+    Tests that :meth:`trainer.get_model` extracts the model correctly when using ddp on cpu
+    """
+
+    model = TrainerGetModel()
+
+    limit_train_batches = 2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        accelerator='ddp_cpu',
+        num_processes=2
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_get_model_gpu(tmpdir):
+    """
+    Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU
+    """
+
+    model = TrainerGetModel()
+
+    limit_train_batches = 2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        gpus=1
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
+@DDPLauncher.run("--accelerator [accelerator]",
+                 max_epochs=["1"],
+                 accelerator=["ddp", "ddp_spawn"])
+def test_get_model_ddp_gpu(tmpdir, args=None):
+    """
+    Tests that :meth:`trainer.get_model` extracts the model correctly when using GPU + ddp accelerators
+    """
+
+    model = TrainerGetModel()
+
+    limit_train_batches = 2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=2,
+        max_epochs=1,
+        gpus=1,
+        accelerator=args.accelerator
+    )
+    trainer.fit(model)
+    return 1
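For user code, the practical effect is that ``trainer.get_model()`` returns the unwrapped LightningModule regardless of accelerator, which is what the tests above assert from inside the module's hooks. A hedged sketch of the same check from a callback; ``CheckUnwrapped`` is a hypothetical name, and the ``on_train_start(trainer, pl_module)`` signature is assumed from the Callback API of this Lightning version::

    from pytorch_lightning.callbacks import Callback

    class CheckUnwrapped(Callback):  # hypothetical helper, for illustration only
        def on_train_start(self, trainer, pl_module):
            # get_model() now resolves through the accelerator's get_reference_model,
            # so it should hand back the plain LightningModule, not a parallel wrapper.
            assert trainer.get_model() is pl_module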