Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DeepSpeed] fix flag forwarding in DeepSpeedPlugin #10899

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for `CombinedLoader` while checking for warning raised with eval dataloaders ([#10994](https://github.com/PyTorchLightning/pytorch-lightning/pull/10994))


-
- Fixed a bug where the DeepSpeedPlugin arguments `cpu_checkpointing` and `contiguous_memory_optimization` were not being forwarded to deepspeed correctly ([#10874](https://github.com/PyTorchLightning/pytorch-lightning/issues/10874))


-
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,8 @@ def _set_deepspeed_activation_checkpointing(self):
deepspeed.checkpointing.configure(
mpu_=None,
partition_activations=checkpoint_config.get("partition_activations"),
contiguous_checkpointing=checkpoint_config.get("contiguous_checkpointing"),
checkpoint_in_cpu=checkpoint_config.get("checkpoint_in_cpu"),
contiguous_checkpointing=checkpoint_config.get("contiguous_memory_optimization"),
checkpoint_in_cpu=checkpoint_config.get("cpu_checkpointing"),
profile=checkpoint_config.get("profile"),
)

Expand Down
30 changes: 30 additions & 0 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,36 @@ def test_deepspeed_custom_activation_checkpointing_params(tmpdir):
assert checkpoint_config["synchronize_checkpoint_boundary"]


@RunIf(min_gpus=1, deepspeed=True, standalone=True)
def test_deepspeed_custom_activation_checkpointing_params_forwarded(tmpdir):
"""Ensure if we modify the activation checkpointing parameters, we pass these to
deepspeed.checkpointing.configure correctly."""
ds = DeepSpeedPlugin(
partition_activations=True,
cpu_checkpointing=True,
contiguous_memory_optimization=True,
synchronize_checkpoint_boundary=True,
)

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
fast_dev_run=1,
strategy=ds,
precision=16,
gpus=1,
)
with mock.patch(
"deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure
) as deepspeed_checkpointing_configure:
trainer.fit(model)

deepspeed_checkpointing_configure.assert_called_with(
mpu_=None, partition_activations=True, contiguous_checkpointing=True, checkpoint_in_cpu=True, profile=None
)


@RunIf(min_gpus=1, deepspeed=True)
def test_deepspeed_assert_config_zero_offload_disabled(tmpdir, deepspeed_zero_config):
"""Ensure if we use a config and turn off offload_optimizer, that this is set to False within the config."""
Expand Down