Skip to content

Commit

Permalink
fix self.device access in DataParallel (#6414)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Apr 13, 2021
1 parent 030b76b commit 80c5293
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 0 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))


## [1.2.7] - 2021-04-06

### Fixed
Expand Down
34 changes: 34 additions & 0 deletions pytorch_lightning/overrides/data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.overrides.distributed import LightningDistributedModule
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection


Expand Down Expand Up @@ -71,6 +72,8 @@ def __init__(self, pl_module: LightningModule):
super().__init__(pl_module)

def forward(self, *inputs, **kwargs):
self.update_replica_device_attributes(inputs)
# forward call will redirect to training_step, validation_step, etc.
output = super().forward(*inputs, **kwargs)

def output_transform(data: Any):
Expand All @@ -85,6 +88,37 @@ def output_transform(data: Any):
)
return output

def update_replica_device_attributes(self, inputs: Any) -> None:
"""
Updates the device information of LightningModule by reading the device from the inputs.
In :class:`~torch.nn.data_parallel.DataParallel` changes to the state during the `forward` pass
are lost when the replicas get discarded. The only way to know the current device is from the
inputs passed into the model.
Args:
inputs: A collection of inputs (typically a tuple). If the inputs don't contain tensors,
a warning is shown that accessing ``self.device`` will not return the correct device.
"""
replica_device = None

def find_tensor_with_device(tensor: torch.Tensor) -> torch.Tensor:
nonlocal replica_device
if replica_device is None and tensor.device != torch.device("cpu"):
replica_device = tensor.device
return tensor

apply_to_collection(inputs, dtype=torch.Tensor, function=find_tensor_with_device)

if replica_device is not None:
# by calling .to() we force the update to the self.device property
self.module.to(device=replica_device)
else:
rank_zero_warn(
"Could not determine on which device the inputs are."
" When using DataParallel (accelerator='dp'), be aware that in case you are using self.device"
" in your code, it will reference only the root device."
)


def python_scalar_to_tensor(data: Any, device: torch.device = torch.device("cpu")) -> Any:
""" Converts a Python scalar number to a torch tensor and places it on the given device. """
Expand Down
68 changes: 68 additions & 0 deletions tests/overrides/test_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

import pytest
import torch
import torch.nn as nn
from torch.nn import DataParallel

from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.data_parallel import (
LightningParallelModule,
Expand Down Expand Up @@ -123,3 +126,68 @@ def training_step(self, batch, batch_idx):
wrapped_model = LightningParallelModule(model)
output = wrapped_model(batch, batch_idx)
assert output["python scalar"] == torch.tensor([12.3], device=device)


@RunIf(min_gpus=2)
@pytest.mark.parametrize(
"nest, unnest", [
(lambda x: x, lambda x: x),
(lambda x: dict(data=x), lambda x: x["data"]),
(lambda x: [x, (x, x)], lambda x: x[1][0]),
]
)
def test_lightning_parallel_module_device_access(nest, unnest):
""" Test that self.device returns the correct value in replicas of DataParallel. """

class DeviceAccessModel(LightningModule):

def __init__(self):
super().__init__()
self.layer = nn.Linear(2, 3)

@auto_move_data
def training_step(self, batch, batch_idx):
batch = unnest(batch)
assert batch.shape == torch.Size([1, 1])
assert self.device.index == batch.item()
assert self.device == self.layer.weight.device
return torch.tensor(1, device=self.device)

pl_module = DeviceAccessModel()
# required for redirecting the forward call to training_step
pl_module.trainer = Mock()
pl_module.trainer._running_stage = RunningStage.TRAINING

root_device = torch.device("cuda", 0)
wrapped_module = LightningParallelModule(pl_module).to(root_device)
model = DataParallel(wrapped_module, device_ids=[0, 1])

data = torch.tensor([0.0, 1.0], device=root_device).view(2, 1) # one value per gpu
data = data.to(root_device)
data = nest(data)
output = model(data, 0)
assert output.device == root_device
assert pl_module.device == root_device
assert torch.all(output.cpu().eq(torch.tensor([1, 1])))


@RunIf(min_gpus=2)
def test_lightning_parallel_module_device_access_warning():
""" Test that we show a warning when the device can't be inferred from the input. """

class DeviceAccessModel(LightningModule):

def training_step(self, batch, batch_idx):
pass

pl_module = DeviceAccessModel()
# required for redirecting the forward call to training_step
pl_module.trainer = Mock()
pl_module.trainer._running_stage = RunningStage.TRAINING

wrapped_module = LightningParallelModule(pl_module).cuda()
model = DataParallel(wrapped_module, device_ids=[0, 1])

data = dict(x=1) # contains no tensors
with pytest.warns(UserWarning, match="Could not determine on which device the inputs are."):
_ = model(data, 0)

0 comments on commit 80c5293

Please sign in to comment.