From 57b917c5368289955b35e89c9a077ad1d33475da Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 3 Jul 2022 01:02:32 +0200 Subject: [PATCH 01/53] wip --- src/pytorch_lightning/overrides/base.py | 38 ++++++------------- src/pytorch_lightning/overrides/fairscale.py | 38 ------------------- .../plugins/precision/sharded_native_amp.py | 1 - src/pytorch_lightning/strategies/bagua.py | 10 +---- src/pytorch_lightning/strategies/parallel.py | 5 --- src/pytorch_lightning/strategies/sharded.py | 28 +++++++------- .../strategies/sharded_spawn.py | 25 ++++++------ src/pytorch_lightning/strategies/strategy.py | 18 ++++----- tests/tests_pytorch/overrides/test_base.py | 11 ------ 9 files changed, 50 insertions(+), 124 deletions(-) delete mode 100644 src/pytorch_lightning/overrides/fairscale.py diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 8064154579bae..2b34e2178fc60 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union +from typing import Any, Optional import torch import torch.nn as nn @@ -57,7 +57,7 @@ def on_post_move_to_device(self) -> None: class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: + def __init__(self, forward_module: nn.Module, lightning_module: Optional["pl.LightningDataModule"] = None) -> None: """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``. @@ -67,15 +67,19 @@ def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionMod pl_module: the model to wrap """ super().__init__() - self.module = pl_module + self._forward_module = forward_module + self._lightning_module = lightning_module or forward_module - # set the parameters_to_ignore from LightningModule. 
- _ddp_params_and_buffers_to_ignore = getattr(pl_module, "_ddp_params_and_buffers_to_ignore", []) + # set the parameters_to_ignore from LightningModule + _ddp_params_and_buffers_to_ignore = getattr(self._lightning_module, "_ddp_params_and_buffers_to_ignore", []) self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore] + @property + def module(self): + return self._forward_module + def forward(self, *inputs: Any, **kwargs: Any) -> Any: - pl_module = unwrap_lightning_module(self.module) - trainer = pl_module.trainer + trainer = self._lightning_module.trainer if trainer is not None: if trainer.training: @@ -84,7 +88,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: # it is done manually in `LightningModule.manual_backward` # `require_backward_grad_sync` will be reset in the # ddp_strategy `post_training_step` hook - if not pl_module.automatic_optimization: + if not self._lightning_module.automatic_optimization: trainer.model.require_backward_grad_sync = False # type: ignore[assignment] return output if trainer.testing: @@ -97,21 +101,3 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: def on_post_move_to_device(self) -> None: pass - - -def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule": - """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module`` - attributes on the wrapper. - - Raises: - TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped - further. - """ - model = wrapped_model - if isinstance(model, (DistributedDataParallel, DataParallel)): - model = unwrap_lightning_module(model.module) - if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)): - model = unwrap_lightning_module(model.module) - if not isinstance(model, pl.LightningModule): - raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.") - return model diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py deleted file mode 100644 index f48fa8dcf9ccf..0000000000000 --- a/src/pytorch_lightning/overrides/fairscale.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import torch.nn as nn - -import pytorch_lightning as pl -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module -from pytorch_lightning.utilities import _IS_WINDOWS, _module_available - -_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") - -if _FAIRSCALE_AVAILABLE: - from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel - - class LightningShardedDataParallel(_LightningModuleWrapperBase): - # Just do this for later docstrings - pass - - def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule": - model = wrapped_model - if isinstance(model, ShardedDataParallel): - model = model.module - - return unwrap_lightning_module(model) - -else: - LightningShardedDataParallel = ... # type: ignore[assignment,misc] - unwrap_lightning_module_sharded = ... # type: ignore[assignment] diff --git a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py index f5646c2094253..fd85fd6fe93f7 100644 --- a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -15,7 +15,6 @@ import torch -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py index f7beac54fae85..3236a41c24b3f 100644 --- a/src/pytorch_lightning/strategies/bagua.py +++ b/src/pytorch_lightning/strategies/bagua.py @@ -10,7 +10,6 @@ from pytorch_lightning.overrides.base import ( _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase, - unwrap_lightning_module, ) from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -109,13 +108,6 @@ def __init__( self._bagua_flatten = flatten self._bagua_kwargs = bagua_kwargs - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - model = self.model - if isinstance(model, BaguaDistributedDataParallel): - model = model.module - return unwrap_lightning_module(model) if model is not None else None - def setup_distributed(self) -> None: reset_seed() @@ -189,7 +181,7 @@ def _check_qadam_optimizer(self) -> None: def _configure_bagua_model(self, trainer: "pl.Trainer") -> None: model = LightningBaguaModule(self.model) # type: ignore[arg-type] - self._model = self._setup_model(model) + self.model = self._setup_model(model) # start the background communication for async algorithm if trainer.training and self._bagua_algorithm == "async": diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index 4fc846870ad59..a39cc9699dd40 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -20,7 +20,6 @@ from torch.nn.parallel import DistributedDataParallel import pytorch_lightning as pl -from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import LayerSync from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -56,10 +55,6 @@ def __init__( def root_device(self) -> torch.device: """Return the root device.""" 
- @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - return unwrap_lightning_module(self.model) if self.model is not None else None - @property def global_rank(self) -> int: return self.cluster_environment.global_rank() if self.cluster_environment is not None else 0 diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index 01401bd53bb56..375563e40d878 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -20,7 +20,6 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.enums import PrecisionType @@ -28,12 +27,16 @@ from pytorch_lightning.utilities.imports import _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.utilities import _IS_WINDOWS, _module_available + + +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") + if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel from fairscale.optim import OSS - - from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded else: OSS = ShardedDataParallel = object @@ -44,6 +47,14 @@ class DDPShardedStrategy(DDPStrategy): strategy_name = "ddp_sharded" _REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23 # 8M + def connect(self, model: pl.LightningModule) -> None: + if not _FAIRSCALE_AVAILABLE: # pragma: no cover + raise MisconfigurationException( + "`DDPShardedStrategy` requires `fairscale` to be installed." + " Install it by running `pip install fairscale`." + ) + return super().connect(model) + def setup(self, trainer: "pl.Trainer") -> None: # share ddp pids to all processes self._rank_0_will_call_children_scripts = self.broadcast(self._rank_0_will_call_children_scripts) @@ -70,7 +81,7 @@ def configure_ddp(self) -> None: self._set_ddp_kwargs() self.setup_optimizers(self.model.trainer) self.model, self.optimizers = self._setup_model_and_optimizers( - model=LightningShardedDataParallel(self.model), + model=_LightningModuleWrapperBase(self.model), optimizers=self.optimizers, ) optimizers_to_device(self.optimizers, self.root_device) @@ -128,15 +139,6 @@ def _optim_state_dict(self, optimizer): """ return optimizer.state_dict() - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - if not _FAIRSCALE_AVAILABLE: # pragma: no cover - raise MisconfigurationException( - "`DDPShardedStrategy` requires `fairscale` to be installed." - " Install it by running `pip install fairscale`." 
- ) - return unwrap_lightning_module_sharded(self.model) if self.model is not None else None - def pre_backward(self, closure_loss: Tensor) -> None: pass diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index 4550e397ded80..89d0b7ab351ea 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -19,18 +19,20 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +from pytorch_lightning.utilities import _IS_WINDOWS, _module_available + +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel from fairscale.optim import OSS - from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded else: OSS = ShardedDataParallel = object @@ -40,11 +42,19 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy): strategy_name = "ddp_sharded_spawn" + def connect(self, model: pl.LightningModule) -> None: + if not _FAIRSCALE_AVAILABLE: # pragma: no cover + raise MisconfigurationException( + "`DDPSpawnShardedStrategy` requires `fairscale` to be installed." + " Install it by running `pip install fairscale`." + ) + return super().connect(model) + def configure_ddp(self) -> None: # set up optimizers after the wrapped module has been moved to the device self.setup_optimizers(self.lightning_module.trainer) self.model, self.optimizers = self._setup_model_and_optimizers( - model=LightningShardedDataParallel(self.model), optimizers=self.optimizers + model=_LightningModuleWrapperBase(self.model), optimizers=self.optimizers ) optimizers_to_device(self.optimizers, self.root_device) @@ -100,15 +110,6 @@ def _optim_state_dict(self, optimizer): """ return optimizer.state_dict() - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - if not _FAIRSCALE_AVAILABLE: # pragma: no cover - raise MisconfigurationException( - "`DDPSpawnShardedStrategy` requires `fairscale` to be installed." - " Install it by running `pip install fairscale`." 
- ) - return unwrap_lightning_module_sharded(self.model) if self.model is not None else None - def pre_backward(self, closure_loss: Tensor) -> None: pass diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 9de30889336fe..98d0dc4625bdf 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -24,7 +24,6 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import _init_optimizers_and_lr_schedulers, LightningOptimizer -from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins import TorchCheckpointIO from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -52,14 +51,15 @@ def __init__( precision_plugin: Optional[PrecisionPlugin] = None, ) -> None: self.accelerator = accelerator - self._launcher: Optional[_Launcher] = None - self._model: Optional[Module] = None self.checkpoint_io = checkpoint_io self.precision_plugin = precision_plugin - self._optimizers: List[Optimizer] = [] - self._lightning_optimizers: Dict[int, LightningOptimizer] = {} self.lr_scheduler_configs: List[LRSchedulerConfig] = [] self.optimizer_frequencies: List[int] = [] + self._lightning_module: Optional[pl.LightningModule] = None + self._model: Optional[Module] = None + self._launcher: Optional[_Launcher] = None + self._optimizers: List[Optimizer] = [] + self._lightning_optimizers: Dict[int, LightningOptimizer] = {} if is_overridden("post_dispatch", self, parent=Strategy): rank_zero_deprecation( f"`{self.__class__.__name__}.post_dispatch()` has been deprecated in v1.6 and will be removed in v1.7." @@ -105,9 +105,9 @@ def optimizers(self, optimizers: List[Optimizer]) -> None: idx: LightningOptimizer._to_lightning_optimizer(opt, self, idx) for idx, opt in enumerate(self.optimizers) } - def connect(self, model: Module) -> None: + def connect(self, model: pl.LightningModule) -> None: """Called by the accelerator to connect the accelerator and the model with this plugin.""" - self.model = model + self._lightning_module = model def _configure_launcher(self): """Attach the launcher based on Strategy.""" @@ -303,7 +303,7 @@ def post_backward(self, closure_loss: Tensor) -> None: @property def model(self) -> Optional[Module]: """Returns the potentially wrapped LightningModule.""" - return self._model + return self._model if self._model is not None else self._lightning_module @model.setter def model(self, new_model: Optional[Module]) -> None: @@ -312,7 +312,7 @@ def model(self, new_model: Optional[Module]) -> None: @property def lightning_module(self) -> Optional["pl.LightningModule"]: """Returns the pure LightningModule without potential wrappers.""" - return unwrap_lightning_module(self.model) if self.model is not None else None + return self._lightning_module def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() diff --git a/tests/tests_pytorch/overrides/test_base.py b/tests/tests_pytorch/overrides/test_base.py index fa07912d0d44e..ca131f21d2aef 100644 --- a/tests/tests_pytorch/overrides/test_base.py +++ b/tests/tests_pytorch/overrides/test_base.py @@ -13,13 +13,11 @@ # limitations under the License. 
import pytest import torch -from torch.nn import DataParallel from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.overrides.base import ( _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase, - unwrap_lightning_module, ) @@ -30,12 +28,3 @@ def test_wrapper_device_dtype(wrapper_class): wrapped_model.to(dtype=torch.float16) assert model.dtype == torch.float16 - - -def test_unwrap_lightning_module(): - model = BoringModel() - wrapped_model = _LightningPrecisionModuleWrapperBase(model) - wrapped_model = _LightningModuleWrapperBase(wrapped_model) - wrapped_model = DataParallel(wrapped_model) - - assert unwrap_lightning_module(wrapped_model) == model From 626ac9e57a550487e642b347ee10ab87e54f68cc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 00:03:56 +0200 Subject: [PATCH 02/53] model setter --- .../plugins/precision/sharded_native_amp.py | 1 + src/pytorch_lightning/strategies/ddp.py | 2 +- src/pytorch_lightning/strategies/sharded.py | 2 +- src/pytorch_lightning/strategies/sharded_spawn.py | 2 +- src/pytorch_lightning/strategies/strategy.py | 2 +- src/pytorch_lightning/trainer/trainer.py | 8 ++++---- tests/tests_pytorch/helpers/runif.py | 2 +- tests/tests_pytorch/strategies/test_sharded_strategy.py | 2 +- 8 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py index fd85fd6fe93f7..570e25bd85caa 100644 --- a/src/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/src/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -17,6 +17,7 @@ from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE if _FAIRSCALE_AVAILABLE: from fairscale.optim import OSS diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index 922730df35269..bd528aaaa0a62 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -33,7 +33,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index 375563e40d878..b3531bbdc4b7e 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -47,7 +47,7 @@ class DDPShardedStrategy(DDPStrategy): strategy_name = "ddp_sharded" _REDUCE_BUFFER_SIZE_DEFAULT: int = 2**23 # 8M - def connect(self, model: pl.LightningModule) -> None: + def connect(self, model: "pl.LightningModule") -> None: if not _FAIRSCALE_AVAILABLE: # pragma: no cover raise MisconfigurationException( "`DDPShardedStrategy` requires `fairscale` to be installed." 
diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index 89d0b7ab351ea..a523d6cac3dff 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -42,7 +42,7 @@ class DDPSpawnShardedStrategy(DDPSpawnStrategy): strategy_name = "ddp_sharded_spawn" - def connect(self, model: pl.LightningModule) -> None: + def connect(self, model: "pl.LightningModule") -> None: if not _FAIRSCALE_AVAILABLE: # pragma: no cover raise MisconfigurationException( "`DDPSpawnShardedStrategy` requires `fairscale` to be installed." diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 557f385f7393b..87d960a59c92c 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -114,7 +114,7 @@ def optimizers(self, optimizers: List[Optimizer]) -> None: idx: LightningOptimizer._to_lightning_optimizer(opt, self, idx) for idx, opt in enumerate(self.optimizers) } - def connect(self, model: pl.LightningModule) -> None: + def connect(self, model: "pl.LightningModule") -> None: """Called by the accelerator to connect the accelerator and the model with this plugin.""" self._lightning_module = model diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 561fe799f1010..f74217813e48c 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -696,7 +696,7 @@ def fit( datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`. """ - self.strategy.model = model + self.strategy._lightning_module = model self._call_and_handle_interrupt( self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path ) @@ -776,7 +776,7 @@ def validate( :meth:`~pytorch_lightning.core.module.LightningModule.validation_epoch_end`, etc. The length of the list corresponds to the number of validation dataloaders used. """ - self.strategy.model = model or self.lightning_module + self.strategy._lightning_module = model or self.lightning_module return self._call_and_handle_interrupt(self._validate_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _validate_impl( @@ -864,7 +864,7 @@ def test( :meth:`~pytorch_lightning.core.module.LightningModule.test_epoch_end`, etc. The length of the list corresponds to the number of test dataloaders used. """ - self.strategy.model = model or self.lightning_module + self.strategy._lightning_module = model or self.lightning_module return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule) def _test_impl( @@ -951,7 +951,7 @@ def predict( Returns: Returns a list of dictionaries, one for each provided dataloader containing their respective predictions. 
""" - self.strategy.model = model or self.lightning_module + self.strategy._lightning_module = model or self.lightning_module return self._call_and_handle_interrupt( self._predict_impl, model, dataloaders, datamodule, return_predictions, ckpt_path ) diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index d8e38e7101fe0..f7b29f6f1740e 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -22,7 +22,7 @@ from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.imports import ( diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py index bfd2a3abfc411..9dc3400163536 100644 --- a/tests/tests_pytorch/strategies/test_sharded_strategy.py +++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py @@ -7,7 +7,7 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy from pytorch_lightning.trainer.states import TrainerFn from tests_pytorch.helpers.runif import RunIf From b07e11cb0e0e674152205efd38c96d144d4fcdaa Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 00:05:59 +0200 Subject: [PATCH 03/53] fix import --- tests/tests_pytorch/plugins/precision/test_sharded_precision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py index 0c08c8e9540eb..b53ddc4084242 100644 --- a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py @@ -15,7 +15,7 @@ import pytest import torch -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins import ShardedNativeMixedPrecisionPlugin from tests_pytorch.helpers.runif import RunIf From 352397f94f390ab2d98b2d9a1da2e65d7ec01739 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Jul 2022 22:07:42 +0000 Subject: [PATCH 04/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/overrides/base.py | 1 - src/pytorch_lightning/strategies/bagua.py | 5 +---- src/pytorch_lightning/strategies/ddp.py | 8 ++++++-- src/pytorch_lightning/strategies/sharded.py | 5 ++--- src/pytorch_lightning/strategies/sharded_spawn.py | 4 ++-- src/pytorch_lightning/strategies/strategy.py | 1 - tests/tests_pytorch/helpers/runif.py | 2 +- tests/tests_pytorch/overrides/test_base.py | 5 +---- .../plugins/precision/test_sharded_precision.py | 2 +- tests/tests_pytorch/strategies/test_sharded_strategy.py | 2 +- 10 files changed, 15 insertions(+), 20 deletions(-) diff --git 
a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 1d29c9c12583d..1aba7142df9e5 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -95,4 +95,3 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: if trainer.predicting: return self.module.predict_step(*inputs, **kwargs) return self.module(*inputs, **kwargs) - diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py index 2b56229b32aaf..acc8875fe1660 100644 --- a/src/pytorch_lightning/strategies/bagua.py +++ b/src/pytorch_lightning/strategies/bagua.py @@ -7,10 +7,7 @@ from torch.nn import Module import pytorch_lightning as pl -from pytorch_lightning.overrides.base import ( - _LightningModuleWrapperBase, - _LightningPrecisionModuleWrapperBase, -) +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin diff --git a/src/pytorch_lightning/strategies/ddp.py b/src/pytorch_lightning/strategies/ddp.py index bd528aaaa0a62..3e2e987288103 100644 --- a/src/pytorch_lightning/strategies/ddp.py +++ b/src/pytorch_lightning/strategies/ddp.py @@ -33,7 +33,6 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO from pytorch_lightning.plugins.precision import PrecisionPlugin @@ -53,7 +52,12 @@ sync_ddp_if_available, ) from pytorch_lightning.utilities.exceptions import DeadlockDetectedException -from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11 +from pytorch_lightning.utilities.imports import ( + _FAIRSCALE_AVAILABLE, + _IS_WINDOWS, + _TORCH_GREATER_EQUAL_1_10, + _TORCH_GREATER_EQUAL_1_11, +) from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.seed import reset_seed diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index b3531bbdc4b7e..ee512db310f69 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -20,16 +20,15 @@ import pytorch_lightning as pl from pytorch_lightning.core.optimizer import LightningOptimizer +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import _IS_WINDOWS, _module_available from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only -from pytorch_lightning.overrides.base 
import _LightningModuleWrapperBase -from pytorch_lightning.utilities import _IS_WINDOWS, _module_available - _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index a523d6cac3dff..b06b7c7c2d51b 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -19,13 +19,13 @@ from torch.optim import Optimizer import pytorch_lightning as pl +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities import _IS_WINDOWS, _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase -from pytorch_lightning.utilities import _IS_WINDOWS, _module_available _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 87d960a59c92c..40fccb2dd70fb 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -68,7 +68,6 @@ def __init__( self._lightning_optimizers: Dict[int, LightningOptimizer] = {} self.lr_scheduler_configs: List[LRSchedulerConfig] = [] self.optimizer_frequencies: List[int] = [] - @property def launcher(self) -> Optional[_Launcher]: diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index f7b29f6f1740e..f568e01449290 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -22,11 +22,11 @@ from pytorch_lightning.accelerators.mps import _MPS_AVAILABLE from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE -from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities.imports import ( _APEX_AVAILABLE, + _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _HIVEMIND_AVAILABLE, _HOROVOD_AVAILABLE, diff --git a/tests/tests_pytorch/overrides/test_base.py b/tests/tests_pytorch/overrides/test_base.py index ca131f21d2aef..101cf415713d5 100644 --- a/tests/tests_pytorch/overrides/test_base.py +++ b/tests/tests_pytorch/overrides/test_base.py @@ -15,10 +15,7 @@ import torch from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides.base import ( - _LightningModuleWrapperBase, - _LightningPrecisionModuleWrapperBase, -) +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase @pytest.mark.parametrize("wrapper_class", [_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase]) diff --git a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py index b53ddc4084242..ab7a4a432a2c6 100644 --- a/tests/tests_pytorch/plugins/precision/test_sharded_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_sharded_precision.py @@ -15,8 +15,8 @@ import pytest import torch -from 
pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.plugins import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from tests_pytorch.helpers.runif import RunIf ShardedGradScaler = None diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py index 9dc3400163536..6d0c34b734689 100644 --- a/tests/tests_pytorch/strategies/test_sharded_strategy.py +++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py @@ -7,9 +7,9 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy from pytorch_lightning.trainer.states import TrainerFn +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from tests_pytorch.helpers.runif import RunIf if _FAIRSCALE_AVAILABLE: From 7184b59d73c3c192012deae2efb3b7ab8c0cc1e8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 00:29:15 +0200 Subject: [PATCH 05/53] refactor wrappers overrides --- src/pytorch_lightning/overrides/base.py | 38 +++++++++---------- .../overrides/data_parallel.py | 7 ++-- src/pytorch_lightning/strategies/bagua.py | 6 +-- src/pytorch_lightning/strategies/deepspeed.py | 4 +- src/pytorch_lightning/strategies/ipu.py | 4 +- 5 files changed, 29 insertions(+), 30 deletions(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 1d29c9c12583d..3d928f656bf9a 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -11,13 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Union import torch -import torch.nn as nn -from torch.nn import DataParallel -from torch.nn.parallel import DistributedDataParallel - import pytorch_lightning as pl from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin @@ -54,45 +50,47 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, forward_module: nn.Module, lightning_module: Optional["pl.LightningDataModule"] = None) -> None: + def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``. Inheriting classes may also modify the inputs or outputs of forward. Args: - pl_module: the model to wrap + forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module`` + which must be a LightningModule reference. """ super().__init__() self._forward_module = forward_module - self._lightning_module = lightning_module or forward_module - # set the parameters_to_ignore from LightningModule - _ddp_params_and_buffers_to_ignore = getattr(self._lightning_module, "_ddp_params_and_buffers_to_ignore", []) + # set the parameters_to_ignore from LightningModule. 
+ _ddp_params_and_buffers_to_ignore = getattr(self._forward_module, "_ddp_params_and_buffers_to_ignore", []) self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore] @property - def module(self): - return self._forward_module + def lightning_module(self) -> "pl.LightningModule": + if isinstance(self._forward_module, pl.LightningModule): + return self._forward_module + return self._forward_module.module def forward(self, *inputs: Any, **kwargs: Any) -> Any: - trainer = self._lightning_module.trainer + pl_module = self.lightning_module + trainer = pl_module._trainer if trainer is not None: if trainer.training: - output = self.module.training_step(*inputs, **kwargs) + output = self._forward_module.training_step(*inputs, **kwargs) # In manual_optimization, we need to prevent DDP reducer as # it is done manually in `LightningModule.manual_backward` # `require_backward_grad_sync` will be reset in the # ddp_strategy `post_training_step` hook - if not self._lightning_module.automatic_optimization: + if not pl_module.automatic_optimization: trainer.model.require_backward_grad_sync = False # type: ignore[assignment] return output if trainer.testing: - return self.module.test_step(*inputs, **kwargs) + return self._forward_module.test_step(*inputs, **kwargs) if trainer.sanity_checking or trainer.validating: - return self.module.validation_step(*inputs, **kwargs) + return self._forward_module.validation_step(*inputs, **kwargs) if trainer.predicting: - return self.module.predict_step(*inputs, **kwargs) - return self.module(*inputs, **kwargs) - + return self._forward_module.predict_step(*inputs, **kwargs) + return self._forward_module(*inputs, **kwargs) diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index 9fa253b9d8321..df48cd9602e08 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -52,11 +52,12 @@ class LightningParallelModule(_LightningModuleWrapperBase): ) Args: - pl_module: the model to wrap + forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module`` which + must be a LightningModule reference. 
""" - def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: - super().__init__(pl_module) + def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: + super().__init__(forward_module) _ignore_scalar_return_in_dp() def forward(self, *inputs: Any, **kwargs: Any) -> Any: diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py index 2b56229b32aaf..cbf8738f19e84 100644 --- a/src/pytorch_lightning/strategies/bagua.py +++ b/src/pytorch_lightning/strategies/bagua.py @@ -53,10 +53,10 @@ class LightningBaguaModule(_LightningModuleWrapperBase): - def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: - super().__init__(pl_module) + def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: + super().__init__(forward_module) # Bagua use `bagua_module_name` to distinguish different modules - self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}" + self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}" class BaguaStrategy(DDPStrategy): diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index ede42754aafc9..9e3f628f36ce0 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -71,9 +71,9 @@ def remove_module_hooks(model: torch.nn.Module) -> None: class LightningDeepSpeedModule(_LightningModuleWrapperBase): def __init__( - self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] + self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] ) -> None: - super().__init__(pl_module) + super().__init__(forward_module) self.precision = precision def forward(self, *inputs, **kwargs): diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 5413756c15271..c477c2463c681 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -43,9 +43,9 @@ class LightningIPUModule(_LightningModuleWrapperBase): def __init__( - self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] + self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] ) -> None: - super().__init__(pl_module) + super().__init__(forward_module) self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: From f861a01c7b1f9357e2e8f277dfe0aef4d565e59c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 00:31:19 +0200 Subject: [PATCH 06/53] refactor --- tests/tests_pytorch/models/test_amp.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index 159a3767c1df2..786de99f59714 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -96,8 +96,6 @@ def test_amp_cpus(tmpdir, strategy, precision, devices): trainer.test(model) trainer.predict(model, DataLoader(RandomDataset(32, 64))) - assert trainer.state.finished, f"Training failed with {trainer.state}" - @RunIf(min_cuda_gpus=2, min_torch="1.10") @pytest.mark.parametrize("strategy", [None, "dp", "ddp_spawn"]) @@ -121,8 +119,6 @@ def 
test_amp_gpus(tmpdir, strategy, precision, devices): trainer.test(model) trainer.predict(model, DataLoader(RandomDataset(32, 64))) - assert trainer.state.finished, f"Training failed with {trainer.state}" - @RunIf(min_cuda_gpus=2) @mock.patch.dict( @@ -162,9 +158,6 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir): ) trainer.fit(model) - # correct result and ok accuracy - assert trainer.state.finished, "amp + ddp model failed to complete" - # test root model address assert isinstance(trainer.strategy.cluster_environment, SLURMEnvironment) assert trainer.strategy.cluster_environment.resolve_root_node_address("abc") == "abc" @@ -185,7 +178,6 @@ def test_amp_without_apex(bwd_mock, tmpdir): trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, amp_backend="apex") assert trainer.amp_backend is None trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" assert not bwd_mock.called @@ -213,7 +205,6 @@ def configure_optimizers(self): ) assert str(trainer.amp_backend) == "AMPType.APEX" trainer.fit(model) - assert trainer.state.finished, f"Training failed with {trainer.state}" # `max_steps` is fulfilled in the third batch first optimizer, but we don't check the loop # `done` condition until all optimizers have run, so the number of backwards is higher than `max_steps` assert bwd_mock.call_count == 6 From 78da9affbe8a44d256ed2e78da64bf55b10cf1ec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Jul 2022 22:32:51 +0000 Subject: [PATCH 07/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/overrides/base.py | 1 + src/pytorch_lightning/strategies/deepspeed.py | 4 +++- src/pytorch_lightning/strategies/ipu.py | 4 +++- 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 3d928f656bf9a..759d2531bdbb8 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -14,6 +14,7 @@ from typing import Any, Union import torch + import pytorch_lightning as pl from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 9e3f628f36ce0..1a4564634a180 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -71,7 +71,9 @@ def remove_module_hooks(model: torch.nn.Module) -> None: class LightningDeepSpeedModule(_LightningModuleWrapperBase): def __init__( - self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] + self, + forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], + precision: Union[str, int], ) -> None: super().__init__(forward_module) self.precision = precision diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index c477c2463c681..e66dccbe282aa 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -43,7 +43,9 @@ class LightningIPUModule(_LightningModuleWrapperBase): def __init__( - self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] + self, + forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], + precision: Union[str, int], ) -> None: 
super().__init__(forward_module) self.precision = precision From 2a2e1e803b70a4bf3c5e7c44333552821582e134 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 00:34:06 +0200 Subject: [PATCH 08/53] update --- src/pytorch_lightning/overrides/base.py | 2 +- src/pytorch_lightning/overrides/data_parallel.py | 4 ++-- src/pytorch_lightning/strategies/deepspeed.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 3d928f656bf9a..2782e11829352 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -58,7 +58,7 @@ def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisi Args: forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module`` - which must be a LightningModule reference. + pointing to a LightningModule reference. """ super().__init__() self._forward_module = forward_module diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index df48cd9602e08..25f6780a74786 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -52,8 +52,8 @@ class LightningParallelModule(_LightningModuleWrapperBase): ) Args: - forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module`` which - must be a LightningModule reference. + forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module`` + pointing to a LightningModule reference. """ def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 9e3f628f36ce0..23e702cc903fd 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -53,7 +53,7 @@ warning_cache = WarningCache() -_DEEPSPEED_AVAILABLE: bool = _RequirementAvailable("deepspeed") +_DEEPSPEED_AVAILABLE = _RequirementAvailable("deepspeed") if _DEEPSPEED_AVAILABLE: import deepspeed From e94797f16d823b1fc5fd0a32524193c372f9308c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 01:05:24 +0200 Subject: [PATCH 09/53] update --- src/pytorch_lightning/strategies/sharded.py | 6 +----- tests/tests_pytorch/utilities/test_imports.py | 3 +-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/pytorch_lightning/strategies/sharded.py b/src/pytorch_lightning/strategies/sharded.py index ee512db310f69..ce1e4cd96b961 100644 --- a/src/pytorch_lightning/strategies/sharded.py +++ b/src/pytorch_lightning/strategies/sharded.py @@ -23,16 +23,12 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.strategies.ddp import DDPStrategy from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _IS_WINDOWS, _module_available from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import 
rank_zero_only -_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") - - if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel from fairscale.optim import OSS diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py index 25995bb029f3a..c673716c457f2 100644 --- a/tests/tests_pytorch/utilities/test_imports.py +++ b/tests/tests_pytorch/utilities/test_imports.py @@ -13,7 +13,6 @@ # limitations under the License. import operator -from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE from pytorch_lightning.utilities import ( @@ -23,7 +22,7 @@ _OMEGACONF_AVAILABLE, _POPTORCH_AVAILABLE, ) -from pytorch_lightning.utilities.imports import _compare_version, _RequirementAvailable, torch +from pytorch_lightning.utilities.imports import _compare_version, _FAIRSCALE_AVAILABLE, _RequirementAvailable, torch def test_module_exists(): From 39d775513636eef399ccfe02f61e127373c0c2b8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 01:06:35 +0200 Subject: [PATCH 10/53] update --- src/pytorch_lightning/strategies/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index ea8a57dcc8f5f..18d1c3e0a25a6 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -466,7 +466,7 @@ def init_deepspeed(self): "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs." ) - model = LightningDeepSpeedModule(pl_module=self.model, precision=self.precision_plugin.precision) + model = LightningDeepSpeedModule(forward_module=self.model, precision=self.precision_plugin.precision) if self.lightning_module.trainer and self.lightning_module.trainer.training: self._initialize_deepspeed_train(model) From cad60f260865aa9af4e447f7b2ae8c4ae571221d Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 01:37:29 +0200 Subject: [PATCH 11/53] fixes --- src/pytorch_lightning/overrides/data_parallel.py | 4 ++-- src/pytorch_lightning/strategies/deepspeed.py | 6 ------ src/pytorch_lightning/strategies/ipu.py | 4 ---- 3 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index 25f6780a74786..0be86a7efb1a0 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -66,7 +66,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: output = super().forward(*inputs, **kwargs) def output_transform(data: Any) -> Any: - device = cast(torch.device, self.module.device) + device = cast(torch.device, self.lightning_module.device) data = python_scalar_to_tensor(data, device) data = unsqueeze_scalar_tensor(data) return data @@ -96,7 +96,7 @@ def find_tensor_with_device(tensor: Tensor) -> Tensor: if replica_device is not None: # by calling .to() we force the update to the self.device property - self.module.to(device=replica_device) + self.lightning_module.to(device=replica_device) else: rank_zero_warn( "Could not determine on which device the inputs are." 
diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 18d1c3e0a25a6..57e96771380de 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -586,12 +586,6 @@ def _initialize_deepspeed_inference(self, model): ) self.model = model - @property - def lightning_module(self): - # the model may not be wrapped with DeepEngine & LightningDeepSpeedModule if calling this too early - module = getattr(self.model, "module", self.model) - return module.module if isinstance(module, LightningDeepSpeedModule) else module - @property def distributed_sampler_kwargs(self): distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index e66dccbe282aa..181b447a7cdf0 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -219,10 +219,6 @@ def inference_opts(self) -> "poptorch.Options": self._inference_opts = self._create_opts(training=False) return self._inference_opts - @property - def lightning_module(self) -> Optional["pl.LightningModule"]: - return self.model.module if isinstance(self.model, LightningIPUModule) else self.model - def _convert_to_poptorch_loader( self, dataloader: DataLoader, sampler, mode: Optional[RunningStage] = None ) -> "poptorch.DataLoader": From 9ee06306e1ee4f2430b48463eaf9b3b3f4517395 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 01:38:53 +0200 Subject: [PATCH 12/53] update --- src/pytorch_lightning/strategies/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 57e96771380de..91967aa11c373 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -53,7 +53,7 @@ warning_cache = WarningCache() -_DEEPSPEED_AVAILABLE = _RequirementAvailable("deepspeed") +_DEEPSPEED_AVAILABLE: bool = _RequirementAvailable("deepspeed") if _DEEPSPEED_AVAILABLE: import deepspeed From 12a3ad6a0d61e472876386c5351dbe9e7b6f4b9c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 01:53:13 +0200 Subject: [PATCH 13/53] simplify --- src/pytorch_lightning/overrides/data_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index 0be86a7efb1a0..e1914020054c4 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -96,7 +96,7 @@ def find_tensor_with_device(tensor: Tensor) -> Tensor: if replica_device is not None: # by calling .to() we force the update to the self.device property - self.lightning_module.to(device=replica_device) + self._forward_module.to(device=replica_device) else: rank_zero_warn( "Could not determine on which device the inputs are." 
From b99bc234bf0cf7d8bc450782afa4d9f5044971ab Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 02:06:46 +0200 Subject: [PATCH 14/53] debug --- src/pytorch_lightning/strategies/dp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/strategies/dp.py b/src/pytorch_lightning/strategies/dp.py index 5ab5021b8ac50..fc6407901dc74 100644 --- a/src/pytorch_lightning/strategies/dp.py +++ b/src/pytorch_lightning/strategies/dp.py @@ -70,6 +70,7 @@ def world_size(self) -> int: def setup(self, trainer: "pl.Trainer") -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() + breakpoint() assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model = self._setup_model(LightningParallelModule(self.model)) super().setup(trainer) From b7c74ef1ec96d50c9b05ed6b6bb543c6530ff6ba Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 02:11:48 +0200 Subject: [PATCH 15/53] update --- src/pytorch_lightning/strategies/dp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/dp.py b/src/pytorch_lightning/strategies/dp.py index fc6407901dc74..5ab5021b8ac50 100644 --- a/src/pytorch_lightning/strategies/dp.py +++ b/src/pytorch_lightning/strategies/dp.py @@ -70,7 +70,6 @@ def world_size(self) -> int: def setup(self, trainer: "pl.Trainer") -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() - breakpoint() assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model = self._setup_model(LightningParallelModule(self.model)) super().setup(trainer) From d842246ea02142a689835cf216b58efc698d3f44 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 03:52:17 +0200 Subject: [PATCH 16/53] validate --- src/pytorch_lightning/overrides/base.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 1360232a5f982..047c7bb0b8a36 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -63,6 +63,12 @@ def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisi """ super().__init__() self._forward_module = forward_module + if not isinstance(forward_module, pl.LightningModule) or ( + not isinstance(getattr(forward_module, "module", None), pl.LightningModule) + ): + raise ValueError( + "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one." + ) # set the parameters_to_ignore from LightningModule. _ddp_params_and_buffers_to_ignore = getattr(self._forward_module, "_ddp_params_and_buffers_to_ignore", []) From b17653481137f72a74a1fe2ddee08a7c2d51b1e4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 03:55:28 +0200 Subject: [PATCH 17/53] update --- src/pytorch_lightning/overrides/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 047c7bb0b8a36..89d5b5de28554 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -62,13 +62,14 @@ def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisi pointing to a LightningModule reference. 
""" super().__init__() - self._forward_module = forward_module if not isinstance(forward_module, pl.LightningModule) or ( not isinstance(getattr(forward_module, "module", None), pl.LightningModule) ): raise ValueError( - "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one." + "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one," + f" got: {forward_module.__class__.__qualname__}" ) + self._forward_module = forward_module # set the parameters_to_ignore from LightningModule. _ddp_params_and_buffers_to_ignore = getattr(self._forward_module, "_ddp_params_and_buffers_to_ignore", []) From 67fae07c4575fd816396572d57ee72b1e9d2d185 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 03:57:46 +0200 Subject: [PATCH 18/53] fix --- src/pytorch_lightning/overrides/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 89d5b5de28554..18ac438d3ab94 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -62,7 +62,7 @@ def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisi pointing to a LightningModule reference. """ super().__init__() - if not isinstance(forward_module, pl.LightningModule) or ( + if not isinstance(forward_module, pl.LightningModule) and ( not isinstance(getattr(forward_module, "module", None), pl.LightningModule) ): raise ValueError( From 1e9fe947511ea1906d22982141df21f233dacad7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 04:01:14 +0200 Subject: [PATCH 19/53] debug --- src/pytorch_lightning/strategies/dp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/dp.py b/src/pytorch_lightning/strategies/dp.py index 5ab5021b8ac50..ceae9bbead8e8 100644 --- a/src/pytorch_lightning/strategies/dp.py +++ b/src/pytorch_lightning/strategies/dp.py @@ -70,7 +70,8 @@ def world_size(self) -> int: def setup(self, trainer: "pl.Trainer") -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() - assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) + assert isinstance(self.model, pl.LightningModule), f"{self.model}" + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)), f"{self.model}" self.model = self._setup_model(LightningParallelModule(self.model)) super().setup(trainer) From 50503999019ca0be95ce8700d2de0a1b1b581e45 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 04:03:21 +0200 Subject: [PATCH 20/53] teardown --- src/pytorch_lightning/strategies/strategy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 40fccb2dd70fb..6656b0d069b0f 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -467,6 +467,7 @@ def teardown(self) -> None: if self.lightning_module is not None: log.detail(f"{self.__class__.__name__}: moving model to CPU") self.lightning_module.cpu() + self.model = self.lightning_module self.precision_plugin.teardown() assert self.accelerator is not None self.accelerator.teardown() From dfffb4985d5e7c23737e3fdd916e127082e83dd5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 04:04:55 +0200 Subject: [PATCH 21/53] fix --- 
src/pytorch_lightning/strategies/dp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/dp.py b/src/pytorch_lightning/strategies/dp.py index ceae9bbead8e8..5ab5021b8ac50 100644 --- a/src/pytorch_lightning/strategies/dp.py +++ b/src/pytorch_lightning/strategies/dp.py @@ -70,8 +70,7 @@ def world_size(self) -> int: def setup(self, trainer: "pl.Trainer") -> None: # model needs to be moved to the device before it is wrapped self.model_to_device() - assert isinstance(self.model, pl.LightningModule), f"{self.model}" - assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)), f"{self.model}" + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model = self._setup_model(LightningParallelModule(self.model)) super().setup(trainer) From a0d1940ad7e2d22287107ccaede9a3a3b2879e1a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 04:15:10 +0200 Subject: [PATCH 22/53] discussion --- src/pytorch_lightning/trainer/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index f74217813e48c..3ef638a71ed18 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -659,6 +659,8 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: if not self.interrupted: self.state.status = TrainerStatus.INTERRUPTED self._call_callback_hooks("on_exception", exception) + # TODO: Do we need teardown here? What if we get keyboard interrupt and the model remains wrapped? + # Will a subsequent call wrap the model again (and fail)? except BaseException as exception: self.state.status = TrainerStatus.INTERRUPTED if distributed_available() and self.world_size > 1: From 0bbd41aefc86c25ef6b5f7d1429a44edb3c78be4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 04:33:11 +0200 Subject: [PATCH 23/53] rename --- tests/tests_pytorch/accelerators/test_ipu.py | 46 ++++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index 97f374a40d6c3..e03827e3021d9 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -118,7 +118,7 @@ def test_warning_if_ipus_not_used(): @RunIf(ipu=True) -def test_no_warning_plugin(tmpdir): +def test_no_warning_strategy(tmpdir): with pytest.warns(None) as record: Trainer(default_root_dir=tmpdir, max_epochs=1, strategy=IPUStrategy(training_opts=poptorch.Options())) assert len(record) == 0 @@ -225,7 +225,7 @@ def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @RunIf(ipu=True) -def test_device_iterations_ipu_plugin(tmpdir): +def test_device_iterations_ipu_strategy(tmpdir): class TestCallback(Callback): def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: assert trainer.strategy.device_iterations == 2 @@ -432,10 +432,10 @@ def test_manual_poptorch_opts_custom(tmpdir): class TestCallback(Callback): def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: # ensure dataloaders were correctly set up during training. 
- plugin = trainer.strategy - assert isinstance(plugin, IPUStrategy) - assert plugin.training_opts.replication_factor == 2 - assert plugin.inference_opts.replication_factor == 1 + strategy = trainer.strategy + assert isinstance(strategy, IPUStrategy) + assert strategy.training_opts.replication_factor == 2 + assert strategy.inference_opts.replication_factor == 1 val_dataloader = trainer.val_dataloaders[0] train_dataloader = trainer.train_dataloader @@ -446,21 +446,21 @@ def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None: assert train_dataloader.options.replication_factor == 2 assert val_dataloader.options.replication_factor == 1 - plugin = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts) + strategy = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts) # ensure we default to the training options replication factor - assert plugin.replication_factor == 2 - trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=plugin, callbacks=TestCallback()) + assert strategy.replication_factor == 2 + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, strategy=strategy, callbacks=TestCallback()) trainer.fit(model) - plugin = trainer.strategy - assert isinstance(plugin, IPUStrategy) + strategy = trainer.strategy + assert isinstance(strategy, IPUStrategy) - training_opts = plugin.training_opts + training_opts = strategy.training_opts assert training_opts.device_iterations == 8 assert training_opts.replication_factor == 2 assert training_opts.Training.gradient_accumulation == 2 - inference_opts = plugin.inference_opts + inference_opts = strategy.inference_opts assert inference_opts.device_iterations == 16 assert inference_opts.replication_factor == 1 assert inference_opts.Training.gradient_accumulation == 1 @@ -471,8 +471,8 @@ def test_replication_factor(tmpdir): """Ensure if the user passes manual poptorch Options with custom parameters set, we set them correctly in the dataloaders.""" - plugin = IPUStrategy() - trainer = Trainer(accelerator="ipu", devices=2, default_root_dir=tmpdir, fast_dev_run=True, strategy=plugin) + strategy = IPUStrategy() + trainer = Trainer(accelerator="ipu", devices=2, default_root_dir=tmpdir, fast_dev_run=True, strategy=strategy) assert isinstance(trainer.accelerator, IPUAccelerator) assert trainer.num_devices == 2 assert trainer.strategy.replication_factor == 2 @@ -482,11 +482,11 @@ def test_replication_factor(tmpdir): inference_opts = poptorch.Options() training_opts.replicationFactor(8) inference_opts.replicationFactor(7) - plugin = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts) + strategy = IPUStrategy(inference_opts=inference_opts, training_opts=training_opts) - trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1, strategy=plugin) + trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1, strategy=strategy) trainer.optimizers = model.configure_optimizers()[0] - plugin.model = model + strategy.model = model model.trainer = trainer trainer.state.fn = TrainerFn.FITTING trainer.strategy.setup(trainer) @@ -596,13 +596,13 @@ def test_set_devices_if_none_ipu(): @RunIf(ipu=True) -def test_strategy_choice_ipu_plugin(tmpdir): +def test_strategy_choice_ipu_strategy(): trainer = Trainer(strategy=IPUStrategy(), accelerator="ipu", devices=8) assert isinstance(trainer.strategy, IPUStrategy) @RunIf(ipu=True) -def test_device_type_when_training_plugin_ipu_passed(tmpdir): +def test_device_type_when_ipu_strategy_passed(): trainer = 
Trainer(strategy=IPUStrategy(), accelerator="ipu", devices=8) assert isinstance(trainer.strategy, IPUStrategy) assert isinstance(trainer.accelerator, IPUAccelerator) @@ -610,11 +610,11 @@ def test_device_type_when_training_plugin_ipu_passed(tmpdir): @RunIf(ipu=True) def test_poptorch_models_at_different_stages(tmpdir): - plugin = IPUStrategy() - trainer = Trainer(default_root_dir=tmpdir, strategy=plugin, accelerator="ipu", devices=8) + strategy = IPUStrategy() + trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, accelerator="ipu", devices=8) model = BoringModel() model.trainer = trainer - plugin.model = model + strategy.model = model trainer.optimizers = model.configure_optimizers()[0] trainer.state.fn = TrainerFn.FITTING From b7211a6a00250068716f320c8cb3ecfe23b2a5f5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 04:35:27 +0200 Subject: [PATCH 24/53] fix --- tests/tests_pytorch/accelerators/test_ipu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/accelerators/test_ipu.py b/tests/tests_pytorch/accelerators/test_ipu.py index e03827e3021d9..f4c13f10f068d 100644 --- a/tests/tests_pytorch/accelerators/test_ipu.py +++ b/tests/tests_pytorch/accelerators/test_ipu.py @@ -99,7 +99,7 @@ def test_epoch_end(self, outputs) -> None: @pytest.mark.skipif(_IPU_AVAILABLE, reason="test requires non-IPU machine") @mock.patch("pytorch_lightning.accelerators.ipu.IPUAccelerator.is_available", return_value=True) -def test_fail_if_no_ipus(mock_ipu_acc_avail, tmpdir): +def test_fail_if_no_ipus(_, tmpdir): with pytest.raises(MisconfigurationException, match="IPU Accelerator requires IPU devices to run"): Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1) @@ -486,7 +486,7 @@ def test_replication_factor(tmpdir): trainer = Trainer(default_root_dir=tmpdir, accelerator="ipu", devices=1, strategy=strategy) trainer.optimizers = model.configure_optimizers()[0] - strategy.model = model + strategy._lightning_module = model model.trainer = trainer trainer.state.fn = TrainerFn.FITTING trainer.strategy.setup(trainer) @@ -541,7 +541,7 @@ def configure_optimizers(self): @RunIf(ipu=True) -def test_precision_plugin(tmpdir): +def test_precision_plugin(): """Ensure precision plugin value is set correctly.""" plugin = IPUPrecisionPlugin(precision=16) @@ -614,7 +614,7 @@ def test_poptorch_models_at_different_stages(tmpdir): trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, accelerator="ipu", devices=8) model = BoringModel() model.trainer = trainer - strategy.model = model + strategy._lightning_module = model trainer.optimizers = model.configure_optimizers()[0] trainer.state.fn = TrainerFn.FITTING From 3a7e8859944f68038bd933e1d29d04a19ad29ebc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 05:41:10 +0200 Subject: [PATCH 25/53] clear model reference on connect() --- src/pytorch_lightning/strategies/strategy.py | 2 +- src/pytorch_lightning/trainer/trainer.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pytorch_lightning/strategies/strategy.py b/src/pytorch_lightning/strategies/strategy.py index 6656b0d069b0f..7f359e1cb601f 100644 --- a/src/pytorch_lightning/strategies/strategy.py +++ b/src/pytorch_lightning/strategies/strategy.py @@ -116,6 +116,7 @@ def optimizers(self, optimizers: List[Optimizer]) -> None: def connect(self, model: "pl.LightningModule") -> None: """Called by the accelerator to connect the accelerator and the model with this plugin.""" self._lightning_module = model + self.model = 
model def _configure_launcher(self) -> None: """Attach the launcher based on Strategy.""" @@ -467,7 +468,6 @@ def teardown(self) -> None: if self.lightning_module is not None: log.detail(f"{self.__class__.__name__}: moving model to CPU") self.lightning_module.cpu() - self.model = self.lightning_module self.precision_plugin.teardown() assert self.accelerator is not None self.accelerator.teardown() diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 3ef638a71ed18..f74217813e48c 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -659,8 +659,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: if not self.interrupted: self.state.status = TrainerStatus.INTERRUPTED self._call_callback_hooks("on_exception", exception) - # TODO: Do we need teardown here? What if we get keyboard interrupt and the model remains wrapped? - # Will a subsequent call wrap the model again (and fail)? except BaseException as exception: self.state.status = TrainerStatus.INTERRUPTED if distributed_available() and self.world_size > 1: From be8abf7dc4cd45777d7c99e978f957ee48f9029a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 05:41:19 +0200 Subject: [PATCH 26/53] fix tests --- tests/tests_pytorch/strategies/test_sharded_strategy.py | 8 ++++---- .../trainer/connectors/test_callback_connector.py | 8 ++++---- tests/tests_pytorch/trainer/flags/test_overfit_batches.py | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_sharded_strategy.py b/tests/tests_pytorch/strategies/test_sharded_strategy.py index 6d0c34b734689..e0bb4ee5d9c37 100644 --- a/tests/tests_pytorch/strategies/test_sharded_strategy.py +++ b/tests/tests_pytorch/strategies/test_sharded_strategy.py @@ -271,8 +271,8 @@ def test_configure_ddp(tmpdir): def test_custom_kwargs_sharded(tmpdir, cls): """Tests to ensure that if custom kwargs are passed, they are set correctly.""" strategy = cls(reduce_fp16=True) - strategy.model = Mock(spec=LightningModule) - strategy.model.trainer = Mock() + strategy._lightning_module = Mock(spec=LightningModule) + strategy._lightning_module.trainer = Mock() strategy.parallel_devices = [Mock()] class_name = "sharded" if isinstance(strategy, DDPShardedStrategy) else "sharded_spawn" @@ -291,8 +291,8 @@ def test_custom_kwargs_sharded_reduce_buffer_size(tmpdir, params, expected_buffe """Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs.""" strategy = DDPShardedStrategy(**params) strategy.num_nodes = num_nodes - strategy.model = Mock(spec=LightningModule) - strategy.model.trainer = Mock() + strategy._lightning_module = Mock(spec=LightningModule) + strategy._lightning_module.trainer = Mock() strategy.parallel_devices = [Mock()] with mock.patch("pytorch_lightning.strategies.sharded.ShardedDataParallel", autospec=True) as mock_sharded: diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index d6d5018aa1dd0..0d97cce6a29f4 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -56,7 +56,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): # no model callbacks model = LightningModule() model.configure_callbacks = lambda: [] - trainer.model = model + trainer.strategy._lightning_module = model cb_connector = 
CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [ @@ -72,7 +72,7 @@ def test_checkpoint_callbacks_are_last(tmpdir): model = LightningModule() model.configure_callbacks = lambda: [checkpoint1, early_stopping, model_summary, checkpoint2] trainer = Trainer(callbacks=[progress_bar, lr_monitor, ModelCheckpoint(tmpdir)]) - trainer.model = model + trainer.strategy._lightning_module = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() assert trainer.callbacks == [ @@ -154,7 +154,7 @@ def _attach_callbacks(trainer_callbacks, model_callbacks): enable_model_summary=False, callbacks=trainer_callbacks, ) - trainer.model = model + trainer.strategy._lightning_module = model cb_connector = CallbackConnector(trainer) cb_connector._attach_model_callbacks() return trainer @@ -212,7 +212,7 @@ def test_attach_model_callbacks_override_info(caplog): trainer = Trainer( enable_checkpointing=False, callbacks=[EarlyStopping(monitor="foo"), LearningRateMonitor(), TQDMProgressBar()] ) - trainer.model = model + trainer.strategy._lightning_module = model cb_connector = CallbackConnector(trainer) with caplog.at_level(logging.INFO): cb_connector._attach_model_callbacks() diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index 32f0b8938caf6..da3e154349e1b 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -142,7 +142,7 @@ def test_distributed_sampler_with_overfit_batches(): strategy="ddp_spawn", ) model.trainer = trainer - trainer.model = model + trainer.strategy._lightning_module = model trainer._data_connector.attach_dataloaders(model) trainer.reset_train_dataloader() train_sampler = trainer.train_dataloader.loaders.sampler From bdf407c6344a47b45370ad87abf150c7aae5ad80 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Aug 2022 16:01:19 +0200 Subject: [PATCH 27/53] resolve merge conflict --- src/pytorch_lightning/strategies/deepspeed.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 7b0079998f8bc..e5e7a482695f4 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -478,6 +478,7 @@ def init_deepspeed(self) -> None: "DeepSpeed currently does not support different `accumulate_grad_batches` at different epochs." 
) + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) model = LightningDeepSpeedModule(forward_module=self.model, precision=self.precision_plugin.precision) if self.lightning_module.trainer and self.lightning_module.trainer.training: From f353b9b0d507784cc16ba59b9a6ebe19f2055d95 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Aug 2022 20:30:42 +0200 Subject: [PATCH 28/53] fix property --- src/pytorch_lightning/strategies/deepspeed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index e5e7a482695f4..eaa6d7c684d8c 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -605,7 +605,6 @@ def _initialize_deepspeed_inference(self, model: Module) -> None: ) self.model = model - @property @property def distributed_sampler_kwargs(self) -> Dict[str, int]: distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) From b7baa82a20454f38948e95ebb0b06ac6cb1ee8e7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Aug 2022 22:52:34 +0200 Subject: [PATCH 29/53] fix attribute error --- src/pytorch_lightning/strategies/tpu_spawn.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 62bb1c308480b..5c5dc3b147950 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -139,11 +139,11 @@ def setup(self, trainer: "pl.Trainer") -> None: if self.debug: os.environ["PT_XLA_DEBUG"] = "1" - assert self.model - shared_params = find_shared_parameters(self.model) + assert self.lightning_module + shared_params = find_shared_parameters(self.lightning_module) self.model_to_device() - assert isinstance(self.model.module, Module) - set_shared_parameters(self.model.module, shared_params) + + set_shared_parameters(self.lightning_module, shared_params) self.setup_precision_plugin() if trainer.state.fn == TrainerFn.FITTING: From 4b40060e2ed1016a4f7d4cacfa9939d723b0c55c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Aug 2022 23:29:38 +0200 Subject: [PATCH 30/53] add backward compatibility --- src/pytorch_lightning/overrides/base.py | 18 ++++++++- .../overrides/data_parallel.py | 19 ++++++++-- .../overrides/distributed.py | 11 +++++- src/pytorch_lightning/strategies/bagua.py | 10 ++++- src/pytorch_lightning/strategies/ipu.py | 4 +- .../deprecated_api/test_remove_1-10.py | 37 +++++++++++++++++++ .../overrides/test_data_parallel.py | 6 +-- 7 files changed, 91 insertions(+), 14 deletions(-) create mode 100644 tests/tests_pytorch/deprecated_api/test_remove_1-10.py diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 18ac438d3ab94..8002262f9942b 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Union +from typing import Any, Optional, Union import torch import pytorch_lightning as pl from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin +from pytorch_lightning.utilities import rank_zero_deprecation class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): @@ -102,3 +103,18 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: if trainer.predicting: return self._forward_module.predict_step(*inputs, **kwargs) return self._forward_module(*inputs, **kwargs) + + @classmethod + def _validate_init_arguments( + cls, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ): + # TODO: In v1.10, remove this method and mark the forward_module init argument in all subclasses as required + if pl_module is not None: + rank_zero_deprecation( + f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8 and will be removed in" + " v1.10. Please use `forward_module` instead." + ) + elif forward_module is None: + raise ValueError("Argument `forward_module` is required.") diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index e1914020054c4..e3681373473e1 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -13,7 +13,7 @@ # limitations under the License. import numbers import warnings -from typing import Any, cast, Union +from typing import Any, cast, Optional, Union import torch from torch import Tensor @@ -21,7 +21,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn def _ignore_scalar_return_in_dp() -> None: @@ -52,12 +52,23 @@ class LightningParallelModule(_LightningModuleWrapperBase): ) Args: + pl_module: The module to wrap. See description for `forward_module`. + + .. deprecated:: v1.6 + The argument `pl_module` is deprecated in v1.8 and will be removed in v1.10. Please use + `forward_module` instead. + forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module`` pointing to a LightningModule reference. 
""" - def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: - super().__init__(forward_module) + def __init__( + self, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) _ignore_scalar_return_in_dp() def forward(self, *inputs: Any, **kwargs: Any) -> Any: diff --git a/src/pytorch_lightning/overrides/distributed.py b/src/pytorch_lightning/overrides/distributed.py index f09a7b9e3ae08..145ebb319b132 100644 --- a/src/pytorch_lightning/overrides/distributed.py +++ b/src/pytorch_lightning/overrides/distributed.py @@ -19,12 +19,19 @@ from torch.nn.parallel import DistributedDataParallel from torch.utils.data import BatchSampler, Dataset, DistributedSampler, Sampler -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase +import pytorch_lightning as pl +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.utilities.exceptions import MisconfigurationException class LightningDistributedModule(_LightningModuleWrapperBase): - ... + def __init__( + self, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) def _find_tensors( diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py index f19f156adf47a..0c731d57bc904 100644 --- a/src/pytorch_lightning/strategies/bagua.py +++ b/src/pytorch_lightning/strategies/bagua.py @@ -50,8 +50,14 @@ class LightningBaguaModule(_LightningModuleWrapperBase): - def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: - super().__init__(forward_module) + def __init__( + self, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + forward_module = pl_module or forward_module + super().__init__(forward_module=forward_module) # Bagua use `bagua_module_name` to distinguish different modules self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}" diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 2de4535f4fd13..efa8521b625a3 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -45,11 +45,11 @@ class LightningIPUModule(_LightningModuleWrapperBase): def __init__( self, - forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], + pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int], ) -> None: rank_zero_deprecation("`LightningIPUModule` has been deprecated in v1.7.0 and will be removed in v1.8.0") - super().__init__(forward_module) + super().__init__(pl_module) self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: diff --git 
a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py new file mode 100644 index 0000000000000..b6b2bf71f2b4e --- /dev/null +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -0,0 +1,37 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test deprecated functionality which will be removed in v1.10.0.""" +from unittest.mock import Mock + +import pytest + +from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule +from pytorch_lightning.strategies.bagua import LightningBaguaModule + + +@pytest.mark.parametrize( + "wrapper_class", + [ + LightningParallelModule, + LightningDistributedModule, + LightningBaguaModule, + ], +) +def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): + with pytest.deprecated_call(match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): + wrapper_class(BoringModel()) + + with pytest.deprecated_call(match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): + wrapper_class(pl_module=BoringModel()) diff --git a/tests/tests_pytorch/overrides/test_data_parallel.py b/tests/tests_pytorch/overrides/test_data_parallel.py index 68f625a427cef..4f04835e2ea30 100644 --- a/tests/tests_pytorch/overrides/test_data_parallel.py +++ b/tests/tests_pytorch/overrides/test_data_parallel.py @@ -45,7 +45,7 @@ def test_lightning_wrapper_module_methods(wrapper_class, stage): pl_module = Mock(spec=LightningModule) trainer = Mock() pl_module._trainer = trainer - wrapped_module = wrapper_class(pl_module) + wrapped_module = wrapper_class(forward_module=pl_module) batch = torch.rand(5) batch_idx = 3 @@ -171,7 +171,7 @@ def training_step(self, batch, batch_idx): trainer.state.stage = RunningStage.TRAINING root_device = torch.device("cuda", 0) - wrapped_module = LightningParallelModule(pl_module).to(root_device) + wrapped_module = LightningParallelModule(forward_module=pl_module).to(root_device) model = DataParallel(wrapped_module, device_ids=[0, 1]) data = torch.tensor([0.0, 1.0], device=root_device).view(2, 1) # one value per gpu @@ -197,7 +197,7 @@ def training_step(self, batch, batch_idx): pl_module.trainer = trainer trainer.state.stage = RunningStage.TRAINING - wrapped_module = LightningParallelModule(pl_module).cuda() + wrapped_module = LightningParallelModule(forward_module=pl_module).cuda() model = DataParallel(wrapped_module, device_ids=[0, 1]) data = dict(x=1) # contains no tensors From 4147fa7356d9df80cfd54cbc0743f01455ac2b55 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Aug 2022 23:33:19 +0200 Subject: [PATCH 31/53] undo changes in ipu --- src/pytorch_lightning/strategies/ipu.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index efa8521b625a3..82ba4ad227f7c 100644 --- 
a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -44,9 +44,7 @@ class LightningIPUModule(_LightningModuleWrapperBase): def __init__( - self, - pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], - precision: Union[str, int], + self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] ) -> None: rank_zero_deprecation("`LightningIPUModule` has been deprecated in v1.7.0 and will be removed in v1.8.0") super().__init__(pl_module) From e1c5cbdf64c18b1aafff244323bdb436aa85ce54 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Aug 2022 23:37:55 +0200 Subject: [PATCH 32/53] include deepspeed --- src/pytorch_lightning/strategies/deepspeed.py | 8 +++++--- tests/tests_pytorch/deprecated_api/test_remove_1-10.py | 3 ++- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index eaa6d7c684d8c..c8b5e151b20a0 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -72,10 +72,12 @@ def remove_module_hooks(model: torch.nn.Module) -> None: class LightningDeepSpeedModule(_LightningModuleWrapperBase): def __init__( self, - forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], - precision: Union[str, int], + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + precision: Union[str, int] = 32, ) -> None: - super().__init__(forward_module) + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index b6b2bf71f2b4e..4bb35c8e0c75e 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Test deprecated functionality which will be removed in v1.10.0.""" -from unittest.mock import Mock import pytest from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule from pytorch_lightning.strategies.bagua import LightningBaguaModule +from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule @pytest.mark.parametrize( @@ -27,6 +27,7 @@ LightningParallelModule, LightningDistributedModule, LightningBaguaModule, + LightningDeepSpeedModule, ], ) def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): From bb4f8ccab9e97fd6d7f16524f3e5a5fe7fae9aa8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Aug 2022 23:43:03 +0200 Subject: [PATCH 33/53] mypy --- src/pytorch_lightning/overrides/base.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 8002262f9942b..76999dc10e216 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -52,7 +52,9 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module): - def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None: + def __init__( + self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] + ) -> None: """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``. @@ -70,6 +72,8 @@ def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisi "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one," f" got: {forward_module.__class__.__qualname__}" ) + # TODO: In v1.10, remove the Optional type from forward_module and remove the assertion + assert forward_module is not None self._forward_module = forward_module # set the parameters_to_ignore from LightningModule. 
From 0f5d458c6b9b2d5aa5af918f178f91e3c065d260 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 2 Aug 2022 23:55:04 +0200 Subject: [PATCH 34/53] mypy --- src/pytorch_lightning/strategies/tpu_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/tpu_spawn.py b/src/pytorch_lightning/strategies/tpu_spawn.py index 5c5dc3b147950..5ca8db74c4620 100644 --- a/src/pytorch_lightning/strategies/tpu_spawn.py +++ b/src/pytorch_lightning/strategies/tpu_spawn.py @@ -124,7 +124,7 @@ def _validate_patched_dataloaders(model: "pl.LightningModule") -> None: assert not isinstance(source.instance, (pl.LightningModule, pl.LightningDataModule)) TPUSpawnStrategy._validate_dataloader(source.instance) - def connect(self, model: "pl.LightningModule") -> None: # type: ignore + def connect(self, model: "pl.LightningModule") -> None: TPUSpawnStrategy._validate_patched_dataloaders(model) self.wrapped_model = xmp.MpModelWrapper(LightningDistributedModule(model)) return super().connect(model) From 0b0df7635aac6a482da9314bd8318122f1fd4420 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 3 Aug 2022 00:03:29 +0200 Subject: [PATCH 35/53] deprecate unwrap function --- src/pytorch_lightning/overrides/base.py | 29 +++++++++++++++++++ .../overrides/data_parallel.py | 2 +- .../deprecated_api/test_remove_1-10.py | 6 ++++ 3 files changed, 36 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 76999dc10e216..385e5de288eaa 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -14,6 +14,9 @@ from typing import Any, Optional, Union import torch +import torch.nn as nn +from torch.nn import DataParallel +from torch.nn.parallel import DistributedDataParallel import pytorch_lightning as pl from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin @@ -122,3 +125,29 @@ def _validate_init_arguments( ) elif forward_module is None: raise ValueError("Argument `forward_module` is required.") + + +def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule": + """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module`` + attributes on the wrapper. + + .. deprecated:: v1.8 + The function `unwrap_lightning_module` is deprecated in v1.8 and will be removed in v1.10. Access the + `LightningModule` directly through the strategy attribute `Strategy.lightning_module`. + + Raises: + TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped + further. + """ + rank_zero_deprecation( + "The function `unwrap_lightning_module` is deprecated in v1.8 and will be removed in v1.10. Access the" + " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." 
+ ) + model = wrapped_model + if isinstance(model, (DistributedDataParallel, DataParallel)): + model = unwrap_lightning_module(model.module) + if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)): + model = unwrap_lightning_module(model.module) + if not isinstance(model, pl.LightningModule): + raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.") + return model diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index e3681373473e1..bd2f983eeff60 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -54,7 +54,7 @@ class LightningParallelModule(_LightningModuleWrapperBase): Args: pl_module: The module to wrap. See description for `forward_module`. - .. deprecated:: v1.6 + .. deprecated:: v1.8 The argument `pl_module` is deprecated in v1.8 and will be removed in v1.10. Please use `forward_module` instead. diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 4bb35c8e0c75e..a90a765d0cfae 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -17,6 +17,7 @@ from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule +from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule @@ -36,3 +37,8 @@ def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): with pytest.deprecated_call(match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): wrapper_class(pl_module=BoringModel()) + + +def test_v1_10_deprecated_unwrap_lightning_module(): + with pytest.deprecated_call(match=r"The function `unwrap_lightning_module` is deprecated in v1.8"): + unwrap_lightning_module(BoringModel()) From 5485099f7ef08ca1510b77952c73512cd8e789be Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 3 Aug 2022 00:26:37 +0200 Subject: [PATCH 36/53] fairscale backward compat --- src/pytorch_lightning/overrides/base.py | 13 ++--- .../overrides/data_parallel.py | 2 +- src/pytorch_lightning/overrides/fairscale.py | 52 +++++++++++++++++++ .../deprecated_api/test_remove_1-10.py | 9 ++++ 4 files changed, 69 insertions(+), 7 deletions(-) create mode 100644 src/pytorch_lightning/overrides/fairscale.py diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 385e5de288eaa..f3f9b1b063805 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -116,7 +116,7 @@ def _validate_init_arguments( cls, pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - ): + ) -> None: # TODO: In v1.10, remove this method and mark the forward_module init argument in all subclasses as required if pl_module is not None: rank_zero_deprecation( @@ -127,7 +127,7 @@ def _validate_init_arguments( raise ValueError("Argument `forward_module` is required.") -def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule": +def unwrap_lightning_module(wrapped_model: nn.Module, 
_suppress_warning: bool = False) -> "pl.LightningModule": """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module`` attributes on the wrapper. @@ -139,10 +139,11 @@ def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule": TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped further. """ - rank_zero_deprecation( - "The function `unwrap_lightning_module` is deprecated in v1.8 and will be removed in v1.10. Access the" - " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." - ) + if not _suppress_warning: + rank_zero_deprecation( + "The function `unwrap_lightning_module` is deprecated in v1.8 and will be removed in v1.10. Access the" + " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." + ) model = wrapped_model if isinstance(model, (DistributedDataParallel, DataParallel)): model = unwrap_lightning_module(model.module) diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index bd2f983eeff60..1d258d2c9a8a3 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -21,7 +21,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn +from pytorch_lightning.utilities.rank_zero import rank_zero_warn def _ignore_scalar_return_in_dp() -> None: diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py new file mode 100644 index 0000000000000..52f123709f2b1 --- /dev/null +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -0,0 +1,52 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional, Union + +import torch.nn as nn + +import pytorch_lightning as pl +from pytorch_lightning.overrides.base import ( + _LightningModuleWrapperBase, + _LightningPrecisionModuleWrapperBase, + unwrap_lightning_module, +) +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE + +if _FAIRSCALE_AVAILABLE: # pragma: no-cover + from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel + + class LightningShardedDataParallel(_LightningModuleWrapperBase): + def __init__( + self, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + ) -> None: + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) + + def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule": + rank_zero_deprecation( + "The function `unwrap_lightning_module_sharded` is deprecated in v1.8 and will be removed in v1.10. Access" + " the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." + ) + model = wrapped_model + if isinstance(model, ShardedDataParallel): + model = model.module + + return unwrap_lightning_module(model, _suppress_warning=True) + +else: + LightningShardedDataParallel = ... # type: ignore[assignment,misc] + unwrap_lightning_module_sharded = ... # type: ignore[assignment] diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index a90a765d0cfae..7474eb23f8f48 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -18,8 +18,10 @@ from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule from pytorch_lightning.overrides.base import unwrap_lightning_module +from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel, unwrap_lightning_module_sharded from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule +from tests_pytorch.helpers.runif import RunIf @pytest.mark.parametrize( @@ -29,6 +31,7 @@ LightningDistributedModule, LightningBaguaModule, LightningDeepSpeedModule, + pytest.param(LightningShardedDataParallel, marks=RunIf(fairscale=True)), ], ) def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): @@ -42,3 +45,9 @@ def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): def test_v1_10_deprecated_unwrap_lightning_module(): with pytest.deprecated_call(match=r"The function `unwrap_lightning_module` is deprecated in v1.8"): unwrap_lightning_module(BoringModel()) + + +@RunIf(fairscale=True) +def test_v1_10_deprecated_unwrap_lightning_module_sharded(): + with pytest.deprecated_call(match=r"The function `unwrap_lightning_module_sharded` is deprecated in v1.8"): + unwrap_lightning_module_sharded(BoringModel()) From e6f7e90fba41574e1defd6dbccd0a90833f53080 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 3 Aug 2022 01:01:09 +0200 Subject: [PATCH 37/53] flip position --- src/pytorch_lightning/overrides/data_parallel.py | 2 +- src/pytorch_lightning/overrides/distributed.py | 2 +- src/pytorch_lightning/overrides/fairscale.py | 2 +- src/pytorch_lightning/strategies/bagua.py | 
2 +- src/pytorch_lightning/strategies/deepspeed.py | 2 +- tests/tests_pytorch/deprecated_api/test_remove_1-10.py | 5 ++++- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index 1d258d2c9a8a3..e1742c8a75682 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -64,8 +64,8 @@ class LightningParallelModule(_LightningModuleWrapperBase): def __init__( self, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: self._validate_init_arguments(pl_module, forward_module) super().__init__(forward_module=(pl_module or forward_module)) diff --git a/src/pytorch_lightning/overrides/distributed.py b/src/pytorch_lightning/overrides/distributed.py index 145ebb319b132..0130eac28c3ef 100644 --- a/src/pytorch_lightning/overrides/distributed.py +++ b/src/pytorch_lightning/overrides/distributed.py @@ -27,8 +27,8 @@ class LightningDistributedModule(_LightningModuleWrapperBase): def __init__( self, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: self._validate_init_arguments(pl_module, forward_module) super().__init__(forward_module=(pl_module or forward_module)) diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py index 52f123709f2b1..59989b4e63f03 100644 --- a/src/pytorch_lightning/overrides/fairscale.py +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -30,8 +30,8 @@ class LightningShardedDataParallel(_LightningModuleWrapperBase): def __init__( self, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: self._validate_init_arguments(pl_module, forward_module) super().__init__(forward_module=(pl_module or forward_module)) diff --git a/src/pytorch_lightning/strategies/bagua.py b/src/pytorch_lightning/strategies/bagua.py index 0c731d57bc904..f08d1aebf1b7c 100644 --- a/src/pytorch_lightning/strategies/bagua.py +++ b/src/pytorch_lightning/strategies/bagua.py @@ -52,8 +52,8 @@ class LightningBaguaModule(_LightningModuleWrapperBase): def __init__( self, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: self._validate_init_arguments(pl_module, forward_module) forward_module = pl_module or forward_module diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index c8b5e151b20a0..50587a5e338b4 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -72,8 +72,8 @@ def 
remove_module_hooks(model: torch.nn.Module) -> None: class LightningDeepSpeedModule(_LightningModuleWrapperBase): def __init__( self, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, precision: Union[str, int] = 32, ) -> None: self._validate_init_arguments(pl_module, forward_module) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 7474eb23f8f48..d55da324cd50e 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -22,6 +22,7 @@ from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule from tests_pytorch.helpers.runif import RunIf +from tests_pytorch.helpers.utils import no_warning_call @pytest.mark.parametrize( @@ -35,7 +36,9 @@ ], ) def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): - with pytest.deprecated_call(match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): + with no_warning_call( + DeprecationWarning, + match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): wrapper_class(BoringModel()) with pytest.deprecated_call(match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): From 96b547de855b9028246c202ea9731ea8710aa424 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 3 Aug 2022 01:02:54 +0200 Subject: [PATCH 38/53] update --- tests/tests_pytorch/overrides/test_data_parallel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/overrides/test_data_parallel.py b/tests/tests_pytorch/overrides/test_data_parallel.py index 4f04835e2ea30..68f625a427cef 100644 --- a/tests/tests_pytorch/overrides/test_data_parallel.py +++ b/tests/tests_pytorch/overrides/test_data_parallel.py @@ -45,7 +45,7 @@ def test_lightning_wrapper_module_methods(wrapper_class, stage): pl_module = Mock(spec=LightningModule) trainer = Mock() pl_module._trainer = trainer - wrapped_module = wrapper_class(forward_module=pl_module) + wrapped_module = wrapper_class(pl_module) batch = torch.rand(5) batch_idx = 3 @@ -171,7 +171,7 @@ def training_step(self, batch, batch_idx): trainer.state.stage = RunningStage.TRAINING root_device = torch.device("cuda", 0) - wrapped_module = LightningParallelModule(forward_module=pl_module).to(root_device) + wrapped_module = LightningParallelModule(pl_module).to(root_device) model = DataParallel(wrapped_module, device_ids=[0, 1]) data = torch.tensor([0.0, 1.0], device=root_device).view(2, 1) # one value per gpu @@ -197,7 +197,7 @@ def training_step(self, batch, batch_idx): pl_module.trainer = trainer trainer.state.stage = RunningStage.TRAINING - wrapped_module = LightningParallelModule(forward_module=pl_module).cuda() + wrapped_module = LightningParallelModule(pl_module).cuda() model = DataParallel(wrapped_module, device_ids=[0, 1]) data = dict(x=1) # contains no tensors From de2715f4a793900c08b02fb4dd1cd484ca855fd0 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 3 Aug 2022 01:03:40 +0200 Subject: [PATCH 39/53] format --- tests/tests_pytorch/deprecated_api/test_remove_1-10.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index d55da324cd50e..3bbe847491c8e 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -37,8 +37,8 @@ ) def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): with no_warning_call( - DeprecationWarning, - match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): + DeprecationWarning, match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8" + ): wrapper_class(BoringModel()) with pytest.deprecated_call(match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): From 0564cd97819aca20912faab6288a1149ccc605ac Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 3 Aug 2022 01:15:29 +0200 Subject: [PATCH 40/53] changelog placeholders --- src/pytorch_lightning/CHANGELOG.md | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 89fa726922a40..7ec943e5c4a31 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). + +### Changed + +- Replaced the unwrapping logic in strategies with direct access to unwrapped `LightningModule` ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) + + +### Deprecated + +- Deprecated the `unwrap_lightning_module` and `unwrap_lightning_module_sharded` utility functions in favor of accessing the unwrapped `LightningModule` on the strategy directly ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) +- Deprecated the `pl_module` argument in `LightningParallelModule`, `LightningDistributedModule`, `LightningShardedDataParallel`, `LightningBaguaModule` and `LightningDeepSpeedModule` wrapper classes ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) + + ## [1.7.0] - 2022-08-02 ### Added From c87998c326ae28e17125d67005387bd23a01e349 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 5 Aug 2022 12:15:01 +0200 Subject: [PATCH 41/53] chlog --- src/pytorch_lightning/CHANGELOG.md | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 1da64074845ec..4b2a44545a0a8 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -22,12 +22,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Raised a `MisconfigurationException` if batch transfer hooks are overriden with `IPUAccelerator` ([13961](https://github.com/Lightning-AI/lightning/pull/13961)) +- Replaced the unwrapping logic in strategies with direct access to unwrapped `LightningModule` ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) + + ### Deprecated - Deprecated `amp_level` from `Trainer` in favour of passing it explictly via precision plugin ([#13898](https://github.com/Lightning-AI/lightning/pull/13898)) -- +- Deprecated the `unwrap_lightning_module` and `unwrap_lightning_module_sharded` utility functions in favor of accessing the unwrapped `LightningModule` on the strategy directly ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) + + +- Deprecated the `pl_module` argument in `LightningParallelModule`, `LightningDistributedModule`, `LightningShardedDataParallel`, `LightningBaguaModule` and `LightningDeepSpeedModule` wrapper classes ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) + ### Removed @@ -56,17 +63,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -### Changed - -- Replaced the unwrapping logic in strategies with direct access to unwrapped `LightningModule` ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) - - -### Deprecated - -- Deprecated the `unwrap_lightning_module` and `unwrap_lightning_module_sharded` utility functions in favor of accessing the unwrapped `LightningModule` on the strategy directly ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) -- Deprecated the `pl_module` argument in `LightningParallelModule`, `LightningDistributedModule`, `LightningShardedDataParallel`, `LightningBaguaModule` and `LightningDeepSpeedModule` wrapper classes ([#13738](https://github.com/Lightning-AI/lightning/pull/13738)) - - ## [1.7.0] - 2022-08-02 ### Added From 2331296dfd37e7d8a167dd4825c9b45955b89488 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 5 Aug 2022 12:25:58 +0200 Subject: [PATCH 42/53] fix mypy error --- src/pytorch_lightning/overrides/base.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index f3f9b1b063805..123aca579df74 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -147,8 +147,10 @@ def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = model = wrapped_model if isinstance(model, (DistributedDataParallel, DataParallel)): model = unwrap_lightning_module(model.module) - if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)): - model = unwrap_lightning_module(model.module) + if isinstance(model, _LightningModuleWrapperBase): + model = model.lightning_module + if isinstance(model, _LightningPrecisionModuleWrapperBase): + model = model.module if not isinstance(model, pl.LightningModule): raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.") return model From 3399e0e63f89b8c39bb58b5ac9a8e46c42aa12eb Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 5 Aug 2022 12:28:36 +0200 Subject: [PATCH 43/53] update test --- tests/tests_pytorch/overrides/test_base.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/overrides/test_base.py b/tests/tests_pytorch/overrides/test_base.py index 101cf415713d5..32f2f0ba15646 100644 --- a/tests/tests_pytorch/overrides/test_base.py +++ 
b/tests/tests_pytorch/overrides/test_base.py @@ -13,9 +13,14 @@ # limitations under the License. import pytest import torch +from torch.nn import DataParallel from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase +from pytorch_lightning.overrides.base import ( + _LightningModuleWrapperBase, + _LightningPrecisionModuleWrapperBase, + unwrap_lightning_module, +) @pytest.mark.parametrize("wrapper_class", [_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase]) @@ -25,3 +30,13 @@ def test_wrapper_device_dtype(wrapper_class): wrapped_model.to(dtype=torch.float16) assert model.dtype == torch.float16 + + +def test_unwrap_lightning_module(): + model = BoringModel() + wrapped_model = _LightningPrecisionModuleWrapperBase(model) + wrapped_model = _LightningModuleWrapperBase(wrapped_model) + wrapped_model = DataParallel(wrapped_model) + + with pytest.deprecated_call(match="The function `unwrap_lightning_module` is deprecated in v1.8"): + assert unwrap_lightning_module(wrapped_model) == model From d5f5d69cfa7a3c69074741c68f9d6234a35dd511 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 5 Aug 2022 14:31:12 +0200 Subject: [PATCH 44/53] update import for fairscale --- src/pytorch_lightning/strategies/sharded_spawn.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index b06b7c7c2d51b..1454f638c7748 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -22,13 +22,11 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import _IS_WINDOWS, _module_available from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_only -_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") - if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel from fairscale.optim import OSS From d552c9bb57442c99e1f051d8bd5a5283159d7247 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 5 Aug 2022 14:39:28 +0200 Subject: [PATCH 45/53] revert fairscale import refactor --- src/pytorch_lightning/overrides/fairscale.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py index 59989b4e63f03..b995ecaeee4e2 100644 --- a/src/pytorch_lightning/overrides/fairscale.py +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -22,7 +22,10 @@ unwrap_lightning_module, ) from pytorch_lightning.utilities import rank_zero_deprecation -from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities.imports import _IS_WINDOWS, _module_available + +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn") + if _FAIRSCALE_AVAILABLE: # pragma: no-cover from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel From 6acb768e9d84133dc8e201dcf10e4419f8c28e76 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 5 Aug 2022 08:40:23 -0400 Subject: [PATCH 46/53] Update src/pytorch_lightning/CHANGELOG.md MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- src/pytorch_lightning/CHANGELOG.md | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 4b2a44545a0a8..a75c8b1173105 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -62,7 +62,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed MPS device being unrecognized ([#13992](https://github.com/Lightning-AI/lightning/pull/13992)) - ## [1.7.0] - 2022-08-02 ### Added From 3bfb48ef73b8194ba1e21ff0cce984787e8c1a87 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 10 Aug 2022 15:24:06 +0200 Subject: [PATCH 47/53] do same in ipu model as in deepspeed --- src/pytorch_lightning/strategies/deepspeed.py | 2 +- src/pytorch_lightning/strategies/ipu.py | 10 +++++++--- tests/tests_pytorch/deprecated_api/test_remove_1-10.py | 2 ++ 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index 5ee8ac2a28371..4a70eb983fd86 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -79,8 +79,8 @@ class LightningDeepSpeedModule(_LightningModuleWrapperBase): def __init__( self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, - pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, precision: Union[str, int] = 32, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: rank_zero_deprecation("`LightningDeepSpeedModule` has been deprecated in v1.7.1 and will be removed in v1.9.0") self._validate_init_arguments(pl_module, forward_module) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 4bedbfd6d70fc..4fd44825ca434 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -51,10 +51,14 @@ class LightningIPUModule(_LightningModuleWrapperBase): """ def __init__( - self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase], precision: Union[str, int] + self, + forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, + precision: Union[str, int] = 32, + pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: - rank_zero_deprecation("`LightningIPUModule` has been deprecated in v1.7.0 and will be removed in v1.8.0") - super().__init__(pl_module) + rank_zero_deprecation("`LightningDeepSpeedModule` has been deprecated in v1.7.1 and will be removed in v1.9.0") + self._validate_init_arguments(pl_module, forward_module) + super().__init__(forward_module=(pl_module or forward_module)) self.precision = precision def forward(self, *inputs: Any, **kwargs: Any) -> Any: diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 926f2f5f1b58c..20dae62595e35 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -21,6 +21,7 @@ from pytorch_lightning.overrides.fairscale import 
LightningShardedDataParallel, unwrap_lightning_module_sharded from pytorch_lightning.strategies.bagua import LightningBaguaModule from pytorch_lightning.strategies.deepspeed import LightningDeepSpeedModule +from pytorch_lightning.strategies.ipu import LightningIPUModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.utils import no_warning_call @@ -38,6 +39,7 @@ def test_deprecated_amp_level(): LightningBaguaModule, LightningDeepSpeedModule, pytest.param(LightningShardedDataParallel, marks=RunIf(fairscale=True)), + LightningIPUModule, ], ) def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): From 0756191741c47631511141afba725053d2c2a3d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Wed, 10 Aug 2022 09:35:45 -0400 Subject: [PATCH 48/53] Update src/pytorch_lightning/overrides/data_parallel.py Co-authored-by: Rohit Gupta --- src/pytorch_lightning/overrides/data_parallel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index e1742c8a75682..9fff7b174f683 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -55,11 +55,11 @@ class LightningParallelModule(_LightningModuleWrapperBase): pl_module: The module to wrap. See description for `forward_module`. .. deprecated:: v1.8 - The argument `pl_module` is deprecated in v1.8 and will be removed in v1.10. Please use - `forward_module` instead. + The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Please use + ``forward_module`` instead. - forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module`` - pointing to a LightningModule reference. + forward_module: The module to wrap. If it's not a ``LightningModule``, it must have an attribute ``.module`` + pointing to a ``LightningModule`` reference. 
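A minimal sketch of the constructor migration the docstring above describes — hypothetical user code, not part of the patch, mirroring the `BoringModel`-based deprecation tests:

    from pytorch_lightning.demos.boring_classes import BoringModel
    from pytorch_lightning.overrides.data_parallel import LightningParallelModule

    model = BoringModel()

    # Deprecated in v1.8.0: the `pl_module` keyword still works but emits a DeprecationWarning
    # (raised by `_validate_init_arguments`) before the model is forwarded as `forward_module`.
    wrapped = LightningParallelModule(pl_module=model)

    # Preferred: pass the model as `forward_module` (keyword or first positional argument).
    wrapped = LightningParallelModule(forward_module=model)
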
""" def __init__( From 5102398236a55925f777186934151a4df90f9d13 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 10 Aug 2022 18:14:43 +0200 Subject: [PATCH 49/53] revert --- src/pytorch_lightning/strategies/ipu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 4fd44825ca434..f56c095dc12c1 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -56,7 +56,7 @@ def __init__( precision: Union[str, int] = 32, pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: - rank_zero_deprecation("`LightningDeepSpeedModule` has been deprecated in v1.7.1 and will be removed in v1.9.0") + rank_zero_deprecation("`LightningIPUModule` has been deprecated in v1.7.0 and will be removed in v1.8.0") self._validate_init_arguments(pl_module, forward_module) super().__init__(forward_module=(pl_module or forward_module)) self.precision = precision From 41a98fceea3300d3c5f040e9104346f81ff78817 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 12 Aug 2022 00:56:33 +0200 Subject: [PATCH 50/53] update versions --- src/pytorch_lightning/overrides/base.py | 16 ++++++++-------- src/pytorch_lightning/overrides/data_parallel.py | 2 +- src/pytorch_lightning/overrides/fairscale.py | 4 ++-- .../deprecated_api/test_remove_1-10.py | 10 ++++++---- tests/tests_pytorch/overrides/test_base.py | 2 +- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 123aca579df74..07f30c271b207 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -75,7 +75,7 @@ def __init__( "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one," f" got: {forward_module.__class__.__qualname__}" ) - # TODO: In v1.10, remove the Optional type from forward_module and remove the assertion + # TODO: In v1.10.0, remove the Optional type from forward_module and remove the assertion assert forward_module is not None self._forward_module = forward_module @@ -117,11 +117,11 @@ def _validate_init_arguments( pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None, ) -> None: - # TODO: In v1.10, remove this method and mark the forward_module init argument in all subclasses as required + # TODO: In v1.10.0, remove this method and mark the forward_module init argument in all subclasses as required if pl_module is not None: rank_zero_deprecation( - f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8 and will be removed in" - " v1.10. Please use `forward_module` instead." + f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8.0 and will be removed in" + " v1.10.0. Please use `forward_module` instead." ) elif forward_module is None: raise ValueError("Argument `forward_module` is required.") @@ -131,9 +131,9 @@ def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module`` attributes on the wrapper. - .. deprecated:: v1.8 - The function `unwrap_lightning_module` is deprecated in v1.8 and will be removed in v1.10. 
Access the - `LightningModule` directly through the strategy attribute `Strategy.lightning_module`. + .. deprecated:: v1.8.0 + The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Access the + ``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``. Raises: TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped @@ -141,7 +141,7 @@ def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = """ if not _suppress_warning: rank_zero_deprecation( - "The function `unwrap_lightning_module` is deprecated in v1.8 and will be removed in v1.10. Access the" + "The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v1.10.0. Access the" " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." ) model = wrapped_model diff --git a/src/pytorch_lightning/overrides/data_parallel.py b/src/pytorch_lightning/overrides/data_parallel.py index 9fff7b174f683..98d23cee391bc 100644 --- a/src/pytorch_lightning/overrides/data_parallel.py +++ b/src/pytorch_lightning/overrides/data_parallel.py @@ -54,7 +54,7 @@ class LightningParallelModule(_LightningModuleWrapperBase): Args: pl_module: The module to wrap. See description for `forward_module`. - .. deprecated:: v1.8 + .. deprecated:: v1.8.0 The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Please use ``forward_module`` instead. diff --git a/src/pytorch_lightning/overrides/fairscale.py b/src/pytorch_lightning/overrides/fairscale.py index b995ecaeee4e2..d9fd2e60aff61 100644 --- a/src/pytorch_lightning/overrides/fairscale.py +++ b/src/pytorch_lightning/overrides/fairscale.py @@ -41,8 +41,8 @@ def __init__( def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule": rank_zero_deprecation( - "The function `unwrap_lightning_module_sharded` is deprecated in v1.8 and will be removed in v1.10. Access" - " the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." + "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0." + " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`." 
) model = wrapped_model if isinstance(model, ShardedDataParallel): diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py index 20dae62595e35..186e526313bba 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-10.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-10.py @@ -44,20 +44,22 @@ def test_deprecated_amp_level(): ) def test_v1_10_deprecated_pl_module_init_parameter(wrapper_class): with no_warning_call( - DeprecationWarning, match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8" + DeprecationWarning, match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0" ): wrapper_class(BoringModel()) - with pytest.deprecated_call(match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8"): + with pytest.deprecated_call( + match=rf"The argument `pl_module` in `{wrapper_class.__name__}` is deprecated in v1.8.0" + ): wrapper_class(pl_module=BoringModel()) def test_v1_10_deprecated_unwrap_lightning_module(): - with pytest.deprecated_call(match=r"The function `unwrap_lightning_module` is deprecated in v1.8"): + with pytest.deprecated_call(match=r"The function `unwrap_lightning_module` is deprecated in v1.8.0"): unwrap_lightning_module(BoringModel()) @RunIf(fairscale=True) def test_v1_10_deprecated_unwrap_lightning_module_sharded(): - with pytest.deprecated_call(match=r"The function `unwrap_lightning_module_sharded` is deprecated in v1.8"): + with pytest.deprecated_call(match=r"The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0"): unwrap_lightning_module_sharded(BoringModel()) diff --git a/tests/tests_pytorch/overrides/test_base.py b/tests/tests_pytorch/overrides/test_base.py index 32f2f0ba15646..27d2db688d7ae 100644 --- a/tests/tests_pytorch/overrides/test_base.py +++ b/tests/tests_pytorch/overrides/test_base.py @@ -38,5 +38,5 @@ def test_unwrap_lightning_module(): wrapped_model = _LightningModuleWrapperBase(wrapped_model) wrapped_model = DataParallel(wrapped_model) - with pytest.deprecated_call(match="The function `unwrap_lightning_module` is deprecated in v1.8"): + with pytest.deprecated_call(match="The function `unwrap_lightning_module` is deprecated in v1.8.0"): assert unwrap_lightning_module(wrapped_model) == model From 13c05eeac402c95dfee950fda19ee98aee559a1e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 12 Aug 2022 01:03:17 +0200 Subject: [PATCH 51/53] formatting --- src/pytorch_lightning/strategies/sharded_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index 554836bc735e9..48d47841511cf 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Any, Dict, Generator, List, Tuple from torch import Tensor from torch.nn import Module from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase +from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase, _LightningModuleWrapperBase from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException From 3793896c241d699b579304ac1e01632c91b58945 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Aug 2022 23:04:53 +0000 Subject: [PATCH 52/53] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/sharded_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/sharded_spawn.py b/src/pytorch_lightning/strategies/sharded_spawn.py index 48d47841511cf..f19aae7302eea 100644 --- a/src/pytorch_lightning/strategies/sharded_spawn.py +++ b/src/pytorch_lightning/strategies/sharded_spawn.py @@ -19,7 +19,7 @@ from torch.optim import Optimizer import pytorch_lightning as pl -from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase, _LightningModuleWrapperBase +from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.exceptions import MisconfigurationException From 76d60379b8fd08702ed9f4f60671eb89528f2fd8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 12 Aug 2022 01:39:40 +0200 Subject: [PATCH 53/53] remove redundant mypy assertion --- src/pytorch_lightning/overrides/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/overrides/base.py b/src/pytorch_lightning/overrides/base.py index 027d61e4ee2dd..07f30c271b207 100644 --- a/src/pytorch_lightning/overrides/base.py +++ b/src/pytorch_lightning/overrides/base.py @@ -94,7 +94,6 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any: trainer = pl_module._trainer if trainer is not None: - assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) if trainer.training: output = self._forward_module.training_step(*inputs, **kwargs) # In manual_optimization, we need to prevent DDP reducer as