Commit
Disable batch transfer in DP mode (#6098)
* add exceptions and test

* hook

* fix

* clean up

* clean up

* regex

* regex

* docs

* rev

* comment and docs

* chlog

* Apply suggestions from code review

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Apply suggestions from code review

Co-authored-by: chaton <thomas@grid.ai>

* Monkey-patch device count

* docs

* pep

* api_change

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
Co-authored-by: chaton <thomas@grid.ai>

(cherry picked from commit c53edce)
rohitgr7 authored and SeanNaren committed Mar 16, 2021
1 parent 344b455 commit dfba137
Showing 5 changed files with 205 additions and 29 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -178,6 +178,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))


- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))


## [1.2.0] - 2021-02-18

### Added
27 changes: 23 additions & 4 deletions pytorch_lightning/accelerators/gpu.py
@@ -1,40 +1,59 @@
import logging
import os
from typing import TYPE_CHECKING, Any

import torch

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins import DataParallelPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if TYPE_CHECKING:
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.trainer import Trainer

_log = logging.getLogger(__name__)


class GPUAccelerator(Accelerator):

def setup(self, trainer, model):
def setup(self, trainer: 'Trainer', model: 'LightningModule') -> None:
"""
Raises:
MisconfigurationException:
If the selected device is not GPU.
"""
if "cuda" not in str(self.root_device):
raise MisconfigurationException(f"Device should be GPU, got {self.root_device} instead")
self.set_nvidia_flags()
torch.cuda.set_device(self.root_device)
return super().setup(trainer, model)

def on_train_start(self):
def on_train_start(self) -> None:
# clear cache before training
# use context because of:
# https://discuss.pytorch.org/t/out-of-memory-when-i-use-torch-cuda-empty-cache/57898
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

def on_train_end(self):
def on_train_end(self) -> None:
# clean up memory
self.model.cpu()
with torch.cuda.device(self.root_device):
torch.cuda.empty_cache()

@staticmethod
def set_nvidia_flags():
def set_nvidia_flags() -> None:
# set the correct cuda visible devices (using pci order)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())])
devices = os.getenv("CUDA_VISIBLE_DEVICES", all_gpu_ids)
_log.info(f"LOCAL_RANK: {os.getenv('LOCAL_RANK', 0)} - CUDA_VISIBLE_DEVICES: [{devices}]")

def to_device(self, batch: Any) -> Any:
# no need to transfer batch to device in DP mode
# TODO: Add support to allow batch transfer to device in Lightning for DP mode.
if not isinstance(self.training_type_plugin, DataParallelPlugin):
batch = super().to_device(batch)

return batch
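
For context (not part of this diff): the transfer can be skipped here because torch.nn.DataParallel scatters the input batch to each replica on its own, so pre-moving the batch to the root device would be redundant. A minimal plain-PyTorch sketch, assuming a machine with at least two GPUs:

# Sketch only (not from this commit): DataParallel copies/scatters a CPU batch to
# every GPU replica by itself, which is why GPUAccelerator.to_device leaves the
# batch untouched when the DataParallelPlugin is active.
import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    module = nn.Linear(32, 2).to("cuda:0")              # root device, as in DP mode
    dp_module = nn.DataParallel(module, device_ids=[0, 1])
    batch = torch.randn(8, 32)                          # stays on the CPU
    out = dp_module(batch)                              # scattered to cuda:0 and cuda:1
    print(out.device)                                   # outputs gathered back on cuda:0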
21 changes: 16 additions & 5 deletions pytorch_lightning/core/hooks.py
@@ -615,10 +615,7 @@ def transfer_batch_to_device(self, batch: Any, device: Optional[torch.device] =
Note:
This hook only runs on single GPU training and DDP (no data-parallel).
If you need multi-GPU support for your custom batch objects, you need to define your custom
:class:`~torch.nn.parallel.DistributedDataParallel` or
:class:`~pytorch_lightning.overrides.data_parallel.LightningDistributedDataParallel` and
override :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_ddp`.
Data-Parallel support will come in the near future.
Args:
batch: A batch of data that needs to be transferred to a new device.
@@ -638,6 +635,10 @@ def transfer_batch_to_device(self, batch, device):
batch = super().transfer_batch_to_device(data, device)
return batch
Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.
See Also:
- :meth:`move_data_to_device`
- :meth:`apply_to_collection`
@@ -649,10 +650,11 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
"""
Override to alter or apply batch augmentations to your batch before it is transferred to the device.
.. warning:: dataloader_idx always returns 0, and will be updated to support the true idx in the future.
.. warning:: ``dataloader_idx`` always returns 0, and will be updated to support the true index in the future.
Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Data-Parallel support will come in the near future.
Args:
batch: A batch of data that needs to be altered or augmented.
@@ -667,6 +669,10 @@ def on_before_batch_transfer(self, batch, dataloader_idx):
batch['x'] = transforms(batch['x'])
return batch
Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.
See Also:
- :meth:`on_after_batch_transfer`
- :meth:`transfer_batch_to_device`
@@ -681,6 +687,7 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
Note:
This hook only runs on single GPU training and DDP (no data-parallel).
Data-Parallel support will come in the near future.
Args:
batch: A batch of data that needs to be altered or augmented.
@@ -695,6 +702,10 @@ def on_after_batch_transfer(self, batch, dataloader_idx):
batch['x'] = gpu_transforms(batch['x'])
return batch
Raises:
MisconfigurationException:
If using data-parallel, ``Trainer(accelerator='dp')``.
See Also:
- :meth:`on_before_batch_transfer`
- :meth:`transfer_batch_to_device`
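
A minimal sketch (not from this diff) of the behaviour the updated docstrings describe: overriding any of the three batch-transfer hooks still works for single-GPU and DDP training, but now raises a MisconfigurationException under Trainer(accelerator='dp'). BoringModel is the helper from the project's tests/helpers package, as used in tests/accelerators/test_dp.py below; the 1.2-era Trainer arguments are assumed.

# Sketch only, assuming a 2-GPU machine and the 1.2-era API.
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


class CustomBatchModel(BoringModel):

    def transfer_batch_to_device(self, batch, device):
        # dataset-specific to_device logic: fine for single-GPU and DDP
        return batch.to(device)


Trainer(gpus=1, fast_dev_run=True).fit(CustomBatchModel())  # runs

try:
    Trainer(gpus=2, accelerator='dp', fast_dev_run=True).fit(CustomBatchModel())
except MisconfigurationException as err:
    print(err)  # Overriding `transfer_batch_to_device` is not supported in DP mode.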
33 changes: 18 additions & 15 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -81,7 +81,8 @@ def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):

# set up the passed in dataloaders (if needed)
self.attach_dataloaders(model, train_dataloader, val_dataloaders)
self.attach_datamodule(model, datamodule, 'fit')
self.attach_datamodule(model, datamodule)
self._validate_data_hooks(model)

def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
# If you supply a datamodule you can't supply train_dataloader or val_dataloaders
@@ -90,6 +91,14 @@ def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloa
'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
)

def _validate_data_hooks(self, model):
# Raise Misconfiguration exception since these hooks are not supported in DP mode
# TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if self.trainer.accelerator_connector.use_dp and is_overridden(hook, model):
raise MisconfigurationException(f'Overriding `{hook}` is not supported in DP mode.')

def attach_dataloaders(
self,
model,
@@ -122,22 +131,16 @@ def attach_datamodule(self, model, datamodule: Optional[LightningDataModule], st
if datamodule:

# Override loader hooks
if is_overridden('train_dataloader', datamodule):
model.train_dataloader = datamodule.train_dataloader
if is_overridden('val_dataloader', datamodule):
model.val_dataloader = datamodule.val_dataloader
if is_overridden('test_dataloader', datamodule):
model.test_dataloader = datamodule.test_dataloader
if is_overridden('predict_dataloader', datamodule):
model.predict_dataloader = datamodule.predict_dataloader
dl_methods = ('train_dataloader', 'val_dataloader', 'test_dataloader', 'predict_dataloader')
for method in dl_methods:
if is_overridden(method, datamodule):
setattr(model, method, getattr(datamodule, method))

# Override data transfer hooks if dataset-specific to_device logic has been defined in datamodule
if is_overridden('on_before_batch_transfer', datamodule):
model.on_before_batch_transfer = datamodule.on_before_batch_transfer
if is_overridden('transfer_batch_to_device', datamodule):
model.transfer_batch_to_device = datamodule.transfer_batch_to_device
if is_overridden('on_after_batch_transfer', datamodule):
model.on_after_batch_transfer = datamodule.on_after_batch_transfer
batch_transfer_hooks = ('on_before_batch_transfer', 'transfer_batch_to_device', 'on_after_batch_transfer')
for hook in batch_transfer_hooks:
if is_overridden(hook, datamodule):
setattr(model, hook, getattr(datamodule, hook))

self.trainer.datamodule = datamodule
datamodule.trainer = self.trainer
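
For reference (not from this diff), both the new _validate_data_hooks check and the refactored attach_datamodule loop lean on is_overridden, which compares an instance's hook against the LightningModule (or LightningDataModule) default, so only user-defined hooks trigger the DP-mode exception or get copied onto the model. A rough sketch, assuming the 1.2-era import path pytorch_lightning.utilities.model_helpers:

# Sketch only; the import path is an assumption based on the 1.2-era module layout.
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.model_helpers import is_overridden


class PlainModel(LightningModule):
    pass


class HookModel(LightningModule):

    def on_after_batch_transfer(self, batch, dataloader_idx):
        return batch + 1


print(is_overridden('on_after_batch_transfer', PlainModel()))  # False: default hook
print(is_overridden('on_after_batch_transfer', HookModel()))   # True: user override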
150 changes: 145 additions & 5 deletions tests/accelerators/test_dp.py
@@ -13,34 +13,66 @@
# limitations under the License.
import pytest
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

import pytorch_lightning as pl
import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core import memory
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.simple_models import ClassificationModel
from tests.base import EvalModelTemplate

PRETEND_N_OF_GPUS = 16

class CustomClassificationModelDP(ClassificationModel):

def _step(self, batch, batch_idx):
x, y = batch
logits = self(x)
return {'logits': logits, 'y': y}

def training_step(self, batch, batch_idx):
out = self._step(batch, batch_idx)
loss = F.cross_entropy(out['logits'], out['y'])
return loss

def validation_step(self, batch, batch_idx):
return self._step(batch, batch_idx)

def test_step(self, batch, batch_idx):
return self._step(batch, batch_idx)

def validation_step_end(self, outputs):
self.log('val_acc', self.valid_acc(outputs['logits'], outputs['y']))

def test_step_end(self, outputs):
self.log('test_acc', self.test_acc(outputs['logits'], outputs['y']))


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_multi_gpu_early_stop_dp(tmpdir):
"""Make sure DDP works. with early stopping"""
tutils.set_random_master_port()

dm = ClassifDataModule()
model = CustomClassificationModelDP()

trainer_options = dict(
default_root_dir=tmpdir,
callbacks=[EarlyStopping()],
callbacks=[EarlyStopping(monitor='val_acc')],
max_epochs=50,
limit_train_batches=10,
limit_val_batches=10,
gpus=[0, 1],
accelerator='dp',
)

model = EvalModelTemplate()
tpipes.run_model_test(trainer_options, model)
tpipes.run_model_test(trainer_options, model, dm)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@@ -57,14 +89,122 @@ def test_multi_gpu_model_dp(tmpdir):
progress_bar_refresh_rate=0,
)

model = EvalModelTemplate()
model = BoringModel()

tpipes.run_model_test(trainer_options, model)

# test memory helper functions
memory.get_memory_profile('min_max')


class ReductionTestModel(BoringModel):

def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def val_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def test_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

def add_outputs(self, output, device):
output.update({
"reduce_int": torch.tensor(device.index, dtype=torch.int, device=device),
"reduce_float": torch.tensor(device.index, dtype=torch.float, device=device),
})

def training_step(self, batch, batch_idx):
output = super().training_step(batch, batch_idx)
self.add_outputs(output, batch.device)
return output

def validation_step(self, batch, batch_idx):
output = super().validation_step(batch, batch_idx)
self.add_outputs(output, batch.device)
return output

def test_step(self, batch, batch_idx):
output = super().test_step(batch, batch_idx)
self.add_outputs(output, batch.device)
return output

def training_epoch_end(self, outputs):
assert outputs[0]["loss"].shape == torch.Size([])
assert outputs[0]["reduce_int"].item() == 0 # mean([0, 1]) = 0
assert outputs[0]["reduce_float"].item() == 0.5 # mean([0., 1.]) = 0.5


def test_dp_raise_exception_with_batch_transfer_hooks(tmpdir, monkeypatch):
"""
Test that an exception is raised when overriding batch_transfer_hooks in DP mode.
"""
monkeypatch.setattr("torch.cuda.device_count", lambda: 2)

class CustomModel(BoringModel):

def transfer_batch_to_device(self, batch, device):
batch = batch.to(device)
return batch

trainer_options = dict(
default_root_dir=tmpdir,
max_steps=7,
gpus=[0, 1],
accelerator='dp',
)

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `transfer_batch_to_device` is not .* in DP'):
trainer.fit(model)

class CustomModel(BoringModel):

def on_before_batch_transfer(self, batch, dataloader_idx):
batch += 1
return batch

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `on_before_batch_transfer` is not .* in DP'):
trainer.fit(model)

class CustomModel(BoringModel):

def on_after_batch_transfer(self, batch, dataloader_idx):
batch += 1
return batch

trainer = Trainer(**trainer_options)
model = CustomModel()

with pytest.raises(MisconfigurationException, match=r'Overriding `on_after_batch_transfer` is not .* in DP'):
trainer.fit(model)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_dp_training_step_dict(tmpdir):
""" This test verifies that dp properly reduces dictionaries """
model = ReductionTestModel()
model.training_step_end = None
model.validation_step_end = None
model.test_step_end = None

trainer = pl.Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_train_batches=1,
limit_val_batches=1,
limit_test_batches=1,
gpus=2,
accelerator='dp',
)
trainer.fit(model)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_dp_test(tmpdir):
tutils.set_random_master_port()
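
A side note on the "Monkey-patch device count" item in the commit message above: faking torch.cuda.device_count lets the new hook-validation test construct Trainer(gpus=[0, 1], accelerator='dp') and hit the MisconfigurationException on a CPU-only machine, since the check fires during data attachment, before the model is moved to GPU. A minimal sketch of the trick in isolation:

# Sketch only: pytest's built-in monkeypatch fixture swaps out torch.cuda.device_count
# for the duration of the test, so GPU-count parsing succeeds without physical GPUs.
import torch


def test_fake_gpu_count(monkeypatch):
    monkeypatch.setattr("torch.cuda.device_count", lambda: 2)
    assert torch.cuda.device_count() == 2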
