From f974e4d26e7d847b1d0952c4c23ba052a6a80f1b Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 14 Jun 2021 08:21:26 -0400 Subject: [PATCH 1/8] resolve manual optimization --- pytorch_lightning/overrides/distributed.py | 4 + .../plugins/training_type/ddp.py | 2 +- .../plugins/training_type/deepspeed.py | 3 +- tests/plugins/test_deepspeed_plugin.py | 78 +++++++++++++++++-- 4 files changed, 78 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index d4b1e6ed22d55..6f4e57d76efd9 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -63,6 +63,10 @@ def _find_tensors(obj): # pragma: no-cover # Note: Keep track of Pytorch DDP and update if there is a change # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 def prepare_for_backward(model: DistributedDataParallel, output: Any): + # DDP based plugin are being subclasses + # and `prepare_for_backward` is `DistributedDataParallel` specific. + if not isinstance(model, DistributedDataParallel): + return if torch.is_grad_enabled() and model.require_backward_grad_sync: model.require_forward_param_sync = True # We'll return the output object verbatim since it is a freeform diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 4990f95f14ac0..42be8009aab28 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -320,7 +320,7 @@ def broadcast(self, obj: object, src: int = 0) -> object: def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int): """Run before precision plugin executes backward""" - if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync: + if not self.lightning_module.automatic_optimization: prepare_for_backward(self.model, closure_loss) def model_to_device(self): diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index dc688de65cd34..d04c21104c020 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -312,7 +312,8 @@ def _initialize_deepspeed_train(self, model): @contextlib.contextmanager def model_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: - model_parallel_context = deepspeed.zero.Init(remote_device="cpu", pin_memory=True) + dtype = torch.float16 if self.lightning_module.trainer.accelerator.precision == 16 else torch.float32 + model_parallel_context = deepspeed.zero.Init(remote_device="cpu", pin_memory=True, dtype=dtype) else: model_parallel_context = super().model_sharded_context() diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 85d069b90288d..4587c99997fad 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -24,10 +24,32 @@ class ModelParallelBoringModel(BoringModel): def __init__(self): super().__init__() - self.linear = None + self.layer = None def configure_sharded_model(self) -> None: - self.linear = torch.nn.Linear(32, 2) + self.layer = torch.nn.Linear(32, 2) + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + self.configure_sharded_model() + + +class ModelParallelBoringModelManualOptim(BoringModel): + + def __init__(self): + super().__init__() + self.layer 
= None + self.automatic_optimization = False + + def training_step(self, batch, batch_idx): + opt = self.optimizers()[0] + output = self(batch) + loss = self.loss(batch, output) + opt.zero_grad() + self.manual_backward(loss) + opt.step() + + def configure_sharded_model(self) -> None: + self.layer = torch.nn.Linear(32, 2) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self.configure_sharded_model() @@ -425,7 +447,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config): class ModelParallelClassificationModel(LightningModule): - def __init__(self, lr: float = 0.01, num_blocks: int = 5): + def __init__(self, lr: float = 0.01, num_blocks: int = 5, automatic_optimization: bool = True): super().__init__() self.lr = lr self.num_blocks = num_blocks @@ -433,6 +455,8 @@ def __init__(self, lr: float = 0.01, num_blocks: int = 5): self.train_acc = Accuracy() self.valid_acc = Accuracy() self.test_acc = Accuracy() + self.automatic_optimization = automatic_optimization + self.training_step = self.training_step_automatic if self.automatic_optimization else self.training_step_manual def make_block(self): return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) @@ -447,7 +471,7 @@ def forward(self, x): logits = F.softmax(x, dim=1) return logits - def training_step(self, batch, batch_idx): + def training_step_automatic(self, batch, batch_idx): x, y = batch logits = self.forward(x) loss = F.cross_entropy(logits, y) @@ -455,6 +479,17 @@ def training_step(self, batch, batch_idx): self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True) return {"loss": loss} + def training_step_manual(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.cross_entropy(logits, y) + opt = self.optimizers()[0] + self.log('train_loss', loss, prog_bar=True) + self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True) + opt.zero_grad() + self.manual_backward(loss) + opt.step() + def validation_step(self, batch, batch_idx): x, y = batch logits = self.forward(x) @@ -502,9 +537,29 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel) -def run_checkpoint_test(tmpdir, save_full_weights): +@RunIf(min_gpus=2, deepspeed=True, special=True) +def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config): + """ + Test to ensure ZeRO Stage 3 works with a parallel model. 
+ """ + model = ModelParallelBoringModelManualOptim() + model.training_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + plugins=[DeepSpeedPlugin(stage=3)], + gpus=2, + fast_dev_run=True, + precision=16, + ) + trainer.fit(model) + trainer.test(model) + + _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim) + + +def run_checkpoint_test(tmpdir, save_full_weights, automatic_optimization=True, accumulate_grad_batches=2): seed_everything(1) - model = ModelParallelClassificationModel() + model = ModelParallelClassificationModel(automatic_optimization=automatic_optimization) dm = ClassifDataModule() ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) trainer = Trainer( @@ -514,7 +569,7 @@ def run_checkpoint_test(tmpdir, save_full_weights): plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)], gpus=2, precision=16, - accumulate_grad_batches=2, + accumulate_grad_batches=accumulate_grad_batches, callbacks=[ck] ) trainer.fit(model, datamodule=dm) @@ -562,6 +617,15 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir): run_checkpoint_test(tmpdir, save_full_weights=True) +@RunIf(min_gpus=2, deepspeed=True, special=True) +def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir): + """ + Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint, + where we save the full weights to one file. + """ + run_checkpoint_test(tmpdir, save_full_weights=True, automatic_optimization=False, accumulate_grad_batches=1) + + @RunIf(min_gpus=2, deepspeed=True, special=True) @pytest.mark.parametrize('cpu_offload', [True, False]) def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, cpu_offload): From 0bdaef8505deddef5f0f1a62400fce4cb8f64acb Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Mon, 14 Jun 2021 08:36:42 -0400 Subject: [PATCH 2/8] resolve manual optimization --- dockers/base-cuda/Dockerfile | 3 +-- .../plugins/training_type/deepspeed.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index e16971bdc2a1a..49dbaae3472aa 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -118,8 +118,7 @@ RUN \ RUN \ # install DeepSpeed - # TODO(@SeanNaren): CI failing with `>=0.3.15` - skipping to unblock - pip install deepspeed==0.3.14 + pip install deepspeed>=0.4.0 RUN \ # Show what we have diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index d04c21104c020..83b2412155244 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -21,7 +21,7 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch - +import inspect import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase @@ -312,8 +312,15 @@ def _initialize_deepspeed_train(self, model): @contextlib.contextmanager def model_sharded_context(self) -> Generator[None, None, None]: if self.zero_stage_3: - dtype = torch.float16 if self.lightning_module.trainer.accelerator.precision == 16 else torch.float32 - model_parallel_context = deepspeed.zero.Init(remote_device="cpu", pin_memory=True, dtype=dtype) + kwargs = { + "remote_device": "cpu", + "pin_memory": 
True, + } + # from DeepSpeed 0.4.0, weights need to be properly casted before calling `deepspeed.initialize`. + if "dtype" in inspect.signature(deepspeed.zero.Init).parameters.keys(): + precision: int = self.lightning_module.trainer.accelerator.precision + kwargs["dtype"] = torch.float16 if precision == 16 else torch.float32 + model_parallel_context = deepspeed.zero.Init(**kwargs) else: model_parallel_context = super().model_sharded_context() From a138fa273393cc92e3ece3fd1292d33283319b4d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 14 Jun 2021 12:38:34 +0000 Subject: [PATCH 3/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/overrides/distributed.py | 2 +- pytorch_lightning/plugins/training_type/deepspeed.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 6f4e57d76efd9..a7827b0667582 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -63,7 +63,7 @@ def _find_tensors(obj): # pragma: no-cover # Note: Keep track of Pytorch DDP and update if there is a change # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 def prepare_for_backward(model: DistributedDataParallel, output: Any): - # DDP based plugin are being subclasses + # DDP based plugin are being subclasses # and `prepare_for_backward` is `DistributedDataParallel` specific. if not isinstance(model, DistributedDataParallel): return diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 83b2412155244..e65959177c6b9 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import contextlib +import inspect import json import logging import os @@ -21,7 +22,7 @@ from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch -import inspect + import pytorch_lightning as pl from pytorch_lightning.callbacks import GradientAccumulationScheduler from pytorch_lightning.overrides.base import _LightningModuleWrapperBase From f60ee759bf01a081bcb29dc31aeb64a52ef70d9f Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 14 Jun 2021 13:47:47 +0100 Subject: [PATCH 4/8] update changelog --- CHANGELOG.md | 3 +++ pytorch_lightning/overrides/distributed.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 784a1581ee97a..b1fb641a7fca0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -226,6 +226,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924)) +- Support manual optimization with DeepSpeed ([#7970](https://github.com/PyTorchLightning/pytorch-lightning/pull/7970)) + + ## [1.3.5] - 2021-06-08 ### Added diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index a7827b0667582..47a93d05d32cc 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -63,7 +63,7 @@ def _find_tensors(obj): # pragma: no-cover # Note: Keep track of Pytorch DDP and update if there is a change # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 def prepare_for_backward(model: DistributedDataParallel, output: Any): - # DDP based plugin are being subclasses + # DDP based plugin are being subclasses # and `prepare_for_backward` is `DistributedDataParallel` specific. if not isinstance(model, DistributedDataParallel): return From 94c72836b20e46d95d6206352cb933f7f001a34f Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 15 Jun 2021 10:58:44 +0100 Subject: [PATCH 5/8] Simplify message --- pytorch_lightning/overrides/distributed.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/overrides/distributed.py b/pytorch_lightning/overrides/distributed.py index 47a93d05d32cc..71ed9c8018ec3 100644 --- a/pytorch_lightning/overrides/distributed.py +++ b/pytorch_lightning/overrides/distributed.py @@ -63,8 +63,7 @@ def _find_tensors(obj): # pragma: no-cover # Note: Keep track of Pytorch DDP and update if there is a change # https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638 def prepare_for_backward(model: DistributedDataParallel, output: Any): - # DDP based plugin are being subclasses - # and `prepare_for_backward` is `DistributedDataParallel` specific. + # `prepare_for_backward` is `DistributedDataParallel` specific. if not isinstance(model, DistributedDataParallel): return if torch.is_grad_enabled() and model.require_backward_grad_sync: From 222be315a9b5ceb74e7b61aff9690efe90548fe2 Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 15 Jun 2021 13:48:42 +0100 Subject: [PATCH 6/8] Move from deprecated --- tests/plugins/test_deepspeed_plugin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 4587c99997fad..a95b734c57758 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -627,8 +627,8 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir): @RunIf(min_gpus=2, deepspeed=True, special=True) -@pytest.mark.parametrize('cpu_offload', [True, False]) -def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, cpu_offload): +@pytest.mark.parametrize('offload_optimizer', [True, False]) +def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer): """ Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works. 
""" @@ -649,7 +649,7 @@ def on_train_batch_start( default_root_dir=tmpdir, progress_bar_refresh_rate=0, max_epochs=5, - plugins=[DeepSpeedPlugin(stage=2, cpu_offload=cpu_offload)], + plugins=[DeepSpeedPlugin(stage=2, offload_optimizer=offload_optimizer)], gpus=2, limit_val_batches=2, precision=16, From 19789952058271cfafb028abcf960a81eedcd48e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 15 Jun 2021 14:11:54 +0100 Subject: [PATCH 7/8] Split model parallel/manual model --- tests/plugins/test_deepspeed_plugin.py | 43 ++++++++++++++++---------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index a95b734c57758..0ef037e192682 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -447,7 +447,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config): class ModelParallelClassificationModel(LightningModule): - def __init__(self, lr: float = 0.01, num_blocks: int = 5, automatic_optimization: bool = True): + def __init__(self, lr: float = 0.01, num_blocks: int = 5): super().__init__() self.lr = lr self.num_blocks = num_blocks @@ -455,8 +455,6 @@ def __init__(self, lr: float = 0.01, num_blocks: int = 5, automatic_optimization self.train_acc = Accuracy() self.valid_acc = Accuracy() self.test_acc = Accuracy() - self.automatic_optimization = automatic_optimization - self.training_step = self.training_step_automatic if self.automatic_optimization else self.training_step_manual def make_block(self): return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) @@ -471,7 +469,7 @@ def forward(self, x): logits = F.softmax(x, dim=1) return logits - def training_step_automatic(self, batch, batch_idx): + def training_step(self, batch, batch_idx): x, y = batch logits = self.forward(x) loss = F.cross_entropy(logits, y) @@ -479,17 +477,6 @@ def training_step_automatic(self, batch, batch_idx): self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True) return {"loss": loss} - def training_step_manual(self, batch, batch_idx): - x, y = batch - logits = self.forward(x) - loss = F.cross_entropy(logits, y) - opt = self.optimizers()[0] - self.log('train_loss', loss, prog_bar=True) - self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True) - opt.zero_grad() - self.manual_backward(loss) - opt.step() - def validation_step(self, batch, batch_idx): x, y = batch logits = self.forward(x) @@ -518,6 +505,23 @@ def configure_optimizers(self): }] +class ManualModelParallelClassificationModel(ModelParallelClassificationModel): + + def automatic_optimization(self) -> bool: + return False + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self.forward(x) + loss = F.cross_entropy(logits, y) + opt = self.optimizers()[0] + self.log('train_loss', loss, prog_bar=True) + self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True) + opt.zero_grad() + self.manual_backward(loss) + opt.step() + + @RunIf(min_gpus=2, deepspeed=True, special=True) def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config): """ @@ -557,9 +561,14 @@ def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config _assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim) -def run_checkpoint_test(tmpdir, save_full_weights, automatic_optimization=True, accumulate_grad_batches=2): +def run_checkpoint_test( + tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, 
accumulate_grad_batches: int = 2 +): seed_everything(1) - model = ModelParallelClassificationModel(automatic_optimization=automatic_optimization) + if automatic_optimization: + model = ModelParallelClassificationModel() + else: + model = ManualModelParallelClassificationModel() dm = ClassifDataModule() ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1) trainer = Trainer( From 8a36389933c1c486db7fddc7c1c363a480d95e3e Mon Sep 17 00:00:00 2001 From: SeanNaren Date: Tue, 15 Jun 2021 22:35:08 +0100 Subject: [PATCH 8/8] Use property --- tests/plugins/test_deepspeed_plugin.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 0ef037e192682..65ccf05361d0e 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -38,7 +38,6 @@ class ModelParallelBoringModelManualOptim(BoringModel): def __init__(self): super().__init__() self.layer = None - self.automatic_optimization = False def training_step(self, batch, batch_idx): opt = self.optimizers()[0] @@ -54,6 +53,10 @@ def configure_sharded_model(self) -> None: def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self.configure_sharded_model() + @property + def automatic_optimization(self) -> bool: + return False + def test_deepspeed_lightning_module(tmpdir): """ @@ -507,6 +510,7 @@ def configure_optimizers(self): class ManualModelParallelClassificationModel(ModelParallelClassificationModel): + @property def automatic_optimization(self) -> bool: return False
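
Note: the sketch below shows the end-to-end pattern this series enables,
manual optimization under the DeepSpeed ZeRO Stage 3 plugin, using the
Lightning 1.3-era API that appears in the patches above (DeepSpeedPlugin,
plugins=[...], self.optimizers()). The module, dataset and hyperparameters
are illustrative only and not part of the PR; automatic_optimization is
overridden as a read-only property, as in the final patch.

    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, Dataset

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.plugins import DeepSpeedPlugin


    class RandomDataset(Dataset):
        """Illustrative stand-in for a real dataset."""

        def __getitem__(self, index):
            return torch.randn(32)

        def __len__(self):
            return 64


    class ManualOptimModel(LightningModule):

        def __init__(self):
            super().__init__()
            self.layer = None  # created lazily inside the sharded-model context

        @property
        def automatic_optimization(self) -> bool:
            # Overridden as a property (the point of patch 8/8): Lightning then
            # skips its automatic backward/step and leaves both to training_step.
            return False

        def configure_sharded_model(self) -> None:
            # Runs under deepspeed.zero.Init (see model_sharded_context in the
            # patches above), so weights are partitioned as they are created.
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()[0]
            output = self(batch)
            loss = F.mse_loss(output, torch.ones_like(output))
            opt.zero_grad()
            # manual_backward routes the backward pass through the DeepSpeed
            # engine; a bare loss.backward() would bypass fp16 loss scaling.
            self.manual_backward(loss)
            opt.step()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.1)


    if __name__ == "__main__":
        trainer = Trainer(
            plugins=[DeepSpeedPlugin(stage=3)],
            gpus=2,
            precision=16,
            fast_dev_run=True,
        )
        trainer.fit(ManualOptimModel(), DataLoader(RandomDataset(), batch_size=2))

DeepSpeed synchronizes gradients inside its own engine, which is why patch 1/8
turns prepare_for_backward into a no-op for any wrapper that is not a
DistributedDataParallel instance.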