Add branch condition for calling move to device in prefetch (FSDP 3/n) #6342

Status: Closed · wants to merge 5 commits
Showing changes from 4 commits
14 changes: 11 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp.py
@@ -43,7 +43,6 @@
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path


log = logging.getLogger(__name__)


@@ -253,8 +252,9 @@ def pre_dispatch(self):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
-        # move the model to the correct device
-        self.model_to_device()
+        if self.call_move_to_device_hook_in_pre_dispatch:
+            # move the model to the correct device
+            self.model_to_device()
 
         self.configure_ddp()
@@ -313,3 +313,11 @@ def predict(self, *args, **kwargs):
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True
+
+    @property
+    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
+        """
+        Call the ``model_to_device`` function within pre_dispatch if this is set to True.
+        Useful when a plugin would like to call ``model_to_device`` at another time, or skip the call.
+        """
+        return True
Comment on lines +317 to +323 (Contributor):

Should this be in ParallelPlugin?
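
For illustration, a sharded plugin (such as the FSDP plugin this series builds towards) could override the new property to defer the device move. A minimal sketch follows; the subclass name and the point at which the move happens are hypothetical and not part of this PR:

```python
# Minimal sketch (hypothetical subclass, not part of this PR): a plugin that
# wraps or shards parameters can opt out of the eager move in ``pre_dispatch``
# and decide placement itself once wrapping has happened.
from pytorch_lightning.plugins import DDPPlugin


class ShardedExamplePlugin(DDPPlugin):

    @property
    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
        # Skip the move in pre_dispatch; placement happens after wrapping.
        return False

    def configure_ddp(self):
        # Wrap or shard ``self.model`` here first (details omitted), then move
        # whatever this rank owns to its device.
        ...
        self.model_to_device()
```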

15 changes: 12 additions & 3 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -59,7 +59,7 @@ def __init__(
         self.sync_batchnorm = sync_batchnorm
         self._ddp_kwargs = kwargs
         self.dist = LightningDistributed()
-        self.num_processes = len(parallel_devices)
+        self.num_processes = len(parallel_devices) if parallel_devices is not None else parallel_devices
         self.node_rank = 0
         self.mp_queue = None
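
A hypothetical usage note (not part of the diff), assuming ``parallel_devices`` defaults to ``None`` as the guard suggests: the plugin can now be constructed before devices are known, instead of raising ``TypeError: object of type 'NoneType' has no len()``.

```python
# Hypothetical illustration, assuming ``parallel_devices`` defaults to ``None``.
from pytorch_lightning.plugins import DDPSpawnPlugin

plugin = DDPSpawnPlugin()            # no devices attached yet
assert plugin.num_processes is None  # resolved once parallel_devices is set
```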

@@ -151,8 +151,9 @@ def new_process(self, process_idx, trainer, mp_queue):
         if self.sync_batchnorm:
             self.model = self.configure_sync_batchnorm(self.model)
 
-        # move the model to the correct device
-        self.model_to_device()
+        if self.call_move_to_device_hook_in_pre_dispatch:
+            # move the model to the correct device
+            self.model_to_device()
 
         self.configure_ddp()

@@ -290,3 +291,11 @@ def predict(self, *args, **kwargs):
     def post_training_step(self):
         if not self.lightning_module.automatic_optimization:
             self.model.require_backward_grad_sync = True
+
+    @property
+    def call_move_to_device_hook_in_pre_dispatch(self) -> bool:
+        """
+        Call the ``model_to_device`` function within pre_dispatch if this is set to True.
+        Useful when a plugin would like to call ``model_to_device`` at another time, or skip the call.
+        """
+        return True
28 changes: 27 additions & 1 deletion tests/accelerators/test_ddp.py
@@ -12,12 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
 import torch
 
 from pytorch_lightning import Trainer
+from pytorch_lightning.plugins import DDPPlugin, DDPSpawnPlugin
 from tests.accelerators import ddp_model, DDPLauncher
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
@@ -91,7 +93,6 @@ def test_torch_distributed_backend_env_variables(tmpdir):
     _environ = {"PL_TORCH_DISTRIBUTED_BACKEND": "undefined", "CUDA_VISIBLE_DEVICES": "0,1", "WORLD_SIZE": "2"}
     with patch.dict(os.environ, _environ), \
             patch('torch.cuda.device_count', return_value=2):
-
         with pytest.raises(ValueError, match="Invalid backend: 'undefined'"):
             model = BoringModel()
             trainer = Trainer(
@@ -102,3 +103,28 @@
                 logger=False,
             )
             trainer.fit(model)
+
+
+@pytest.mark.parametrize('move_to_device_pre_dispatch_enabled', [True, False])
+@mock.patch('pytorch_lightning.plugins.DDPPlugin.model_to_device')
+def test_move_to_device_in_pre_dispatch(mock_model_to_device, tmpdir, move_to_device_pre_dispatch_enabled):
Comment on this line (Contributor, Author):

Two things:

  • Would've liked to add a spawn test, but mocks don't seem to carry across to new processes?
  • Can I combine patch and parametrize?

Reply (Contributor):

  1. Nope, you'd have to create the mock in each process.
  2. What do you mean? Having different mocks for each parametrize? I don't think so.

Reply (Contributor):

You might be able to do so with a callback, applying the patch context manager.

Reply (Contributor):

You can make patch conditional on parametrize, with the return object (see the sketch after this diff).
"""
Test if ``call_move_to_device_hook_in_pre_dispatch`` is disabled we do not move to device till later
in training.
"""

with mock.patch(
f'pytorch_lightning.plugins.DDPPlugin.call_move_to_device_hook_in_pre_dispatch',
move_to_device_pre_dispatch_enabled
):
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir, fast_dev_run=True, accelerator='ddp', plugins=DDPPlugin(), num_processes=1
Comment on this line (Contributor):

Why do you need to pass DDPPlugin?

+        )
+        trainer.fit(model)
+
+    # Check that the mocked model_to_device was called. Since we're on CPU, model_to_device does nothing anyway.
+    if move_to_device_pre_dispatch_enabled:
+        mock_model_to_device.assert_called()
+    else:
+        mock_model_to_device.assert_not_called()
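
Regarding the review question about combining ``patch`` with ``parametrize``: below is a minimal sketch of the suggestion to make the patch conditional on the parametrized value via the patcher object returned by ``mock.patch``. The test name and body are hypothetical, not part of this PR. Note that a mock created in the parent process still does not carry over into processes started by a spawn plugin; there it would have to be created inside each child process.

```python
# Minimal sketch (hypothetical, not part of this PR): apply ``mock.patch``
# conditionally on the parametrized value via the returned patcher object,
# instead of the decorator form.
from unittest import mock

import pytest


@pytest.mark.parametrize('apply_patch', [True, False])
def test_conditional_patch(tmpdir, apply_patch):
    patcher = mock.patch('pytorch_lightning.plugins.DDPPlugin.model_to_device')
    mocked = patcher.start() if apply_patch else None
    try:
        ...  # build the Trainer and run fit() here; ``mocked`` records calls only while the patch is active
    finally:
        if apply_patch:
            patcher.stop()
```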