From 78f1eb46b789655933e16ec82dcc0b50676c42d7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 23 Feb 2021 13:31:59 +0000 Subject: [PATCH 01/62] Add initial FSDP integration --- pytorch_lightning/accelerators/accelerator.py | 4 +- pytorch_lightning/overrides/fairscale.py | 21 ++- pytorch_lightning/plugins/__init__.py | 6 + .../plugins/precision/__init__.py | 3 + .../plugins/precision/deepspeed_precision.py | 6 +- .../precision/full_sharded_native_amp.py | 29 ++++ .../plugins/precision/precision_plugin.py | 4 +- .../plugins/precision/sharded_native_amp.py | 6 +- .../plugins/training_type/__init__.py | 1 + .../plugins/training_type/full_sharded.py | 150 ++++++++++++++++++ .../plugins/training_type/rpc_sequential.py | 10 +- .../connectors/accelerator_connector.py | 28 +++- .../connectors/checkpoint_connector.py | 3 +- pytorch_lightning/utilities/__init__.py | 1 + pytorch_lightning/utilities/enums.py | 1 + pytorch_lightning/utilities/imports.py | 1 + requirements/extra.txt | 2 +- .../test_accelerator_connector.py | 6 +- tests/plugins/test_full_sharded_plugin.py | 146 +++++++++++++++++ tests/plugins/test_sharded_plugin.py | 2 +- tests/special_tests.sh | 1 + 21 files changed, 410 insertions(+), 21 deletions(-) create mode 100644 pytorch_lightning/plugins/precision/full_sharded_native_amp.py create mode 100644 pytorch_lightning/plugins/training_type/full_sharded.py create mode 100644 tests/plugins/test_full_sharded_plugin.py diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 84d53b5addd6b..c24f96db85f35 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -290,7 +290,7 @@ def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Opt def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None: """clips all the optimizer parameters to the given value""" - self.precision_plugin.clip_gradients(optimizer, clip_val) + self.precision_plugin.clip_gradients(self.model, optimizer, clip_val) def on_train_epoch_end(self, outputs) -> None: """Hook to do something on the end of an training epoch @@ -371,7 +371,7 @@ def optimizer_state(self, optimizer: Optimizer) -> dict: return optimizer.state_dict() def on_save(self, checkpoint): - return checkpoint + return self.training_type_plugin.on_save(checkpoint) def barrier(self, name: Optional[str] = None) -> None: self.training_type_plugin.barrier(name=name) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index f7c3b8d5fd575..4caa4e9b109c9 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -13,7 +13,7 @@ # limitations under the License. from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module -from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE LightningShardedDataParallel = None if _FAIRSCALE_AVAILABLE: @@ -29,3 +29,22 @@ def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: model = model.module return unwrap_lightning_module(model) + + +LightningFullShardedDataParallel = None +if _FAIRSCALE_FULL_SHARDED_AVAILABLE: + from fairscale.nn import FlattenParamsWrapper + from fairscale.nn.data_parallel import FullyShardedDataParallel + + class LightningFullShardedDataParallel(_LightningModuleWrapperBase): + # Just do this for later docstrings + pass + + def unwrap_lightning_module_full_sharded(wrapped_model) -> LightningModule: + model = wrapped_model + if isinstance(model, FullyShardedDataParallel): + model = model.module + # Additional check if we're using a flattened parameters buffer + if isinstance(model, FlattenParamsWrapper): + model = model.module + return unwrap_lightning_module(model) diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index dec672d025294..df836cddbfa92 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,6 +1,9 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.full_sharded_native_amp import ( # noqa: F401 + FullShardedNativeMixedPrecisionPlugin, +) from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401 @@ -10,6 +13,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 @@ -29,6 +33,8 @@ "DDPSpawnPlugin", "DeepSpeedPlugin", "DeepSpeedPrecisionPlugin", + "FullShardedPlugin", + "FullShardedNativeMixedPrecisionPlugin", "HorovodPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index fc60deffcbb77..5c5d8d95cc30f 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,5 +1,8 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.precision.full_sharded_native_amp import ( # noqa: F401 + FullShardedNativeMixedPrecisionPlugin, +) from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 711ede2f7ded4..a636c679fc488 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -1,4 +1,4 @@ -from typing import Callable, Union +from typing import Any, Callable, Union import torch from torch.optim import Optimizer @@ -54,7 +54,9 @@ def backward( return closure_loss - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + def clip_gradients( + self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0) + ): """ DeepSpeed handles clipping gradients via the training type plugin. """ diff --git a/pytorch_lightning/plugins/precision/full_sharded_native_amp.py b/pytorch_lightning/plugins/precision/full_sharded_native_amp.py new file mode 100644 index 0000000000000..f2b5a53f9aff0 --- /dev/null +++ b/pytorch_lightning/plugins/precision/full_sharded_native_amp.py @@ -0,0 +1,29 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Union + +from torch.optim import Optimizer + +from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin + + +class FullShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): + """Mixed Precision for Full Sharded Training + """ + + def clip_gradients( + self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0) + ): + # Model manages clipping of gradients + model.clip_grad_norm_(clip_val, norm_type) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 34879e514a6f2..0da1960713274 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -86,7 +86,9 @@ def pre_optimizer_step( def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: """Hook to do something after each optimizer step.""" - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)) -> None: + def clip_gradients( + self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0) + ): """Clips the gradients to a specific value""" # TODO: separate TPU case from here if clip_val is None: diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index b3b01fc720d2b..dbe9ac4b43e52 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast, Union +from typing import Any, cast, Union from torch.optim import Optimizer @@ -31,6 +31,8 @@ def __init__(self): super().__init__() self.scaler = ShardedGradScaler() - def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + def clip_gradients( + self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0) + ): optimizer = cast(OSS, optimizer) optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index 30723d67da3f4..bdce16121375c 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -3,6 +3,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/full_sharded.py b/pytorch_lightning/plugins/training_type/full_sharded.py new file mode 100644 index 0000000000000..9409d969ab418 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/full_sharded.py @@ -0,0 +1,150 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional + +import torch + +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException + +if _FAIRSCALE_AVAILABLE: + from fairscale.nn.data_parallel import FullyShardedDataParallel + + from pytorch_lightning.overrides.fairscale import ( + LightningFullShardedDataParallel, + unwrap_lightning_module_full_sharded, + ) + + +class FullShardedPlugin(DDPPlugin): + + def __init__( + self, + cpu_offload: bool = True, + flatten_parameters: bool = False, + reshard_after_forward: bool = True, + move_grads_to_cpu: Optional[bool] = None, + fp32_reduce_scatter: Optional[bool] = None, + compute_dtype: Optional[torch.dtype] = None, + bucket_cap_mb: int = 25, + parallel_devices: Optional[List[torch.device]] = None, + num_nodes: int = 1, + cluster_environment: ClusterEnvironment = None, + sync_batchnorm: Optional[bool] = False + ): + """ + + Provides capabilities to run training using the Full Sharded capabilities provided by FairScale. + + Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model + size, whilst using efficient communication to reduce overhead. In practice, this means we can remain + at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar + to ZeRO-Stage 3 but have been modified/adjusted for PyTorch. + + `For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`. + + .. warning:: ``FullShardedPlugin`` is in beta and subject to change. + + Defaults have been set to enable CPU Offload, but options have been exposed and may require configuration + based on your level of memory/speed efficiency. + We suggest having a look at this PR for more information. + `https://github.com/facebookresearch/fairscale/pull/413` + + + Many of the helpful doc strings below came from the original FairScale documentation: + `https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html` + + Arguments: + + cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False). + + move_grads_to_cpu: Moves gradient shards to CPU after reducation. + Only disable if using CPU based optimizers (defaults to ``cpu_offload``). + + flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency + (default: False). + + reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows + down training. Only revelant when nesting FullyShardedDataParallel wrappers inside the model. + (default: False). + + fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision + (default: None) + + compute_dtype: dtype for full parameters for computation. Default to torch.float32, + unless using mixed precision, in which case defaults to torch.float16. + + bucket_cap_mb: bucket parameters so that gradient reduction + can potentially overlap with backward computation. + bucket_cap_mb controls the bucket size in MegaBytes (MB). + Buckets are sub-divided based on world_size, + so the max shard size is roughly bucket_cap_mb / world_size. + Values <= 0 disable bucketing. (Default: 25). + + """ + if not _FAIRSCALE_FULL_SHARDED_AVAILABLE: + raise MisconfigurationException( + "Full Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`" + ) + + if sync_batchnorm: + raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.") + super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm) + self.cpu_offload = cpu_offload + self.move_grads_to_cpu = move_grads_to_cpu + self.flatten_parameters = flatten_parameters + self.reshard_after_forward = reshard_after_forward + self.fp32_reduce_scatter = fp32_reduce_scatter + self.compute_dtype = compute_dtype + self.bucket_cap_mb = bucket_cap_mb + + def configure_ddp(self): + precision = self.lightning_module.trainer.precision + self.model = FullyShardedDataParallel( + LightningFullShardedDataParallel(self.model), + cpu_offload=self.cpu_offload, + move_grads_to_cpu=self.move_grads_to_cpu, + flatten_parameters=self.flatten_parameters, + mixed_precision=precision == "mixed", + reshard_after_forward=self.reshard_after_forward, + fp32_reduce_scatter=self.fp32_reduce_scatter, + compute_dtype=self.compute_dtype, + bucket_cap_mb=self.bucket_cap_mb, + ) + + @property + def lightning_module(self) -> LightningModule: + return unwrap_lightning_module_full_sharded(self.model) + + def model_to_device(self): + if not self.cpu_offload: + super().model_to_device() + + def on_save(self, checkpoint: dict) -> dict: + state_dict = self.collate_state_dict() + checkpoint['state_dict'] = state_dict + return checkpoint + + def collate_state_dict(self): + """ + Collects the models sharded state dict from all processes before returning. + Returns: The unsharded model state dict. + """ + state_dict = self.model.state_dict() + # Remove module prefix from state dict as this is the behaviour of state dict. + state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()} + return state_dict diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 3878aa9db3ea4..9ffcfc770ab59 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -25,7 +25,7 @@ from pytorch_lightning.overrides.distributed import LightningDistributedModule from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin from pytorch_lightning.trainer.states import RunningStage -from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only +from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_PIPE_AVAILABLE: @@ -56,6 +56,10 @@ def __init__( .. _RPCSequentialPlugin: https://arxiv.org/abs/1811.06965 + .. warning:: + This plugin has been deprecated. Please use the ``FullShardedPlugin`` which provides better performance + and scaling without pipelining the model. + Pipeline parallelism comes with with checkpointing to reduce peak memory required to train while minimizing device under-utilization. This is turned on by default and can be turned off via the checkpoint argument. @@ -87,6 +91,10 @@ def __init__( at the same time. Defaults to `True` if `get_model_parallel_world_size() > 1` """ + rank_zero_warn( + "RPC Sequential Plugin has been deprecated. Please use the `FullShardedPlugin` " + "which provides better performance and scaling without pipelining the model." + ) self._check_pipe_available() super().__init__(rpc_timeout_sec=rpc_timeout_sec, **kwargs) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 7021081d6cc90..74391b61580d2 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -32,6 +32,8 @@ DDPSpawnShardedPlugin, DeepSpeedPlugin, DeepSpeedPrecisionPlugin, + FullShardedNativeMixedPrecisionPlugin, + FullShardedPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin, @@ -249,7 +251,7 @@ def use_dp(self) -> bool: def use_ddp(self) -> bool: return self._distrib_type in ( DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, DistributedType.DEEPSPEED + DistributedType.DDP_SHARDED_SPAWN, DistributedType.FULL_SHARDED, DistributedType.DEEPSPEED ) @property @@ -329,8 +331,10 @@ def select_precision_plugin(self) -> PrecisionPlugin: raise MisconfigurationException(msg) else: log.info("Using native 16bit precision.") - if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + if self._sharded_training_type: return ShardedNativeMixedPrecisionPlugin() + if self._full_sharded_training_type: + return FullShardedNativeMixedPrecisionPlugin() return NativeMixedPrecisionPlugin() if self.amp_type == AMPType.APEX: @@ -339,9 +343,9 @@ def select_precision_plugin(self) -> PrecisionPlugin: "You have asked for Apex AMP but you have not installed it yet." " Install apex first using this guide: https://github.com/NVIDIA/apex#linux" ) - if isinstance(self.training_type_plugin, (DDPShardedPlugin, DDPSpawnShardedPlugin)): + if self._sharded_training_type or self._full_sharded_training_type: raise MisconfigurationException( - "Sharded Plugin is not supported with Apex AMP," + "Sharded Plugins are not supported with Apex AMP," " please using native AMP for 16-bit precision." ) log.info("Using APEX 16bit precision.") @@ -367,7 +371,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN - + use_ddp_full_sharded = self._distrib_type == DistributedType.FULL_SHARDED # TODO: decouple from TE # ddp script mode uses the same flags as TE if os.environ.get("PL_IN_DDP_SUBPROCESS", False): @@ -375,6 +379,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: if self.on_tpu: ddp_plugin_cls = TPUSpawnPlugin + elif use_ddp_full_sharded: + ddp_plugin_cls = FullShardedPlugin elif use_ddp_sharded: ddp_plugin_cls = DDPShardedPlugin elif use_ddp_sharded_spawn: @@ -612,3 +618,15 @@ def configure_slurm_ddp(self): # notify user the that slurm is managing tasks if self.is_slurm_managing_tasks: rank_zero_info("Multi-processing is handled by Slurm.") + + @property + def _sharded_training_type(self) -> bool: + return isinstance(self.training_type_plugin, + (DDPShardedPlugin, DDPSpawnShardedPlugin + )) or self._distrib_type in (DistributedType.DDP_SHARDED, DistributedType.DDP_SHARDED_SPAWN) + + @property + def _full_sharded_training_type(self) -> bool: + return isinstance( + self.training_type_plugin, FullShardedPlugin + ) or self._distrib_type == DistributedType.FULL_SHARDED diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 3b75f406b1917..2384f2aa8fbe5 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -402,8 +402,7 @@ def save_checkpoint(self, filepath, weights_only: bool = False): if self.trainer.is_global_zero: # write the checkpoint dictionary on the file - if self.trainer.training_type_plugin: - checkpoint = self.trainer.training_type_plugin.on_save(checkpoint) + checkpoint = self.trainer.accelerator.on_save(checkpoint) try: atomic_save(checkpoint, filepath) except AttributeError as err: diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index cf3aa06f305b8..905237c3ed164 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -28,6 +28,7 @@ _BOLTS_AVAILABLE, _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, + _FAIRSCALE_FULL_SHARDED_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _GROUP_AVAILABLE, _HOROVOD_AVAILABLE, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 3e4add4fb68d1..e83d70b22bb09 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -66,6 +66,7 @@ class DistributedType(LightningEnum): HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' + FULL_SHARDED = 'full_sharded' RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 8024997382457..8a5a36495ca9c 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -57,6 +57,7 @@ def _compare_version(package: str, op, version) -> bool: _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') +_FAIRSCALE_FULL_SHARDED_AVAILABLE = not _IS_WINDOWS and _compare_version("fairscale", operator.ge, "0.3.0") _FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') _HOROVOD_AVAILABLE = _module_available("horovod.torch") diff --git a/requirements/extra.txt b/requirements/extra.txt index 0e7dffbcb39b0..61ce039395702 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.5, <0.7 # TODO: temporary fix fix for compatibility onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -https://github.com/PyTorchLightning/fairscale/archive/pl_1.2.0.zip +fairscale>=0.3.0 diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 82b631807c8e9..730086dede580 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -26,8 +26,8 @@ from pytorch_lightning.plugins import ( DDP2Plugin, DDPPlugin, - DDPShardedPlugin, DDPSpawnPlugin, + FullShardedPlugin, PrecisionPlugin, SingleDevicePlugin, ) @@ -396,7 +396,7 @@ def test_plugin_accelerator_choice(accelerator, plugin): Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent. """ trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2) - assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + assert isinstance(trainer.accelerator.training_type_plugin, FullShardedPlugin) trainer = Trainer(plugins=plugin, num_processes=2) - assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) + assert isinstance(trainer.accelerator.training_type_plugin, FullShardedPlugin) diff --git a/tests/plugins/test_full_sharded_plugin.py b/tests/plugins/test_full_sharded_plugin.py new file mode 100644 index 0000000000000..362a22281326a --- /dev/null +++ b/tests/plugins/test_full_sharded_plugin.py @@ -0,0 +1,146 @@ +import os +import platform +from unittest import mock + +import pytest +import torch + +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.plugins import FullShardedNativeMixedPrecisionPlugin, FullShardedPlugin +from pytorch_lightning.utilities import _APEX_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE, _NATIVE_AMP_AVAILABLE +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel + + +@pytest.mark.parametrize(["plugin"], [("full_sharded", )]) +@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +def test_sharded_ddp_choice(tmpdir, plugin): + """ + Test to ensure that plugin is correctly chosen + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + if plugin == 'full_sharded': + assert isinstance(trainer.accelerator.training_type_plugin, FullShardedPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + plugins=plugin, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex") +@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +def test_invalid_apex_sharded(tmpdir): + """ + Test to ensure that we raise an error when we try to use apex and sharded + """ + + model = BoringModel() + with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'): + trainer = Trainer( + fast_dev_run=True, + plugins='full_sharded', + precision=16, + amp_backend='apex', + ) + + trainer.fit(model) + + +@pytest.mark.parametrize(["plugin"], [("full_sharded", )]) +@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) +@mock.patch('torch.cuda.device_count', return_value=1) +@mock.patch('torch.cuda.is_available', return_value=True) +def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, plugin, tmpdir): + """ + Test to ensure that plugin native amp plugin is correctly chosen when using sharded + """ + + class CB(Callback): + + def on_fit_start(self, trainer, pl_module): + if plugin == 'full_sharded': + assert isinstance(trainer.accelerator.training_type_plugin, FullShardedPlugin) + assert isinstance(trainer.accelerator.precision_plugin, FullShardedNativeMixedPrecisionPlugin) + raise SystemExit() + + model = BoringModel() + trainer = Trainer( + fast_dev_run=True, + gpus=1, + precision=16, + plugins=plugin, + callbacks=[CB()], + ) + + with pytest.raises(SystemExit): + trainer.fit(model) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +def test_full_sharded_plugin_checkpoint(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using a single GPU. + """ + model = BoringModel() + trainer = Trainer( + gpus=1, + plugins='full_sharded', + fast_dev_run=True, + precision=16, + ) + + trainer.fit(model) + + _assert_save_equality(tmpdir, trainer) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.skipif( + not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" +) +def test_full_sharded_plugin_checkpoint_multi_gpu(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using multiple GPUs + """ + model = BoringModel() + trainer = Trainer( + gpus=2, + plugins='full_sharded', + fast_dev_run=True, + precision=16, + ) + + trainer.fit(model) + + _assert_save_equality(tmpdir, trainer) + + +def _assert_save_equality(tmpdir, trainer): + if trainer.global_rank == 0: + + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) + + # Ensure we gather all shards for comparison + model_state_dict = trainer.accelerator.training_type_plugin.collate_state_dict() + # Assert model parameters are identical after loading + for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): + assert torch.equal(ddp_param.float().cpu(), shard_param) diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index f3683ffcba252..73d958c0db7de 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -47,7 +47,7 @@ def test_invalid_apex_sharded(tmpdir): """ model = BoringModel() - with pytest.raises(MisconfigurationException, match='Sharded Plugin is not supported with Apex AMP'): + with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'): trainer = Trainer( fast_dev_run=True, accelerator='ddp_sharded_spawn', diff --git a/tests/special_tests.sh b/tests/special_tests.sh index ffb21255a6d3c..7a9d514e35045 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -18,6 +18,7 @@ DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --ca python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu +python ${DEFAULTS} tests/plugins/test_full_sharded_plugin.py::test_full_sharded_plugin_checkpoint_multi_gpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual_amp From c36e00a0f8bb5eb552c84768f64f3b178650dc46 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 23 Feb 2021 14:11:50 +0000 Subject: [PATCH 02/62] Fix error in refactor --- tests/accelerators/test_accelerator_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/accelerators/test_accelerator_connector.py b/tests/accelerators/test_accelerator_connector.py index 730086dede580..82b631807c8e9 100644 --- a/tests/accelerators/test_accelerator_connector.py +++ b/tests/accelerators/test_accelerator_connector.py @@ -26,8 +26,8 @@ from pytorch_lightning.plugins import ( DDP2Plugin, DDPPlugin, + DDPShardedPlugin, DDPSpawnPlugin, - FullShardedPlugin, PrecisionPlugin, SingleDevicePlugin, ) @@ -396,7 +396,7 @@ def test_plugin_accelerator_choice(accelerator, plugin): Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent. """ trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2) - assert isinstance(trainer.accelerator.training_type_plugin, FullShardedPlugin) + assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) trainer = Trainer(plugins=plugin, num_processes=2) - assert isinstance(trainer.accelerator.training_type_plugin, FullShardedPlugin) + assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin) From 59dbb8371e0ef30f87316c18296b6ad372eee458 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 24 Feb 2021 10:35:19 +0000 Subject: [PATCH 03/62] update --- .../plugins/precision/native_amp.py | 8 ++++ .../plugins/training_type/full_sharded.py | 43 +++++++++++++++---- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index 94e6cf376b03a..ec03f5e28b236 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -61,6 +61,7 @@ def backward( # unscale gradient to allow analyze within `on_after_backward` if not should_accumulate and model.automatic_optimization: self.scaler.unscale_(optimizer) + self.move_grad_to_cpu(model.trainer.model) return closure_loss @@ -88,6 +89,13 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: self.scaler.step(optimizer) self.scaler.update() + def move_grad_to_cpu(self, model): + if hasattr(model, "cpu_offload"): + if model.cpu_offload: + for param in model.params: + param._cpu_grad.copy_(param.grad.data, non_blocking=True) + param.grad.data = param._cpu_grad + @contextmanager def train_step_context(self) -> Generator[autocast, None, None]: """Enable autocast context""" diff --git a/pytorch_lightning/plugins/training_type/full_sharded.py b/pytorch_lightning/plugins/training_type/full_sharded.py index 9409d969ab418..3d057f6e8e42c 100644 --- a/pytorch_lightning/plugins/training_type/full_sharded.py +++ b/pytorch_lightning/plugins/training_type/full_sharded.py @@ -22,7 +22,8 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_AVAILABLE: - from fairscale.nn.data_parallel import FullyShardedDataParallel + from fairscale.nn.data_parallel.fully_sharded_data_parallel import ( + FullyShardedDataParallel, Parameter, TrainingState) from pytorch_lightning.overrides.fairscale import ( LightningFullShardedDataParallel, @@ -30,21 +31,44 @@ ) +class LightningFullyShardedDataParallel(FullyShardedDataParallel): + def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None: + """Hook to call on each param after the reduce-scatter.""" + assert torch.cuda.current_stream() == self._streams["post_backward"] + assert param.grad is not None + self.assert_state(TrainingState.BACKWARD) + param.grad.data = reduced_grad + # Cast grad to param's dtype (typically FP32). Note: we do this + # before the move_grads_to_cpu step so that this entire hook remains + # non-blocking. The downside is a bit more D2H transfer in that case. + if self.mixed_precision: + param.grad.data = param.grad.data.to(dtype=param.data.dtype) + # Optionally move gradients to CPU, typically used if one is running + # the optimizer on the CPU. + # issues with this part + + # This part needs to be done after unscaling the gradients. + #if self.move_grads_to_cpu: + # param._cpu_grad.copy_(param.grad.data, non_blocking=True) + # param.grad.data = param._cpu_grad + # Don't let this memory get reused until after the transfers. + #reduced_grad.record_stream(torch.cuda.current_stream()) + + class FullShardedPlugin(DDPPlugin): def __init__( self, cpu_offload: bool = True, flatten_parameters: bool = False, - reshard_after_forward: bool = True, - move_grads_to_cpu: Optional[bool] = None, - fp32_reduce_scatter: Optional[bool] = None, + reshard_after_forward: bool = False, + fp32_reduce_scatter: Optional[bool] = False, compute_dtype: Optional[torch.dtype] = None, bucket_cap_mb: int = 25, parallel_devices: Optional[List[torch.device]] = None, num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, - sync_batchnorm: Optional[bool] = False + sync_batchnorm: Optional[bool] = False, ): """ @@ -72,7 +96,7 @@ def __init__( cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False). - move_grads_to_cpu: Moves gradient shards to CPU after reducation. + move_grads_to_cpu: Moves gradient shards to CPU after reduction. Only disable if using CPU based optimizers (defaults to ``cpu_offload``). flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency @@ -105,7 +129,6 @@ def __init__( raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.") super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm) self.cpu_offload = cpu_offload - self.move_grads_to_cpu = move_grads_to_cpu self.flatten_parameters = flatten_parameters self.reshard_after_forward = reshard_after_forward self.fp32_reduce_scatter = fp32_reduce_scatter @@ -113,11 +136,12 @@ def __init__( self.bucket_cap_mb = bucket_cap_mb def configure_ddp(self): - precision = self.lightning_module.trainer.precision + trainer = self.lightning_module.trainer + precision = trainer.precision self.model = FullyShardedDataParallel( LightningFullShardedDataParallel(self.model), cpu_offload=self.cpu_offload, - move_grads_to_cpu=self.move_grads_to_cpu, + move_grads_to_cpu=self.cpu_offload, flatten_parameters=self.flatten_parameters, mixed_precision=precision == "mixed", reshard_after_forward=self.reshard_after_forward, @@ -125,6 +149,7 @@ def configure_ddp(self): compute_dtype=self.compute_dtype, bucket_cap_mb=self.bucket_cap_mb, ) + trainer.accelerator.setup_optimizers(trainer) @property def lightning_module(self) -> LightningModule: From 19a1440e3d8b33120b1d76a81fda61e72c5dc7ce Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 24 Feb 2021 22:41:26 +0000 Subject: [PATCH 04/62] Revert "update" This reverts commit 59dbb837 --- .../plugins/precision/native_amp.py | 8 ---- .../plugins/training_type/full_sharded.py | 43 ++++--------------- 2 files changed, 9 insertions(+), 42 deletions(-) diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index ec03f5e28b236..94e6cf376b03a 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -61,7 +61,6 @@ def backward( # unscale gradient to allow analyze within `on_after_backward` if not should_accumulate and model.automatic_optimization: self.scaler.unscale_(optimizer) - self.move_grad_to_cpu(model.trainer.model) return closure_loss @@ -89,13 +88,6 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: self.scaler.step(optimizer) self.scaler.update() - def move_grad_to_cpu(self, model): - if hasattr(model, "cpu_offload"): - if model.cpu_offload: - for param in model.params: - param._cpu_grad.copy_(param.grad.data, non_blocking=True) - param.grad.data = param._cpu_grad - @contextmanager def train_step_context(self) -> Generator[autocast, None, None]: """Enable autocast context""" diff --git a/pytorch_lightning/plugins/training_type/full_sharded.py b/pytorch_lightning/plugins/training_type/full_sharded.py index 3d057f6e8e42c..9409d969ab418 100644 --- a/pytorch_lightning/plugins/training_type/full_sharded.py +++ b/pytorch_lightning/plugins/training_type/full_sharded.py @@ -22,8 +22,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_AVAILABLE: - from fairscale.nn.data_parallel.fully_sharded_data_parallel import ( - FullyShardedDataParallel, Parameter, TrainingState) + from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import ( LightningFullShardedDataParallel, @@ -31,44 +30,21 @@ ) -class LightningFullyShardedDataParallel(FullyShardedDataParallel): - def _post_reduction_hook(self, param: Parameter, reduced_grad: torch.Tensor) -> None: - """Hook to call on each param after the reduce-scatter.""" - assert torch.cuda.current_stream() == self._streams["post_backward"] - assert param.grad is not None - self.assert_state(TrainingState.BACKWARD) - param.grad.data = reduced_grad - # Cast grad to param's dtype (typically FP32). Note: we do this - # before the move_grads_to_cpu step so that this entire hook remains - # non-blocking. The downside is a bit more D2H transfer in that case. - if self.mixed_precision: - param.grad.data = param.grad.data.to(dtype=param.data.dtype) - # Optionally move gradients to CPU, typically used if one is running - # the optimizer on the CPU. - # issues with this part - - # This part needs to be done after unscaling the gradients. - #if self.move_grads_to_cpu: - # param._cpu_grad.copy_(param.grad.data, non_blocking=True) - # param.grad.data = param._cpu_grad - # Don't let this memory get reused until after the transfers. - #reduced_grad.record_stream(torch.cuda.current_stream()) - - class FullShardedPlugin(DDPPlugin): def __init__( self, cpu_offload: bool = True, flatten_parameters: bool = False, - reshard_after_forward: bool = False, - fp32_reduce_scatter: Optional[bool] = False, + reshard_after_forward: bool = True, + move_grads_to_cpu: Optional[bool] = None, + fp32_reduce_scatter: Optional[bool] = None, compute_dtype: Optional[torch.dtype] = None, bucket_cap_mb: int = 25, parallel_devices: Optional[List[torch.device]] = None, num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, - sync_batchnorm: Optional[bool] = False, + sync_batchnorm: Optional[bool] = False ): """ @@ -96,7 +72,7 @@ def __init__( cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False). - move_grads_to_cpu: Moves gradient shards to CPU after reduction. + move_grads_to_cpu: Moves gradient shards to CPU after reducation. Only disable if using CPU based optimizers (defaults to ``cpu_offload``). flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency @@ -129,6 +105,7 @@ def __init__( raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.") super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm) self.cpu_offload = cpu_offload + self.move_grads_to_cpu = move_grads_to_cpu self.flatten_parameters = flatten_parameters self.reshard_after_forward = reshard_after_forward self.fp32_reduce_scatter = fp32_reduce_scatter @@ -136,12 +113,11 @@ def __init__( self.bucket_cap_mb = bucket_cap_mb def configure_ddp(self): - trainer = self.lightning_module.trainer - precision = trainer.precision + precision = self.lightning_module.trainer.precision self.model = FullyShardedDataParallel( LightningFullShardedDataParallel(self.model), cpu_offload=self.cpu_offload, - move_grads_to_cpu=self.cpu_offload, + move_grads_to_cpu=self.move_grads_to_cpu, flatten_parameters=self.flatten_parameters, mixed_precision=precision == "mixed", reshard_after_forward=self.reshard_after_forward, @@ -149,7 +125,6 @@ def configure_ddp(self): compute_dtype=self.compute_dtype, bucket_cap_mb=self.bucket_cap_mb, ) - trainer.accelerator.setup_optimizers(trainer) @property def lightning_module(self) -> LightningModule: From 3b38615fd27b265789bf0236de8bb9339fb2d63f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 24 Feb 2021 22:49:56 +0000 Subject: [PATCH 05/62] Address reviews --- pytorch_lightning/overrides/fairscale.py | 10 +-- pytorch_lightning/plugins/__init__.py | 10 +-- .../plugins/precision/__init__.py | 4 +- ...ive_amp.py => fully_sharded_native_amp.py} | 2 +- .../plugins/training_type/__init__.py | 2 +- .../{full_sharded.py => fully_sharded.py} | 80 +++++++++---------- .../plugins/training_type/rpc_sequential.py | 4 +- .../connectors/accelerator_connector.py | 24 +++--- pytorch_lightning/utilities/__init__.py | 2 +- pytorch_lightning/utilities/enums.py | 2 +- pytorch_lightning/utilities/imports.py | 2 +- tests/plugins/test_full_sharded_plugin.py | 38 ++++----- tests/special_tests.sh | 2 +- 13 files changed, 91 insertions(+), 91 deletions(-) rename pytorch_lightning/plugins/precision/{full_sharded_native_amp.py => fully_sharded_native_amp.py} (92%) rename pytorch_lightning/plugins/training_type/{full_sharded.py => fully_sharded.py} (54%) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 4caa4e9b109c9..681fba9eb7dc9 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -13,7 +13,7 @@ # limitations under the License. from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module -from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE LightningShardedDataParallel = None if _FAIRSCALE_AVAILABLE: @@ -31,16 +31,16 @@ def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: return unwrap_lightning_module(model) -LightningFullShardedDataParallel = None -if _FAIRSCALE_FULL_SHARDED_AVAILABLE: +LightningFullyShardedDataParallel = None +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: from fairscale.nn import FlattenParamsWrapper from fairscale.nn.data_parallel import FullyShardedDataParallel - class LightningFullShardedDataParallel(_LightningModuleWrapperBase): + class LightningFullyShardedDataParallel(_LightningModuleWrapperBase): # Just do this for later docstrings pass - def unwrap_lightning_module_full_sharded(wrapped_model) -> LightningModule: + def unwrap_lightning_module_fully_sharded(wrapped_model) -> LightningModule: model = wrapped_model if isinstance(model, FullyShardedDataParallel): model = model.module diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index df836cddbfa92..8487defc68977 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -1,8 +1,8 @@ from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 -from pytorch_lightning.plugins.precision.full_sharded_native_amp import ( # noqa: F401 - FullShardedNativeMixedPrecisionPlugin, +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 + FullyShardedNativeMixedPrecisionPlugin, ) from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 @@ -13,7 +13,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 -from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.fully_sharded import FullyShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 @@ -33,8 +33,8 @@ "DDPSpawnPlugin", "DeepSpeedPlugin", "DeepSpeedPrecisionPlugin", - "FullShardedPlugin", - "FullShardedNativeMixedPrecisionPlugin", + "FullyShardedPlugin", + "FullyShardedNativeMixedPrecisionPlugin", "HorovodPlugin", "NativeMixedPrecisionPlugin", "PrecisionPlugin", diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py index 5c5d8d95cc30f..43fdc24c1f4ab 100644 --- a/pytorch_lightning/plugins/precision/__init__.py +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -1,7 +1,7 @@ from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.deepspeed_precision import DeepSpeedPrecisionPlugin # noqa: F401 -from pytorch_lightning.plugins.precision.full_sharded_native_amp import ( # noqa: F401 - FullShardedNativeMixedPrecisionPlugin, +from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401 + FullyShardedNativeMixedPrecisionPlugin, ) from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401 from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/precision/full_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py similarity index 92% rename from pytorch_lightning/plugins/precision/full_sharded_native_amp.py rename to pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index f2b5a53f9aff0..9e6837f47c352 100644 --- a/pytorch_lightning/plugins/precision/full_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -18,7 +18,7 @@ from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -class FullShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): +class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): """Mixed Precision for Full Sharded Training """ diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index bdce16121375c..cca55ece01857 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -3,7 +3,7 @@ from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 -from pytorch_lightning.plugins.training_type.full_sharded import FullShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.fully_sharded import FullyShardedPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/full_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py similarity index 54% rename from pytorch_lightning/plugins/training_type/full_sharded.py rename to pytorch_lightning/plugins/training_type/fully_sharded.py index 9409d969ab418..7018ab7152ae1 100644 --- a/pytorch_lightning/plugins/training_type/full_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -18,19 +18,19 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin -from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE +from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -if _FAIRSCALE_AVAILABLE: +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import ( - LightningFullShardedDataParallel, - unwrap_lightning_module_full_sharded, + LightningFullyShardedDataParallel, + unwrap_lightning_module_fully_sharded, ) -class FullShardedPlugin(DDPPlugin): +class FullyShardedPlugin(DDPPlugin): def __init__( self, @@ -48,55 +48,55 @@ def __init__( ): """ - Provides capabilities to run training using the Full Sharded capabilities provided by FairScale. + Provides capabilities to run training using the Full Sharded capabilities provided by FairScale. - Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model - size, whilst using efficient communication to reduce overhead. In practice, this means we can remain - at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar - to ZeRO-Stage 3 but have been modified/adjusted for PyTorch. + Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model + size, whilst using efficient communication to reduce overhead. In practice, this means we can remain + at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar + to ZeRO-Stage 3 but have been modified/adjusted for PyTorch. - `For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`. + `For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`. - .. warning:: ``FullShardedPlugin`` is in beta and subject to change. + .. warning:: ``FullyShardedPlugin`` is in beta and subject to change. - Defaults have been set to enable CPU Offload, but options have been exposed and may require configuration - based on your level of memory/speed efficiency. - We suggest having a look at this PR for more information. - `https://github.com/facebookresearch/fairscale/pull/413` + Defaults have been set to enable CPU Offload, but options have been exposed and may require configuration + based on your level of memory/speed efficiency. + We suggest having a look at this PR for more information. + `https://github.com/facebookresearch/fairscale/pull/413` - Many of the helpful doc strings below came from the original FairScale documentation: - `https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html` + Many of the helpful doc strings below came from the original FairScale documentation: + `https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html` - Arguments: + Arguments: - cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False). + cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False). - move_grads_to_cpu: Moves gradient shards to CPU after reducation. - Only disable if using CPU based optimizers (defaults to ``cpu_offload``). + move_grads_to_cpu: Moves gradient shards to CPU after reducation. + Only disable if using CPU based optimizers (defaults to ``cpu_offload``). - flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency - (default: False). + flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency + (default: False). - reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows - down training. Only revelant when nesting FullyShardedDataParallel wrappers inside the model. - (default: False). + reshard_after_forward: Reshard parameters after the forward pass, which saves memory but slows + down training. Only revelant when nesting FullyShardedDataParallel wrappers inside the model. + (default: False). - fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision - (default: None) + fp32_reduce_scatter: Reduce-Scatter gradients in FP32. Only relevant in mixed precision + (default: None) - compute_dtype: dtype for full parameters for computation. Default to torch.float32, - unless using mixed precision, in which case defaults to torch.float16. + compute_dtype: dtype for full parameters for computation. Default to torch.float32, + unless using mixed precision, in which case defaults to torch.float16. - bucket_cap_mb: bucket parameters so that gradient reduction - can potentially overlap with backward computation. - bucket_cap_mb controls the bucket size in MegaBytes (MB). - Buckets are sub-divided based on world_size, - so the max shard size is roughly bucket_cap_mb / world_size. - Values <= 0 disable bucketing. (Default: 25). + bucket_cap_mb: bucket parameters so that gradient reduction + can potentially overlap with backward computation. + bucket_cap_mb controls the bucket size in MegaBytes (MB). + Buckets are sub-divided based on world_size, + so the max shard size is roughly bucket_cap_mb / world_size. + Values <= 0 disable bucketing. (Default: 25). """ - if not _FAIRSCALE_FULL_SHARDED_AVAILABLE: + if not _FAIRSCALE_FULLY_SHARDED_AVAILABLE: raise MisconfigurationException( "Full Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`" ) @@ -115,7 +115,7 @@ def __init__( def configure_ddp(self): precision = self.lightning_module.trainer.precision self.model = FullyShardedDataParallel( - LightningFullShardedDataParallel(self.model), + LightningFullyShardedDataParallel(self.model), cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, flatten_parameters=self.flatten_parameters, @@ -128,7 +128,7 @@ def configure_ddp(self): @property def lightning_module(self) -> LightningModule: - return unwrap_lightning_module_full_sharded(self.model) + return unwrap_lightning_module_fully_sharded(self.model) def model_to_device(self): if not self.cpu_offload: diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 9ffcfc770ab59..7882831d50b55 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -57,7 +57,7 @@ def __init__( .. _RPCSequentialPlugin: https://arxiv.org/abs/1811.06965 .. warning:: - This plugin has been deprecated. Please use the ``FullShardedPlugin`` which provides better performance + This plugin has been deprecated. Please use the ``FullyShardedPlugin`` which provides better performance and scaling without pipelining the model. Pipeline parallelism comes with with checkpointing to reduce peak @@ -92,7 +92,7 @@ def __init__( `get_model_parallel_world_size() > 1` """ rank_zero_warn( - "RPC Sequential Plugin has been deprecated. Please use the `FullShardedPlugin` " + "RPC Sequential Plugin has been deprecated. Please use the `FullyShardedPlugin` " "which provides better performance and scaling without pipelining the model." ) self._check_pipe_available() diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 74391b61580d2..12a89305b66e4 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -32,8 +32,8 @@ DDPSpawnShardedPlugin, DeepSpeedPlugin, DeepSpeedPrecisionPlugin, - FullShardedNativeMixedPrecisionPlugin, - FullShardedPlugin, + FullyShardedNativeMixedPrecisionPlugin, + FullyShardedPlugin, HorovodPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin, @@ -251,7 +251,7 @@ def use_dp(self) -> bool: def use_ddp(self) -> bool: return self._distrib_type in ( DistributedType.DDP, DistributedType.DDP_SPAWN, DistributedType.DDP_SHARDED, - DistributedType.DDP_SHARDED_SPAWN, DistributedType.FULL_SHARDED, DistributedType.DEEPSPEED + DistributedType.DDP_SHARDED_SPAWN, DistributedType.FULLY_SHARDED, DistributedType.DEEPSPEED ) @property @@ -333,8 +333,8 @@ def select_precision_plugin(self) -> PrecisionPlugin: log.info("Using native 16bit precision.") if self._sharded_training_type: return ShardedNativeMixedPrecisionPlugin() - if self._full_sharded_training_type: - return FullShardedNativeMixedPrecisionPlugin() + if self._fully_sharded_training_type: + return FullyShardedNativeMixedPrecisionPlugin() return NativeMixedPrecisionPlugin() if self.amp_type == AMPType.APEX: @@ -343,7 +343,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: "You have asked for Apex AMP but you have not installed it yet." " Install apex first using this guide: https://github.com/NVIDIA/apex#linux" ) - if self._sharded_training_type or self._full_sharded_training_type: + if self._sharded_training_type or self._fully_sharded_training_type: raise MisconfigurationException( "Sharded Plugins are not supported with Apex AMP," " please using native AMP for 16-bit precision." @@ -371,7 +371,7 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: use_ddp_cpu_slurm = use_ddp_cpu_spawn and self.is_slurm_managing_tasks use_ddp_sharded = self._distrib_type == DistributedType.DDP_SHARDED use_ddp_sharded_spawn = self._distrib_type == DistributedType.DDP_SHARDED_SPAWN - use_ddp_full_sharded = self._distrib_type == DistributedType.FULL_SHARDED + use_ddp_fully_sharded = self._distrib_type == DistributedType.FULLY_SHARDED # TODO: decouple from TE # ddp script mode uses the same flags as TE if os.environ.get("PL_IN_DDP_SUBPROCESS", False): @@ -379,8 +379,8 @@ def select_training_type_plugin(self) -> TrainingTypePlugin: if self.on_tpu: ddp_plugin_cls = TPUSpawnPlugin - elif use_ddp_full_sharded: - ddp_plugin_cls = FullShardedPlugin + elif use_ddp_fully_sharded: + ddp_plugin_cls = FullyShardedPlugin elif use_ddp_sharded: ddp_plugin_cls = DDPShardedPlugin elif use_ddp_sharded_spawn: @@ -626,7 +626,7 @@ def _sharded_training_type(self) -> bool: )) or self._distrib_type in (DistributedType.DDP_SHARDED, DistributedType.DDP_SHARDED_SPAWN) @property - def _full_sharded_training_type(self) -> bool: + def _fully_sharded_training_type(self) -> bool: return isinstance( - self.training_type_plugin, FullShardedPlugin - ) or self._distrib_type == DistributedType.FULL_SHARDED + self.training_type_plugin, FullyShardedPlugin + ) or self._distrib_type == DistributedType.FULLY_SHARDED diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 905237c3ed164..9b50835b38a3e 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -28,7 +28,7 @@ _BOLTS_AVAILABLE, _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, - _FAIRSCALE_FULL_SHARDED_AVAILABLE, + _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _GROUP_AVAILABLE, _HOROVOD_AVAILABLE, diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index e83d70b22bb09..ba7b8c9acec5f 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -66,7 +66,7 @@ class DistributedType(LightningEnum): HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' - FULL_SHARDED = 'full_sharded' + FULLY_SHARDED = 'fully_sharded' RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 8a5a36495ca9c..4d2cbf863eaa7 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -57,7 +57,7 @@ def _compare_version(package: str, op, version) -> bool: _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') -_FAIRSCALE_FULL_SHARDED_AVAILABLE = not _IS_WINDOWS and _compare_version("fairscale", operator.ge, "0.3.0") +_FAIRSCALE_FULLY_SHARDED_AVAILABLE = not _IS_WINDOWS and _compare_version("fairscale", operator.ge, "0.3.0") _FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') _HOROVOD_AVAILABLE = _module_available("horovod.torch") diff --git a/tests/plugins/test_full_sharded_plugin.py b/tests/plugins/test_full_sharded_plugin.py index 362a22281326a..ed5a4f4aceeb0 100644 --- a/tests/plugins/test_full_sharded_plugin.py +++ b/tests/plugins/test_full_sharded_plugin.py @@ -7,14 +7,14 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback -from pytorch_lightning.plugins import FullShardedNativeMixedPrecisionPlugin, FullShardedPlugin -from pytorch_lightning.utilities import _APEX_AVAILABLE, _FAIRSCALE_FULL_SHARDED_AVAILABLE, _NATIVE_AMP_AVAILABLE +from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin, FullyShardedPlugin +from pytorch_lightning.utilities import _APEX_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _NATIVE_AMP_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel -@pytest.mark.parametrize(["plugin"], [("full_sharded", )]) -@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.parametrize(["plugin"], [("fully_sharded", )]) +@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") def test_sharded_ddp_choice(tmpdir, plugin): """ Test to ensure that plugin is correctly chosen @@ -23,8 +23,8 @@ def test_sharded_ddp_choice(tmpdir, plugin): class CB(Callback): def on_fit_start(self, trainer, pl_module): - if plugin == 'full_sharded': - assert isinstance(trainer.accelerator.training_type_plugin, FullShardedPlugin) + if plugin == 'fully_sharded': + assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) raise SystemExit() model = BoringModel() @@ -39,7 +39,7 @@ def on_fit_start(self, trainer, pl_module): @pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex") -@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") def test_invalid_apex_sharded(tmpdir): """ Test to ensure that we raise an error when we try to use apex and sharded @@ -49,7 +49,7 @@ def test_invalid_apex_sharded(tmpdir): with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'): trainer = Trainer( fast_dev_run=True, - plugins='full_sharded', + plugins='fully_sharded', precision=16, amp_backend='apex', ) @@ -57,8 +57,8 @@ def test_invalid_apex_sharded(tmpdir): trainer.fit(model) -@pytest.mark.parametrize(["plugin"], [("full_sharded", )]) -@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.parametrize(["plugin"], [("fully_sharded", )]) +@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @mock.patch('torch.cuda.device_count', return_value=1) @@ -71,9 +71,9 @@ def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, plugin, class CB(Callback): def on_fit_start(self, trainer, pl_module): - if plugin == 'full_sharded': - assert isinstance(trainer.accelerator.training_type_plugin, FullShardedPlugin) - assert isinstance(trainer.accelerator.precision_plugin, FullShardedNativeMixedPrecisionPlugin) + if plugin == 'fully_sharded': + assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) + assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) raise SystemExit() model = BoringModel() @@ -91,15 +91,15 @@ def on_fit_start(self, trainer, pl_module): @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") -def test_full_sharded_plugin_checkpoint(tmpdir): +@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") +def test_fully_sharded_plugin_checkpoint(tmpdir): """ Test to ensure that checkpoint is saved correctly when using a single GPU. """ model = BoringModel() trainer = Trainer( gpus=1, - plugins='full_sharded', + plugins='fully_sharded', fast_dev_run=True, precision=16, ) @@ -111,18 +111,18 @@ def test_full_sharded_plugin_checkpoint(tmpdir): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not _FAIRSCALE_FULL_SHARDED_AVAILABLE, reason="Fairscale is not available") +@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) -def test_full_sharded_plugin_checkpoint_multi_gpu(tmpdir): +def test_fully_sharded_plugin_checkpoint_multi_gpu(tmpdir): """ Test to ensure that checkpoint is saved correctly when using multiple GPUs """ model = BoringModel() trainer = Trainer( gpus=2, - plugins='full_sharded', + plugins='fully_sharded', fast_dev_run=True, precision=16, ) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 7a9d514e35045..e79b8d7ac7ea5 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -18,7 +18,7 @@ DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --ca python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp python ${DEFAULTS} tests/models/test_sync_batchnorm.py::test_sync_batchnorm_ddp python ${DEFAULTS} tests/plugins/test_deepspeed_plugin.py::test_deepspeed_multigpu -python ${DEFAULTS} tests/plugins/test_full_sharded_plugin.py::test_full_sharded_plugin_checkpoint_multi_gpu +python ${DEFAULTS} tests/plugins/test_fully_sharded_plugin.py::test_fully_sharded_plugin_checkpoint_multi_gpu python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual python ${DEFAULTS} tests/plugins/test_rpc_sequential_plugin.py::test_rpc_sequential_plugin_manual_amp From 5ff06ab00ca713dc8da1b24f7772d5b9f778bfa7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 24 Feb 2021 22:50:30 +0000 Subject: [PATCH 06/62] Fix doc string --- .../plugins/precision/fully_sharded_native_amp.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 9e6837f47c352..34ff55229360e 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -19,8 +19,7 @@ class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): - """Mixed Precision for Full Sharded Training - """ + """Mixed Precision for Full Sharded Training""" def clip_gradients( self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0) From 36434f0561c6d22112290e64c8f4d78efc62f12a Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 24 Feb 2021 22:56:42 +0000 Subject: [PATCH 07/62] Even moar code review --- pytorch_lightning/overrides/fairscale.py | 21 ++++++++++--------- .../plugins/training_type/rpc_sequential.py | 2 +- pytorch_lightning/utilities/enums.py | 2 +- tests/plugins/test_full_sharded_plugin.py | 12 +++++------ 4 files changed, 19 insertions(+), 18 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 681fba9eb7dc9..af8330f605cf9 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -15,14 +15,15 @@ from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE -LightningShardedDataParallel = None + +class LightningShardedDataParallel(_LightningModuleWrapperBase): + # Just do this for later docstrings + pass + + if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel - class LightningShardedDataParallel(_LightningModuleWrapperBase): - # Just do this for later docstrings - pass - def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: model = wrapped_model if isinstance(model, ShardedDataParallel): @@ -31,15 +32,15 @@ def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: return unwrap_lightning_module(model) -LightningFullyShardedDataParallel = None +class LightningFullyShardedDataParallel(_LightningModuleWrapperBase): + # Just do this for later docstrings + pass + + if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: from fairscale.nn import FlattenParamsWrapper from fairscale.nn.data_parallel import FullyShardedDataParallel - class LightningFullyShardedDataParallel(_LightningModuleWrapperBase): - # Just do this for later docstrings - pass - def unwrap_lightning_module_fully_sharded(wrapped_model) -> LightningModule: model = wrapped_model if isinstance(model, FullyShardedDataParallel): diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 7882831d50b55..82fea322747da 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -56,7 +56,7 @@ def __init__( .. _RPCSequentialPlugin: https://arxiv.org/abs/1811.06965 - .. warning:: + .. deprecated:: This plugin has been deprecated. Please use the ``FullyShardedPlugin`` which provides better performance and scaling without pipelining the model. diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index ba7b8c9acec5f..35059f92274a6 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -66,7 +66,7 @@ class DistributedType(LightningEnum): HOROVOD = 'horovod' DDP_SHARDED = 'ddp_sharded' DDP_SHARDED_SPAWN = 'ddp_sharded_spawn' - FULLY_SHARDED = 'fully_sharded' + FULLY_SHARDED = 'ddp_fully_sharded' RPC_SEQUENTIAL_PLUGIN = 'rpc_sequential' diff --git a/tests/plugins/test_full_sharded_plugin.py b/tests/plugins/test_full_sharded_plugin.py index ed5a4f4aceeb0..cc71e8c668590 100644 --- a/tests/plugins/test_full_sharded_plugin.py +++ b/tests/plugins/test_full_sharded_plugin.py @@ -13,7 +13,7 @@ from tests.helpers.boring_model import BoringModel -@pytest.mark.parametrize(["plugin"], [("fully_sharded", )]) +@pytest.mark.parametrize(["plugin"], [("ddp_fully_sharded", )]) @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") def test_sharded_ddp_choice(tmpdir, plugin): """ @@ -23,7 +23,7 @@ def test_sharded_ddp_choice(tmpdir, plugin): class CB(Callback): def on_fit_start(self, trainer, pl_module): - if plugin == 'fully_sharded': + if plugin == 'ddp_fully_sharded': assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) raise SystemExit() @@ -49,7 +49,7 @@ def test_invalid_apex_sharded(tmpdir): with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'): trainer = Trainer( fast_dev_run=True, - plugins='fully_sharded', + plugins='ddp_fully_sharded', precision=16, amp_backend='apex', ) @@ -57,7 +57,7 @@ def test_invalid_apex_sharded(tmpdir): trainer.fit(model) -@pytest.mark.parametrize(["plugin"], [("fully_sharded", )]) +@pytest.mark.parametrize(["plugin"], [("ddp_fully_sharded", )]) @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @@ -71,7 +71,7 @@ def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, plugin, class CB(Callback): def on_fit_start(self, trainer, pl_module): - if plugin == 'fully_sharded': + if plugin == 'ddp_fully_sharded': assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) raise SystemExit() @@ -99,7 +99,7 @@ def test_fully_sharded_plugin_checkpoint(tmpdir): model = BoringModel() trainer = Trainer( gpus=1, - plugins='fully_sharded', + plugins='ddp_fully_sharded', fast_dev_run=True, precision=16, ) From c61a19081065497d61f79d5b27061b7bd5d6de4d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 24 Feb 2021 23:07:57 +0000 Subject: [PATCH 08/62] Add deprecation --- pytorch_lightning/plugins/training_type/rpc_sequential.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 82fea322747da..09179e21805a0 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -93,7 +93,7 @@ def __init__( """ rank_zero_warn( "RPC Sequential Plugin has been deprecated. Please use the `FullyShardedPlugin` " - "which provides better performance and scaling without pipelining the model." + "which provides better performance and scaling without pipelining the model.", DeprecationWarning ) self._check_pipe_available() super().__init__(rpc_timeout_sec=rpc_timeout_sec, **kwargs) From 02599e6ca8bf802fc9d314edf4abbb199769f17b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Feb 2021 13:59:15 +0000 Subject: [PATCH 09/62] Fix name of test --- .../{test_full_sharded_plugin.py => test_fully_sharded_plugin.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/plugins/{test_full_sharded_plugin.py => test_fully_sharded_plugin.py} (100%) diff --git a/tests/plugins/test_full_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py similarity index 100% rename from tests/plugins/test_full_sharded_plugin.py rename to tests/plugins/test_fully_sharded_plugin.py From e79977a45c8a6818594bae0c64d2651b28b1eb8f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 1 Mar 2021 14:35:38 +0000 Subject: [PATCH 10/62] Integrate nesting, fix bugs across implementation --- pytorch_lightning/accelerators/accelerator.py | 3 +- pytorch_lightning/core/hooks.py | 7 +++ pytorch_lightning/core/lightning.py | 4 ++ pytorch_lightning/overrides/fairscale.py | 3 +- .../plugins/training_type/ddp.py | 13 ++++- .../plugins/training_type/fully_sharded.py | 54 +++++++++++++++---- .../training_type/training_type_plugin.py | 4 ++ .../connectors/checkpoint_connector.py | 3 +- 8 files changed, 76 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index f0c8cfdf4552e..5965cb5619665 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -74,7 +74,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None: model: the model to train """ self.connect_training_type_plugin(self.training_type_plugin, model) - self.setup_optimizers(trainer) + if not self.training_type_plugin.manage_configure_optimizers: + self.setup_optimizers(trainer) self.connect_precision_plugin(self.precision_plugin) def start_training(self, trainer: 'Trainer') -> None: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index e0b33c1219e8b..569262c430b7a 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -334,6 +334,13 @@ def on_post_move_to_device(self): """ + def on_distributed_model_setup(self) -> None: + """ + + Returns: + + """ + class DataHooks: """Hooks to be used for data related stuff.""" diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c4d63cff4637b..a954becc7b21e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -462,6 +462,10 @@ def all_gather( all_gather = partial(all_gather, group=group, sync_grads=sync_grads) return apply_to_collection(data, torch.Tensor, all_gather) + @property + def accelerator_model(self): + return self.trainer.accelerator.model + def forward(self, *args, **kwargs): r""" Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index af8330f605cf9..52032bd437536 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, unwrap_lightning_module from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE @@ -32,7 +33,7 @@ def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: return unwrap_lightning_module(model) -class LightningFullyShardedDataParallel(_LightningModuleWrapperBase): +class LightningFullyShardedDataModule(_LightningModuleWrapperBase): # Just do this for later docstrings pass diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 80161d6e59b6b..3a7c6e760e79a 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -250,13 +250,22 @@ def pre_dispatch(self): if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) - # move the model to the correct device - self.model_to_device() + if self.move_to_device_in_prefetch: + # move the model to the correct device + self.model_to_device() self.configure_ddp() self.barrier() + @property + def move_to_device_in_prefetch(self) -> bool: + """ + We will call the model_to_device hook within pre-fetch if this is set to True. + Useful for when plugin would like to call model_to_device at another time, or skip the call. + """ + return True + def post_dispatch(self): if "WORLD_SIZE" in os.environ: del os.environ["WORLD_SIZE"] diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 7018ab7152ae1..6bd5f7a093ad1 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -22,10 +22,11 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import enable_wrap from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import ( - LightningFullyShardedDataParallel, + LightningFullyShardedDataModule, unwrap_lightning_module_fully_sharded, ) @@ -34,8 +35,8 @@ class FullyShardedPlugin(DDPPlugin): def __init__( self, - cpu_offload: bool = True, - flatten_parameters: bool = False, + cpu_offload: bool = False, + flatten_parameters: bool = True, reshard_after_forward: bool = True, move_grads_to_cpu: Optional[bool] = None, fp32_reduce_scatter: Optional[bool] = None, @@ -72,8 +73,8 @@ def __init__( cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False). - move_grads_to_cpu: Moves gradient shards to CPU after reducation. - Only disable if using CPU based optimizers (defaults to ``cpu_offload``). + move_grads_to_cpu: Moves gradient shards to CPU after reduction. + Only disable if using CPU based optimizers (defaults to ``cpu_offload``). flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency (default: False). @@ -111,11 +112,35 @@ def __init__( self.fp32_reduce_scatter = fp32_reduce_scatter self.compute_dtype = compute_dtype self.bucket_cap_mb = bucket_cap_mb + self._process_group = None + + @property + def process_group(self): + if self._process_group is None: + self._process_group = torch.distributed.new_group() + return self._process_group def configure_ddp(self): precision = self.lightning_module.trainer.precision + + # set the device before instantiate the wrapper + if self.root_device.type == "cuda": + torch.cuda.set_device(self.root_device) + + with enable_wrap( + cpu_offload=self.cpu_offload, + flatten_parameters=self.flatten_parameters, + move_grads_to_cpu=self.move_grads_to_cpu, + mixed_precision=precision == "mixed", + process_group=self.process_group + ): + # todo: this should somehow be incorporated as a general hook. + # currently this also means you have to use fully sharded to load the model as well. + self.lightning_module.trainer.call_hook("on_distributed_model_setup") + self.model = FullyShardedDataParallel( - LightningFullyShardedDataParallel(self.model), + LightningFullyShardedDataModule(self.model), + process_group=self.process_group, cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, flatten_parameters=self.flatten_parameters, @@ -125,14 +150,21 @@ def configure_ddp(self): compute_dtype=self.compute_dtype, bucket_cap_mb=self.bucket_cap_mb, ) + if not self.cpu_offload: + super().model_to_device() + # setup optimizers after fully sharded has wrapped the lightning module + self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer) + + def model_to_device(self): + self.model.to(self.root_device) @property def lightning_module(self) -> LightningModule: return unwrap_lightning_module_fully_sharded(self.model) - def model_to_device(self): - if not self.cpu_offload: - super().model_to_device() + @property + def move_to_device_in_prefetch(self): + return False def on_save(self, checkpoint: dict) -> dict: state_dict = self.collate_state_dict() @@ -148,3 +180,7 @@ def collate_state_dict(self): # Remove module prefix from state dict as this is the behaviour of state dict. state_dict = {k.partition('module.')[2]: state_dict[k] for k in state_dict.keys()} return state_dict + + @property + def manage_configure_optimizers(self) -> bool: + return True diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 77fae2746c402..59e9270b5a8f0 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -165,3 +165,7 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule): def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): optimizer.step(closure=lambda_closure, **kwargs) + + @property + def manage_configure_optimizers(self) -> bool: + return False diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 2384f2aa8fbe5..0ea9866edc425 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -399,10 +399,9 @@ def save_checkpoint(self, filepath, weights_only: bool = False): """ # dump states as a checkpoint dictionary object checkpoint = self.dump_checkpoint(weights_only) + checkpoint = self.trainer.accelerator.on_save(checkpoint) if self.trainer.is_global_zero: # write the checkpoint dictionary on the file - - checkpoint = self.trainer.accelerator.on_save(checkpoint) try: atomic_save(checkpoint, filepath) except AttributeError as err: From ebf1818ed40a440e988c6a16274fc5a56c1163d4 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 2 Mar 2021 15:59:59 +0000 Subject: [PATCH 11/62] Formatting types --- pytorch_lightning/plugins/precision/deepspeed_precision.py | 3 ++- .../plugins/precision/fully_sharded_native_amp.py | 2 +- pytorch_lightning/plugins/precision/precision_plugin.py | 3 ++- pytorch_lightning/plugins/precision/sharded_native_amp.py | 4 +++- requirements/extra.txt | 2 +- 5 files changed, 9 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 9574da0fb982c..700019c851376 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -76,7 +76,8 @@ def backward( return closure_loss def clip_gradients( - self, model: Any, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + self, model: Any, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0 + ) -> None: """ DeepSpeed handles clipping gradients via the training type plugin. """ diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 34ff55229360e..651b7030c9316 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -23,6 +23,6 @@ class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): def clip_gradients( self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0) - ): + ) -> None: # Model manages clipping of gradients model.clip_grad_norm_(clip_val, norm_type) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 18b4591e82b43..6281efa3859e2 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -99,7 +99,8 @@ def post_optimizer_step(self, optimizer: 'Optimizer', optimizer_idx: int) -> Non """Hook to do something after each optimizer step.""" def clip_gradients( - self, model: Any, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + self, model: Any, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0 + ) -> None: """Clips the gradients to a specific value""" # TODO: separate TPU case from here if clip_val is None: diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py index bf0a4511206b9..719ebd3190e26 100644 --- a/pytorch_lightning/plugins/precision/sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py @@ -32,6 +32,8 @@ def __init__(self) -> None: super().__init__() self.scaler = ShardedGradScaler() - def clip_gradients(self, model: Any, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None: + def clip_gradients( + self, model: Any, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0 + ) -> None: optimizer = cast(OSS, optimizer) optimizer.clip_grad_norm(clip_val, norm_type=norm_type) diff --git a/requirements/extra.txt b/requirements/extra.txt index dc9796af4ee8c..33009078dffb2 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,4 +7,4 @@ torchtext>=0.5 onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -fairscale>=0.3.0 +https://github.com/facebookresearch/fairscale/archive/f3359550d9bb3a2e4c1fcdcfef739e0d97fff774.zip From 290e8fd042fc60e041bae7f955c38191a49b3ce1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 2 Mar 2021 16:48:23 +0000 Subject: [PATCH 12/62] Add additional tests for accelerator model --- tests/plugins/test_fully_sharded_plugin.py | 47 ++++++++++++++++++---- tests/trainer/properties/test_get_model.py | 42 +++++++++++++++++++ 2 files changed, 81 insertions(+), 8 deletions(-) diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index cc71e8c668590..153b61cce1b69 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -1,5 +1,4 @@ import os -import platform from unittest import mock import pytest @@ -8,9 +7,13 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin, FullyShardedPlugin -from pytorch_lightning.utilities import _APEX_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _NATIVE_AMP_AVAILABLE +from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel +from tests.helpers.runif import RunIf + +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import auto_wrap, FullyShardedDataParallel @pytest.mark.parametrize(["plugin"], [("ddp_fully_sharded", )]) @@ -38,8 +41,8 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) -@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex") @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") +@RunIf(amp_apex=True) def test_invalid_apex_sharded(tmpdir): """ Test to ensure that we raise an error when we try to use apex and sharded @@ -59,10 +62,10 @@ def test_invalid_apex_sharded(tmpdir): @pytest.mark.parametrize(["plugin"], [("ddp_fully_sharded", )]) @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") -@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP") @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @mock.patch('torch.cuda.device_count', return_value=1) @mock.patch('torch.cuda.is_available', return_value=True) +@RunIf(amp_native=True) def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, plugin, tmpdir): """ Test to ensure that plugin native amp plugin is correctly chosen when using sharded @@ -89,9 +92,8 @@ def on_fit_start(self, trainer, pl_module): trainer.fit(model) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires CUDA") -@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") +@RunIf(min_gpus=1, skip_windows=True) def test_fully_sharded_plugin_checkpoint(tmpdir): """ Test to ensure that checkpoint is saved correctly when using a single GPU. @@ -109,12 +111,41 @@ def test_fully_sharded_plugin_checkpoint(tmpdir): _assert_save_equality(tmpdir, trainer) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") +@RunIf(min_gpus=1, skip_windows=True) +def test_fully_sharded_plugin_checkpoint_autowrap(tmpdir): + """ + Test to ensure that checkpoint is saved correctly when using auto_wrap. + """ + + class TestModel(BoringModel): + + def on_distributed_model_setup(self) -> None: + self.layer = auto_wrap(self.layer, min_num_params=1) + + def on_train_start(self) -> None: + assert isinstance(self.layer, FullyShardedDataParallel) + assert isinstance(self.accelerator_model, FullyShardedDataParallel) + + model = TestModel() + + trainer = Trainer( + gpus=1, + plugins='ddp_fully_sharded', + fast_dev_run=True, + precision=16, + ) + + trainer.fit(model) + + _assert_save_equality(tmpdir, trainer) + + @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @pytest.mark.skipif( not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" ) +@RunIf(min_gpus=2, skip_windows=True) def test_fully_sharded_plugin_checkpoint_multi_gpu(tmpdir): """ Test to ensure that checkpoint is saved correctly when using multiple GPUs diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 4dc5b5f34b50c..29248aa5c8f89 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -11,12 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pytest +import torch +from torch.nn.parallel import DistributedDataParallel from pytorch_lightning import Trainer +from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _FAIRSCALE_FULLY_SHARDED_AVAILABLE from tests.accelerators import DDPLauncher from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf +if _FAIRSCALE_AVAILABLE: + from fairscale.nn import ShardedDataParallel +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn import FullyShardedDataParallel + class TrainerGetModel(BoringModel): @@ -103,3 +112,36 @@ def test_get_model_ddp_gpu(tmpdir, args=None): ) trainer.fit(model) return 1 + + +@pytest.mark.parametrize(["accelerator", "wrapper"], [ + ('ddp', DistributedDataParallel), + pytest.param( + 'ddp_sharded', + ShardedDataParallel, + marks=pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="FairScale not available.") + ), + pytest.param( + 'ddp_fully_sharded', + FullyShardedDataParallel, + marks=pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="FairScale not available.") + ), +]) +@RunIf(min_gpus=1, skip_windows=True) +def test_get_accelerator_wrapped_model(accelerator, wrapper, tmpdir): + """ + Ensure we can access the wrapped accelerator model during training. + """ + + class TestModel(BoringModel): + + def on_train_start(self) -> None: + assert isinstance(self.accelerator_model, wrapper) + + def configure_optimizers(self): + return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1) + + model = TestModel() + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, accelerator=accelerator, gpus=1) + trainer.fit(model) From 5c5f762e776946b57245fd65376f52b715c87110 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 2 Mar 2021 17:09:55 +0000 Subject: [PATCH 13/62] Fix import --- tests/trainer/properties/test_get_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 29248aa5c8f89..b9e804ac3ba78 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -22,7 +22,7 @@ from tests.helpers.runif import RunIf if _FAIRSCALE_AVAILABLE: - from fairscale.nn import ShardedDataParallel + from fairscale.nn.data_parallel import ShardedDataParallel if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: from fairscale.nn import FullyShardedDataParallel @@ -130,7 +130,7 @@ def test_get_model_ddp_gpu(tmpdir, args=None): @RunIf(min_gpus=1, skip_windows=True) def test_get_accelerator_wrapped_model(accelerator, wrapper, tmpdir): """ - Ensure we can access the wrapped accelerator model during training. + Ensure we can access the wrapped accelerator model during training.ShardedDataParallel """ class TestModel(BoringModel): From d28438b4895f91c38ef1084ebbf738f519d43cb5 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 3 Mar 2021 10:53:18 +0000 Subject: [PATCH 14/62] Few test fixes, expose params --- .../plugins/training_type/fully_sharded.py | 40 +++++++---- tests/plugins/test_fully_sharded_plugin.py | 69 ++++++++----------- 2 files changed, 56 insertions(+), 53 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 6bd5f7a093ad1..5011ebf0b51cb 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -22,7 +22,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: - from fairscale.nn import enable_wrap + from fairscale.nn import auto_wrap, enable_wrap, wrap from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import ( @@ -42,6 +42,9 @@ def __init__( fp32_reduce_scatter: Optional[bool] = None, compute_dtype: Optional[torch.dtype] = None, bucket_cap_mb: int = 25, + automatic_module_wrap: bool = False, + min_num_params: int = 1e8, + activation_checkpoint: bool = False, parallel_devices: Optional[List[torch.device]] = None, num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, @@ -112,6 +115,9 @@ def __init__( self.fp32_reduce_scatter = fp32_reduce_scatter self.compute_dtype = compute_dtype self.bucket_cap_mb = bucket_cap_mb + self.automatic_module_wrap = automatic_module_wrap + self.min_num_params = min_num_params + self.activation_checkpoint = activation_checkpoint self._process_group = None @property @@ -128,18 +134,6 @@ def configure_ddp(self): torch.cuda.set_device(self.root_device) with enable_wrap( - cpu_offload=self.cpu_offload, - flatten_parameters=self.flatten_parameters, - move_grads_to_cpu=self.move_grads_to_cpu, - mixed_precision=precision == "mixed", - process_group=self.process_group - ): - # todo: this should somehow be incorporated as a general hook. - # currently this also means you have to use fully sharded to load the model as well. - self.lightning_module.trainer.call_hook("on_distributed_model_setup") - - self.model = FullyShardedDataParallel( - LightningFullyShardedDataModule(self.model), process_group=self.process_group, cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, @@ -149,8 +143,26 @@ def configure_ddp(self): fp32_reduce_scatter=self.fp32_reduce_scatter, compute_dtype=self.compute_dtype, bucket_cap_mb=self.bucket_cap_mb, - ) + ): + # Allow user to manually wrap the lightning modules, and any internal modules + # todo: this should somehow be incorporated as a general hook. + # currently this also means you have to use fully sharded to load the model as well. + self.lightning_module.trainer.call_hook("on_distributed_model_setup") + if self.automatic_module_wrap: + self.model = auto_wrap( + LightningFullyShardedDataModule(self.model), + min_num_params=self.min_num_params, + activation_checkpoint=self.activation_checkpoint + ) + if not isinstance(self.model, FullyShardedDataParallel): + self.model = wrap(self.model, activation_checkpoint=self.activation_checkpoint) + else: + self.model = wrap( + LightningFullyShardedDataModule(self.model), activation_checkpoint=self.activation_checkpoint + ) + if not self.cpu_offload: + # When using CPU Offload, FSDP will manage the CUDA movement for us super().model_to_device() # setup optimizers after fully sharded has wrapped the lightning module self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer) diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index 153b61cce1b69..3970a5163f4dc 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -5,7 +5,6 @@ import torch from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin, FullyShardedPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -16,29 +15,16 @@ from fairscale.nn import auto_wrap, FullyShardedDataParallel -@pytest.mark.parametrize(["plugin"], [("ddp_fully_sharded", )]) @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") -def test_sharded_ddp_choice(tmpdir, plugin): +def test_sharded_ddp_choice(tmpdir): """ Test to ensure that plugin is correctly chosen """ - - class CB(Callback): - - def on_fit_start(self, trainer, pl_module): - if plugin == 'ddp_fully_sharded': - assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) - raise SystemExit() - - model = BoringModel() trainer = Trainer( fast_dev_run=True, - plugins=plugin, - callbacks=[CB()], + plugins='ddp_fully_sharded', ) - - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @@ -60,36 +46,24 @@ def test_invalid_apex_sharded(tmpdir): trainer.fit(model) -@pytest.mark.parametrize(["plugin"], [("ddp_fully_sharded", )]) @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @mock.patch('torch.cuda.device_count', return_value=1) @mock.patch('torch.cuda.is_available', return_value=True) @RunIf(amp_native=True) -def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, plugin, tmpdir): +def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): """ Test to ensure that plugin native amp plugin is correctly chosen when using sharded """ - - class CB(Callback): - - def on_fit_start(self, trainer, pl_module): - if plugin == 'ddp_fully_sharded': - assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) - assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) - raise SystemExit() - - model = BoringModel() trainer = Trainer( fast_dev_run=True, gpus=1, precision=16, - plugins=plugin, - callbacks=[CB()], + plugins='ddp_fully_sharded', ) - with pytest.raises(SystemExit): - trainer.fit(model) + assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) + assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @@ -98,7 +72,13 @@ def test_fully_sharded_plugin_checkpoint(tmpdir): """ Test to ensure that checkpoint is saved correctly when using a single GPU. """ - model = BoringModel() + + class TestModel(BoringModel): + + def configure_optimizers(self): + return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1) + + model = TestModel() trainer = Trainer( gpus=1, plugins='ddp_fully_sharded', @@ -111,27 +91,32 @@ def test_fully_sharded_plugin_checkpoint(tmpdir): _assert_save_equality(tmpdir, trainer) +@pytest.mark.parametrize('automatic_module_wrap', [True, False]) @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @RunIf(min_gpus=1, skip_windows=True) -def test_fully_sharded_plugin_checkpoint_autowrap(tmpdir): +def test_fully_sharded_plugin_checkpoint_manual_autowrap(automatic_module_wrap, tmpdir): """ - Test to ensure that checkpoint is saved correctly when using auto_wrap. + Test to ensure that checkpoint is saved correctly when using automatic, and manual auto_wrap. """ class TestModel(BoringModel): def on_distributed_model_setup(self) -> None: - self.layer = auto_wrap(self.layer, min_num_params=1) + if not automatic_module_wrap: + self.layer = auto_wrap(self.layer, min_num_params=1) def on_train_start(self) -> None: assert isinstance(self.layer, FullyShardedDataParallel) assert isinstance(self.accelerator_model, FullyShardedDataParallel) + def configure_optimizers(self): + return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1) + model = TestModel() trainer = Trainer( gpus=1, - plugins='ddp_fully_sharded', + plugins=FullyShardedPlugin(automatic_module_wrap=automatic_module_wrap, min_num_params=1), fast_dev_run=True, precision=16, ) @@ -150,7 +135,13 @@ def test_fully_sharded_plugin_checkpoint_multi_gpu(tmpdir): """ Test to ensure that checkpoint is saved correctly when using multiple GPUs """ - model = BoringModel() + + class TestModel(BoringModel): + + def configure_optimizers(self): + return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1) + + model = TestModel() trainer = Trainer( gpus=2, plugins='fully_sharded', From ab591a81bef0c4e4fb638b42ee48d6fa7b410d07 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 3 Mar 2021 11:47:17 +0000 Subject: [PATCH 15/62] Allow training_type_plugin to delay optimizer configure --- pytorch_lightning/accelerators/accelerator.py | 9 ++++++--- .../plugins/training_type/training_type_plugin.py | 9 +++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 38fb423d22aa8..375e94a7bffbc 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -74,7 +74,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None: model: the model to train """ self.connect_training_type_plugin(self.training_type_plugin, model) - self.setup_optimizers(trainer) + if not self.training_type_plugin.setup_optimizers_after_dispatch: + self.setup_optimizers(trainer) self.connect_precision_plugin(self.precision_plugin) def start_training(self, trainer: 'Trainer') -> None: @@ -86,12 +87,14 @@ def start_testing(self, trainer: 'Trainer') -> None: def start_predicting(self, trainer: 'Trainer') -> None: self.training_type_plugin.start_predicting(trainer) - def pre_dispatch(self) -> None: + def pre_dispatch(self, trainer: 'Trainer') -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.pre_dispatch() + if self.training_type_plugin.setup_optimizers_after_dispatch: + self.setup_optimizers(trainer) self.precision_plugin.pre_dispatch() - def post_dispatch(self) -> None: + def post_dispatch(self, trainer: 'Trainer') -> None: """Hook to do something before the training/evaluation/prediction starts.""" self.training_type_plugin.post_dispatch() self.precision_plugin.post_dispatch() diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index cf4b93e04e2dc..af77547ccd144 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -169,3 +169,12 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule): def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs): optimizer.step(closure=lambda_closure, **kwargs) + + @property + def setup_optimizers_after_dispatch(self) -> bool: + """ + Override to delay setting optimizers and schedulers till after dispatch. + This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model. + Returns: True if delaying setup optimizers till after dispatch, False to call within setup. + """ + return False From a60f2c03cbed93c565c5a14fae7b5ec9cdcba9a4 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 3 Mar 2021 16:14:56 +0000 Subject: [PATCH 16/62] Add missing references to trainer, add a CPU accelerator based test --- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/accelerators/test_cpu.py | 35 +++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7bfc3d41f9a8d..3f3567a02e29c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -510,10 +510,10 @@ def fit( return self.accelerator.results or 1 def pre_dispatch(self): - self.accelerator.pre_dispatch() + self.accelerator.pre_dispatch(self) def post_dispatch(self): - self.accelerator.post_dispatch() + self.accelerator.post_dispatch(self) self.accelerator.teardown() def dispatch(self): diff --git a/tests/accelerators/test_cpu.py b/tests/accelerators/test_cpu.py index 81a5132e47356..f4b4067b00953 100644 --- a/tests/accelerators/test_cpu.py +++ b/tests/accelerators/test_cpu.py @@ -1,12 +1,13 @@ from unittest.mock import Mock import pytest +import pytorch_lightning as pl import torch - from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.plugins import SingleDevicePlugin from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from pytorch_lightning.utilities.exceptions import MisconfigurationException +from tests.helpers.boring_model import BoringModel def test_unsupported_precision_plugins(): @@ -18,3 +19,35 @@ def test_unsupported_precision_plugins(): ) with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."): accelerator.setup(trainer=trainer, model=model) + + +@pytest.mark.parametrize("delay_dispatch", [True, False]) +def test_plugin_setup_optimizers_after_dispatch(tmpdir, delay_dispatch): + """ + Test when using a custom training type plugin that delays setup optimizers, + we do not call setup optimizers till after ``pre_dispatch``. + """ + + class TestModel(BoringModel): + def on_fit_start(self): + if delay_dispatch: + # Ensure we haven't setup optimizers if we've delayed dispatch + assert len(self.trainer.optimizers) == 0 + else: + assert len(self.trainer.optimizers) > 0 + + def on_fit_end(self): + assert len(self.trainer.optimizers) > 0 + + class CustomPlugin(SingleDevicePlugin): + @property + def setup_optimizers_after_dispatch(self) -> bool: + return delay_dispatch + + model = TestModel() + trainer = pl.Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + plugins=CustomPlugin(device=torch.device("cpu")) + ) + trainer.fit(model) From 516bd04398bbf1d8a2b8b8e46cd8e7046cb4fb9d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 9 Mar 2021 11:05:57 +0000 Subject: [PATCH 17/62] Update for latest API changes to fairscale --- pytorch_lightning/overrides/fairscale.py | 16 +++++++++----- .../plugins/training_type/fully_sharded.py | 21 ++++++++----------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py index 52032bd437536..d44e4a336fc61 100644 --- a/pytorch_lightning/overrides/fairscale.py +++ b/pytorch_lightning/overrides/fairscale.py @@ -33,7 +33,7 @@ def unwrap_lightning_module_sharded(wrapped_model) -> LightningModule: return unwrap_lightning_module(model) -class LightningFullyShardedDataModule(_LightningModuleWrapperBase): +class LightningFullyShardedModule(_LightningModuleWrapperBase): # Just do this for later docstrings pass @@ -43,10 +43,16 @@ class LightningFullyShardedDataModule(_LightningModuleWrapperBase): from fairscale.nn.data_parallel import FullyShardedDataParallel def unwrap_lightning_module_fully_sharded(wrapped_model) -> LightningModule: + """ + Unwrap the lightning module within the FSDP wrapper. This is recursive as FSDP can be nested, meaning + the LightningModule could be a few layers deep. + """ model = wrapped_model if isinstance(model, FullyShardedDataParallel): - model = model.module + model = unwrap_lightning_module_fully_sharded(model.module) # Additional check if we're using a flattened parameters buffer - if isinstance(model, FlattenParamsWrapper): - model = model.module - return unwrap_lightning_module(model) + elif isinstance(model, FlattenParamsWrapper): + model = unwrap_lightning_module_fully_sharded(model.module) + if isinstance(model, _LightningModuleWrapperBase): + model = unwrap_lightning_module_fully_sharded(model.module) + return model diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 1eb7bcfd41215..538951669fdf7 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -26,7 +26,7 @@ from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import ( - LightningFullyShardedDataModule, + LightningFullyShardedModule, unwrap_lightning_module_fully_sharded, ) @@ -42,7 +42,7 @@ def __init__( fp32_reduce_scatter: Optional[bool] = None, compute_dtype: Optional[torch.dtype] = None, bucket_cap_mb: int = 25, - automatic_module_wrap: bool = False, + automatic_module_wrap: bool = True, min_num_params: int = 1e8, activation_checkpoint: bool = False, parallel_devices: Optional[List[torch.device]] = None, @@ -134,6 +134,7 @@ def configure_ddp(self): torch.cuda.set_device(self.root_device) with enable_wrap( + wrapper_cls=FullyShardedDataParallel, process_group=self.process_group, cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, @@ -149,26 +150,22 @@ def configure_ddp(self): # currently this also means you have to use fully sharded to load the model as well. self.lightning_module.trainer.call_hook("on_distributed_model_setup") if self.automatic_module_wrap: - self.model = auto_wrap( - LightningFullyShardedDataModule(self.model), - min_num_params=self.min_num_params, - activation_checkpoint=self.activation_checkpoint - ) + self.model = auto_wrap(LightningFullyShardedModule(self.model)) if not isinstance(self.model, FullyShardedDataParallel): - self.model = wrap(self.model, activation_checkpoint=self.activation_checkpoint) + self.model = wrap(self.model) else: - self.model = wrap( - LightningFullyShardedDataModule(self.model), activation_checkpoint=self.activation_checkpoint - ) + self.model = wrap(LightningFullyShardedModule(self.model)) if not self.cpu_offload: # When using CPU Offload, FSDP will manage the CUDA movement for us - super().model_to_device() + self.model_to_device() # setup optimizers after fully sharded has wrapped the lightning module self.lightning_module.trainer.accelerator.setup_optimizers(self.lightning_module.trainer) def model_to_device(self): self.model.to(self.root_device) + # ensure we update the device type in the lightning module + self.lightning_module.to(self.root_device) @property def lightning_module(self) -> LightningModule: From 9f8864f2519875646527f370f08f63cbb809b215 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 23 Mar 2021 12:17:18 +0000 Subject: [PATCH 18/62] Add base hook for model parallel --- pytorch_lightning/accelerators/accelerator.py | 15 ++++++++++++++- pytorch_lightning/callbacks/base.py | 3 +++ pytorch_lightning/core/hooks.py | 7 +++++++ .../plugins/training_type/training_type_plugin.py | 14 +++++++++++++- pytorch_lightning/trainer/callback_hook.py | 5 +++++ pytorch_lightning/trainer/trainer.py | 8 ++++++++ tests/callbacks/test_callbacks.py | 3 +++ 7 files changed, 53 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 60e6ea88b4250..c97918e1e407e 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union +import contextlib +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, TYPE_CHECKING, Union import torch from torch.optim import Optimizer @@ -432,3 +433,15 @@ def results(self) -> Any: In distributed training, we make sure to transfer the results to the appropriate master process. """ return self.training_type_plugin.results + + @contextlib.contextmanager + def model_parallel_context(self) -> Generator: + """ + Provide hook to create modules in a parallel aware context. This is useful for when we'd like to + shard the model instantly, which is useful for extremely large models which can save memory and + initialization time. + + Returns: Model parallel context. + """ + with self.training_type_plugin.model_parallel_context(): + yield diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index db507fa991446..d9d056bbc4fee 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -29,6 +29,9 @@ class Callback(abc.ABC): Subclass this class and override any of the relevant hooks """ + def on_model_parallel_setup(self, trainer, pl_module: LightningModule) -> None: + """Called before model parallel accelerator setup""" + def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule) -> None: """Called before accelerator is being setup""" pass diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 9624f94652713..bbea853c6b0e7 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -334,6 +334,13 @@ def on_post_move_to_device(self): """ + def on_model_parallel_setup(self) -> None: + """ + Hook to create modules in a parallel aware context. This is useful for when using sharded plugins, + where we'd like to shard the model instantly, which is useful for extremely large models + which can save memory and initialization time. + """ + class DataHooks: """Hooks to be used for data related stuff.""" diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 89f27963caadf..a407acd4a6040 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Iterable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generator, Iterable, Optional, TYPE_CHECKING, Union import torch from torch.nn import Module @@ -192,3 +193,14 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: Returns: If True, delay setup optimizers till pre_dispatch, else call within setup. """ return False + + @contextlib.contextmanager + def model_parallel_context(self) -> Generator: + """ + Provide hook to create modules in a parallel aware context. This is useful for when we'd like to + shard the model instantly, which is useful for extremely large models which can save memory and + initialization time. + + Returns: Model parallel context. + """ + yield diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 8823d48a7817e..849f99d4f8e09 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -38,6 +38,11 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) + def on_model_parallel_setup(self, model: LightningModule, stage: Optional[str]) -> None: + """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" + for callback in self.callbacks: + callback.on_model_parallel_setup(self, model, stage) + def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f7bd1757b9bc2..716a1bc3707af 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -433,6 +433,7 @@ def fit( self.accelerator.setup_environment() self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment self.accelerator.setup(self, model) # note: this sets up self.lightning_module + self.call_model_parallel_hook(model) # allow user to setup in model parallel environment # ---------------------------- # INSPECT THE CORE LOOPS @@ -1075,6 +1076,13 @@ def call_setup_hook(self, model: LightningModule) -> None: self.setup(model, stage=state) model.setup(stage=state) + def call_model_parallel_hook(self, model: LightningModule) -> None: + if not hasattr(self.lightning_module, 'is_model_parallel_setup'): + self.on_model_parallel_setup(model) + with self.accelerator.model_parallel_context(): + model.on_model_parallel_setup() + self.lightning_module.is_model_parallel_setup = True + def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state self.profiler.teardown(stage=state) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index fdefc6ae9ef1c..1490a7f53cd77 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -48,6 +48,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'fit'), + call.on_model_parallel_setup(model), call.on_fit_start(trainer, model), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), @@ -117,6 +118,7 @@ def test_trainer_callback_hook_system_test(tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'test'), + call.on_model_parallel_setup(trainer, model), call.on_test_start(trainer, model), call.on_test_epoch_start(trainer, model), call.on_test_batch_start(trainer, model, ANY, 0, 0), @@ -150,6 +152,7 @@ def test_trainer_callback_hook_system_validate(tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'validate'), + call.on_model_parallel_setup(trainer, model), call.on_validation_start(trainer, model), call.on_validation_epoch_start(trainer, model), call.on_validation_batch_start(trainer, model, ANY, 0, 0), From eac5344b077504c1550ceffef859605225102a03 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 20:40:16 +0530 Subject: [PATCH 19/62] fix callback signature --- tests/callbacks/test_callbacks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 1490a7f53cd77..381de5d3e659f 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -48,7 +48,7 @@ def test_trainer_callback_hook_system_fit(_, tmpdir): call.on_init_end(trainer), call.on_before_accelerator_backend_setup(trainer, model), call.setup(trainer, model, 'fit'), - call.on_model_parallel_setup(model), + call.on_model_parallel_setup(trainer, model), call.on_fit_start(trainer, model), call.on_pretrain_routine_start(trainer, model), call.on_pretrain_routine_end(trainer, model), From 32df0cb9e6277e0dd9d12d8d1326e41e0dcb9041 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 15:58:09 +0000 Subject: [PATCH 20/62] Simplify hook --- pytorch_lightning/trainer/callback_hook.py | 4 ++-- pytorch_lightning/trainer/trainer.py | 8 +++----- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 849f99d4f8e09..782921cc06047 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -38,10 +38,10 @@ def on_before_accelerator_backend_setup(self, model: LightningModule) -> None: for callback in self.callbacks: callback.on_before_accelerator_backend_setup(self, model) - def on_model_parallel_setup(self, model: LightningModule, stage: Optional[str]) -> None: + def on_model_parallel_setup(self, model: LightningModule) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" for callback in self.callbacks: - callback.on_model_parallel_setup(self, model, stage) + callback.on_model_parallel_setup(self, model) def setup(self, model: LightningModule, stage: Optional[str]) -> None: """Called at the beginning of fit (train + validate), validate, test, or predict, or tune.""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 716a1bc3707af..a20b8618161a4 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1077,11 +1077,9 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_model_parallel_hook(self, model: LightningModule) -> None: - if not hasattr(self.lightning_module, 'is_model_parallel_setup'): - self.on_model_parallel_setup(model) - with self.accelerator.model_parallel_context(): - model.on_model_parallel_setup() - self.lightning_module.is_model_parallel_setup = True + self.on_model_parallel_setup(model) + with self.accelerator.model_parallel_context(): + model.on_model_parallel_setup() def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state From 282a133dd834fca4a2d419a637b445a94cad7aca Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 16:52:57 +0000 Subject: [PATCH 21/62] Add hook logic --- pytorch_lightning/accelerators/accelerator.py | 9 +++++++++ .../plugins/training_type/training_type_plugin.py | 9 +++++++++ pytorch_lightning/trainer/trainer.py | 9 ++++++--- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c97918e1e407e..d6362bb103f5c 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -445,3 +445,12 @@ def model_parallel_context(self) -> Generator: """ with self.training_type_plugin.model_parallel_context(): yield + + @property + def call_model_parallel_setup_hook(self) -> bool: + """ + Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. + Returns: True if we want to call the model parallel setup hook. + """ + return self.training_type_plugin.call_model_parallel_setup_hook diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index a407acd4a6040..2997b14d37316 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -204,3 +204,12 @@ def model_parallel_context(self) -> Generator: Returns: Model parallel context. """ yield + + @property + def call_model_parallel_setup_hook(self) -> bool: + """ + Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. + Returns: True if we want to call the model parallel setup hook. + """ + return True diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a20b8618161a4..aacc71a823a3b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1077,9 +1077,12 @@ def call_setup_hook(self, model: LightningModule) -> None: model.setup(stage=state) def call_model_parallel_hook(self, model: LightningModule) -> None: - self.on_model_parallel_setup(model) - with self.accelerator.model_parallel_context(): - model.on_model_parallel_setup() + # Call model parallel hook if accelerator requests. In some cases + # we will not call the hook; the hook has initialized the sharded model for example. + if self.accelerator.call_model_parallel_setup_hook: + self.on_model_parallel_setup(model) + with self.accelerator.model_parallel_context(): + model.on_model_parallel_setup() def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state From 7a94e72f183d8b2855f07c5ccf10f85e93b35f2e Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:17:49 +0530 Subject: [PATCH 22/62] add tests --- tests/accelerators/test_common.py | 65 +++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index bd8636ba839f9..ee45878863551 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -1,9 +1,24 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest import torch import tests.helpers.utils as tutils from pytorch_lightning import Trainer +from pytorch_lightning.plugins import SingleDevicePlugin from tests.accelerators.test_dp import CustomClassificationModelDP +from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf @@ -44,3 +59,53 @@ def test_evaluate(tmpdir, trainer_kwargs): # make sure weights didn't change new_weights = model.layer_0.weight.clone().detach().cpu() torch.testing.assert_allclose(old_weights, new_weights) + + +def test_model_parallel_setup_called(tmpdir): + + class TestModel(BoringModel): + + def on_model_parallel_setup(self): + self.on_model_parallel_setup_called = True + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + assert model.on_model_parallel_setup_called + + +def test_model_parallel_setup_false(tmpdir): + """Ensure ``on_model_parallel_setup`` is not called, when turned off""" + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_model_parallel_setup_called = False + + def on_model_parallel_setup(self): + self.on_model_parallel_setup_called = True + + class CustomPlugin(SingleDevicePlugin): + + @property + def call_model_parallel_setup_hook(self) -> bool: + return False + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + plugins=CustomPlugin(device=torch.device("cpu")) + ) + trainer.fit(model) + + assert not model.on_model_parallel_setup_called From 809148135218ce102010516e9d064785deae622d Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:33:44 +0530 Subject: [PATCH 23/62] add property setter --- .../plugins/training_type/training_type_plugin.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 2997b14d37316..f68510242839b 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -34,6 +34,7 @@ class TrainingTypePlugin(Plugin, ABC): def __init__(self) -> None: self._model = None self._results = None + self._call_model_parallel_setup_hook = True def connect(self, model: 'Module') -> None: """Called by the accelerator to connect the accelerator and the model with this plugin""" @@ -212,4 +213,9 @@ def call_model_parallel_setup_hook(self) -> bool: This is useful for when we want to shard the model once within fit. Returns: True if we want to call the model parallel setup hook. """ - return True + return self._call_model_parallel_setup_hook + + @call_model_parallel_setup_hook.setter + def call_model_parallel_setup_hook(self, mode: bool) -> bool: + if isinstance(mode, bool): + self._call_model_parallel_setup_hook = mode From 633fc77148c0538e92479de95bae042bbf430cea Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:49:49 +0530 Subject: [PATCH 24/62] add logic for being called once --- pytorch_lightning/accelerators/accelerator.py | 5 ++++ .../training_type/training_type_plugin.py | 3 +-- pytorch_lightning/trainer/trainer.py | 1 + tests/accelerators/test_common.py | 27 +++++++++++++++++++ 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d6362bb103f5c..6ef3895077892 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -454,3 +454,8 @@ def call_model_parallel_setup_hook(self) -> bool: Returns: True if we want to call the model parallel setup hook. """ return self.training_type_plugin.call_model_parallel_setup_hook + + @call_model_parallel_setup_hook.setter + def call_model_parallel_setup_hook(self, mode: bool) -> bool: + if isinstance(mode, bool): + self.training_type_plugin.call_model_parallel_setup_hook = mode diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index f68510242839b..8431f653955e7 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -217,5 +217,4 @@ def call_model_parallel_setup_hook(self) -> bool: @call_model_parallel_setup_hook.setter def call_model_parallel_setup_hook(self, mode: bool) -> bool: - if isinstance(mode, bool): - self._call_model_parallel_setup_hook = mode + self._call_model_parallel_setup_hook = mode diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index aacc71a823a3b..dd942e040efd3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1083,6 +1083,7 @@ def call_model_parallel_hook(self, model: LightningModule) -> None: self.on_model_parallel_setup(model) with self.accelerator.model_parallel_context(): model.on_model_parallel_setup() + self.accelerator.call_model_parallel_setup_hook = False def call_teardown_hook(self, model: LightningModule) -> None: state = self._teardown_state diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index ee45878863551..c0b9efee947fb 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -109,3 +109,30 @@ def call_model_parallel_setup_hook(self) -> bool: trainer.fit(model) assert not model.on_model_parallel_setup_called + + +def test_model_parallel_setup_called_once(tmpdir): + """Ensure ``on_model_parallel_setup`` is only called once""" + + class TestModel(BoringModel): + + def __init__(self): + super().__init__() + self.on_model_parallel_setup_called = False + + def on_model_parallel_setup(self): + self.on_model_parallel_setup_called = True + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + ) + trainer.fit(model) + + assert model.on_model_parallel_setup_called + model.on_model_parallel_setup_called = False + + assert not model.on_model_parallel_setup_called From c99a36f960edda694d18426c2319e1e2a9dadac1 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:53:06 +0530 Subject: [PATCH 25/62] Update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 32cf9122efe34..fc9b18ff3f0b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -57,6 +57,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `outputs` parameter to callback's `on_validation_epoch_end` & `on_test_epoch_end` hooks ([#6120](https://github.com/PyTorchLightning/pytorch-lightning/pull/6120)) +- Added `on_model_parallel_setup` hook ([#6679](https://github.com/PyTorchLightning/pytorch-lightning/pull/6679)) + + ### Changed - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259)) From 9529a22882057fcd4f43d4842d35799083d1de80 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:56:32 +0530 Subject: [PATCH 26/62] Fix --- pytorch_lightning/accelerators/accelerator.py | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 293a5e28b4c79..d6041fe1fca9f 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -452,20 +452,6 @@ def model_parallel_context(self) -> Generator: with self.training_type_plugin.model_parallel_context(): yield - @property - def call_model_parallel_setup_hook(self) -> bool: - """ - Allow model parallel hook to be called in suitable environments determined by the training type plugin. - This is useful for when we want to shard the model once within fit. - Returns: True if we want to call the model parallel setup hook. - """ - return self.training_type_plugin.call_model_parallel_setup_hook - - @call_model_parallel_setup_hook.setter - def call_model_parallel_setup_hook(self, mode: bool) -> bool: - if isinstance(mode, bool): - self.training_type_plugin.call_model_parallel_setup_hook = mode - # todo: remove in v1.5 def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: """ @@ -493,3 +479,17 @@ def connect_precision_plugin(self, plugin: PrecisionPlugin) -> None: ' It will be removed in v1.5.' ) self.setup_precision_plugin(plugin) + + @property + def call_model_parallel_setup_hook(self) -> bool: + """ + Allow model parallel hook to be called in suitable environments determined by the training type plugin. + This is useful for when we want to shard the model once within fit. + Returns: True if we want to call the model parallel setup hook. + """ + return self.training_type_plugin.call_model_parallel_setup_hook + + @call_model_parallel_setup_hook.setter + def call_model_parallel_setup_hook(self, mode: bool) -> bool: + if isinstance(mode, bool): + self.training_type_plugin.call_model_parallel_setup_hook = mode From 3c1c782187923c99cf82352819525f45e31ba5a8 Mon Sep 17 00:00:00 2001 From: Kaushik Bokka Date: Thu, 25 Mar 2021 23:59:06 +0530 Subject: [PATCH 27/62] fix return type --- pytorch_lightning/accelerators/accelerator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d6041fe1fca9f..d7b7156e4ad96 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -490,6 +490,6 @@ def call_model_parallel_setup_hook(self) -> bool: return self.training_type_plugin.call_model_parallel_setup_hook @call_model_parallel_setup_hook.setter - def call_model_parallel_setup_hook(self, mode: bool) -> bool: + def call_model_parallel_setup_hook(self, mode: bool) -> None: if isinstance(mode, bool): self.training_type_plugin.call_model_parallel_setup_hook = mode From 87ec222c8bdafddeb4e08b07f5e6d4666605902d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 18:50:49 +0000 Subject: [PATCH 28/62] Fix property name --- pytorch_lightning/plugins/training_type/fully_sharded.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 538951669fdf7..e55909a4af5d1 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -25,10 +25,7 @@ from fairscale.nn import auto_wrap, enable_wrap, wrap from fairscale.nn.data_parallel import FullyShardedDataParallel - from pytorch_lightning.overrides.fairscale import ( - LightningFullyShardedModule, - unwrap_lightning_module_fully_sharded, - ) + from pytorch_lightning.overrides.fairscale import LightningFullyShardedModule, unwrap_lightning_module_fully_sharded class FullyShardedPlugin(DDPPlugin): @@ -191,5 +188,5 @@ def collate_state_dict(self): return state_dict @property - def setup_optimizers_after_dispatch(self) -> bool: + def setup_optimizers_in_pre_dispatch(self) -> bool: return True From 5f6e039aa1bd00d6ba371c8f422c366cf19058fc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 19:25:16 +0000 Subject: [PATCH 29/62] Updaet wrapper, use latest fixes for hooks --- .../plugins/training_type/fully_sharded.py | 27 ++++++++++++------- tests/plugins/test_fully_sharded_plugin.py | 10 ++++--- 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index e55909a4af5d1..f8813978feabb 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +import contextlib +from typing import Generator, List, Optional import torch @@ -22,7 +23,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: - from fairscale.nn import auto_wrap, enable_wrap, wrap + from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import LightningFullyShardedModule, unwrap_lightning_module_fully_sharded @@ -123,15 +124,16 @@ def process_group(self): self._process_group = torch.distributed.new_group() return self._process_group - def configure_ddp(self): + @contextlib.contextmanager + def model_parallel_context(self) -> Generator: precision = self.lightning_module.trainer.precision - # set the device before instantiate the wrapper - if self.root_device.type == "cuda": - torch.cuda.set_device(self.root_device) + def wrap_policy(*args, **kwargs): + return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params) with enable_wrap( wrapper_cls=FullyShardedDataParallel, + auto_wrap_policy=wrap_policy, process_group=self.process_group, cpu_offload=self.cpu_offload, move_grads_to_cpu=self.move_grads_to_cpu, @@ -142,10 +144,15 @@ def configure_ddp(self): compute_dtype=self.compute_dtype, bucket_cap_mb=self.bucket_cap_mb, ): - # Allow user to manually wrap the lightning modules, and any internal modules - # todo: this should somehow be incorporated as a general hook. - # currently this also means you have to use fully sharded to load the model as well. - self.lightning_module.trainer.call_hook("on_distributed_model_setup") + yield + + def configure_ddp(self): + + # set the device before instantiate the wrapper + if self.root_device.type == "cuda": + torch.cuda.set_device(self.root_device) + + with self.model_parallel_context(): if self.automatic_module_wrap: self.model = auto_wrap(LightningFullyShardedModule(self.model)) if not isinstance(self.model, FullyShardedDataParallel): diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index 3970a5163f4dc..dd0253f8141ac 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -12,7 +12,7 @@ from tests.helpers.runif import RunIf if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: - from fairscale.nn import auto_wrap, FullyShardedDataParallel + from fairscale.nn import auto_wrap, default_auto_wrap_policy, FullyShardedDataParallel @pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @@ -101,9 +101,13 @@ def test_fully_sharded_plugin_checkpoint_manual_autowrap(automatic_module_wrap, class TestModel(BoringModel): - def on_distributed_model_setup(self) -> None: + def on_model_parallel_setup(self) -> None: if not automatic_module_wrap: - self.layer = auto_wrap(self.layer, min_num_params=1) + + def wrap_policy(*args, **kwargs): + return default_auto_wrap_policy(*args, **kwargs, min_num_params=1) + + self.layer = auto_wrap(self.layer, auto_wrap_policy=wrap_policy) def on_train_start(self) -> None: assert isinstance(self.layer, FullyShardedDataParallel) From b512e72971e57c4fc64dd387f14dd9f4696f9440 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 25 Mar 2021 23:07:03 +0000 Subject: [PATCH 30/62] Swap hook order --- pytorch_lightning/trainer/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6198d4603d555..6e1dee2bcf12e 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -436,8 +436,8 @@ def fit( self.accelerator.connect(model) self.accelerator.setup_environment() self.call_setup_hook(model) # allow user to setup lightning_module in accelerator environment - self.accelerator.setup(self, model) # note: this sets up self.lightning_module self.call_model_parallel_hook(model) # allow user to setup in model parallel environment + self.accelerator.setup(self, model) # note: this sets up self.lightning_module # ---------------------------- # INSPECT THE CORE LOOPS From 1e5ca37b42e3a8324026b263fe92feedab984f1d Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 29 Mar 2021 22:08:06 +0100 Subject: [PATCH 31/62] Small changes --- .../plugins/training_type/fully_sharded.py | 20 +++++++------------ pytorch_lightning/utilities/__init__.py | 1 + 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index f8813978feabb..96ccf10e69ae6 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -125,7 +125,7 @@ def process_group(self): return self._process_group @contextlib.contextmanager - def model_parallel_context(self) -> Generator: + def model_sharded_context(self) -> Generator: precision = self.lightning_module.trainer.precision def wrap_policy(*args, **kwargs): @@ -147,18 +147,12 @@ def wrap_policy(*args, **kwargs): yield def configure_ddp(self): - - # set the device before instantiate the wrapper - if self.root_device.type == "cuda": - torch.cuda.set_device(self.root_device) - - with self.model_parallel_context(): - if self.automatic_module_wrap: - self.model = auto_wrap(LightningFullyShardedModule(self.model)) - if not isinstance(self.model, FullyShardedDataParallel): - self.model = wrap(self.model) - else: - self.model = wrap(LightningFullyShardedModule(self.model)) + if self.automatic_module_wrap: + self.model = auto_wrap(LightningFullyShardedModule(self.model)) + if not isinstance(self.model, FullyShardedDataParallel): + self.model = wrap(self.model) + else: + self.model = wrap(LightningFullyShardedModule(self.model)) if not self.cpu_offload: # When using CPU Offload, FSDP will manage the CUDA movement for us diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index f20bc7b3f16dd..c1f0dd3d4aec1 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -14,6 +14,7 @@ """General utilities""" import numpy + from pytorch_lightning.utilities.apply_func import move_data_to_device # noqa: F401 from pytorch_lightning.utilities.distributed import ( # noqa: F401 AllGatherGrad, From 936dc1a94b67d31a1f5836d91b76feeee59d47c9 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 29 Mar 2021 22:57:37 +0100 Subject: [PATCH 32/62] Fixes --- .../plugins/training_type/fully_sharded.py | 29 +++++++++++++------ 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 96ccf10e69ae6..ba49bb1754409 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -23,7 +23,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: - from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap + from fairscale.nn import auto_wrap, checkpoint_wrapper, default_auto_wrap_policy, enable_wrap, wrap from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import LightningFullyShardedModule, unwrap_lightning_module_fully_sharded @@ -125,14 +125,24 @@ def process_group(self): return self._process_group @contextlib.contextmanager - def model_sharded_context(self) -> Generator: + def model_sharded_context(self, override_checkpointing=False) -> Generator: + + # set the device before instantiate the wrapper + if self.root_device.type == "cuda": + torch.cuda.set_device(self.root_device) + precision = self.lightning_module.trainer.precision def wrap_policy(*args, **kwargs): return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params) + def model_wrapper(module, **wrap_overrides): + if self.activation_checkpoint: + module = checkpoint_wrapper(module) + return FullyShardedDataParallel(module, **wrap_overrides) + with enable_wrap( - wrapper_cls=FullyShardedDataParallel, + wrapper_cls=model_wrapper, auto_wrap_policy=wrap_policy, process_group=self.process_group, cpu_offload=self.cpu_offload, @@ -147,12 +157,13 @@ def wrap_policy(*args, **kwargs): yield def configure_ddp(self): - if self.automatic_module_wrap: - self.model = auto_wrap(LightningFullyShardedModule(self.model)) - if not isinstance(self.model, FullyShardedDataParallel): - self.model = wrap(self.model) - else: - self.model = wrap(LightningFullyShardedModule(self.model)) + with self.model_sharded_context(): + if self.automatic_module_wrap: + self.model = auto_wrap(LightningFullyShardedModule(self.model)) + if not isinstance(self.model, FullyShardedDataParallel): + self.model = wrap(self.model) + else: + self.model = wrap(LightningFullyShardedModule(self.model)) if not self.cpu_offload: # When using CPU Offload, FSDP will manage the CUDA movement for us From a6de18ea3efc6b25c0cd2eaa65dbf82fa6d7b650 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 1 Apr 2021 21:25:33 +0100 Subject: [PATCH 33/62] Remove activation checkpointing --- pytorch_lightning/plugins/training_type/fully_sharded.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index ba49bb1754409..dbc37fc087e80 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -42,7 +42,6 @@ def __init__( bucket_cap_mb: int = 25, automatic_module_wrap: bool = True, min_num_params: int = 1e8, - activation_checkpoint: bool = False, parallel_devices: Optional[List[torch.device]] = None, num_nodes: int = 1, cluster_environment: ClusterEnvironment = None, @@ -115,7 +114,6 @@ def __init__( self.bucket_cap_mb = bucket_cap_mb self.automatic_module_wrap = automatic_module_wrap self.min_num_params = min_num_params - self.activation_checkpoint = activation_checkpoint self._process_group = None @property @@ -136,13 +134,8 @@ def model_sharded_context(self, override_checkpointing=False) -> Generator: def wrap_policy(*args, **kwargs): return default_auto_wrap_policy(*args, **kwargs, min_num_params=self.min_num_params) - def model_wrapper(module, **wrap_overrides): - if self.activation_checkpoint: - module = checkpoint_wrapper(module) - return FullyShardedDataParallel(module, **wrap_overrides) - with enable_wrap( - wrapper_cls=model_wrapper, + wrapper_cls=FullyShardedDataParallel, auto_wrap_policy=wrap_policy, process_group=self.process_group, cpu_offload=self.cpu_offload, From 8684f94d160c5e7f23f2d2ee15a0dbff740095fe Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 2 Apr 2021 00:41:43 +0100 Subject: [PATCH 34/62] Turn off auto wrap by default --- pytorch_lightning/plugins/training_type/fully_sharded.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index dbc37fc087e80..deeaf621d21a0 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -40,7 +40,7 @@ def __init__( fp32_reduce_scatter: Optional[bool] = None, compute_dtype: Optional[torch.dtype] = None, bucket_cap_mb: int = 25, - automatic_module_wrap: bool = True, + automatic_module_wrap: bool = False, min_num_params: int = 1e8, parallel_devices: Optional[List[torch.device]] = None, num_nodes: int = 1, From 76091ae230fe4bb247173718ed172c77e7658c2f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 7 Apr 2021 17:41:47 +0100 Subject: [PATCH 35/62] Move to trainer.model --- pytorch_lightning/core/lightning.py | 4 ---- tests/plugins/test_fully_sharded_plugin.py | 8 ++++---- tests/trainer/properties/test_get_model.py | 4 ++-- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 5884ed9da39f1..7efe88515b37e 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -469,10 +469,6 @@ def all_gather( all_gather = partial(all_gather, group=group, sync_grads=sync_grads) return apply_to_collection(data, torch.Tensor, all_gather) - @property - def accelerator_model(self): - return self.trainer.accelerator.model - def forward(self, *args, **kwargs): r""" Same as :meth:`torch.nn.Module.forward()`, however in Lightning you want this to define diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index dd0253f8141ac..53a9c99554991 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -76,7 +76,7 @@ def test_fully_sharded_plugin_checkpoint(tmpdir): class TestModel(BoringModel): def configure_optimizers(self): - return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1) + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) model = TestModel() trainer = Trainer( @@ -111,10 +111,10 @@ def wrap_policy(*args, **kwargs): def on_train_start(self) -> None: assert isinstance(self.layer, FullyShardedDataParallel) - assert isinstance(self.accelerator_model, FullyShardedDataParallel) + assert isinstance(self.trainer.model, FullyShardedDataParallel) def configure_optimizers(self): - return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1) + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) model = TestModel() @@ -143,7 +143,7 @@ def test_fully_sharded_plugin_checkpoint_multi_gpu(tmpdir): class TestModel(BoringModel): def configure_optimizers(self): - return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1) + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) model = TestModel() trainer = Trainer( diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 15898cfd1b6e3..3bd3f24740def 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -113,10 +113,10 @@ def test_get_accelerator_wrapped_model(accelerator, wrapper, tmpdir): class TestModel(BoringModel): def on_train_start(self) -> None: - assert isinstance(self.accelerator_model, wrapper) + assert isinstance(self.trainer.model, wrapper) def configure_optimizers(self): - return torch.optim.SGD(self.accelerator_model.parameters(), lr=0.1) + return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) model = TestModel() From 226d4982f89c01e54925350b9389e0b0600f9dfa Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 7 Apr 2021 17:54:04 +0100 Subject: [PATCH 36/62] fix reference --- tests/plugins/test_fully_sharded_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index 53a9c99554991..54c3ed078d2e2 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -101,7 +101,7 @@ def test_fully_sharded_plugin_checkpoint_manual_autowrap(automatic_module_wrap, class TestModel(BoringModel): - def on_model_parallel_setup(self) -> None: + def configure_sharded_model(self) -> None: if not automatic_module_wrap: def wrap_policy(*args, **kwargs): From b881e2f4f426f618acf18f72bfd2247f40cb08a1 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 7 Apr 2021 18:49:15 +0100 Subject: [PATCH 37/62] Remove flag --- pytorch_lightning/plugins/training_type/fully_sharded.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index deeaf621d21a0..752af2eda99c5 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -123,7 +123,7 @@ def process_group(self): return self._process_group @contextlib.contextmanager - def model_sharded_context(self, override_checkpointing=False) -> Generator: + def model_sharded_context(self) -> Generator: # set the device before instantiate the wrapper if self.root_device.type == "cuda": From e8959be7890f088711f1cd63834801953b89c437 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 7 Apr 2021 18:55:34 +0100 Subject: [PATCH 38/62] Fix imports --- pytorch_lightning/plugins/training_type/fully_sharded.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 752af2eda99c5..9d85d2ae4ecd6 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -23,7 +23,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: - from fairscale.nn import auto_wrap, checkpoint_wrapper, default_auto_wrap_policy, enable_wrap, wrap + from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import LightningFullyShardedModule, unwrap_lightning_module_fully_sharded From 52478ac5b46c6f81ab01b0e973fd307f91d992f0 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 7 Apr 2021 20:33:45 +0100 Subject: [PATCH 39/62] Fix versions, update docs --- .../plugins/training_type/fully_sharded.py | 23 ++++++++++++------- requirements/extra.txt | 2 +- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 9d85d2ae4ecd6..41a94bfeba2ec 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -54,13 +54,13 @@ def __init__( Full Sharded Training shards the entire model across all available GPUs, allowing you to scale model size, whilst using efficient communication to reduce overhead. In practice, this means we can remain at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar - to ZeRO-Stage 3 but have been modified/adjusted for PyTorch. + to ZeRO-Stage 3 but has been built for upstreaming to PyTorch. `For more information: https://fairscale.readthedocs.io/en/latest/api/nn/fsdp.html`. .. warning:: ``FullyShardedPlugin`` is in beta and subject to change. - Defaults have been set to enable CPU Offload, but options have been exposed and may require configuration + Defaults have been set and options have been exposed, but may require configuration based on your level of memory/speed efficiency. We suggest having a look at this PR for more information. `https://github.com/facebookresearch/fairscale/pull/413` @@ -73,7 +73,7 @@ def __init__( cpu_offload: Offload FP32 params to CPU. Only useable in precision=16 mode (default: False). - move_grads_to_cpu: Moves gradient shards to CPU after reduction. + move_grads_to_cpu: Moves gradient shards to CPU after reduction. Only disable if using CPU based optimizers (defaults to ``cpu_offload``). flatten_parameters: Flattens parameter into single contiguous tensor for speed efficiency @@ -90,11 +90,18 @@ def __init__( unless using mixed precision, in which case defaults to torch.float16. bucket_cap_mb: bucket parameters so that gradient reduction - can potentially overlap with backward computation. - bucket_cap_mb controls the bucket size in MegaBytes (MB). - Buckets are sub-divided based on world_size, - so the max shard size is roughly bucket_cap_mb / world_size. - Values <= 0 disable bucketing. (Default: 25). + can potentially overlap with backward computation. + bucket_cap_mb controls the bucket size in MegaBytes (MB). + Buckets are sub-divided based on world_size, + so the max shard size is roughly bucket_cap_mb / world_size. + Values <= 0 disable bucketing. (Default: 25). + + automatic_module_wrap: Automatically wrap the lightning module with Fully Sharded recursively. + Using ``min_num_params`` to determine the amount of parameters to wrap at a time. + (default: False) + + min_num_params: Number of parameters to wrap when using FairScale ``auto_wrap``. + (default: 1e8) """ if not _FAIRSCALE_FULLY_SHARDED_AVAILABLE: diff --git a/requirements/extra.txt b/requirements/extra.txt index 8f23f8e200f2b..3b4e4efbb8879 100644 --- a/requirements/extra.txt +++ b/requirements/extra.txt @@ -7,6 +7,6 @@ torchtext>=0.5 # onnx>=1.7.0 onnxruntime>=1.3.0 hydra-core>=1.0 -fairscale>=0.3.2 +fairscale>=0.3.3 jsonargparse[signatures]>=3.3.1 deepspeed>=0.3.13 From b7f189680bd278b35f9b6a5a4727691e313d4e39 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 8 Apr 2021 16:53:32 +0100 Subject: [PATCH 40/62] Fix clip gradients --- .../plugins/precision/fully_sharded_native_amp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 651b7030c9316..5e225caa5acac 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -16,13 +16,19 @@ from torch.optim import Optimizer from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin +from pytorch_lightning.utilities import GradClipAlgorithmType class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): """Mixed Precision for Full Sharded Training""" def clip_gradients( - self, model: Any, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0) + self, + model: 'LightningModule', + optimizer: 'Optimizer', + clip_val: Union[int, float], + gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, + norm_type: float = 2.0 ) -> None: # Model manages clipping of gradients model.clip_grad_norm_(clip_val, norm_type) From 9fa26c00821d388ff6ca37de0be8108553d8bafe Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Apr 2021 10:17:12 +0100 Subject: [PATCH 41/62] Fixes --- .github/workflows/events-nightly.yml | 1 + dockers/base-cuda/Dockerfile | 4 ++-- pytorch_lightning/utilities/imports.py | 2 +- tests/trainer/properties/test_get_model.py | 1 + 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 5ad4396a006f7..91d509f193339 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -3,6 +3,7 @@ name: Nightly events # https://jasonet.co/posts/scheduled-actions/ # https://github.community/t/distinct-job-for-each-schedule/17811/2 on: + push: {} # fixme schedule: - cron: "0 0 * * *" # At the end of every day diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index d98cabd12a469..ec0d7cb97563b 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -114,8 +114,8 @@ RUN \ rm -rf apex RUN \ - # install DeepSpeed - pip install deepspeed>=0.3.14 + # install DeepSpeed and FairScale + pip install deepspeed>=0.3.14 fairscale>=0.3.4 RUN \ # Show what we have diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 91406ad10185b..7a86bd9f638cf 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -76,7 +76,7 @@ def _compare_version(package: str, op, version) -> bool: _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') -_FAIRSCALE_FULLY_SHARDED_AVAILABLE = not _IS_WINDOWS and _compare_version("fairscale", operator.ge, "0.3.0") +_FAIRSCALE_FULLY_SHARDED_AVAILABLE = not _IS_WINDOWS and _compare_version("fairscale", operator.ge, "0.3.4") _FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') _HOROVOD_AVAILABLE = _module_available("horovod.torch") diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 3bd3f24740def..6f0b0c18fe8e1 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -20,6 +20,7 @@ from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf +FullyShardedDataParallel = None if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel import ShardedDataParallel if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: From 56f23ce4329a345d046b7aefcf0bfdc8eca0e297 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Apr 2021 10:17:23 +0100 Subject: [PATCH 42/62] pull --- .github/workflows/events-nightly.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 91d509f193339..5ad4396a006f7 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -3,7 +3,6 @@ name: Nightly events # https://jasonet.co/posts/scheduled-actions/ # https://github.community/t/distinct-job-for-each-schedule/17811/2 on: - push: {} # fixme schedule: - cron: "0 0 * * *" # At the end of every day From 9ca3f0c77dfa35eadc6eb6155d164921ee0bde18 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Apr 2021 11:08:53 +0100 Subject: [PATCH 43/62] Few changes across the board --- .../plugins/training_type/fully_sharded.py | 10 ++-- .../training_type/training_type_plugin.py | 2 +- tests/helpers/runif.py | 8 ++++ tests/plugins/test_fully_sharded_plugin.py | 46 +++++++++---------- 4 files changed, 38 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 41a94bfeba2ec..89523a7b8e048 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -180,10 +180,6 @@ def model_to_device(self): def lightning_module(self) -> LightningModule: return unwrap_lightning_module_fully_sharded(self.model) - @property - def move_to_device_in_prefetch(self): - return False - def on_save(self, checkpoint: dict) -> dict: state_dict = self.collate_state_dict() checkpoint['state_dict'] = state_dict @@ -201,4 +197,10 @@ def collate_state_dict(self): @property def setup_optimizers_in_pre_dispatch(self) -> bool: + # Setup optimizers after the Fully Sharded Model has been made return True + + @property + def move_to_device_in_prefetch(self): + # Fully Sharded handles moving to device + return False diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 6fd02142bf410..096d4616f327b 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -247,8 +247,8 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None: filepath: write-target file's path """ # dump states as a checkpoint dictionary object + checkpoint = self.on_save(checkpoint) if self.is_global_zero: - checkpoint = self.on_save(checkpoint) try: # write the checkpoint dictionary on the file atomic_save(checkpoint, filepath) diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 5483e33d9cddb..2b25a9333f52c 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -24,6 +24,7 @@ _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _FAIRSCALE_AVAILABLE, + _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE, _HOROVOD_AVAILABLE, _NATIVE_AMP_AVAILABLE, @@ -68,6 +69,7 @@ def __new__( special: bool = False, rpc: bool = False, fairscale: bool = False, + fairscale_fully_sharded: bool = False, fairscale_pipe: bool = False, deepspeed: bool = False, **kwargs @@ -89,6 +91,8 @@ def __new__( special: running in special mode, outside pytest suit rpc: requires Remote Procedure Call (RPC) fairscale: if `fairscale` module is required to run the test + fairscale_fully_sharded: if `fairscale` fully sharded module is required to run the test + fairscale_pipe: if `fairscale` with pipe module is required to run the test deepspeed: if `deepspeed` module is required to run the test kwargs: native pytest.mark.skipif keyword arguments """ @@ -156,6 +160,10 @@ def __new__( conditions.append(not _FAIRSCALE_AVAILABLE) reasons.append("Fairscale") + if fairscale_fully_sharded: + conditions.append(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE) + reasons.append("Fairscale Fully Sharded") + if fairscale_pipe: conditions.append(not _FAIRSCALE_PIPE_AVAILABLE) reasons.append("Fairscale Pipe") diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index 54c3ed078d2e2..2fc7ac13e00d6 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -5,6 +5,7 @@ import torch from pytorch_lightning import Trainer +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.plugins import FullyShardedNativeMixedPrecisionPlugin, FullyShardedPlugin from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -15,7 +16,7 @@ from fairscale.nn import auto_wrap, default_auto_wrap_policy, FullyShardedDataParallel -@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") +@RunIf(fairscale_fully_sharded=True) def test_sharded_ddp_choice(tmpdir): """ Test to ensure that plugin is correctly chosen @@ -27,8 +28,7 @@ def test_sharded_ddp_choice(tmpdir): assert isinstance(trainer.accelerator.training_type_plugin, FullyShardedPlugin) -@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") -@RunIf(amp_apex=True) +@RunIf(amp_apex=True, fairscale_fully_sharded=True) def test_invalid_apex_sharded(tmpdir): """ Test to ensure that we raise an error when we try to use apex and sharded @@ -46,11 +46,10 @@ def test_invalid_apex_sharded(tmpdir): trainer.fit(model) -@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") @mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0"}) @mock.patch('torch.cuda.device_count', return_value=1) @mock.patch('torch.cuda.is_available', return_value=True) -@RunIf(amp_native=True) +@RunIf(amp_native=True, fairscale_fully_sharded=True) def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): """ Test to ensure that plugin native amp plugin is correctly chosen when using sharded @@ -66,8 +65,7 @@ def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): assert isinstance(trainer.accelerator.precision_plugin, FullyShardedNativeMixedPrecisionPlugin) -@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") -@RunIf(min_gpus=1, skip_windows=True) +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True) def test_fully_sharded_plugin_checkpoint(tmpdir): """ Test to ensure that checkpoint is saved correctly when using a single GPU. @@ -92,8 +90,7 @@ def configure_optimizers(self): @pytest.mark.parametrize('automatic_module_wrap', [True, False]) -@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") -@RunIf(min_gpus=1, skip_windows=True) +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True) def test_fully_sharded_plugin_checkpoint_manual_autowrap(automatic_module_wrap, tmpdir): """ Test to ensure that checkpoint is saved correctly when using automatic, and manual auto_wrap. @@ -130,14 +127,10 @@ def configure_optimizers(self): _assert_save_equality(tmpdir, trainer) -@pytest.mark.skipif(not _FAIRSCALE_FULLY_SHARDED_AVAILABLE, reason="Fairscale is not available") -@pytest.mark.skipif( - not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest" -) -@RunIf(min_gpus=2, skip_windows=True) -def test_fully_sharded_plugin_checkpoint_multi_gpu(tmpdir): +@RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, special=False) +def test_fully_sharded_plugin_multi_gpu(tmpdir): """ - Test to ensure that checkpoint is saved correctly when using multiple GPUs + Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run. """ class TestModel(BoringModel): @@ -145,28 +138,35 @@ class TestModel(BoringModel): def configure_optimizers(self): return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) + ck = ModelCheckpoint(save_last=True) model = TestModel() trainer = Trainer( gpus=2, - plugins='fully_sharded', - fast_dev_run=True, + plugins='ddp_fully_sharded', + max_epochs=5, precision=16, ) trainer.fit(model) + trainer.test(model) + trainer.test(ck.last_model_path) + trainer.validate() + trainer.validate(ck.last_model_path) + trainer.predict(dataloaders=model.val_dataloader()) _assert_save_equality(tmpdir, trainer) def _assert_save_equality(tmpdir, trainer): - if trainer.global_rank == 0: + checkpoint_path = os.path.join(tmpdir, 'model.pt') + trainer.save_checkpoint(checkpoint_path) - checkpoint_path = os.path.join(tmpdir, 'model.pt') - trainer.save_checkpoint(checkpoint_path) + # Use FullySharded to get the state dict for the sake of comparison + model_state_dict = trainer.accelerator.training_type_plugin.collate_state_dict() + + if trainer.global_rank == 0: saved_model = BoringModel.load_from_checkpoint(checkpoint_path) - # Ensure we gather all shards for comparison - model_state_dict = trainer.accelerator.training_type_plugin.collate_state_dict() # Assert model parameters are identical after loading for ddp_param, shard_param in zip(model_state_dict.values(), saved_model.state_dict().values()): assert torch.equal(ddp_param.float().cpu(), shard_param) From b53ba36e88785ac5abd8e7ba546d3fcffc5a33d7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Apr 2021 12:56:01 +0100 Subject: [PATCH 44/62] Fix imports --- pytorch_lightning/utilities/imports.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 7a86bd9f638cf..18a32ceb5305d 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -75,9 +75,9 @@ def _compare_version(package: str, op, version) -> bool: _APEX_AVAILABLE = _module_available("apex.amp") _BOLTS_AVAILABLE = _module_available('pl_bolts') _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed') -_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel') -_FAIRSCALE_FULLY_SHARDED_AVAILABLE = not _IS_WINDOWS and _compare_version("fairscale", operator.ge, "0.3.4") -_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3") +_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale') and _TORCH_GREATER_EQUAL_1_6 +_FAIRSCALE_FULLY_SHARDED_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.ge, "0.3.4") +_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and _compare_version("fairscale", operator.le, "0.1.3") _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group') _HOROVOD_AVAILABLE = _module_available("horovod.torch") _HYDRA_AVAILABLE = _module_available("hydra") From 0da52498dc61aae8615e098daf16de415e35930b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Apr 2021 12:56:59 +0100 Subject: [PATCH 45/62] Set none --- tests/trainer/properties/test_get_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/properties/test_get_model.py b/tests/trainer/properties/test_get_model.py index 6f0b0c18fe8e1..23f9afce5b9b0 100644 --- a/tests/trainer/properties/test_get_model.py +++ b/tests/trainer/properties/test_get_model.py @@ -20,7 +20,7 @@ from tests.helpers.boring_model import BoringModel from tests.helpers.runif import RunIf -FullyShardedDataParallel = None +FullyShardedDataParallel, ShardedDataParallel = None, None if _FAIRSCALE_AVAILABLE: from fairscale.nn.data_parallel import ShardedDataParallel if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: From 90c647944ebd38d9ad280fa455c731f6e1693203 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Apr 2021 12:59:27 +0100 Subject: [PATCH 46/62] Swap to warnings --- pytorch_lightning/plugins/training_type/rpc_sequential.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 7f291d681d8f5..2c4e5c1cb9c70 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -13,6 +13,7 @@ # limitations under the License import logging import os +import warnings from typing import Callable, List, Optional import torch @@ -91,7 +92,7 @@ def __init__( at the same time. Defaults to `True` if `get_model_parallel_world_size() > 1` """ - rank_zero_warn( + warnings.warn( "RPC Sequential Plugin has been deprecated. Please use the `FullyShardedPlugin` " "which provides better performance and scaling without pipelining the model.", DeprecationWarning ) From 69d81786102539ce1445dad858074b4078be0c78 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Apr 2021 14:21:53 +0100 Subject: [PATCH 47/62] Remove fairscale from container --- .github/workflows/events-nightly.yml | 1 + dockers/base-cuda/Dockerfile | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 5ad4396a006f7..91d509f193339 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -3,6 +3,7 @@ name: Nightly events # https://jasonet.co/posts/scheduled-actions/ # https://github.community/t/distinct-job-for-each-schedule/17811/2 on: + push: {} # fixme schedule: - cron: "0 0 * * *" # At the end of every day diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index ec0d7cb97563b..228f7f0751970 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -115,7 +115,7 @@ RUN \ RUN \ # install DeepSpeed and FairScale - pip install deepspeed>=0.3.14 fairscale>=0.3.4 + pip install deepspeed>=0.3.14 RUN \ # Show what we have From a459d100c38a4578dd40b46abbedb9595741d752 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 14 Apr 2021 14:22:11 +0100 Subject: [PATCH 48/62] pull --- .github/workflows/events-nightly.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/events-nightly.yml b/.github/workflows/events-nightly.yml index 91d509f193339..5ad4396a006f7 100644 --- a/.github/workflows/events-nightly.yml +++ b/.github/workflows/events-nightly.yml @@ -3,7 +3,6 @@ name: Nightly events # https://jasonet.co/posts/scheduled-actions/ # https://github.community/t/distinct-job-for-each-schedule/17811/2 on: - push: {} # fixme schedule: - cron: "0 0 * * *" # At the end of every day From a7842d9e1e035c695495d466f865f84474354503 Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Wed, 14 Apr 2021 14:26:13 +0100 Subject: [PATCH 49/62] Update dockers/base-cuda/Dockerfile --- dockers/base-cuda/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 228f7f0751970..d98cabd12a469 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -114,7 +114,7 @@ RUN \ rm -rf apex RUN \ - # install DeepSpeed and FairScale + # install DeepSpeed pip install deepspeed>=0.3.14 RUN \ From 48ee83faf0ef159a9f854de89a37bac6dc3efa78 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 15 Apr 2021 14:36:16 +0100 Subject: [PATCH 50/62] Add defaults, add test to ensure nested wrapper is set correctly --- tests/plugins/test_fully_sharded_plugin.py | 35 +++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index 2fc7ac13e00d6..6628c11030eca 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -13,7 +13,7 @@ from tests.helpers.runif import RunIf if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: - from fairscale.nn import auto_wrap, default_auto_wrap_policy, FullyShardedDataParallel + from fairscale.nn import auto_wrap, default_auto_wrap_policy, FullyShardedDataParallel, wrap @RunIf(fairscale_fully_sharded=True) @@ -22,6 +22,7 @@ def test_sharded_ddp_choice(tmpdir): Test to ensure that plugin is correctly chosen """ trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, plugins='ddp_fully_sharded', ) @@ -37,6 +38,7 @@ def test_invalid_apex_sharded(tmpdir): model = BoringModel() with pytest.raises(MisconfigurationException, match='Sharded Plugins are not supported with Apex AMP'): trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, plugins='ddp_fully_sharded', precision=16, @@ -55,6 +57,7 @@ def test_ddp_choice_sharded_amp(device_count_mock, mock_cuda_available, tmpdir): Test to ensure that plugin native amp plugin is correctly chosen when using sharded """ trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, gpus=1, precision=16, @@ -78,6 +81,7 @@ def configure_optimizers(self): model = TestModel() trainer = Trainer( + default_root_dir=tmpdir, gpus=1, plugins='ddp_fully_sharded', fast_dev_run=True, @@ -89,6 +93,33 @@ def configure_optimizers(self): _assert_save_equality(tmpdir, trainer) +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True) +def test_nested_fsdp(tmpdir): + """ + Test that nested FSDP wrappers are set correctly to reshard after forward/backward pass. + This happens lazily so we need to run at-least one forward pass. + """ + + class TestModel(BoringModel): + + def configure_sharded_model(self) -> None: + self.layer = wrap( + torch.nn.Sequential(wrap(torch.nn.Linear(32, 32)), torch.nn.ReLU(), wrap(torch.nn.Linear(32, 2))) + ) + + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, fast_dev_run=True, gpus=1, plugins=FullyShardedPlugin(reshard_after_forward=True) + ) + trainer.fit(model) + + # root should not be resharding + assert model.layer.reshard_after_forward is False + # Assert that the nested layers are set reshard_after_forward to True + assert model.layer.module[0].reshard_after_forward is True + assert model.layer.module[2].reshard_after_forward is True + + @pytest.mark.parametrize('automatic_module_wrap', [True, False]) @RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True) def test_fully_sharded_plugin_checkpoint_manual_autowrap(automatic_module_wrap, tmpdir): @@ -116,6 +147,7 @@ def configure_optimizers(self): model = TestModel() trainer = Trainer( + default_root_dir=tmpdir, gpus=1, plugins=FullyShardedPlugin(automatic_module_wrap=automatic_module_wrap, min_num_params=1), fast_dev_run=True, @@ -141,6 +173,7 @@ def configure_optimizers(self): ck = ModelCheckpoint(save_last=True) model = TestModel() trainer = Trainer( + default_root_dir=tmpdir, gpus=2, plugins='ddp_fully_sharded', max_epochs=5, From 57a696c6029c595cc35acc9330c403f837b4e106 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Thu, 15 Apr 2021 16:09:37 +0100 Subject: [PATCH 51/62] Remove deprecation as this will be removed completely --- .../plugins/training_type/rpc_sequential.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index 2c4e5c1cb9c70..37b7ae994585b 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -13,7 +13,6 @@ # limitations under the License import logging import os -import warnings from typing import Callable, List, Optional import torch @@ -26,7 +25,7 @@ from pytorch_lightning.overrides.distributed import LightningDistributedModule from pytorch_lightning.plugins.training_type.rpc import DEFAULT_RPC_TIMEOUT_SEC, RPCPlugin from pytorch_lightning.trainer.states import TrainerState -from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only, rank_zero_warn +from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_PIPE_AVAILABLE: @@ -57,10 +56,6 @@ def __init__( .. _RPCSequentialPlugin: https://arxiv.org/abs/1811.06965 - .. deprecated:: - This plugin has been deprecated. Please use the ``FullyShardedPlugin`` which provides better performance - and scaling without pipelining the model. - Pipeline parallelism comes with with checkpointing to reduce peak memory required to train while minimizing device under-utilization. This is turned on by default and can be turned off via the checkpoint argument. @@ -92,10 +87,6 @@ def __init__( at the same time. Defaults to `True` if `get_model_parallel_world_size() > 1` """ - warnings.warn( - "RPC Sequential Plugin has been deprecated. Please use the `FullyShardedPlugin` " - "which provides better performance and scaling without pipelining the model.", DeprecationWarning - ) self._check_pipe_available() super().__init__(rpc_timeout_sec=rpc_timeout_sec, **kwargs) From 36889b80909956e1550c2e84c649f3eb61f0545e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Fri, 16 Apr 2021 11:54:16 +0100 Subject: [PATCH 52/62] Check for nested FSDP wrappers, and omit wrapping algorithm --- pytorch_lightning/plugins/training_type/fully_sharded.py | 8 +++++++- tests/plugins/test_fully_sharded_plugin.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 89523a7b8e048..a0b79752833e4 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -156,9 +156,15 @@ def wrap_policy(*args, **kwargs): ): yield + def _model_has_nested_fsdp(self): + for module in self.model.modules(): + if isinstance(module, FullyShardedDataParallel): + return True + return False + def configure_ddp(self): with self.model_sharded_context(): - if self.automatic_module_wrap: + if self.automatic_module_wrap and not self._model_has_nested_fsdp(): self.model = auto_wrap(LightningFullyShardedModule(self.model)) if not isinstance(self.model, FullyShardedDataParallel): self.model = wrap(self.model) diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index 6628c11030eca..95b1c8d367dd0 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -159,7 +159,7 @@ def configure_optimizers(self): _assert_save_equality(tmpdir, trainer) -@RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, special=False) +@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True, special=True) def test_fully_sharded_plugin_multi_gpu(tmpdir): """ Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run. From 0c1d2de5ede60a87d486a34effb492d96508b94f Mon Sep 17 00:00:00 2001 From: Sean Naren Date: Wed, 21 Apr 2021 10:44:44 +0100 Subject: [PATCH 53/62] Update pytorch_lightning/trainer/connectors/accelerator_connector.py Co-authored-by: ananthsub --- pytorch_lightning/trainer/connectors/accelerator_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/accelerator_connector.py b/pytorch_lightning/trainer/connectors/accelerator_connector.py index 91094a6154dc4..d5f0d2c3e44d3 100644 --- a/pytorch_lightning/trainer/connectors/accelerator_connector.py +++ b/pytorch_lightning/trainer/connectors/accelerator_connector.py @@ -364,7 +364,7 @@ def select_precision_plugin(self) -> PrecisionPlugin: if self._sharded_training_type or self._fully_sharded_training_type: raise MisconfigurationException( "Sharded Plugins are not supported with Apex AMP," - " please using native AMP for 16-bit precision." + " please use native AMP for 16-bit precision." ) log.info("Using APEX 16bit precision.") return ApexMixedPrecisionPlugin(self.amp_level) From 592bb28089a2e5e64a2b08b216775c777c496b74 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Wed, 21 Apr 2021 12:21:35 +0100 Subject: [PATCH 54/62] Address code review points --- .../plugins/training_type/ddp.py | 13 ++------ .../plugins/training_type/fully_sharded.py | 33 ++++++++++--------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 113d22044f437..977145a4cc7ba 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -270,9 +270,8 @@ def init_ddp_connection(self, global_rank: Optional[int] = None, world_size: Opt torch_distrib.init_process_group(self.torch_distributed_backend, rank=global_rank, world_size=world_size) def pre_dispatch(self): - if self.move_to_device_in_prefetch: - # move the model to the correct device - self.model_to_device() + # move the model to the correct device + self.model_to_device() if self.sync_batchnorm: self.model = self.configure_sync_batchnorm(self.model) @@ -284,14 +283,6 @@ def pre_dispatch(self): def post_dispatch(self) -> None: self.cluster_environment.teardown() - @property - def move_to_device_in_prefetch(self) -> bool: - """ - We will call the model_to_device hook within pre-fetch if this is set to True. - Useful for when plugin would like to call model_to_device at another time, or skip the call. - """ - return True - def barrier(self, *args, **kwargs): if torch_distrib.is_initialized(): torch_distrib.barrier() diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index a0b79752833e4..ee40a9a72fdac 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -43,9 +43,9 @@ def __init__( automatic_module_wrap: bool = False, min_num_params: int = 1e8, parallel_devices: Optional[List[torch.device]] = None, - num_nodes: int = 1, + num_nodes: Optional[int] = None, cluster_environment: ClusterEnvironment = None, - sync_batchnorm: Optional[bool] = False + sync_batchnorm: Optional[bool] = None ): """ @@ -111,7 +111,7 @@ def __init__( if sync_batchnorm: raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.") - super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm=sync_batchnorm) + super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu self.flatten_parameters = flatten_parameters @@ -129,13 +129,13 @@ def process_group(self): self._process_group = torch.distributed.new_group() return self._process_group - @contextlib.contextmanager - def model_sharded_context(self) -> Generator: - - # set the device before instantiate the wrapper + def setup_distributed(self): + super().setup_distributed() if self.root_device.type == "cuda": torch.cuda.set_device(self.root_device) + @contextlib.contextmanager + def model_sharded_context(self) -> Generator: precision = self.lightning_module.trainer.precision def wrap_policy(*args, **kwargs): @@ -156,12 +156,6 @@ def wrap_policy(*args, **kwargs): ): yield - def _model_has_nested_fsdp(self): - for module in self.model.modules(): - if isinstance(module, FullyShardedDataParallel): - return True - return False - def configure_ddp(self): with self.model_sharded_context(): if self.automatic_module_wrap and not self._model_has_nested_fsdp(): @@ -182,6 +176,12 @@ def model_to_device(self): # ensure we update the device type in the lightning module self.lightning_module.to(self.root_device) + def pre_dispatch(self): + if self.sync_batchnorm: + self.model = self.configure_sync_batchnorm(self.model) + self.configure_ddp() + self.barrier() + @property def lightning_module(self) -> LightningModule: return unwrap_lightning_module_fully_sharded(self.model) @@ -206,7 +206,8 @@ def setup_optimizers_in_pre_dispatch(self) -> bool: # Setup optimizers after the Fully Sharded Model has been made return True - @property - def move_to_device_in_prefetch(self): - # Fully Sharded handles moving to device + def _model_has_nested_fsdp(self): + for module in self.model.modules(): + if isinstance(module, FullyShardedDataParallel): + return True return False From ca8e586758fb7e86113be04be0ef5ac07e1be73f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 26 Apr 2021 22:05:37 +0100 Subject: [PATCH 55/62] Add back missing model that was removed from clipping signature --- pytorch_lightning/accelerators/accelerator.py | 4 +++- pytorch_lightning/plugins/precision/deepspeed_precision.py | 2 ++ .../plugins/precision/fully_sharded_native_amp.py | 6 ++++-- pytorch_lightning/plugins/precision/precision_plugin.py | 1 + pytorch_lightning/trainer/training_loop.py | 5 ++++- 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index c1d7878e4e38f..52c205a213573 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -331,7 +331,9 @@ def clip_gradients( gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, ) -> None: """clips all the optimizer parameters to the given value""" - self.precision_plugin.clip_gradients(optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm) + self.precision_plugin.clip_gradients( + self.model, optimizer, clip_val, gradient_clip_algorithm=gradient_clip_algorithm + ) def on_train_epoch_end(self, outputs: EPOCH_OUTPUT) -> None: """Hook to do something on the end of an training epoch diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index dc29a5cee4014..41e84e5c33178 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Any, Callable, Union +import torch from torch import Tensor from torch.optim import Optimizer @@ -76,6 +77,7 @@ def backward( def clip_gradients( self, + model: Union[torch.nn.Module, 'pl.LightningModule'], optimizer: Optimizer, clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 5e225caa5acac..2790dd5d8a569 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Union +from typing import Union +import torch from torch.optim import Optimizer +import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType @@ -24,7 +26,7 @@ class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): def clip_gradients( self, - model: 'LightningModule', + model: Union[torch.nn.Module, 'pl.LightningModule'], optimizer: 'Optimizer', clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index c1ea3287964a8..cd271b75e23ba 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -101,6 +101,7 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: def clip_gradients( self, + model: Union[torch.nn.Module, 'pl.LightningModule'], optimizer: Optimizer, clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 9284c75879270..b235033a8af74 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -437,7 +437,10 @@ def track_and_norm_grad(self, optimizer): # clip gradients self.trainer.accelerator.clip_gradients( - optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm + self.trainer.model, + optimizer, + self.trainer.gradient_clip_val, + gradient_clip_algorithm=self.trainer.gradient_clip_algorithm ) self._cur_grad_norm_dict = grad_norm_dic From 54f501d97693aa6854c84aeb84a5c7b37e4d47f7 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Mon, 26 Apr 2021 22:30:56 +0100 Subject: [PATCH 56/62] Do not pass model through, accelerator does it --- pytorch_lightning/trainer/training_loop.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index b235033a8af74..9284c75879270 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -437,10 +437,7 @@ def track_and_norm_grad(self, optimizer): # clip gradients self.trainer.accelerator.clip_gradients( - self.trainer.model, - optimizer, - self.trainer.gradient_clip_val, - gradient_clip_algorithm=self.trainer.gradient_clip_algorithm + optimizer, self.trainer.gradient_clip_val, gradient_clip_algorithm=self.trainer.gradient_clip_algorithm ) self._cur_grad_norm_dict = grad_norm_dic From b67f1a9941f28cea158ace10cebd75b016d949bc Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 27 Apr 2021 13:34:50 +0100 Subject: [PATCH 57/62] Fix merge --- .../plugins/precision/deepspeed_precision.py | 1 - .../plugins/precision/fully_sharded_native_amp.py | 9 ++++----- pytorch_lightning/plugins/precision/precision_plugin.py | 1 - 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index 47ffde9a461e5..a169f66312877 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -78,7 +78,6 @@ def backward( def clip_gradients( self, - model: Union[torch.nn.Module, 'pl.LightningModule'], optimizer: Optimizer, clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 2790dd5d8a569..42e8861245074 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Optional, Union -import torch +from torch.nn import Module from torch.optim import Optimizer -import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin from pytorch_lightning.utilities import GradClipAlgorithmType @@ -26,11 +25,11 @@ class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): def clip_gradients( self, - model: Union[torch.nn.Module, 'pl.LightningModule'], optimizer: 'Optimizer', clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, - norm_type: float = 2.0 + norm_type: float = 2.0, + model: Optional[Module] = None ) -> None: # Model manages clipping of gradients model.clip_grad_norm_(clip_val, norm_type) diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index b1480b3a31efa..f324b21732235 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -101,7 +101,6 @@ def post_optimizer_step(self, optimizer: Optimizer, optimizer_idx: int) -> None: def clip_gradients( self, - model: Union[torch.nn.Module, 'pl.LightningModule'], optimizer: Optimizer, clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, From 132eb64b4ccd6610a20903a4f027d23da3ea1af3 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 27 Apr 2021 13:37:01 +0100 Subject: [PATCH 58/62] Fix imports --- pytorch_lightning/plugins/precision/deepspeed_precision.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py index a169f66312877..f05fd4d54b811 100644 --- a/pytorch_lightning/plugins/precision/deepspeed_precision.py +++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, Callable, Optional, Union -import torch from torch import Tensor from torch.nn import Module from torch.optim import Optimizer From e6ce3cf86622613d5042eb38630295859715371b Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 27 Apr 2021 14:19:04 +0100 Subject: [PATCH 59/62] Changes to precision plugin --- .../plugins/precision/fully_sharded_native_amp.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 42e8861245074..7220f71438762 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -11,13 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union +from typing import cast, Optional, Union from torch.nn import Module from torch.optim import Optimizer from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.utilities import GradClipAlgorithmType +from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE, GradClipAlgorithmType + +if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.nn.data_parallel import FullyShardedDataParallel class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): @@ -28,8 +31,9 @@ def clip_gradients( optimizer: 'Optimizer', clip_val: Union[int, float], gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM, - norm_type: float = 2.0, model: Optional[Module] = None ) -> None: # Model manages clipping of gradients - model.clip_grad_norm_(clip_val, norm_type) + model = cast(FullyShardedDataParallel, model) + # todo: expose norm type once precision plugin supports this. + model.clip_grad_norm_(clip_val, norm_type=2.0) From 01153af93ef65942328743f815927b8a636132ba Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 27 Apr 2021 17:24:13 +0100 Subject: [PATCH 60/62] Require 2 GPU for multi gpu test --- tests/plugins/test_fully_sharded_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index 95b1c8d367dd0..6628c11030eca 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -159,7 +159,7 @@ def configure_optimizers(self): _assert_save_equality(tmpdir, trainer) -@RunIf(min_gpus=1, skip_windows=True, fairscale_fully_sharded=True, special=True) +@RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, special=False) def test_fully_sharded_plugin_multi_gpu(tmpdir): """ Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run. From efa81ab1a300249092eb4dff2cf4839c5aff98df Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 4 May 2021 16:11:11 +0100 Subject: [PATCH 61/62] Use callback in test, swap to DynamicLossScaler from fairscale to test it out --- .../plugins/precision/fully_sharded_native_amp.py | 5 +++++ .../plugins/training_type/fully_sharded.py | 13 +++++++++---- tests/plugins/test_fully_sharded_plugin.py | 12 ++++-------- 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 7220f71438762..303d7066ca3c4 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -20,12 +20,17 @@ from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE, GradClipAlgorithmType if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: + from fairscale.experimental.optim import DynamicLossScaler from fairscale.nn.data_parallel import FullyShardedDataParallel class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): """Mixed Precision for Full Sharded Training""" + def __init__(self) -> None: + super().__init__() + self.scaler = DynamicLossScaler() + def clip_gradients( self, optimizer: 'Optimizer', diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index ee40a9a72fdac..0e4d1afa56588 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib -from typing import Generator, List, Optional +from typing import Dict, Generator, List, Optional import torch @@ -23,7 +23,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: - from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, wrap + from fairscale.nn import auto_wrap, default_auto_wrap_policy, enable_wrap, FlattenParamsWrapper, wrap from fairscale.nn.data_parallel import FullyShardedDataParallel from pytorch_lightning.overrides.fairscale import LightningFullyShardedModule, unwrap_lightning_module_fully_sharded @@ -109,8 +109,6 @@ def __init__( "Full Sharded Training is not available. Install the latest FairScale via `pip install fairscale -U`" ) - if sync_batchnorm: - raise MisconfigurationException("Currently sync batch norm is not supported by Full Sharded Training.") super().__init__(parallel_devices, num_nodes, cluster_environment, sync_batchnorm) self.cpu_offload = cpu_offload self.move_grads_to_cpu = move_grads_to_cpu @@ -211,3 +209,10 @@ def _model_has_nested_fsdp(self): if isinstance(module, FullyShardedDataParallel): return True return False + + @classmethod + def register_plugins(cls, plugin_registry: Dict): + plugin_registry.register("fsdp", cls, description="Fully Sharded with LightningModule wrap") + plugin_registry.register( + "fsdp_offload", cls, description="Fully Sharded Training with CPU Offloading.", cpu_offload=True + ) diff --git a/tests/plugins/test_fully_sharded_plugin.py b/tests/plugins/test_fully_sharded_plugin.py index 6628c11030eca..a93896ec2d943 100644 --- a/tests/plugins/test_fully_sharded_plugin.py +++ b/tests/plugins/test_fully_sharded_plugin.py @@ -159,7 +159,7 @@ def configure_optimizers(self): _assert_save_equality(tmpdir, trainer) -@RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, special=False) +@RunIf(min_gpus=2, skip_windows=True, fairscale_fully_sharded=True, special=True) def test_fully_sharded_plugin_multi_gpu(tmpdir): """ Test to ensure that checkpoint is saved correctly when using multiple GPUs, and all stages can be run. @@ -173,18 +173,14 @@ def configure_optimizers(self): ck = ModelCheckpoint(save_last=True) model = TestModel() trainer = Trainer( - default_root_dir=tmpdir, - gpus=2, - plugins='ddp_fully_sharded', - max_epochs=5, - precision=16, + default_root_dir=tmpdir, gpus=2, plugins='ddp_fully_sharded', max_epochs=5, precision=16, callbacks=ck ) trainer.fit(model) trainer.test(model) - trainer.test(ck.last_model_path) + trainer.test(ckpt_path=ck.last_model_path) trainer.validate() - trainer.validate(ck.last_model_path) + trainer.validate(ckpt_path=ck.last_model_path) trainer.predict(dataloaders=model.val_dataloader()) _assert_save_equality(tmpdir, trainer) From 78d52b52ebab1b1e9aa4f8e836ac10086e0f87ac Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 4 May 2021 17:27:35 +0100 Subject: [PATCH 62/62] Disable loss scaler for now --- .../plugins/precision/fully_sharded_native_amp.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py index 303d7066ca3c4..7220f71438762 100644 --- a/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py +++ b/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py @@ -20,17 +20,12 @@ from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE, GradClipAlgorithmType if _FAIRSCALE_FULLY_SHARDED_AVAILABLE: - from fairscale.experimental.optim import DynamicLossScaler from fairscale.nn.data_parallel import FullyShardedDataParallel class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin): """Mixed Precision for Full Sharded Training""" - def __init__(self) -> None: - super().__init__() - self.scaler = DynamicLossScaler() - def clip_gradients( self, optimizer: 'Optimizer',