Replace unwrapping logic in strategies (#13738)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
3 people authored Aug 12, 2022
1 parent 6789a06 commit 807f9d8
Showing 26 changed files with 274 additions and 159 deletions.
10 changes: 10 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated compatibility for LightningLite to run with the latest DeepSpeed 0.7.0 ([13967](https://github.com/Lightning-AI/lightning/pull/13967))


- Replaced the unwrapping logic in strategies with direct access to unwrapped `LightningModule` ([#13738](https://github.com/Lightning-AI/lightning/pull/13738))


### Deprecated

- Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
@@ -39,6 +42,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the calls to `pytorch_lightning.utilities.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868))


- Deprecated the `unwrap_lightning_module` and `unwrap_lightning_module_sharded` utility functions in favor of accessing the unwrapped `LightningModule` on the strategy directly ([#13738](https://github.com/Lightning-AI/lightning/pull/13738))


- Deprecated the `pl_module` argument in `LightningParallelModule`, `LightningDistributedModule`, `LightningShardedDataParallel`, `LightningBaguaModule` and `LightningDeepSpeedModule` wrapper classes ([#13738](https://github.com/Lightning-AI/lightning/pull/13738))



### Removed

- Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011))
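To make the migration concrete, here is a minimal sketch (not part of the commit) of what the deprecation entries above mean in practice; `TinyModel` and the toy dataloader are assumptions for illustration only:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    """Hypothetical model used only to illustrate the migration."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = pl.Trainer(accelerator="cpu", devices=1, fast_dev_run=True, enable_progress_bar=False)
trainer.fit(TinyModel(), DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=4))

# Before (deprecated in v1.8.0, slated for removal in v1.10.0):
#   from pytorch_lightning.overrides.base import unwrap_lightning_module
#   model = unwrap_lightning_module(trainer.strategy.model)

# After: the strategy exposes the unwrapped LightningModule directly.
model = trainer.strategy.lightning_module
assert isinstance(model, pl.LightningModule)
```

The wrapper classes touched in the files below follow the same pattern: `forward_module` replaces the deprecated `pl_module` argument.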
78 changes: 61 additions & 17 deletions src/pytorch_lightning/overrides/base.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union
from typing import Any, Optional, Union

import torch
import torch.nn as nn
@@ -20,6 +20,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.utilities import rank_zero_deprecation


class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
@@ -54,30 +55,47 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:


class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
def __init__(
self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]]
) -> None:
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
Inheriting classes may also modify the inputs or outputs of forward.
Args:
pl_module: the model to wrap
forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module``
pointing to a LightningModule reference.
"""
super().__init__()
self.module = pl_module
if not isinstance(forward_module, pl.LightningModule) and (
not isinstance(getattr(forward_module, "module", None), pl.LightningModule)
):
raise ValueError(
"`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one,"
f" got: {forward_module.__class__.__qualname__}"
)
# TODO: In v1.10.0, remove the Optional type from forward_module and remove the assertion
assert forward_module is not None
self._forward_module = forward_module

# set the parameters_to_ignore from LightningModule.
_ddp_params_and_buffers_to_ignore = getattr(pl_module, "_ddp_params_and_buffers_to_ignore", [])
_ddp_params_and_buffers_to_ignore = getattr(self._forward_module, "_ddp_params_and_buffers_to_ignore", [])
self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore]

@property
def lightning_module(self) -> "pl.LightningModule":
if isinstance(self._forward_module, pl.LightningModule):
return self._forward_module
return self._forward_module.module

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
pl_module = unwrap_lightning_module(self.module)
pl_module = self.lightning_module
trainer = pl_module._trainer

if trainer is not None:
assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
if trainer.training:
output = self.module.training_step(*inputs, **kwargs)
output = self._forward_module.training_step(*inputs, **kwargs)
# In manual_optimization, we need to prevent DDP reducer as
# it is done manually in `LightningModule.manual_backward`
# `require_backward_grad_sync` will be reset in the
@@ -86,27 +104,53 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
trainer.model.require_backward_grad_sync = False # type: ignore[assignment]
return output
if trainer.testing:
return self.module.test_step(*inputs, **kwargs)
return self._forward_module.test_step(*inputs, **kwargs)
if trainer.sanity_checking or trainer.validating:
return self.module.validation_step(*inputs, **kwargs)
return self._forward_module.validation_step(*inputs, **kwargs)
if trainer.predicting:
return self.module.predict_step(*inputs, **kwargs)
return self.module(*inputs, **kwargs)


def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule":
return self._forward_module.predict_step(*inputs, **kwargs)
return self._forward_module(*inputs, **kwargs)

@classmethod
def _validate_init_arguments(
cls,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
# TODO: In v1.10.0, remove this method and mark the forward_module init argument in all subclasses as required
if pl_module is not None:
rank_zero_deprecation(
f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8.0 and will be removed in"
" v1.10.0. Please use `forward_module` instead."
)
elif forward_module is None:
raise ValueError("Argument `forward_module` is required.")


def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = False) -> "pl.LightningModule":
"""Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module``
attributes on the wrapper.
.. deprecated:: v1.8.0
The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Access the
``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``.
Raises:
TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
further.
"""
if not _suppress_warning:
rank_zero_deprecation(
"The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v1.10.0. Access the"
" `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
)
model = wrapped_model
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = unwrap_lightning_module(model.module)
if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)):
model = unwrap_lightning_module(model.module)
if isinstance(model, _LightningModuleWrapperBase):
model = model.lightning_module
if isinstance(model, _LightningPrecisionModuleWrapperBase):
model = model.module
if not isinstance(model, pl.LightningModule):
raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.")
return model
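As a rough sketch of the new wrapper contract above (illustration only, not part of the commit): the base wrapper now validates its input at construction time and resolves the inner `LightningModule` through the `lightning_module` property rather than `unwrap_lightning_module`. `DummyModule` is a hypothetical stand-in:

```python
import torch

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase


class DummyModule(pl.LightningModule):
    """Hypothetical module, only here to exercise the wrapper."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)


wrapper = _LightningModuleWrapperBase(DummyModule())
# The wrapper resolves the inner module itself; no unwrap utility needed.
assert isinstance(wrapper.lightning_module, pl.LightningModule)

# Anything that is neither a LightningModule nor an object whose `.module`
# points to one is rejected up front.
try:
    _LightningModuleWrapperBase(torch.nn.Linear(2, 2))
except ValueError as err:
    print(err)
```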
24 changes: 18 additions & 6 deletions src/pytorch_lightning/overrides/data_parallel.py
@@ -13,7 +13,7 @@
# limitations under the License.
import numbers
import warnings
from typing import Any, cast, Union
from typing import Any, cast, Optional, Union

import torch
from torch import Tensor
@@ -52,11 +52,23 @@ class LightningParallelModule(_LightningModuleWrapperBase):
)
Args:
pl_module: the model to wrap
pl_module: The module to wrap. See description for `forward_module`.
.. deprecated:: v1.8.0
The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Please use
``forward_module`` instead.
forward_module: The module to wrap. If it's not a ``LightningModule``, it must have an attribute ``.module``
pointing to a ``LightningModule`` reference.
"""

def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))
_ignore_scalar_return_in_dp()

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
@@ -65,7 +77,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
output = super().forward(*inputs, **kwargs)

def output_transform(data: Any) -> Any:
device = cast(torch.device, self.module.device)
device = cast(torch.device, self.lightning_module.device)
data = python_scalar_to_tensor(data, device)
data = unsqueeze_scalar_tensor(data)
return data
@@ -95,7 +107,7 @@ def find_tensor_with_device(tensor: Tensor) -> Tensor:

if replica_device is not None:
# by calling .to() we force the update to the self.device property
self.module.to(device=replica_device)
self._forward_module.to(device=replica_device)
else:
rank_zero_warn(
"Could not determine on which device the inputs are."
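A short sketch of the renamed constructor argument (the same pattern is applied to `LightningDistributedModule`, `LightningShardedDataParallel`, `LightningBaguaModule` and `LightningDeepSpeedModule` in this commit); `DummyModule` is again a hypothetical stand-in:

```python
import torch

import pytorch_lightning as pl
from pytorch_lightning.overrides.data_parallel import LightningParallelModule


class DummyModule(pl.LightningModule):
    """Hypothetical module, illustration only."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)


LightningParallelModule(forward_module=DummyModule())  # preferred going forward
LightningParallelModule(pl_module=DummyModule())  # still works, but emits a deprecation warning

# Omitting both arguments now fails fast:
try:
    LightningParallelModule()
except ValueError as err:
    print(err)  # "Argument `forward_module` is required."
```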
11 changes: 9 additions & 2 deletions src/pytorch_lightning/overrides/distributed.py
@@ -19,12 +19,19 @@
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, Dataset, DistributedSampler, Sampler

from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LightningDistributedModule(_LightningModuleWrapperBase):
...
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))


def _find_tensors(
29 changes: 23 additions & 6 deletions src/pytorch_lightning/overrides/fairscale.py
@@ -11,27 +11,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.utilities import _IS_WINDOWS, _module_available
from pytorch_lightning.overrides.base import (
_LightningModuleWrapperBase,
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _module_available

_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")

if _FAIRSCALE_AVAILABLE:

if _FAIRSCALE_AVAILABLE: # pragma: no-cover
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

class LightningShardedDataParallel(_LightningModuleWrapperBase):
# Just do this for later docstrings
pass
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))

def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
rank_zero_deprecation(
"The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0."
" Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
)
model = wrapped_model
if isinstance(model, ShardedDataParallel):
model = model.module

return unwrap_lightning_module(model)
return unwrap_lightning_module(model, _suppress_warning=True)

else:
LightningShardedDataParallel = ... # type: ignore[assignment,misc]
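The sharded unwrap helper is deprecated the same way; the `_suppress_warning=True` flag it passes to `unwrap_lightning_module` keeps the delegation from emitting a second, redundant deprecation warning. A guarded sketch, assuming a non-Windows machine (the body is skipped when fairscale is not installed); `DummyModule` is a hypothetical stand-in:

```python
import torch

import pytorch_lightning as pl
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE

if _FAIRSCALE_AVAILABLE:
    from pytorch_lightning.overrides.fairscale import (
        LightningShardedDataParallel,
        unwrap_lightning_module_sharded,
    )

    class DummyModule(pl.LightningModule):
        """Hypothetical stand-in, illustration only."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(2, 2)

    wrapped = LightningShardedDataParallel(forward_module=DummyModule())

    # Deprecated since v1.8.0; emits a single warning because the inner call to
    # `unwrap_lightning_module` is made with `_suppress_warning=True`.
    model = unwrap_lightning_module_sharded(wrapped)
    assert isinstance(model, pl.LightningModule)
```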
@@ -15,9 +15,9 @@

import torch

from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE

if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
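This hunk and the `ddp.py` hunk below also move the `_FAIRSCALE_AVAILABLE` import from `pytorch_lightning.overrides.fairscale` to `pytorch_lightning.utilities.imports`, presumably so availability checks no longer need to go through the overrides package. A trivial check of the new location, runnable with or without fairscale:

```python
# The flag now lives alongside the other availability/version checks.
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _IS_WINDOWS

# True only on non-Windows systems where `fairscale.nn` can be imported.
print(_FAIRSCALE_AVAILABLE, _IS_WINDOWS)
```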
27 changes: 11 additions & 16 deletions src/pytorch_lightning/strategies/bagua.py
@@ -7,11 +7,7 @@
from torch.nn import Module

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import (
_LightningModuleWrapperBase,
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -54,10 +50,16 @@


class LightningBaguaModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
forward_module = pl_module or forward_module
super().__init__(forward_module=forward_module)
# Bagua use `bagua_module_name` to distinguish different modules
self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}"
self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}"


class BaguaStrategy(DDPStrategy):
@@ -109,13 +111,6 @@ def __init__(
self._bagua_flatten = flatten
self._bagua_kwargs = bagua_kwargs

@property
def lightning_module(self) -> Optional["pl.LightningModule"]:
model = self.model
if isinstance(model, BaguaDistributedDataParallel):
model = model.module
return unwrap_lightning_module(model) if model is not None else None

def setup_distributed(self) -> None:
reset_seed()

Expand Down Expand Up @@ -190,7 +185,7 @@ def _check_qadam_optimizer(self) -> None:

def _configure_bagua_model(self, trainer: "pl.Trainer") -> None:
model = LightningBaguaModule(self.model) # type: ignore[arg-type]
self._model = self._setup_model(model)
self.model = self._setup_model(model)

# start the background communication for async algorithm
if trainer.training and self._bagua_algorithm == "async":
8 changes: 6 additions & 2 deletions src/pytorch_lightning/strategies/ddp.py
@@ -34,7 +34,6 @@
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -55,7 +54,12 @@
sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.imports import (
_FAIRSCALE_AVAILABLE,
_IS_WINDOWS,
_TORCH_GREATER_EQUAL_1_10,
_TORCH_GREATER_EQUAL_1_11,
)
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed