From 9f189ff88a3980f464c64e57fb0136af3741b1c9 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Wed, 17 Aug 2022 18:08:21 +0200
Subject: [PATCH] Use fsdp module to initialize precision scaler for fsdp
 native (#14092)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
Co-authored-by: Laverne Henderson
Co-authored-by: Rohit Gupta
---
 docs/source-pytorch/api_references.rst        |  1 +
 docs/source-pytorch/extensions/plugins.rst    |  1 +
 src/pytorch_lightning/CHANGELOG.md            |  2 +
 src/pytorch_lightning/plugins/__init__.py     |  2 +
 .../plugins/precision/__init__.py             | 43 ++++++++----
 .../precision/fsdp_native_native_amp.py       | 65 +++++++++++++++++++
 .../precision/fully_sharded_native_amp.py     | 26 +-------
 .../strategies/fully_sharded_native.py        |  4 +-
 .../connectors/accelerator_connector.py       |  5 +-
 .../test_ddp_fully_sharded_native.py          |  5 +-
 10 files changed, 110 insertions(+), 44 deletions(-)
 create mode 100644 src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py

diff --git a/docs/source-pytorch/api_references.rst b/docs/source-pytorch/api_references.rst
index db4fc1e2c4cf8..ce7723e418e77 100644
--- a/docs/source-pytorch/api_references.rst
+++ b/docs/source-pytorch/api_references.rst
@@ -173,6 +173,7 @@ precision
     DeepSpeedPrecisionPlugin
     DoublePrecisionPlugin
     FullyShardedNativeMixedPrecisionPlugin
+    FullyShardedNativeNativeMixedPrecisionPlugin
     HPUPrecisionPlugin
     IPUPrecisionPlugin
     MixedPrecisionPlugin
diff --git a/docs/source-pytorch/extensions/plugins.rst b/docs/source-pytorch/extensions/plugins.rst
index a0dbefd141464..27aff0c11fdcb 100644
--- a/docs/source-pytorch/extensions/plugins.rst
+++ b/docs/source-pytorch/extensions/plugins.rst
@@ -56,6 +56,7 @@ The full list of built-in precision plugins is listed below.
     DeepSpeedPrecisionPlugin
     DoublePrecisionPlugin
     FullyShardedNativeMixedPrecisionPlugin
+    FullyShardedNativeNativeMixedPrecisionPlugin
     HPUPrecisionPlugin
     IPUPrecisionPlugin
     MixedPrecisionPlugin
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 85d538d3e2b46..80f6f71a03515 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added `FullyShardedNativeNativeMixedPrecisionPlugin` to handle precision for `DDPFullyShardedNativeStrategy` ([#14092](https://github.com/Lightning-AI/lightning/pull/14092))
 - Added profiling to these hooks: `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`, `configure_gradient_clipping`, `clip_gradients` ([#14069](https://github.com/Lightning-AI/lightning/pull/14069))

 ### Changed

@@ -26,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Avoid `metadata.entry_points` deprecation warning on Python 3.10 ([#14052](https://github.com/Lightning-AI/lightning/pull/14052))
 - Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))
 - Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))
+- Avoided requiring the FairScale package to use precision with the fsdp native strategy ([#14092](https://github.com/Lightning-AI/lightning/pull/14092))


 ## [1.7.1] - 2022-08-09
diff --git a/src/pytorch_lightning/plugins/__init__.py b/src/pytorch_lightning/plugins/__init__.py
index afd10c88c951d..50d83ee708cbe 100644
--- a/src/pytorch_lightning/plugins/__init__.py
+++ b/src/pytorch_lightning/plugins/__init__.py
@@ -10,6 +10,7 @@
 from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
 from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
+from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
 from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
 from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
@@ -63,6 +64,7 @@
     "FullyShardedNativeMixedPrecisionPlugin",
     "SingleDevicePlugin",
     "SingleTPUPlugin",
+    "FullyShardedNativeNativeMixedPrecisionPlugin",
     "TPUPrecisionPlugin",
     "TPUBf16PrecisionPlugin",
     "TPUSpawnPlugin",
diff --git a/src/pytorch_lightning/plugins/precision/__init__.py b/src/pytorch_lightning/plugins/precision/__init__.py
index 4bc29c1be1864..5206aed62c497 100644
--- a/src/pytorch_lightning/plugins/precision/__init__.py
+++ b/src/pytorch_lightning/plugins/precision/__init__.py
@@ -11,17 +11,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.fully_sharded_native_amp import (  # noqa: F401
-    FullyShardedNativeMixedPrecisionPlugin,
-)
-from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin  # noqa: F401
-from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin  # noqa: F401
+from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
+from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
+from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
+from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
+from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
+from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
+from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
+
+__all__ = [
+    "ApexMixedPrecisionPlugin",
+    "DeepSpeedPrecisionPlugin",
+    "DoublePrecisionPlugin",
+    "FullyShardedNativeNativeMixedPrecisionPlugin",
+    "FullyShardedNativeMixedPrecisionPlugin",
+    "HPUPrecisionPlugin",
+    "IPUPrecisionPlugin",
+    "MixedPrecisionPlugin",
+    "NativeMixedPrecisionPlugin",
+    "PrecisionPlugin",
+    "ShardedNativeMixedPrecisionPlugin",
+    "TPUPrecisionPlugin",
+    "TPUBf16PrecisionPlugin",
+]
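With the re-exports above, the new plugin becomes part of the public plugin namespace. A minimal import sketch, assuming an installation that includes this patch (either path should resolve to the same class):

    # Top-level plugins namespace, re-exported via `__all__` in plugins/__init__.py:
    from pytorch_lightning.plugins import FullyShardedNativeNativeMixedPrecisionPlugin

    # Or the precision subpackage, which now declares an explicit `__all__` as well:
    from pytorch_lightning.plugins.precision import FullyShardedNativeNativeMixedPrecisionPlugin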
diff --git a/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py b/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py
new file mode 100644
index 0000000000000..2201db94586a2
--- /dev/null
+++ b/src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py
@@ -0,0 +1,65 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Optional, Union
+
+import torch
+
+from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
+from pytorch_lightning.utilities.enums import PrecisionType
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+
+if _TORCH_GREATER_EQUAL_1_12:
+    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+else:
+    MixedPrecision = None  # type: ignore[misc,assignment]
+
+
+class FullyShardedNativeNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
+    """Native AMP for Fully Sharded Native Training."""
+
+    def __init__(
+        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
+    ) -> None:
+        if not _TORCH_GREATER_EQUAL_1_12:
+            raise MisconfigurationException(
+                "`FullyShardedNativeNativeMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards."
+            )
+        super().__init__(precision, device, scaler=ShardedGradScaler() if scaler is None and precision == 16 else None)
+
+    def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
+        # see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
+        # section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
+        # for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val)
+        # however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to
+        # trace back the root FSDP. Now we only support clip by value.
+        raise MisconfigurationException(
+            f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
+        )
+
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        assert MixedPrecision is not None
+        if self.precision == PrecisionType.HALF:
+            dtype = torch.float16
+        elif self.precision == PrecisionType.BFLOAT:
+            dtype = torch.bfloat16
+        else:
+            raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
+        return MixedPrecision(
+            param_dtype=dtype,
+            reduce_dtype=dtype,
+            buffer_dtype=dtype,
+        )
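For context, the sketch below shows roughly how the two objects this plugin produces are consumed: the `MixedPrecision` config is handed to the FSDP wrapper by the strategy, and the `ShardedGradScaler` takes the place of the regular `torch.cuda.amp.GradScaler` for 16-bit training. The wrapping code is illustrative only (not part of this patch) and assumes PyTorch >= 1.12, a CUDA device, and an already-initialized process group:

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel
    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    # Equivalent to what `mixed_precision_config` returns for precision=16:
    # one dtype applied to parameters, gradient reduction, and buffers.
    config = MixedPrecision(param_dtype=torch.float16, reduce_dtype=torch.float16, buffer_dtype=torch.float16)

    # The strategy passes this config when it wraps the module with FSDP
    # (requires torch.distributed to be initialized), while the plugin's
    # ShardedGradScaler handles loss scaling/unscaling for float16 training.
    sharded_model = FullyShardedDataParallel(torch.nn.Linear(32, 32).cuda(), mixed_precision=config)
    scaler = ShardedGradScaler()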
diff --git a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
index 8c693f2975bbd..870e658bfc9c3 100644
--- a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -11,19 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
-
-import torch
+from typing import Any

 from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
-from pytorch_lightning.utilities.enums import PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
-
-if _TORCH_GREATER_EQUAL_1_12:
-    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
-else:
-    MixedPrecision = None


 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
@@ -38,18 +29,3 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
         raise MisconfigurationException(
             f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
         )
-
-    @property
-    def mixed_precision_config(self) -> Optional[MixedPrecision]:
-        assert MixedPrecision is not None
-        if self.precision == PrecisionType.HALF:
-            dtype = torch.float16
-        elif self.precision == PrecisionType.BFLOAT:
-            dtype = torch.bfloat16
-        else:
-            raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
-        return MixedPrecision(
-            param_dtype=dtype,
-            reduce_dtype=dtype,
-            buffer_dtype=dtype,
-        )
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index 4c351f26fa3b9..9b927aa757d17 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -23,7 +23,7 @@
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
-from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
 from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
 from pytorch_lightning.strategies.strategy import TBroadcast
@@ -158,7 +158,7 @@ def mixed_precision_config(self) -> Optional[MixedPrecision]:
         if self.mixed_precision:
             return self.mixed_precision
         plugin = self.precision_plugin
-        if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin):
+        if isinstance(plugin, FullyShardedNativeNativeMixedPrecisionPlugin):
             return plugin.mixed_precision_config

     @property
diff --git a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
index bd879cf85ff7a..44c3b3ec7540a 100644
--- a/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -53,6 +53,7 @@
     TorchElasticEnvironment,
 )
 from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
+from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import (
     DDP2Strategy,
     DDPFullyShardedNativeStrategy,
@@ -727,7 +728,9 @@ def _check_and_init_precision(self) -> PrecisionPlugin:
         if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
             return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
-        if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
+        if isinstance(self.strategy, DDPFullyShardedNativeStrategy):
+            return FullyShardedNativeNativeMixedPrecisionPlugin(self._precision_flag, device)
+        if isinstance(self.strategy, DDPFullyShardedStrategy):
             return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
         return NativeMixedPrecisionPlugin(self._precision_flag, device)
diff --git a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
index 74f9534c47ce3..ede201da1f68f 100644
--- a/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
+++ b/tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
@@ -7,7 +7,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.demos.boring_classes import BoringModel
-from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
+from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
 from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
@@ -35,7 +35,7 @@ def test_invalid_on_cpu(tmpdir):
 @RunIf(min_torch="1.12", min_cuda_gpus=1)
 @pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
 def test_precision_plugin_config(precision, expected):
-    plugin = FullyShardedNativeMixedPrecisionPlugin(precision=precision, device="cuda")
+    plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
     config = plugin.mixed_precision_config
     assert config.param_dtype == expected
     assert config.buffer_dtype == expected
@@ -96,6 +96,7 @@ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: in

     def _assert_layer_fsdp_instance(self) -> None:
         assert isinstance(self.layer, FullyShardedDataParallel)
+        assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
         assert isinstance(self.layer.module[0], FullyShardedDataParallel)
         assert isinstance(self.layer.module[2], FullyShardedDataParallel)
         # root should not be resharding
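With the accelerator connector change above, selecting the native FSDP strategy picks up the new precision plugin automatically, without FairScale installed. A minimal usage sketch, assuming at least two CUDA GPUs and that `"fsdp_native"` is the registered alias of `DDPFullyShardedNativeStrategy` in this release:

    from pytorch_lightning import Trainer
    from pytorch_lightning.demos.boring_classes import BoringModel
    from pytorch_lightning.plugins import FullyShardedNativeNativeMixedPrecisionPlugin

    trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp_native", precision=16, max_epochs=1)
    # The connector now resolves 16-bit precision for this strategy to the new plugin
    # instead of the FairScale-based FullyShardedNativeMixedPrecisionPlugin.
    assert isinstance(trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
    trainer.fit(BoringModel())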