Skip to content

Commit

Permalink
[bug-fix] Call transfer_batch_to_device in DDPlugin (#5195)
Browse files Browse the repository at this point in the history
* hacking out

* update

* remove useless on_before_forward

* update

* remove overriden

* iremove os

* use on_before_forward

* resolve flake8

* add test

* update

* add single_process_per_device

* resolve flake8

* update

* resolve

* update

* update

* update

* add comment

* resolve bug with sharded

* update

* remove property

* update

* resolve test

* resolve bug

* update on comments

* update doc

* Update pytorch_lightning/core/hooks.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* update on comments

* Update pytorch_lightning/plugins/ddp_plugin.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/plugins/ddp_plugin.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* resolve pep8

* add device_ids to pipe

* update on comments

* update

* resolve

* update

* update

* update

Co-authored-by: Ubuntu <ubuntu@ip-172-31-62-109.ec2.internal>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

(cherry picked from commit d510707)
  • Loading branch information
tchaton authored and Borda committed Jan 23, 2021
1 parent 450956f commit 9fa9b44
Show file tree
Hide file tree
Showing 14 changed files with 74 additions and 21 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `transfer_batch_to_device` for DDP with `len(devices_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195))



## [1.1.3] - 2021-01-05

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def ddp_train(self, process_idx, mp_queue, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def ddp_train(self, process_idx, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def ddp_train(self, process_idx, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

Expand Down
6 changes: 3 additions & 3 deletions pytorch_lightning/core/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,9 @@ def transfer_batch_to_device(self, batch, device)
any other device than the one passed in as argument (unless you know what you are doing).
Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` and
This hook only runs on single GPU training and DDP (no data-parallel).
If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``,
you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
See Also:
Expand Down
17 changes: 9 additions & 8 deletions pytorch_lightning/plugins/ddp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,23 @@ def init_ddp_connection(
torch_backend, rank=global_rank, world_size=world_size
)

@property
def is_running_single_process_per_device(self) -> bool:
# objects do not need to be scattered in single process per device, move objects upfront to device
# This property is used in ``self.on_before_forward`` function.
return self.device_ids is not None and len(self.device_ids) == 1

def on_before_forward(self, model: LightningModule, *args):
"""
Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
within the DDP wrapper.
Example::
def on_before_forward(self, model, *args):
batch, batch_idx = args
return batch.to(model.device)
Override to handle custom edge case.
Args:
args: Inputs to the model.
model: Model to train.
Returns: args moved to correct device if needed.
"""
if self.is_running_single_process_per_device:
args = model.transfer_batch_to_device(args, model.device)
return args

def optimizer_state(self, optimizer: Optimizer) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/ddp_sequential_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning import LightningModule
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down
3 changes: 0 additions & 3 deletions pytorch_lightning/plugins/sharded_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

def on_before_forward(self, model: LightningModule, *args):
return model.transfer_batch_to_device(args, model.trainer.root_gpu)

def _check_fairscale(self):
if not _FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
Expand Down
8 changes: 5 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,14 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
return function(data, *args, **kwargs)

# Recursively apply to collection items
elif isinstance(data, Mapping):
if isinstance(data, Mapping):
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
for k, v in data.items()})
elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple

if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
elif isinstance(data, Sequence) and not isinstance(data, str):

if isinstance(data, Sequence) and not isinstance(data, str):
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])

# data is neither of dtype, nor a collection
Expand Down
46 changes: 45 additions & 1 deletion tests/models/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from unittest.mock import MagicMock

import pytest
Expand All @@ -20,7 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.trainer.states import TrainerState
from tests.base import BoringModel, EvalModelTemplate
from tests.base import BoringModel, EvalModelTemplate, RandomDataset


@pytest.mark.parametrize('max_steps', [1, 2, 3])
Expand Down Expand Up @@ -125,6 +126,49 @@ def transfer_batch_to_device(self, data, device):
assert batch_gpu.samples.device == batch_gpu.targets.device == expected


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_transfer_batch_hook_ddp(tmpdir):
"""
Test custom data are properly moved to the right device using ddp
"""

class CustomBatch:

def __init__(self, data):
self.samples = data[0]

def to(self, device, **kwargs):
self.samples = self.samples.to(device, **kwargs)
return self

def collate_fn(batch):
return CustomBatch(batch)

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
assert batch.samples.device == self.device
assert isinstance(batch_idx, int)

def train_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn)

model = TestModel()
model.validation_step = None
model.training_epoch_end = None
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=0,
max_epochs=1,
weights_summary=None,
accelerator="ddp",
gpus=2,
)
trainer.fit(model)


@pytest.mark.parametrize(
'max_epochs,batch_idx_',
[(2, 5), (3, 8), (4, 12)]
Expand Down
2 changes: 1 addition & 1 deletion tests/models/test_sync_batchnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import torch.nn.functional as F

from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import FLOAT16_EPSILON
from tests.base.datamodules import MNISTDataModule
from tests.base.develop_utils import set_random_master_port
Expand Down
3 changes: 2 additions & 1 deletion tests/special_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/trainer/logging_process/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp

0 comments on commit 9fa9b44

Please sign in to comment.