Use fsdp module to initialize precision scaler for fsdp native (#14092)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: Laverne Henderson <laverne.henderson@coupa.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
4 people committed Aug 17, 2022
1 parent d03a7e9 commit 9f189ff
Showing 10 changed files with 110 additions and 44 deletions.
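
For context, here is a minimal usage sketch (illustrative only, not part of the diff). It assumes PyTorch >= 1.12, at least one CUDA GPU, and that the torch-native FSDP strategy is registered under the "fsdp_native" alias used by Lightning 1.7; with this commit, requesting 16-bit precision for that strategy resolves to the new FullyShardedNativeNativeMixedPrecisionPlugin and no longer requires FairScale to be installed.

# Illustrative sketch, not part of the commit. Assumes PyTorch >= 1.12, a CUDA GPU,
# and the "fsdp_native" strategy alias from Lightning 1.7.
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import (
    FullyShardedNativeNativeMixedPrecisionPlugin,
)

trainer = Trainer(accelerator="gpu", devices=1, strategy="fsdp_native", precision=16)

# The accelerator-connector change in this commit selects the torch-native plugin here.
assert isinstance(trainer.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)

trainer.fit(BoringModel())
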
1 change: 1 addition & 0 deletions docs/source-pytorch/api_references.rst
@@ -173,6 +173,7 @@ precision
DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
FullyShardedNativeNativeMixedPrecisionPlugin
HPUPrecisionPlugin
IPUPrecisionPlugin
MixedPrecisionPlugin
1 change: 1 addition & 0 deletions docs/source-pytorch/extensions/plugins.rst
@@ -56,6 +56,7 @@ The full list of built-in precision plugins is listed below.
DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
FullyShardedNativeNativeMixedPrecisionPlugin
HPUPrecisionPlugin
IPUPrecisionPlugin
MixedPrecisionPlugin
2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `FullyShardedNativeNativeMixedPrecisionPlugin` to handle precision for `DDPFullyShardedNativeStrategy` ([#14092](https://github.com/Lightning-AI/lightning/pull/14092))
- Added profiling to these hooks: `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`, `configure_gradient_clipping`, `clip_gradients` ([#14069](https://github.com/Lightning-AI/lightning/pull/14069))

### Changed
@@ -26,6 +27,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Avoid `metadata.entry_points` deprecation warning on Python 3.10 ([#14052](https://github.com/Lightning-AI/lightning/pull/14052))
- Avoid raising the sampler warning if num_replicas=1 ([#14097](https://github.com/Lightning-AI/lightning/pull/14097))
- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))
- Avoided requiring the FairScale package to use precision with the fsdp native strategy ([#14092](https://github.com/Lightning-AI/lightning/pull/14092))


## [1.7.1] - 2022-08-09
2 changes: 2 additions & 0 deletions src/pytorch_lightning/plugins/__init__.py
@@ -10,6 +10,7 @@
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
@@ -63,6 +64,7 @@
"FullyShardedNativeMixedPrecisionPlugin",
"SingleDevicePlugin",
"SingleTPUPlugin",
"FullyShardedNativeNativeMixedPrecisionPlugin",
"TPUPrecisionPlugin",
"TPUBf16PrecisionPlugin",
"TPUSpawnPlugin",
43 changes: 29 additions & 14 deletions src/pytorch_lightning/plugins/precision/__init__.py
@@ -11,17 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import ( # noqa: F401
FullyShardedNativeMixedPrecisionPlugin,
)
from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin # noqa: F401
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.hpu import HPUPrecisionPlugin
from pytorch_lightning.plugins.precision.ipu import IPUPrecisionPlugin
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu import TPUPrecisionPlugin
from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin

__all__ = [
"ApexMixedPrecisionPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"FullyShardedNativeNativeMixedPrecisionPlugin",
"FullyShardedNativeMixedPrecisionPlugin",
"HPUPrecisionPlugin",
"IPUPrecisionPlugin",
"MixedPrecisionPlugin",
"NativeMixedPrecisionPlugin",
"PrecisionPlugin",
"ShardedNativeMixedPrecisionPlugin",
"TPUPrecisionPlugin",
"TPUBf16PrecisionPlugin",
]
65 changes: 65 additions & 0 deletions src/pytorch_lightning/plugins/precision/fsdp_native_native_amp.py
@@ -0,0 +1,65 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
else:
MixedPrecision = None # type: ignore[misc,assignment]


class FullyShardedNativeNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
"""Native AMP for Fully Sharded Native Training."""

def __init__(
self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
) -> None:
if not _TORCH_GREATER_EQUAL_1_12:
raise MisconfigurationException(
"`FullyShardedNativeNativeMixedPrecisionPlugin` is supported from PyTorch v1.12.0 onwards."
)
super().__init__(precision, device, scaler=ShardedGradScaler() if scaler is None and precision == 16 else None)

def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
# see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
# section `Gradient Clipping`, using `torch.nn.utils.clip_grad_norm_` is incorrect
# for FSDP module. To overcome this, needs to call sharded_module.clip_grad_norm(clip_val)
# however we rely on LightningModule's configure_sharded_model to wrap FSDP, it would be hard to
# trace back the root FSDP. Now we only support clip by value.
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)

@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
assert MixedPrecision is not None
if self.precision == PrecisionType.HALF:
dtype = torch.float16
elif self.precision == PrecisionType.BFLOAT:
dtype = torch.bfloat16
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
return MixedPrecision(
param_dtype=dtype,
reduce_dtype=dtype,
buffer_dtype=dtype,
)
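
A short usage sketch for the new plugin (illustrative, not part of the commit; assumes PyTorch >= 1.12 and an available CUDA device): the plugin turns Lightning's precision flag into a single torch.distributed.fsdp.MixedPrecision config, which DDPFullyShardedNativeStrategy later forwards to FullyShardedDataParallel.

import torch

from pytorch_lightning.plugins.precision.fsdp_native_native_amp import (
    FullyShardedNativeNativeMixedPrecisionPlugin,
)

# precision=16: a ShardedGradScaler is created automatically and every FSDP dtype is float16.
fp16_plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=16, device="cuda")
config = fp16_plugin.mixed_precision_config
assert config.param_dtype == config.reduce_dtype == config.buffer_dtype == torch.float16

# precision="bf16": no grad scaler is needed and every FSDP dtype is bfloat16.
bf16_plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision="bf16", device="cuda")
assert bf16_plugin.mixed_precision_config.param_dtype == torch.bfloat16
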
src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -11,19 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import torch
from typing import Any

from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12

if _TORCH_GREATER_EQUAL_1_12:
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
else:
MixedPrecision = None


class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
@@ -38,18 +29,3 @@ def clip_grad_by_norm(self, *_: Any, **__: Any) -> None:
raise MisconfigurationException(
f"`gradient_clip_algorithm='norm'` is currently not supported for `{self.__class__.__name__}`"
)

@property
def mixed_precision_config(self) -> Optional[MixedPrecision]:
assert MixedPrecision is not None
if self.precision == PrecisionType.HALF:
dtype = torch.float16
elif self.precision == PrecisionType.BFLOAT:
dtype = torch.bfloat16
else:
raise MisconfigurationException(f"Was unable to infer precision type, received {self.precision!r}.")
return MixedPrecision(
param_dtype=dtype,
reduce_dtype=dtype,
buffer_dtype=dtype,
)
4 changes: 2 additions & 2 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -23,7 +23,7 @@
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
@@ -158,7 +158,7 @@ def mixed_precision_config(self) -> Optional[MixedPrecision]:
if self.mixed_precision:
return self.mixed_precision
plugin = self.precision_plugin
if isinstance(plugin, FullyShardedNativeMixedPrecisionPlugin):
if isinstance(plugin, FullyShardedNativeNativeMixedPrecisionPlugin):
return plugin.mixed_precision_config

@property
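
For reference, a simplified hand-off sketch (assumptions: a single-process NCCL group, one CUDA GPU, PyTorch >= 1.12; the real strategy also forwards the process group, CPU offload, and the wrapping policy): the MixedPrecision config produced by the precision plugin is what ends up in FullyShardedDataParallel's mixed_precision argument.

import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from pytorch_lightning.plugins.precision.fsdp_native_native_amp import (
    FullyShardedNativeNativeMixedPrecisionPlugin,
)

# Single-process process group just for this sketch; FSDP needs torch.distributed initialized.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=16, device="cuda")

# Simplified wrapping: the strategy passes the same config under the `mixed_precision` kwarg.
model = FullyShardedDataParallel(
    nn.Linear(32, 2).cuda(),
    mixed_precision=plugin.mixed_precision_config,
)
print(model)
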
src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -53,6 +53,7 @@
TorchElasticEnvironment,
)
from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
from pytorch_lightning.strategies import (
DDP2Strategy,
DDPFullyShardedNativeStrategy,
@@ -727,7 +728,9 @@ def _check_and_init_precision(self) -> PrecisionPlugin:

if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
if isinstance(self.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy)):
if isinstance(self.strategy, DDPFullyShardedNativeStrategy):
return FullyShardedNativeNativeMixedPrecisionPlugin(self._precision_flag, device)
if isinstance(self.strategy, DDPFullyShardedStrategy):
return FullyShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
return NativeMixedPrecisionPlugin(self._precision_flag, device)
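
A hedged sketch of the resulting selection order (the helper below is hypothetical and only mirrors the isinstance checks above; it is not a Lightning API): the torch-native strategy now maps to FullyShardedNativeNativeMixedPrecisionPlugin, while the FairScale-based DDPFullyShardedStrategy keeps the existing FullyShardedNativeMixedPrecisionPlugin.

from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy, DDPFullyShardedStrategy


def expected_fsdp_precision_plugin(strategy, precision=16, device="cuda"):
    # Hypothetical helper mirroring the branching in `_check_and_init_precision` above.
    if isinstance(strategy, DDPFullyShardedNativeStrategy):
        # Torch-native FSDP: no FairScale import is required anymore.
        return FullyShardedNativeNativeMixedPrecisionPlugin(precision, device)
    if isinstance(strategy, DDPFullyShardedStrategy):
        # FairScale FSDP: unchanged, still uses the FairScale-backed plugin.
        return FullyShardedNativeMixedPrecisionPlugin(precision, device)
    raise TypeError(f"Not a fully-sharded strategy: {type(strategy).__name__}")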

tests/tests_pytorch/strategies/test_ddp_fully_sharded_native.py
@@ -7,7 +7,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.plugins.precision.fully_sharded_native_amp import FullyShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
@@ -35,7 +35,7 @@ def test_invalid_on_cpu(tmpdir):
@RunIf(min_torch="1.12", min_cuda_gpus=1)
@pytest.mark.parametrize("precision, expected", [(16, torch.float16), ("bf16", torch.bfloat16)])
def test_precision_plugin_config(precision, expected):
plugin = FullyShardedNativeMixedPrecisionPlugin(precision=precision, device="cuda")
plugin = FullyShardedNativeNativeMixedPrecisionPlugin(precision=precision, device="cuda")
config = plugin.mixed_precision_config
assert config.param_dtype == expected
assert config.buffer_dtype == expected
@@ -96,6 +96,7 @@ def on_predict_batch_end(self, outputs: Optional[Any], batch: Any, batch_idx: in

def _assert_layer_fsdp_instance(self) -> None:
assert isinstance(self.layer, FullyShardedDataParallel)
assert isinstance(self.trainer.strategy.precision_plugin, FullyShardedNativeNativeMixedPrecisionPlugin)
assert isinstance(self.layer.module[0], FullyShardedDataParallel)
assert isinstance(self.layer.module[2], FullyShardedDataParallel)
# root should not be resharding
