
[fix] Enable manual optimization DeepSpeed #7970

Merged (11 commits, Jun 16, 2021)
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -229,6 +229,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))


- Support manual optimization with DeepSpeed ([#7970](https://github.com/PyTorchLightning/pytorch-lightning/pull/7970))


- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))


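For context on the headline change: with this PR, a `LightningModule` that sets `automatic_optimization = False` can drive its own backward/step while running under the DeepSpeed plugin. A minimal sketch of that usage (hypothetical module, mirroring the test models added below):

```python
import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin


class ManualOptimModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False  # opt out of automatic optimization

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()[0]  # list of optimizer wrappers, as in the tests below
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)  # routes backward through the DeepSpeed engine
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# trainer = Trainer(plugins=[DeepSpeedPlugin(stage=3)], gpus=2, precision=16)
# trainer.fit(ManualOptimModel())
```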
3 changes: 1 addition & 2 deletions dockers/base-cuda/Dockerfile
@@ -118,8 +118,7 @@ RUN \

RUN \
# install DeepSpeed
-    # TODO(@SeanNaren): CI failing with `>=0.3.15` - skipping to unblock
-    pip install deepspeed==0.3.14
+    pip install deepspeed>=0.4.0

RUN \
# Show what we have
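A side note on the new line (editorial, not part of the diff): inside a shell `RUN` step an unquoted `>=` acts as output redirection, so the specifier is conventionally quoted, as in `pip install "deepspeed>=0.4.0"`. The `>=0.4.0` floor supersedes the `==0.3.14` pin that had been held back only to work around a CI failure.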
4 changes: 4 additions & 0 deletions pytorch_lightning/overrides/distributed.py
@@ -63,6 +63,10 @@ def _find_tensors(obj): # pragma: no-cover
# Note: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
# DDP-based plugins can be subclassed,
# and `prepare_for_backward` is `DistributedDataParallel` specific.
if not isinstance(model, DistributedDataParallel):
    return
if torch.is_grad_enabled() and model.require_backward_grad_sync:
model.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
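The early return matters because plugins such as `DeepSpeedPlugin` descend from the DDP plugin family while wrapping the module in a DeepSpeed engine rather than in `DistributedDataParallel`. A small illustration of the guard (the `FakeEngine` wrapper is hypothetical, standing in for `deepspeed.DeepSpeedEngine`):

```python
import torch
from pytorch_lightning.overrides.distributed import prepare_for_backward


# Hypothetical stand-in for deepspeed.DeepSpeedEngine: it wraps the module but is
# not a DistributedDataParallel subclass, so the new isinstance guard triggers.
class FakeEngine(torch.nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module


engine = FakeEngine(torch.nn.Linear(32, 2))
out = engine.module(torch.randn(4, 32))
prepare_for_backward(engine, out)  # no-op: returns before touching DDP internals
```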
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/training_type/ddp.py
@@ -320,7 +320,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:

def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
- if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
+ if not self.lightning_module.automatic_optimization:
prepare_for_backward(self.model, closure_loss)

def model_to_device(self):
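The dropped `self.model.require_backward_grad_sync` check is the piece that did not generalize: under the DeepSpeed plugin, `self.model` is a DeepSpeed engine with no such attribute. That DDP-specific condition still runs, but only inside `prepare_for_backward`, after the new `isinstance` guard shown above.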
11 changes: 10 additions & 1 deletion pytorch_lightning/plugins/training_type/deepspeed.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import json
import logging
import os
@@ -312,7 +313,15 @@ def _initialize_deepspeed_train(self, model):
@contextlib.contextmanager
def model_sharded_context(self) -> Generator[None, None, None]:
if self.zero_stage_3:
- model_parallel_context = deepspeed.zero.Init(remote_device="cpu", pin_memory=True)
+ kwargs = {
+     "remote_device": "cpu",
+     "pin_memory": True,
+ }
+ # From DeepSpeed 0.4.0, weights need to be cast properly before calling `deepspeed.initialize`.
+ if "dtype" in inspect.signature(deepspeed.zero.Init).parameters.keys():
+     precision: int = self.lightning_module.trainer.accelerator.precision
+     kwargs["dtype"] = torch.float16 if precision == 16 else torch.float32
+ model_parallel_context = deepspeed.zero.Init(**kwargs)
else:
model_parallel_context = super().model_sharded_context()

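For readers unfamiliar with the context manager being configured here: parameters created inside `deepspeed.zero.Init` are partitioned across ranks at construction time. A minimal standalone sketch, assuming DeepSpeed >= 0.4.0 (where `Init` accepts `dtype`) and fp16 training:

```python
import inspect

import deepspeed
import torch

kwargs = {"remote_device": "cpu", "pin_memory": True}
# Feature-detect the `dtype` parameter rather than parsing version strings,
# mirroring the compatibility check in the patch above.
if "dtype" in inspect.signature(deepspeed.zero.Init).parameters:
    kwargs["dtype"] = torch.float16  # must match the training precision

# Layers built inside this context have their weights sharded (ZeRO stage 3).
with deepspeed.zero.Init(**kwargs):
    model = torch.nn.Linear(32, 2)
```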
78 changes: 71 additions & 7 deletions tests/plugins/test_deepspeed_plugin.py
@@ -24,10 +24,32 @@ class ModelParallelBoringModel(BoringModel):

def __init__(self):
super().__init__()
- self.linear = None
+ self.layer = None

def configure_sharded_model(self) -> None:
- self.linear = torch.nn.Linear(32, 2)
+ self.layer = torch.nn.Linear(32, 2)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.configure_sharded_model()


class ModelParallelBoringModelManualOptim(BoringModel):

def __init__(self):
super().__init__()
self.layer = None
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
opt = self.optimizers()[0]
output = self(batch)
loss = self.loss(batch, output)
opt.zero_grad()
self.manual_backward(loss)
opt.step()

def configure_sharded_model(self) -> None:
self.layer = torch.nn.Linear(32, 2)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.configure_sharded_model()
@@ -425,14 +447,16 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):

class ModelParallelClassificationModel(LightningModule):

- def __init__(self, lr: float = 0.01, num_blocks: int = 5):
+ def __init__(self, lr: float = 0.01, num_blocks: int = 5, automatic_optimization: bool = True):
super().__init__()
self.lr = lr
self.num_blocks = num_blocks

self.train_acc = Accuracy()
self.valid_acc = Accuracy()
self.test_acc = Accuracy()
self.automatic_optimization = automatic_optimization
self.training_step = self.training_step_automatic if self.automatic_optimization else self.training_step_manual

def make_block(self):
return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
@@ -447,14 +471,25 @@ def forward(self, x):
logits = F.softmax(x, dim=1)
return logits

- def training_step(self, batch, batch_idx):
+ def training_step_automatic(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
self.log('train_loss', loss, prog_bar=True)
self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True)
return {"loss": loss}

def training_step_manual(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
opt = self.optimizers()[0]
self.log('train_loss', loss, prog_bar=True)
self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True)
opt.zero_grad()
self.manual_backward(loss)
opt.step()

def validation_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
@@ -502,9 +537,29 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
_assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel)


- def run_checkpoint_test(tmpdir, save_full_weights):
@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config):
"""
Test to ensure ZeRO Stage 3 works with a parallel model and manual optimization.
"""
model = ModelParallelBoringModelManualOptim()
model.training_epoch_end = None
trainer = Trainer(
default_root_dir=tmpdir,
plugins=[DeepSpeedPlugin(stage=3)],
gpus=2,
fast_dev_run=True,
precision=16,
)
trainer.fit(model)
trainer.test(model)

_assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim)


+ def run_checkpoint_test(tmpdir, save_full_weights, automatic_optimization=True, accumulate_grad_batches=2):
seed_everything(1)
- model = ModelParallelClassificationModel()
+ model = ModelParallelClassificationModel(automatic_optimization=automatic_optimization)
dm = ClassifDataModule()
ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
trainer = Trainer(
Expand All @@ -514,7 +569,7 @@ def run_checkpoint_test(tmpdir, save_full_weights):
plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
gpus=2,
precision=16,
- accumulate_grad_batches=2,
+ accumulate_grad_batches=accumulate_grad_batches,
callbacks=[ck]
)
trainer.fit(model, datamodule=dm)
@@ -562,6 +617,15 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir):
run_checkpoint_test(tmpdir, save_full_weights=True)


@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir):
"""
Test to ensure with Stage 3 and multiple GPUs that we can save/load a model resuming from a checkpoint,
where we save the full weights to one file.
"""
run_checkpoint_test(tmpdir, save_full_weights=True, automatic_optimization=False, accumulate_grad_batches=1)


@RunIf(min_gpus=2, deepspeed=True, special=True)
@pytest.mark.parametrize('cpu_offload', [True, False])
def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, cpu_offload):