Replace unwrapping logic in strategies #13738

Merged: 61 commits into master from refactor/remove-unwrap (Aug 12, 2022)
Commits
57b917c
wip
awaelchli Jul 2, 2022
4c65ea6
Merge branch 'master' into refactor/remove-unwrap
awaelchli Jul 26, 2022
626ac9e
model setter
awaelchli Jul 26, 2022
b07e11c
fix import
awaelchli Jul 26, 2022
352397f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2022
7184b59
refactor wrappers overrides
awaelchli Jul 26, 2022
1f5d06f
Merge remote-tracking branch 'origin/refactor/remove-unwrap' into ref…
awaelchli Jul 26, 2022
f861a01
refactor
awaelchli Jul 26, 2022
78da9af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2022
2a2e1e8
update
awaelchli Jul 26, 2022
7e773b4
Merge remote-tracking branch 'origin/refactor/remove-unwrap' into ref…
awaelchli Jul 26, 2022
e94797f
update
awaelchli Jul 26, 2022
39d7755
update
awaelchli Jul 26, 2022
cad60f2
fixes
awaelchli Jul 26, 2022
9ee0630
update
awaelchli Jul 26, 2022
12a3ad6
simplify
awaelchli Jul 26, 2022
b99bc23
debug
awaelchli Jul 27, 2022
b7c74ef
update
awaelchli Jul 27, 2022
d842246
validate
awaelchli Jul 27, 2022
b176534
update
awaelchli Jul 27, 2022
67fae07
fix
awaelchli Jul 27, 2022
1e9fe94
debug
awaelchli Jul 27, 2022
5050399
teardown
awaelchli Jul 27, 2022
dfffb49
fix
awaelchli Jul 27, 2022
a0d1940
discussion
awaelchli Jul 27, 2022
0bbd41a
rename
awaelchli Jul 27, 2022
b7211a6
fix
awaelchli Jul 27, 2022
3a7e885
clear model reference on connect()
awaelchli Jul 27, 2022
be8abf7
fix tests
awaelchli Jul 27, 2022
feea943
Merge branch 'master' into refactor/remove-unwrap
awaelchli Aug 2, 2022
bdf407c
resolve merge conflict
awaelchli Aug 2, 2022
f353b9b
fix property
awaelchli Aug 2, 2022
b7baa82
fix attribute error
awaelchli Aug 2, 2022
4b40060
add backward compatibility
awaelchli Aug 2, 2022
4147fa7
undo changes in ipu
awaelchli Aug 2, 2022
e1c5cbd
include deepspeed
awaelchli Aug 2, 2022
bb4f8cc
mypy
awaelchli Aug 2, 2022
0f5d458
mypy
awaelchli Aug 2, 2022
0b0df76
deprecate unwrap function
awaelchli Aug 2, 2022
5485099
fairscale backward compat
awaelchli Aug 2, 2022
e6f7e90
flip position
awaelchli Aug 2, 2022
96b547d
update
awaelchli Aug 2, 2022
de2715f
format
awaelchli Aug 2, 2022
0564cd9
changelog placeholders
awaelchli Aug 2, 2022
1ac803a
Merge branch 'master' into refactor/remove-unwrap
awaelchli Aug 5, 2022
c87998c
chlog
awaelchli Aug 5, 2022
2331296
fix mypy error
awaelchli Aug 5, 2022
3399e0e
update test
awaelchli Aug 5, 2022
d5f5d69
update import for fairscale
awaelchli Aug 5, 2022
d552c9b
revert fairscale import refactor
awaelchli Aug 5, 2022
6acb768
Update src/pytorch_lightning/CHANGELOG.md
awaelchli Aug 5, 2022
d1baa8a
Merge branch 'master' into refactor/remove-unwrap
awaelchli Aug 10, 2022
3bfb48e
do same in ipu model as in deepspeed
awaelchli Aug 10, 2022
0756191
Update src/pytorch_lightning/overrides/data_parallel.py
awaelchli Aug 10, 2022
9ddeb72
Merge remote-tracking branch 'origin/refactor/remove-unwrap' into ref…
awaelchli Aug 10, 2022
5102398
revert
awaelchli Aug 10, 2022
41a98fc
update versions
awaelchli Aug 11, 2022
7ccd6c9
Merge branch 'master' into refactor/remove-unwrap
awaelchli Aug 11, 2022
13c05ee
formatting
awaelchli Aug 11, 2022
3793896
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 11, 2022
76d6037
remove redundant mypy assertion
awaelchli Aug 11, 2022
10 changes: 10 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -25,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Updated compatibility for LightningLite to run with the latest DeepSpeed 0.7.0 ([13967](https://github.com/Lightning-AI/lightning/pull/13967))


- Replaced the unwrapping logic in strategies with direct access to unwrapped `LightningModule` ([#13738](https://github.com/Lightning-AI/lightning/pull/13738))


### Deprecated

- Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))
@@ -36,6 +39,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated the calls to `pytorch_lightning.utiltiies.meta` functions in favor of built-in https://github.com/pytorch/torchdistx support ([#13868](https://github.com/Lightning-AI/lightning/pull/13868))


- Deprecated the `unwrap_lightning_module` and `unwrap_lightning_module_sharded` utility functions in favor of accessing the unwrapped `LightningModule` on the strategy directly ([#13738](https://github.com/Lightning-AI/lightning/pull/13738))


- Deprecated the `pl_module` argument in `LightningParallelModule`, `LightningDistributedModule`, `LightningShardedDataParallel`, `LightningBaguaModule` and `LightningDeepSpeedModule` wrapper classes ([#13738](https://github.com/Lightning-AI/lightning/pull/13738))



### Removed

- Removed the deprecated `Trainer.training_type_plugin` property in favor of `Trainer.strategy` ([#14011](https://github.com/Lightning-AI/lightning/pull/14011))
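
For context, a minimal sketch of the migration these changelog entries describe. It assumes a single-process CPU run; `TinyModel` and the dataloader are placeholders, and only `unwrap_lightning_module` (deprecated) and `Strategy.lightning_module` are taken from the PR itself.

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class TinyModel(pl.LightningModule):
    """Placeholder model, only here to make the sketch runnable."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = TinyModel()
trainer = pl.Trainer(accelerator="cpu", devices=1, fast_dev_run=True)
trainer.fit(model, train_dataloaders=DataLoader(torch.randn(8, 2), batch_size=4))

# Deprecated in v1.8.0, removal planned for v1.10.0:
#   from pytorch_lightning.overrides.base import unwrap_lightning_module
#   inner = unwrap_lightning_module(trainer.strategy.model)

# Replacement: the strategy exposes the unwrapped LightningModule directly.
assert trainer.strategy.lightning_module is model
```
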
78 changes: 61 additions & 17 deletions src/pytorch_lightning/overrides/base.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union
from typing import Any, Optional, Union

import torch
import torch.nn as nn
@@ -20,6 +20,7 @@

import pytorch_lightning as pl
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin
from pytorch_lightning.utilities import rank_zero_deprecation


class _LightningPrecisionModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
@@ -54,30 +55,47 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:


class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
def __init__(
self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]]
) -> None:
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.

Inheriting classes may also modify the inputs or outputs of forward.

Args:
pl_module: the model to wrap
forward_module: The module to wrap. If it's not a LightningModule, it must have an attribute ``.module``
pointing to a LightningModule reference.
"""
super().__init__()
self.module = pl_module
if not isinstance(forward_module, pl.LightningModule) and (
not isinstance(getattr(forward_module, "module", None), pl.LightningModule)
):
raise ValueError(
"`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one,"
f" got: {forward_module.__class__.__qualname__}"
)
# TODO: In v1.10.0, remove the Optional type from forward_module and remove the assertion
assert forward_module is not None
self._forward_module = forward_module

# set the parameters_to_ignore from LightningModule.
_ddp_params_and_buffers_to_ignore = getattr(pl_module, "_ddp_params_and_buffers_to_ignore", [])
_ddp_params_and_buffers_to_ignore = getattr(self._forward_module, "_ddp_params_and_buffers_to_ignore", [])
self._ddp_params_and_buffers_to_ignore = [f"module.{p}" for p in _ddp_params_and_buffers_to_ignore]

@property
def lightning_module(self) -> "pl.LightningModule":
if isinstance(self._forward_module, pl.LightningModule):
return self._forward_module
return self._forward_module.module

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
pl_module = unwrap_lightning_module(self.module)
pl_module = self.lightning_module
trainer = pl_module._trainer

if trainer is not None:
assert isinstance(self.module, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
if trainer.training:
output = self.module.training_step(*inputs, **kwargs)
output = self._forward_module.training_step(*inputs, **kwargs)
# In manual_optimization, we need to prevent DDP reducer as
# it is done manually in `LightningModule.manual_backward`
# `require_backward_grad_sync` will be reset in the
@@ -86,27 +104,53 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
trainer.model.require_backward_grad_sync = False # type: ignore[assignment]
return output
if trainer.testing:
return self.module.test_step(*inputs, **kwargs)
return self._forward_module.test_step(*inputs, **kwargs)
if trainer.sanity_checking or trainer.validating:
return self.module.validation_step(*inputs, **kwargs)
return self._forward_module.validation_step(*inputs, **kwargs)
if trainer.predicting:
return self.module.predict_step(*inputs, **kwargs)
return self.module(*inputs, **kwargs)


def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule":
return self._forward_module.predict_step(*inputs, **kwargs)
return self._forward_module(*inputs, **kwargs)

@classmethod
def _validate_init_arguments(
cls,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
# TODO: In v1.10.0, remove this method and mark the forward_module init argument in all subclasses as required
if pl_module is not None:
rank_zero_deprecation(
f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8.0 and will be removed in"
" v1.10.0. Please use `forward_module` instead."
)
elif forward_module is None:
raise ValueError("Argument `forward_module` is required.")


def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = False) -> "pl.LightningModule":
"""Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module``
attributes on the wrapper.

.. deprecated:: v1.8.0
The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Access the
``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``.

Raises:
TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
further.
"""
if not _suppress_warning:
rank_zero_deprecation(
"The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v1.10.0. Access the"
" `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
)
model = wrapped_model
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = unwrap_lightning_module(model.module)
if isinstance(model, (_LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase)):
model = unwrap_lightning_module(model.module)
if isinstance(model, _LightningModuleWrapperBase):
model = model.lightning_module
if isinstance(model, _LightningPrecisionModuleWrapperBase):
model = model.module
if not isinstance(model, pl.LightningModule):
raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.")
return model
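
A hedged usage sketch of the new constructor argument and property shown above. The wrapper class, `forward_module`, and `lightning_module` come from this diff; `ToyModule` is a placeholder.

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase


class ToyModule(pl.LightningModule):
    """Placeholder LightningModule used only for illustration."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()


pl_module = ToyModule()

# The wrapper now takes `forward_module` (a LightningModule, or a precision
# wrapper whose `.module` attribute points at one) ...
wrapped = _LightningModuleWrapperBase(forward_module=pl_module)

# ... and exposes the unwrapped reference as a property, so callers no longer
# need the standalone `unwrap_lightning_module` helper.
assert wrapped.lightning_module is pl_module
```
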
24 changes: 18 additions & 6 deletions src/pytorch_lightning/overrides/data_parallel.py
@@ -13,7 +13,7 @@
# limitations under the License.
import numbers
import warnings
from typing import Any, cast, Union
from typing import Any, cast, Optional, Union

import torch
from torch import Tensor
@@ -52,11 +52,23 @@ class LightningParallelModule(_LightningModuleWrapperBase):
)

Args:
pl_module: the model to wrap
pl_module: The module to wrap. See description for `forward_module`.

.. deprecated:: v1.8.0
The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v1.10.0. Please use
``forward_module`` instead.

forward_module: The module to wrap. If it's not a ``LightningModule``, it must have an attribute ``.module``
pointing to a ``LightningModule`` reference.
"""

def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))
_ignore_scalar_return_in_dp()

def forward(self, *inputs: Any, **kwargs: Any) -> Any:
@@ -65,7 +77,7 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
output = super().forward(*inputs, **kwargs)

def output_transform(data: Any) -> Any:
device = cast(torch.device, self.module.device)
device = cast(torch.device, self.lightning_module.device)
data = python_scalar_to_tensor(data, device)
data = unsqueeze_scalar_tensor(data)
return data
@@ -95,7 +107,7 @@ def find_tensor_with_device(tensor: Tensor) -> Tensor:

if replica_device is not None:
# by calling .to() we force the update to the self.device property
self.module.to(device=replica_device)
self._forward_module.to(device=replica_device)
else:
rank_zero_warn(
"Could not determine on which device the inputs are."
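
A hedged sketch of the backward-compatibility path that `_validate_init_arguments` provides for `LightningParallelModule`: the old `pl_module` keyword is still accepted but emits a deprecation warning. `ToyModule` is a placeholder and a single-process run is assumed.

```python
import warnings

import torch
import pytorch_lightning as pl
from pytorch_lightning.overrides.data_parallel import LightningParallelModule


class ToyModule(pl.LightningModule):
    """Placeholder model for the sketch."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(2, 2)


model = ToyModule()

# New keyword:
wrapper = LightningParallelModule(forward_module=model)
assert wrapper.lightning_module is model

# Deprecated keyword: still accepted until v1.10.0, but warns.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    LightningParallelModule(pl_module=model)
assert any("deprecated" in str(w.message) for w in caught)
```
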
11 changes: 9 additions & 2 deletions src/pytorch_lightning/overrides/distributed.py
@@ -19,12 +19,19 @@
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler, Dataset, DistributedSampler, Sampler

from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LightningDistributedModule(_LightningModuleWrapperBase):
...
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))


def _find_tensors(
29 changes: 23 additions & 6 deletions src/pytorch_lightning/overrides/fairscale.py
@@ -11,27 +11,44 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union

import torch.nn as nn

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module
from pytorch_lightning.utilities import _IS_WINDOWS, _module_available
from pytorch_lightning.overrides.base import (
_LightningModuleWrapperBase,
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _module_available

_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available("fairscale.nn")

if _FAIRSCALE_AVAILABLE:

if _FAIRSCALE_AVAILABLE: # pragma: no-cover
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

class LightningShardedDataParallel(_LightningModuleWrapperBase):
# Just do this for later docstrings
pass
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
super().__init__(forward_module=(pl_module or forward_module))

def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
rank_zero_deprecation(
"The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v1.10.0."
" Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
)
model = wrapped_model
if isinstance(model, ShardedDataParallel):
model = model.module

return unwrap_lightning_module(model)
return unwrap_lightning_module(model, _suppress_warning=True)

else:
LightningShardedDataParallel = ... # type: ignore[assignment,misc]
@@ -15,9 +15,9 @@

import torch

from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE

if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
27 changes: 11 additions & 16 deletions src/pytorch_lightning/strategies/bagua.py
@@ -7,11 +7,7 @@
from torch.nn import Module

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import (
_LightningModuleWrapperBase,
_LightningPrecisionModuleWrapperBase,
unwrap_lightning_module,
)
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -54,10 +50,16 @@


class LightningBaguaModule(_LightningModuleWrapperBase):
def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
super().__init__(pl_module)
def __init__(
self,
forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
) -> None:
self._validate_init_arguments(pl_module, forward_module)
forward_module = pl_module or forward_module
super().__init__(forward_module=forward_module)
# Bagua use `bagua_module_name` to distinguish different modules
self._bagua_module_name = f"{pl_module.__class__.__name__}{id(pl_module)}"
self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}"


class BaguaStrategy(DDPStrategy):
@@ -109,13 +111,6 @@ def __init__(
self._bagua_flatten = flatten
self._bagua_kwargs = bagua_kwargs

@property
def lightning_module(self) -> Optional["pl.LightningModule"]:
model = self.model
if isinstance(model, BaguaDistributedDataParallel):
model = model.module
return unwrap_lightning_module(model) if model is not None else None

def setup_distributed(self) -> None:
reset_seed()

@@ -190,7 +185,7 @@ def _check_qadam_optimizer(self) -> None:

def _configure_bagua_model(self, trainer: "pl.Trainer") -> None:
model = LightningBaguaModule(self.model) # type: ignore[arg-type]
self._model = self._setup_model(model)
self.model = self._setup_model(model)

# start the background communication for async algorithm
if trainer.training and self._bagua_algorithm == "async":
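
The `lightning_module` property override removed above no longer has to unwrap `BaguaDistributedDataParallel`; per the changelog entry in this PR, strategies now access the unwrapped `LightningModule` directly. A purely conceptual sketch of that idea (not the actual `Strategy` implementation):

```python
from typing import Optional

import pytorch_lightning as pl


class _StrategySketch:
    """Conceptual stand-in for a strategy that keeps a direct module reference."""

    def __init__(self) -> None:
        self._lightning_module: Optional["pl.LightningModule"] = None
        self.model = None  # possibly wrapped later (e.g. by a DDP-style wrapper)

    def connect(self, model: "pl.LightningModule") -> None:
        # Remember the unwrapped module at connection time ...
        self._lightning_module = model
        self.model = model

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        # ... so no per-strategy unwrapping logic is needed here.
        return self._lightning_module
```
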
8 changes: 6 additions & 2 deletions src/pytorch_lightning/strategies/ddp.py
@@ -34,7 +34,6 @@
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -55,7 +54,12 @@
sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException
from pytorch_lightning.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, _TORCH_GREATER_EQUAL_1_11
from pytorch_lightning.utilities.imports import (
_FAIRSCALE_AVAILABLE,
_IS_WINDOWS,
_TORCH_GREATER_EQUAL_1_10,
_TORCH_GREATER_EQUAL_1_11,
)
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.seed import reset_seed