Skip to content

Commit

Permalink
[Hot Fix] Give priority to plugins to set distributed mode, and then …
Browse files Browse the repository at this point in the history
…accelerator (#6089)

* Give priority to plugins to set distributed mode, and then accelerator

* Add CHANGELOG.md

* Update CHANGELOG.md

* Remove very scary line

* Ensure we set cluster environment after slurm configured if necessary

* Simplify the fix with a reset

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
  • Loading branch information
SeanNaren and carmocca authored Feb 20, 2021
1 parent 3bdc067 commit 97a81c3
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011)


- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))


## [1.2.0] - 2021-02-18

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ def handle_given_plugins(

for plug in plugins:
if isinstance(plug, str):
# Reset the distributed type as the user has overridden training type
# via the plugins argument
self._distrib_type = None
self.set_distributed_mode(plug)

elif isinstance(plug, TrainingTypePlugin):
Expand Down Expand Up @@ -196,7 +199,6 @@ def handle_given_plugins(
)

self._training_type_plugin = training_type
self._training_type_plugin = self.training_type_plugin
self._precision_plugin = precision
self._cluster_environment = cluster_environment or self.select_cluster_environment()

Expand Down
24 changes: 23 additions & 1 deletion tests/accelerators/test_accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.accelerators.gpu import GPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins import DDP2Plugin, DDPPlugin, DDPSpawnPlugin, PrecisionPlugin, SingleDevicePlugin
from pytorch_lightning.plugins import (
DDP2Plugin,
DDPPlugin,
DDPShardedPlugin,
DDPSpawnPlugin,
PrecisionPlugin,
SingleDevicePlugin,
)
from pytorch_lightning.plugins.environments import ClusterEnvironment, SLURMEnvironment, TorchElasticEnvironment
from tests.helpers.boring_model import BoringModel

Expand Down Expand Up @@ -378,3 +385,18 @@ def on_fit_start(self, trainer, pl_module):

with pytest.raises(SystemExit):
trainer.fit(model)


@pytest.mark.parametrize(
["accelerator", "plugin"],
[('ddp_spawn', 'ddp_sharded'), (None, 'ddp_sharded')],
)
def test_plugin_accelerator_choice(accelerator, plugin):
"""
Ensure that when a plugin and accelerator is passed in, that the plugin takes precedent.
"""
trainer = Trainer(accelerator=accelerator, plugins=plugin, num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)

trainer = Trainer(plugins=plugin, num_processes=2)
assert isinstance(trainer.accelerator.training_type_plugin, DDPShardedPlugin)

0 comments on commit 97a81c3

Please sign in to comment.