
Disable batch transfer in DP mode #6098

Merged: 20 commits, Mar 11, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -184,6 +184,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))


- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))


## [1.2.0] - 2021-02-18

### Added
11 changes: 10 additions & 1 deletion pytorch_lightning/accelerators/gpu.py
@@ -1,10 +1,11 @@
import logging
import os
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
@@ -48,3 +49,11 @@ def set_nvidia_flags() -> None:
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")

def to_device(self, batch: Any) -> Any:
# no need to transfer batch to device in DP mode
# TODO: Add support to allow batch transfer to device in Lightning for DP mode.
if not isinstance(self.training_type_plugin, DataParallelPlugin):
batch = super().to_device(batch)

return batch
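
For background, `torch.nn.DataParallel` scatters whatever it receives in `forward` across the configured GPUs on its own, which is why the accelerator can safely skip the explicit transfer above. A minimal sketch of that behaviour, assuming at least two CUDA devices are available; `TinyNet` and the tensor shapes are illustrative, not part of this PR:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):

    def forward(self, x):
        # by the time forward runs, DataParallel has already scattered
        # `x` onto this replica's GPU
        return x * 2


if torch.cuda.device_count() >= 2:
    model = nn.DataParallel(TinyNet().cuda(), device_ids=[0, 1])
    batch = torch.randn(8, 3)  # deliberately left on the CPU
    out = model(batch)         # DP chunks the batch and moves each chunk to a GPU
    print(out.device)          # outputs are gathered back onto cuda:0
```
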
21 changes: 16 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -615,10 +615,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =

Note:
This hook only runs on single GPU training and DDP (no data-parallel).
If you need multi-GPU support for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
Data-Parallel support will come in the near future.

Args:
batch: A batch of data that needs to be transferred to a new device.
@@ -638,6 +635,10 @@ def transfer_batch_to_device(self, batch, device):
batch = super().transfer_batch_to_device(batch, device)
return batch

Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.

See Also:
- :meth:`move_data_to_device`
- :meth:`apply_to_collection`
@@ -649,10 +650,11 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
"""
Override to alter or apply batch augmentations to your batch before it is transferred to the device.

.. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.
.. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true index in the future.

Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Data-Parallel support will come in the near future.

Args:
batch: A batch of data that needs to be altered or augmented.
@@ -667,6 +669,10 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
batch['x'] = transforms(batch['x'])
return batch

Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.

See Also:
- :meth:`on_after_batch_transfer`
- :meth:`transfer_batch_to_device`
@@ -681,6 +687,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):

Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Data-Parallel support will come in the near future.

Args:
batch: A batch of data that needs to be altered or augmented.
@@ -695,6 +702,10 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
batch['x'] = gpu_transforms(batch['x'])
return batch

Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.

See Also:
- :meth:`on_before_batch_transfer`
- :meth:`transfer_batch_to_device`
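
For reference, the three hooks documented above form a small pipeline around device transfer: CPU-side augmentation, the move itself, then device-side augmentation. A minimal sketch of a `LightningModule` overriding all of them, following the signatures shown in the docstrings; the dict-style batch and the normalization are illustrative, and the `move_data_to_device` helper (referenced in the See Also sections) is assumed to live in `pytorch_lightning.utilities.apply_func`:

```python
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.apply_func import move_data_to_device


class AugmentedModule(LightningModule):

    def on_before_batch_transfer(self, batch, dataloader_idx):
        # runs before the batch is moved, i.e. still on the CPU
        batch['x'] = batch['x'] + 0.1 * torch.randn_like(batch['x'])
        return batch

    def transfer_batch_to_device(self, batch, device):
        # custom move logic; here it simply reuses Lightning's default helper
        return move_data_to_device(batch, device)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # runs after the move, so this normalization executes on the target device
        batch['x'] = (batch['x'] - batch['x'].mean()) / (batch['x'].std() + 1e-6)
        return batch
```

With this PR, running such a module under `Trainer(accelerator='dp')` raises a `MisconfigurationException` instead of silently skipping the hooks.
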
31 changes: 17 additions & 14 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -89,6 +89,7 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
# set up the passed in dataloaders (if needed)
self.attach_dataloaders(model, train_dataloader, val_dataloaders)
self.attach_datamodule(model, datamodule)
self._validate_data_hooks(model)

def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -97,6 +98,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)

def _validate_data_hooks(self, model):
# Raise a MisconfigurationException since these hooks are not supported in DP mode
# TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')

def attach_dataloaders(
self,
model,
@@ -127,22 +136,16 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule] = N
if datamodule:

# Override loader hooks
if is_overridden('train_dataloader', datamodule):
model.train_dataloader = datamodule.train_dataloader
if is_overridden('val_dataloader', datamodule):
model.val_dataloader = datamodule.val_dataloader
if is_overridden('test_dataloader', datamodule):
model.test_dataloader = datamodule.test_dataloader
if is_overridden('predict_dataloader', datamodule):
model.predict_dataloader = datamodule.predict_dataloader
dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
if is_overridden('on_before_batch_transfer', datamodule):
model.on_before_batch_transfer = datamodule.on_before_batch_transfer
if is_overridden('transfer_batch_to_device', datamodule):
model.transfer_batch_to_device = datamodule.transfer_batch_to_device
if is_overridden('on_after_batch_transfer', datamodule):
model.on_after_batch_transfer = datamodule.on_after_batch_transfer
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if is_overridden(hook, datamodule):
setattr(model, hook, getattr(datamodule, hook))

self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer
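
Both the new validation and the refactored hook forwarding above hinge on `is_overridden`, which checks whether a user class redefines a hook relative to its base class. A simplified, self-contained sketch of that idea (not the actual Lightning implementation, which also handles wrapped and patched methods):

```python
def is_overridden_sketch(method_name: str, obj: object, base_cls: type) -> bool:
    """Return True if type(obj) defines `method_name` differently from `base_cls`."""
    sub_attr = getattr(type(obj), method_name, None)
    base_attr = getattr(base_cls, method_name, None)
    if sub_attr is None or base_attr is None:
        return False
    # an override is simply a different function object than the base definition
    return sub_attr is not base_attr


class Base:

    def transfer_batch_to_device(self, batch, device):
        return batch


class WithCustomTransfer(Base):

    def transfer_batch_to_device(self, batch, device):
        return batch.to(device)


assert is_overridden_sketch('transfer_batch_to_device', WithCustomTransfer(), Base)
assert not is_overridden_sketch('transfer_batch_to_device', Base(), Base)
```
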
53 changes: 53 additions & 0 deletions tests/accelerators/test_dp.py
@@ -11,15 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
@@ -132,6 +135,56 @@ def training_epoch_end(self, outputs):
assert outputs[0]["reduce_float"].item() == 0.5 # mean([0., 1.]) = 0.5


def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch):
"""
Test that an exception is raised when overriding batch transfer hooks in DP mode.
"""
monkeypatch.setattr("torch.cuda.device_count", lambda: 2)

class CustomModel(BoringModel):

def transfer_batch_to_device(self, batch, device):
batch = batch.to(device)
return batch

trainer_options = dict(
default_root_dir=tmpdir,
max_steps=7,
gpus=[0, 1],
accelerator='dp',
)

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'):
trainer.fit(model)

class CustomModel(BoringModel):

def on_before_batch_transfer(self, batch, dataloader_idx):
batch += 1
return batch

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'):
trainer.fit(model)

class CustomModel(BoringModel):

def on_after_batch_transfer(self, batch, dataloader_idx):
batch += 1
return batch

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'):
trainer.fit(model)


@RunIf(min_gpus=2)
def test_dp_training_step_dict(tmpdir):
""" This test verifies that dp properly reduces dictionaries """
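
The new test above fakes a two-GPU machine by patching `torch.cuda.device_count`, so the misconfiguration check can be exercised without real GPUs. The same pattern in isolation, as a minimal pytest sketch (the `monkeypatch` fixture undoes the patch when the test ends):

```python
import torch


def test_fake_gpu_count(monkeypatch):
    # pretend two CUDA devices exist for the duration of this test only
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
    assert torch.cuda.device_count() == 2
```
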