[bug-fix] Call transfer_batch_to_device in DDPlugin #5195

Merged (62 commits) on Jan 8, 2021

Commits
d156d8a
hacking out
Dec 19, 2020
658dfcd
update
Dec 19, 2020
526e9e6
remove useless on_before_forward
Dec 20, 2020
7396760
update
Dec 20, 2020
dba7875
remove overriden
Dec 20, 2020
895dc72
iremove os
Dec 20, 2020
3b08bd9
Merge branch 'master' into pyg
tchaton Dec 21, 2020
ed0eb42
use on_before_forward
Dec 21, 2020
461a978
Merge branch 'pyg' of https://github.com/PyTorchLightning/pytorch-lig…
Dec 21, 2020
bfe3793
Merge branch 'master' into pyg
tchaton Dec 22, 2020
76fd6c9
Merge branch 'master' into pyg
tchaton Dec 23, 2020
31587f9
resolve flake8
tchaton Dec 23, 2020
bbea81a
Merge branch 'master' into pyg
tchaton Dec 23, 2020
0fd5a41
Merge branch 'master' into pyg
tchaton Dec 28, 2020
2902592
Merge branch 'master' into pyg
tchaton Dec 29, 2020
27f0955
Merge branch 'master' into pyg
tchaton Jan 4, 2021
819d330
Merge branch 'master' into pyg
tchaton Jan 4, 2021
d81a7d0
add test
Jan 4, 2021
ddd7a1c
update
tchaton Jan 4, 2021
f7e30be
add single_process_per_device
Jan 4, 2021
861ae11
Merge branch 'pyg' of https://github.com/PyTorchLightning/pytorch-lig…
Jan 4, 2021
a7e4a41
resolve flake8
tchaton Jan 4, 2021
223ab0e
update
tchaton Jan 4, 2021
6fbe83d
Merge branch 'master' into pyg
tchaton Jan 4, 2021
fd1b63e
resolve
Jan 4, 2021
5d60b49
update
Jan 4, 2021
7cc7760
update
tchaton Jan 4, 2021
f5c2843
update
tchaton Jan 4, 2021
70cde1d
add comment
tchaton Jan 4, 2021
0c7690d
resolve bug with sharded
Jan 4, 2021
f999be1
Merge branch 'pyg' of https://github.com/PyTorchLightning/pytorch-lig…
Jan 4, 2021
2506d3c
update
tchaton Jan 4, 2021
ae071a3
Merge branch 'master' into pyg
tchaton Jan 4, 2021
cdede9a
remove property
tchaton Jan 4, 2021
f38232b
update
tchaton Jan 4, 2021
e66d284
Merge branch 'pyg' of https://github.com/PyTorchLightning/pytorch-lig…
tchaton Jan 4, 2021
5527455
resolve test
tchaton Jan 4, 2021
bf94fa3
resolve bug
tchaton Jan 4, 2021
b23d5ba
update on comments
tchaton Jan 4, 2021
ddad54e
Merge branch 'master' into pyg
tchaton Jan 4, 2021
c337461
update doc
tchaton Jan 4, 2021
d1ef718
Merge branch 'pyg' of https://github.com/PyTorchLightning/pytorch-lig…
tchaton Jan 4, 2021
678e0b9
Update pytorch_lightning/core/hooks.py
tchaton Jan 5, 2021
5d0bcf1
update on comments
tchaton Jan 5, 2021
c911849
Merge branch 'master' into pyg
tchaton Jan 5, 2021
a667861
Update pytorch_lightning/plugins/ddp_plugin.py
tchaton Jan 5, 2021
c75466f
Update pytorch_lightning/plugins/ddp_plugin.py
tchaton Jan 5, 2021
2f4165d
Merge branch 'master' into pyg
tchaton Jan 5, 2021
dd6189b
Merge branch 'master' into pyg
tchaton Jan 5, 2021
d780c16
resolve pep8
tchaton Jan 5, 2021
ff96053
add device_ids to pipe
tchaton Jan 5, 2021
b16b611
update on comments
tchaton Jan 6, 2021
6ed7361
update
tchaton Jan 6, 2021
e71d8cd
resolve
tchaton Jan 6, 2021
525d26d
update
tchaton Jan 6, 2021
779a672
Merge branch 'master' into pyg
tchaton Jan 6, 2021
34def10
update
tchaton Jan 6, 2021
57b037e
Merge branch 'master' into pyg
tchaton Jan 6, 2021
f679440
update
tchaton Jan 6, 2021
37fb96e
Merge branch 'master' into pyg
tchaton Jan 6, 2021
64ef06a
Merge branch 'master' into pyg
SeanNaren Jan 7, 2021
e0358df
Merge branch 'master' into pyg
tchaton Jan 8, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `transfer_batch_to_device` for DDP with `len(device_ids) == 1` ([#5195](https://github.com/PyTorchLightning/pytorch-lightning/pull/5195))



## [1.1.3] - 2021-01-05

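For context, a minimal sketch of the hook this entry refers to. It is not part of the diff, the CustomBatch and LitModel names are hypothetical, and the change in this PR only affects when the hook fires (it now also runs for DDP with a single device per process):

    import torch
    from pytorch_lightning import LightningModule


    class CustomBatch:
        """A user-defined batch object that Lightning cannot move on its own."""

        def __init__(self, samples, targets):
            self.samples = samples
            self.targets = targets


    class LitModel(LightningModule):
        def transfer_batch_to_device(self, batch, device):
            # Move the wrapped tensors by hand; fall back to the default
            # behaviour for anything Lightning already knows how to move.
            if isinstance(batch, CustomBatch):
                batch.samples = batch.samples.to(device)
                batch.targets = batch.targets.to(device)
                return batch
            return super().transfer_batch_to_device(batch, device)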
3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp2_accelerator.py
@@ -26,7 +26,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

if HYDRA_AVAILABLE:
@@ -213,6 +213,7 @@ def ddp_train(self, process_idx, mp_queue, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp_accelerator.py
@@ -30,7 +30,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
@@ -314,6 +314,7 @@ def ddp_train(self, process_idx, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py
@@ -26,7 +26,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
@@ -241,6 +241,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp_hpc_accelerator.py
@@ -26,7 +26,7 @@
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

if HYDRA_AVAILABLE:
@@ -205,6 +205,7 @@ def ddp_train(self, process_idx, model):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

3 changes: 2 additions & 1 deletion pytorch_lightning/accelerators/ddp_spawn_accelerator.py
@@ -27,7 +27,7 @@
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import AMPType, HYDRA_AVAILABLE
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
@@ -273,6 +273,7 @@ def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results):
def configure_ddp(
self, model: LightningModule, device_ids: List[int]
) -> DistributedDataParallel:
self.ddp_plugin.device_ids = device_ids
model = self.ddp_plugin.configure_ddp(model, device_ids)
return model

9 changes: 5 additions & 4 deletions pytorch_lightning/core/hooks.py
@@ -17,10 +17,11 @@
from typing import Any, Dict, List, Optional, Union

import torch
from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning.utilities import move_data_to_device, rank_zero_warn


class ModelHooks:
"""Hooks to be used in LightningModule."""
@@ -539,9 +540,9 @@ def transfer_batch_to_device(self, batch, device)
any other device than the one passed in as argument (unless you know what you are doing).

Note:
This hook only runs on single GPU training (no data-parallel). If you need multi-GPU support
for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
This hook only runs on single GPU training and DDP.
If you need multi-GPU support for your custom batch objects in ``dp`` or ``ddp2``,
you need to define your custom :class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.

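The updated note above points users of dp and ddp2 toward overriding configure_ddp. A hedged sketch of what such an override can look like; LitModel is a placeholder name and only wrapper arguments already shown in this diff are used:

    from pytorch_lightning import LightningModule
    from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel


    class LitModel(LightningModule):
        def configure_ddp(self, model, device_ids):
            # Wrap the model yourself when custom batch objects need special
            # handling; this mirrors the wrapping done by DDPPlugin.configure_ddp.
            return LightningDistributedDataParallel(
                model,
                device_ids=device_ids,
                find_unused_parameters=False,
            )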
26 changes: 14 additions & 12 deletions pytorch_lightning/plugins/ddp_plugin.py
@@ -1,7 +1,8 @@
import os
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Union

import torch
import torch.distributed as torch_distrib
from torch.optim import Optimizer

@@ -47,7 +48,7 @@ def configure_ddp(

def configure_ddp(self, model, device_ids):
model = LightningDistributedDataParallel(
model, device_ids=device_ids, find_unused_parameters=True
model, device_ids=device_ids, find_unused_parameters=False
)
return model

@@ -59,9 +60,9 @@ def configure_ddp(self, model, device_ids):
the model wrapped in LightningDistributedDataParallel

"""
# if unset, default `find_unused_parameters` `True`
# if unset, default `find_unused_parameters` `False`
self._ddp_kwargs["find_unused_parameters"] = self._ddp_kwargs.get(
"find_unused_parameters", True
"find_unused_parameters", False
)
model = LightningDistributedDataParallel(
model,
@@ -91,22 +92,23 @@ def init_ddp_connection(
torch_backend, rank=global_rank, world_size=world_size
)

@property
def is_running_single_process_per_device(self) -> bool:
# objects do not need to be scattered in single process per device, move objects upfront to device
# This property is used in ``self.on_before_forward`` function.
return self.device_ids is not None and len(self.device_ids) == 1

def on_before_forward(self, model: LightningModule, *args):
"""
Override to handle custom input to device logic. For DDP, no logic is required as this is handled internally
within the DDP wrapper.

Example::

def on_before_forward(self, model, *args):
batch, batch_idx = args
return batch.to(model.device)
Override to handle custom edge case.

Args:
args: Inputs to the model.
model: Model to train.
Returns: args moved to correct device if needed.
"""
if self.is_running_single_process_per_device:
args = model.transfer_batch_to_device(args, model.device)
return args

def optimizer_state(self, optimizer: Optimizer) -> dict:
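To make the hand-off above concrete: when each process drives exactly one device, on_before_forward now moves the inputs up front through transfer_batch_to_device rather than leaving it to DDP's scatter. A small sketch of a plugin subclass that keeps this behaviour; MyDDPPlugin and the extra preprocessing step are hypothetical:

    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


    class MyDDPPlugin(DDPPlugin):
        def on_before_forward(self, model, *args):
            # The parent call moves ``args`` to ``model.device`` only when
            # ``len(self.device_ids) == 1``; otherwise DDP scatters the inputs.
            args = super().on_before_forward(model, *args)
            # Any additional per-batch preprocessing would go here.
            return args

Such a subclass is handed to the Trainer through its plugins argument, the same way the sync-batchnorm test further down passes the stock DDPPlugin.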
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/ddp_sequential_plugin.py
@@ -19,8 +19,8 @@
from torch import nn
from torch.nn.parallel import DistributedDataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
7 changes: 2 additions & 5 deletions pytorch_lightning/plugins/sharded_plugin.py
@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union, Any
from typing import Any, List, Optional, Union

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
from pytorch_lightning.utilities import AMPType, FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if FAIRSCALE_AVAILABLE:
@@ -42,9 +42,6 @@ def optimizer_state(self, optimizer: 'OSS') -> Optional[dict]:
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

def on_before_forward(self, model: LightningModule, *args):
return model.transfer_batch_to_device(args, model.trainer.root_gpu)

def _check_fairscale(self):
if not FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
1 change: 0 additions & 1 deletion pytorch_lightning/trainer/training_loop.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import contextmanager
from copy import copy, deepcopy

8 changes: 5 additions & 3 deletions pytorch_lightning/utilities/apply_func.py
@@ -49,12 +49,14 @@ def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable
return function(data, *args, **kwargs)

# Recursively apply to collection items
elif isinstance(data, Mapping):
if isinstance(data, Mapping):
return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
for k, v in data.items()})
elif isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple

if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))
elif isinstance(data, Sequence) and not isinstance(data, str):

if isinstance(data, Sequence) and not isinstance(data, str):
return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])

# data is neither of dtype, nor a collection
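The branch restructuring above does not change behaviour; each container type still returns early. A quick usage sketch of the function being edited, with toy data and a CPU target assumed:

    from collections import namedtuple

    import torch
    from pytorch_lightning.utilities.apply_func import apply_to_collection

    Pair = namedtuple("Pair", ["x", "y"])
    data = {"pair": Pair(torch.zeros(2), torch.ones(2)), "ids": [torch.arange(3)]}

    # Apply ``Tensor.to`` to every tensor while preserving the surrounding
    # dict, namedtuple and list structure.
    moved = apply_to_collection(data, torch.Tensor, lambda t: t.to("cpu"))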
48 changes: 46 additions & 2 deletions tests/models/test_hooks.py
@@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from unittest.mock import MagicMock

import pytest
import torch
from unittest.mock import MagicMock

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
from tests.base import EvalModelTemplate, BoringModel
from tests.base import BoringModel, EvalModelTemplate, RandomDataset


@pytest.mark.parametrize('max_steps', [1, 2, 3])
@@ -124,6 +125,49 @@ def transfer_batch_to_device(self, data, device):
assert batch_gpu.samples.device == batch_gpu.targets.device == expected


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_transfer_batch_hook_ddp(tmpdir):
"""
Test that custom data is properly moved to the right device when using ddp
"""

class CustomBatch:

def __init__(self, data):
self.samples = data[0]

def to(self, device, **kwargs):
self.samples = self.samples.to(device, **kwargs)
return self

def collate_fn(batch):
return CustomBatch(batch)

class TestModel(BoringModel):
def training_step(self, batch, batch_idx):
assert batch.samples.device == self.device
assert isinstance(batch_idx, int)

def train_dataloader(self):
return torch.utils.data.DataLoader(RandomDataset(32, 64), collate_fn=collate_fn)

model = TestModel()
model.validation_step = None
model.training_epoch_end = None
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=0,
max_epochs=1,
weights_summary=None,
accelerator="ddp",
gpus=2,
)
trainer.fit(model)


@pytest.mark.parametrize(
'max_epochs,batch_idx_',
[(2, 5), (3, 8), (4, 12)]
4 changes: 3 additions & 1 deletion tests/models/test_sync_batchnorm.py
@@ -16,8 +16,9 @@
import torch.nn as nn
import torch.nn.functional as F

from pytorch_lightning import Trainer, seed_everything, LightningModule
from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.core.step_result import TrainResult
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import FLOAT16_EPSILON
from tests.base.datamodules import MNISTDataModule
from tests.base.develop_utils import set_random_master_port
@@ -108,6 +109,7 @@ def test_sync_batchnorm_ddp(tmpdir):
sync_batchnorm=True,
num_sanity_val_steps=0,
replace_sampler_ddp=False,
plugins=[DDPPlugin(find_unused_parameters=True)]
)

result = trainer.fit(model, dm)
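The test above opts back into the old behaviour explicitly because this PR flips the find_unused_parameters default to False in DDPPlugin. A hedged sketch of the equivalent user-facing opt-in; the accelerator and gpus values are illustrative:

    from pytorch_lightning import Trainer
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

    # Models whose forward pass leaves some parameters unused must now ask
    # DDP to track them explicitly.
    trainer = Trainer(
        accelerator="ddp",
        gpus=2,
        plugins=[DDPPlugin(find_unused_parameters=True)],
    )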
1 change: 1 addition & 0 deletions tests/special_tests.sh
@@ -21,3 +21,4 @@ python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequent
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp
python ${DEFAULTS} tests/models/test_hooks.py::test_transfer_batch_hook_ddp