[bug-fix] Call transfer_batch_to_device in DDPlugin (#5195)
* hacking out

* update

* remove useless on_before_forward

* update

* remove overridden

* remove os

* use on_before_forward

* resolve flake8

* add test

* update

* add single_process_per_device

* resolve flake8

* update

* resolve

* update

* update

* update

* add comment

* resolve bug with sharded

* update

* remove property

* update

* resolve test

* resolve bug

* update on comments

* update doc

* Update pytorch_lightning/core/hooks.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* update on comments

* Update pytorch_lightning/plugins/ddp_plugin.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* Update pytorch_lightning/plugins/ddp_plugin.py

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>

* resolve pep8

* add device_ids to pipe

* update on comments

* update

* resolve

* update

* update

* update

Co-authored-by: Ubuntu <ubuntu@ip-172-31-62-109.ec2.internal>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>

(cherry picked from commit d510707)
tchaton authored and SeanNaren committed Jan 13, 2021
1 parent 7f4a8e5 commit 02f5159
Showing 14 changed files with 80 additions and 24 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -123,6 +123,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `transfer_batch_to_device` for DDP with `len(device_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195))



## [1.1.3] - 2021-01-05

1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -210,6 +210,7 @@ def ddp_train(self, process_idx, mp_queue, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_accelerator.py
@@ -315,6 +315,7 @@ def ddp_train(self, process_idx, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

@@ -239,6 +239,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -199,6 +199,7 @@ def ddp_train(self, process_idx, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -271,6 +271,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

6 changes: 3 additions & 3 deletions pytorch_lightning/core/hooks.py
@@ -562,9 +562,9 @@ def transfer_batch_to_device(self, batch, device)
any other device than the one passed in as argument (unless you know what you are doing).
Note:
- This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
- for your custom batch objects, you need to define your custom
- :class:`~torch.nn.parallel.DistributedDataParallel` or
+ This hook only runs on single GPU training and DDP.
+ If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``,
+ you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
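The updated note says the hook now covers single-GPU training and DDP, while `dp`/`ddp2` still require a custom wrapper. For orientation, a minimal sketch (not part of this diff) of how the hook is typically overridden for a batch type the default logic cannot move; `CustomWrapper`, its fields, and `LitModel` are illustrative names only:

```python
import torch
from pytorch_lightning import LightningModule


class CustomWrapper:
    """Illustrative batch container that is neither a tensor nor a standard collection."""

    def __init__(self, samples: torch.Tensor, targets: torch.Tensor):
        self.samples = samples
        self.targets = targets


class LitModel(LightningModule):
    def transfer_batch_to_device(self, batch, device):
        if isinstance(batch, CustomWrapper):
            # Move the wrapped tensors by hand; the default logic only
            # recurses through tensors, mappings, and sequences.
            batch.samples = batch.samples.to(device)
            batch.targets = batch.targets.to(device)
            return batch
        # Fall back to the default behaviour for everything else.
        return super().transfer_batch_to_device(batch, device)
```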
24 changes: 13 additions & 11 deletions pytorch_lightning/plugins/ddp_plugin.py
@@ -15,6 +15,7 @@
from contextlib import contextmanager
from typing import Any, Dict, List, Union

+ import torch
import torch.distributed as torch_distrib
from torch.optim import Optimizer

@@ -61,7 +62,7 @@ def configure_ddp(
def configure_ddp(self, model, device_ids):
model = LightningDistributedDataParallel(
- model, device_ids=device_ids, find_unused_parameters=True
+ model, device_ids=device_ids, find_unused_parameters=False
)
return model
@@ -73,9 +74,9 @@ def configure_ddp(self, model, device_ids):
the model wrapped in LightningDistributedDataParallel
"""
- # if unset, default `find_unused_parameters` `True`
+ # if unset, default `find_unused_parameters` `False`
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
- "find_unused_parameters", True
+ "find_unused_parameters", False
)
model = LightningDistributedDataParallel(
model,
@@ -106,22 +107,23 @@ def init_ddp_connection(
torch_backend, rank=global_rank, world_size=world_size
)

+ @property
+ def is_running_single_process_per_device(self) -> bool:
+ # objects do not need to be scattered in single process per device, move objects upfront to device
+ # This property is used in ``self.on_before_forward`` function.
+ return self.device_ids is not None and len(self.device_ids) == 1

def on_before_forward(self, model: LightningModule, *args):
"""
- Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
- within the DDP wrapper.
- Example::
- def on_before_forward(self, model, *args):
- batch, batch_idx = args
- return batch.to(model.device)
+ Override to handle custom edge case.
Args:
args: Inputs to the model.
model: Model to train.
Returns: args moved to correct device if needed.
"""
+ if self.is_running_single_process_per_device:
+ args = model.transfer_batch_to_device(args, model.device)
return args

def optimizer_state(self, optimizer: Optimizer) -> dict:
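Since `on_before_forward` is now documented as an override point for edge cases, here is a hedged sketch (assumed, not from the diff) of a plugin subclass that builds on the new single-process-per-device behaviour; the subclass name and the print statement are illustrative:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class VerboseDDPPlugin(DDPPlugin):
    def on_before_forward(self, model, *args):
        # The base implementation moves the batch to model.device when the
        # plugin runs a single process per device; reuse it, then log.
        args = super().on_before_forward(model, *args)
        print(f"rank {model.global_rank}: forwarding batch on {model.device}")
        return args


# Passed to the Trainer like the stock plugin (requires a multi-GPU machine):
# trainer = Trainer(accelerator="ddp", gpus=2, plugins=[VerboseDDPPlugin(find_unused_parameters=True)])
```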
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/ddp_sequential_plugin.py
@@ -19,8 +19,8 @@
from torch import nn
from torch.nn.parallel import DistributedDataParallel

- from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
+ from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
5 changes: 1 addition & 4 deletions pytorch_lightning/plugins/sharded_plugin.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- from typing import List, Optional, Union, Any
+ from typing import Any, List, Optional, Union

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
@@ -42,9 +42,6 @@ def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

- def on_before_forward(self, model: LightningModule, *args):
- return model.transfer_batch_to_device(args, model.trainer.root_gpu)

def _check_fairscale(self):
if not _FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
8 changes: 5 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
@@ -77,12 +77,14 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
return function(data, *args, **kwargs)

# Recursively apply to collection items
- elif isinstance(data, Mapping):
+ if isinstance(data, Mapping):
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
for k, v in data.items()})
- elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
+
+ if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
- elif isinstance(data, Sequence) and not isinstance(data, str):
+
+ if isinstance(data, Sequence) and not isinstance(data, str):
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])

# data is neither of dtype, nor a collection
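The change above only swaps the `elif` chain for early-returning `if` blocks; behaviour is unchanged. To make what `apply_to_collection` does concrete, a small usage sketch (the sample data is illustrative):

```python
import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

batch = {
    "samples": torch.zeros(4, 32),
    "meta": [torch.ones(4), "keep-me"],  # non-tensor leaves are returned untouched
}

# Recursively apply the callable to every torch.Tensor in the nested structure.
moved = apply_to_collection(batch, torch.Tensor, lambda t: t.to("cpu"))
assert moved["meta"][1] == "keep-me"
assert torch.equal(moved["samples"], batch["samples"])
```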
46 changes: 45 additions & 1 deletion tests/models/test_hooks.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
+ import os
from unittest.mock import MagicMock

import pytest
@@ -20,7 +21,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from pytorch_lightning.trainer.states import TrainerState
- from tests.base import BoringModel, EvalModelTemplate
+ from tests.base import BoringModel, EvalModelTemplate, RandomDataset


@pytest.mark.parametrize('max_steps', [1, 2, 3])
@@ -125,6 +126,49 @@ def transfer_batch_to_device(self, data, device):
assert batch_gpu.samples.device == batch_gpu.targets.device == expected


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_transfer_batch_hook_ddp(tmpdir):
    """
    Test custom data are properly moved to the right device using ddp
    """

    class CustomBatch:

        def __init__(self, data):
            self.samples = data[0]

        def to(self, device, **kwargs):
            self.samples = self.samples.to(device, **kwargs)
            return self

    def collate_fn(batch):
        return CustomBatch(batch)

    class TestModel(BoringModel):
        def training_step(self, batch, batch_idx):
            assert batch.samples.device == self.device
            assert isinstance(batch_idx, int)

        def train_dataloader(self):
            return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn)

    model = TestModel()
    model.validation_step = None
    model.training_epoch_end = None
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=0,
        max_epochs=1,
        weights_summary=None,
        accelerator="ddp",
        gpus=2,
    )
    trainer.fit(model)


@pytest.mark.parametrize(
'max_epochs,batch_idx_',
[(2, 5), (3, 8), (4, 12)]
4 changes: 3 additions & 1 deletion tests/models/test_sync_batchnorm.py
@@ -16,8 +16,9 @@
import torch.nn as nn
import torch.nn.functional as F

- from pytorch_lightning import Trainer, seed_everything, LightningModule
+ from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.trainer.states import TrainerState
+ from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import FLOAT16_EPSILON
from tests.base.datamodules import MNISTDataModule
from tests.base.develop_utils import set_random_master_port
@@ -108,6 +109,7 @@ def test_sync_batchnorm_ddp(tmpdir):
sync_batchnorm=True,
num_sanity_val_steps=0,
replace_sampler_ddp=False,
plugins=[DDPPlugin(find_unused_parameters=True)]
)

trainer.fit(model, dm)
1 change: 1 addition & 0 deletions tests/special_tests.sh
@@ -23,3 +23,4 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent
python ${DEFAULTS} tests/utilities/test_all_gather_grad.py::test_all_gather_collection
# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp
