[fix] Enable manual optimization with DeepSpeed #7970

Merged 11 commits on Jun 16, 2021
CHANGELOG.md (3 additions, 0 deletions)
@@ -255,6 +255,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a bug where `precision=64` with `accelerator='ddp_spawn'` would throw a pickle error ([#6924](https://github.com/PyTorchLightning/pytorch-lightning/pull/6924))


- Support manual optimization with DeepSpeed ([#7970](https://github.com/PyTorchLightning/pytorch-lightning/pull/7970))


- Fixed `dataloader_idx` argument value when predicting with only one `DataLoader` ([#7941](https://github.com/PyTorchLightning/pytorch-lightning/pull/7941))


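For context, a minimal sketch (not part of this diff) of what the changelog entry enables from the user's side: a `LightningModule` that runs its own optimization loop while training under the DeepSpeed plugin. It mirrors the API used in the tests below (`DeepSpeedPlugin`, `plugins=[...]`, ZeRO stage 3, 16-bit precision) and assumes a machine with 2 GPUs and `deepspeed` installed; the module, dataset, and hyperparameters are illustrative.

```python
import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin


class RandomDataset(Dataset):
    """Tiny random dataset so the sketch is self-contained."""

    def __init__(self, size: int = 64, dim: int = 32):
        self.data = torch.randn(size, dim)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class ManualOptimDeepSpeedModel(LightningModule):
    """Illustrative module that drives zero_grad/backward/step itself."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    @property
    def automatic_optimization(self) -> bool:
        # Opt out of Lightning's automatic backward()/step(), as the tests below do.
        return False

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()[0]  # same access pattern as the tests in this PR
        loss = self(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)  # now routed through the DeepSpeed engine
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(
        plugins=[DeepSpeedPlugin(stage=3)],
        gpus=2,
        precision=16,
        fast_dev_run=True,
    )
    trainer.fit(ManualOptimDeepSpeedModel(), DataLoader(RandomDataset(), batch_size=8))
```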
pytorch_lightning/overrides/distributed.py (3 additions, 0 deletions)
@@ -63,6 +63,9 @@ def _find_tensors(obj): # pragma: no-cover
# Note: Keep track of Pytorch DDP and update if there is a change
# https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/parallel/distributed.py#L626-L638
def prepare_for_backward(model: DistributedDataParallel, output: Any):
# `prepare_for_backward` is `DistributedDataParallel` specific.
if not isinstance(model, DistributedDataParallel):
return
if torch.is_grad_enabled() and model.require_backward_grad_sync:
model.require_forward_param_sync = True
# We'll return the output object verbatim since it is a freeform
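The added guard makes `prepare_for_backward` a no-op for anything that is not a `DistributedDataParallel` wrapper. Under the DeepSpeed plugin the LightningModule is wrapped by a DeepSpeed engine rather than DDP, so the call now simply returns instead of touching DDP-only state. A small sketch of that behaviour; `FakeDeepSpeedEngine` is a hypothetical stand-in used only so the snippet runs without `deepspeed` or an initialized process group:

```python
import torch

from pytorch_lightning.overrides.distributed import prepare_for_backward


class FakeDeepSpeedEngine(torch.nn.Module):
    """Hypothetical stand-in for the engine DeepSpeed wraps the module in."""

    def __init__(self):
        super().__init__()
        self.module = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.module(x)


engine = FakeDeepSpeedEngine()
output = engine(torch.randn(4, 32))

# Not a DistributedDataParallel instance, so the new guard returns immediately
# instead of touching DDP-only attributes such as `require_backward_grad_sync`.
prepare_for_backward(engine, output)
```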
pytorch_lightning/plugins/training_type/ddp.py (1 addition, 1 deletion)
@@ -320,7 +320,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:

def pre_backward(self, closure_loss: torch.Tensor, should_accumulate: bool, optimizer: Optimizer, opt_idx: int):
"""Run before precision plugin executes backward"""
if not self.lightning_module.automatic_optimization and self.model.require_backward_grad_sync:
if not self.lightning_module.automatic_optimization:
prepare_for_backward(self.model, closure_loss)

def model_to_device(self):
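With the DDP-only check moved into `prepare_for_backward`, this hook no longer reads `require_backward_grad_sync` off `self.model`. That attribute is set on `DistributedDataParallel` instances, but presumably not on the engine that DeepSpeed wraps the module in, so the old condition could fail as soon as manual optimization was enabled under the DeepSpeed plugin (which builds on this DDP plugin). A tiny illustration using the same hypothetical stand-in wrapper as in the previous sketch:

```python
import torch


class FakeDeepSpeedEngine(torch.nn.Module):
    """Hypothetical stand-in for the DeepSpeed engine (see the previous sketch)."""

    def __init__(self):
        super().__init__()
        self.module = torch.nn.Linear(32, 2)


engine = FakeDeepSpeedEngine()

# DistributedDataParallel sets `require_backward_grad_sync` on its instances; a
# plain wrapper module has no such attribute, so the removed
# `self.model.require_backward_grad_sync` access would fail here.
print(hasattr(engine, "require_backward_grad_sync"))  # False
```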
tests/plugins/test_deepspeed_plugin.py (85 additions, 8 deletions)
@@ -24,15 +24,40 @@ class ModelParallelBoringModel(BoringModel):

def __init__(self):
super().__init__()
self.linear = None
self.layer = None

def configure_sharded_model(self) -> None:
self.linear = torch.nn.Linear(32, 2)
self.layer = torch.nn.Linear(32, 2)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.configure_sharded_model()


class ModelParallelBoringModelManualOptim(BoringModel):

def __init__(self):
super().__init__()
self.layer = None

def training_step(self, batch, batch_idx):
opt = self.optimizers()[0]
output = self(batch)
loss = self.loss(batch, output)
opt.zero_grad()
self.manual_backward(loss)
opt.step()

def configure_sharded_model(self) -> None:
self.layer = torch.nn.Linear(32, 2)

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
self.configure_sharded_model()

@property
def automatic_optimization(self) -> bool:
return False


def test_deepspeed_lightning_module(tmpdir):
"""
Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly.
@@ -483,6 +508,24 @@ def configure_optimizers(self):
}]


class ManualModelParallelClassificationModel(ModelParallelClassificationModel):

@property
def automatic_optimization(self) -> bool:
return False

def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
opt = self.optimizers()[0]
self.log('train_loss', loss, prog_bar=True)
self.log('train_acc', self.train_acc(logits, y), prog_bar=True, sync_dist=True)
opt.zero_grad()
self.manual_backward(loss)
opt.step()


@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
"""
@@ -502,9 +545,34 @@ def test_deepspeed_multigpu_stage_3(tmpdir, deepspeed_config):
_assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModel)


def run_checkpoint_test(tmpdir, save_full_weights):
@RunIf(min_gpus=2, deepspeed=True, special=True)
def test_deepspeed_multigpu_stage_3_manual_optimization(tmpdir, deepspeed_config):
"""
Test to ensure ZeRO Stage 3 works with a parallel model and manual optimization.
"""
model = ModelParallelBoringModelManualOptim()
model.training_epoch_end = None
trainer = Trainer(
default_root_dir=tmpdir,
plugins=[DeepSpeedPlugin(stage=3)],
gpus=2,
fast_dev_run=True,
precision=16,
)
trainer.fit(model)
trainer.test(model)

_assert_save_model_is_equal(model, tmpdir, trainer, cls=ModelParallelBoringModelManualOptim)


def run_checkpoint_test(
tmpdir: str, save_full_weights: bool, automatic_optimization: bool = True, accumulate_grad_batches: int = 2
):
seed_everything(1)
model = ModelParallelClassificationModel()
if automatic_optimization:
model = ModelParallelClassificationModel()
else:
model = ManualModelParallelClassificationModel()
dm = ClassifDataModule()
ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
trainer = Trainer(
Expand All @@ -514,7 +582,7 @@ def run_checkpoint_test(tmpdir, save_full_weights):
plugins=[DeepSpeedPlugin(stage=3, save_full_weights=save_full_weights)],
gpus=2,
precision=16,
accumulate_grad_batches=2,
accumulate_grad_batches=accumulate_grad_batches,
callbacks=[ck]
)
trainer.fit(model, datamodule=dm)
@@ -563,8 +631,17 @@ def test_deepspeed_multigpu_stage_3_checkpointing_full_weights(tmpdir):


@RunIf(min_gpus=2, deepspeed=True, special=True)
@pytest.mark.parametrize('cpu_offload', [True, False])
def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, cpu_offload):
def test_deepspeed_multigpu_stage_3_checkpointing_full_weights_manual(tmpdir):
"""
Test to ensure that with Stage 3, multiple GPUs, and manual optimization we can save/load a model
resuming from a checkpoint, where the full weights are saved to a single file.
"""
run_checkpoint_test(tmpdir, save_full_weights=True, automatic_optimization=False, accumulate_grad_batches=1)


@RunIf(min_gpus=2, deepspeed=True, special=True)
@pytest.mark.parametrize('offload_optimizer', [True, False])
def test_deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimizer):
"""
Test to ensure with Stage 2 and multiple GPUs, accumulated grad batches works.
"""
@@ -585,7 +662,7 @@ def on_train_batch_start(
default_root_dir=tmpdir,
progress_bar_refresh_rate=0,
max_epochs=5,
plugins=[DeepSpeedPlugin(stage=2, cpu_offload=cpu_offload)],
plugins=[DeepSpeedPlugin(stage=2, offload_optimizer=offload_optimizer)],
gpus=2,
limit_val_batches=2,
precision=16,
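As a usage note, the renamed parametrization exercises ZeRO stage 2 with optimizer offloading toggled on and off. Below is a minimal sketch of the corresponding Trainer configuration, built only from arguments visible in the test above; the module, datamodule, and the `accumulate_grad_batches` value are placeholders, and a 2-GPU machine with `deepspeed` installed is assumed.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# ZeRO stage 2 with optimizer state offloaded to CPU; Lightning handles the
# gradient accumulation schedule. Mirrors the test configuration above.
trainer = Trainer(
    plugins=[DeepSpeedPlugin(stage=2, offload_optimizer=True)],
    gpus=2,
    precision=16,
    limit_val_batches=2,
    accumulate_grad_batches=2,  # illustrative value; the test's value is collapsed in this diff
)
# trainer.fit(MyLightningModule(), datamodule=my_datamodule)  # placeholders
```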