fix dtype/device property not getting updated in submodules #2657

Merged: 16 commits, Jul 21, 2020
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed


- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))

## [0.8.5] - 2020-07-09

6 changes: 0 additions & 6 deletions pytorch_lightning/core/lightning.py
@@ -68,12 +68,6 @@ def __init__(self, *args, **kwargs):
#: True if using amp
self.use_amp = False

#: Current dtype
self._dtype = torch.float

#: device reference
self._device = torch.device('cpu')

# optionally can be set by user
self._example_input_array = None

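Note: the deleted defaults are not lost. They are now initialized in `DeviceDtypeModuleMixin.__init__` (see the mixin diff below), which `LightningModule` (and `Metric`, next diff) inherit, so the existing `super().__init__()` call covers them. A rough sketch of the inheritance involved, with the base-class list abbreviated (the real definition mixes in several more classes):

```python
import torch.nn as nn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin

# abbreviated sketch, not the actual class definition
class LightningModule(DeviceDtypeModuleMixin, nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__()  # runs DeviceDtypeModuleMixin.__init__, which sets _dtype and _device
```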
2 changes: 0 additions & 2 deletions pytorch_lightning/metrics/metric.py
@@ -27,8 +27,6 @@ def __init__(self, name: str):
"""
super().__init__()
self.name = name
self._dtype = torch.get_default_dtype()
self._device = torch.device('cpu')

@abstractmethod
def forward(self, *args, **kwargs) -> torch.Tensor:
45 changes: 28 additions & 17 deletions pytorch_lightning/utilities/device_dtype_mixin.py
@@ -5,8 +5,11 @@


class DeviceDtypeModuleMixin(Module):
_device: ...
_dtype: Union[str, torch.dtype]

def __init__(self):
super().__init__()
self._dtype = torch.get_default_dtype()
self._device = torch.device('cpu')

@property
def dtype(self) -> Union[str, torch.dtype]:
@@ -79,17 +82,14 @@ def to(self, *args, **kwargs) -> Module:
ExampleModule()
>>> module.weight #doctest: +ELLIPSIS
tensor([[...]], dtype=torch.float16)
>>> module.device
device(type='cpu')
>>> module.dtype
torch.float16
"""
# _parse_to() returns a different number of values depending on the PyTorch version (changed in PT 1.5)
out = torch._C._nn._parse_to(*args, **kwargs)
device = out[0]
dtype = out[1]
if device is not None:
self._device = device

if dtype is not None:
self._dtype = dtype

self.__update_properties(device=out[0], dtype=out[1])
return super().to(*args, **kwargs)

def cuda(self, device: Optional[int] = None) -> Module:
@@ -105,16 +105,15 @@ def cuda(self, device: Optional[int] = None) -> Module:
Returns:
Module: self
"""

self._device = torch.device('cuda', index=device)
self.__update_properties(device=torch.device('cuda', index=device))
return super().cuda(device=device)

def cpu(self) -> Module:
"""Moves all model parameters and buffers to the CPU.
Returns:
Module: self
"""
self._device = torch.device('cpu')
self.__update_properties(device=torch.device('cpu'))
return super().cpu()

def type(self, dst_type: Union[str, torch.dtype]) -> Module:
@@ -126,7 +125,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Module:
Returns:
Module: self
"""
self._dtype = dst_type
self.__update_properties(dtype=dst_type)
return super().type(dst_type=dst_type)

def float(self) -> Module:
@@ -135,7 +134,7 @@
Returns:
Module: self
"""
self._dtype = torch.float
self.__update_properties(dtype=torch.float)
return super().float()

def double(self) -> Module:
@@ -144,7 +143,7 @@
Returns:
Module: self
"""
self._dtype = torch.double
self.__update_properties(dtype=torch.double)
return super().double()

def half(self) -> Module:
@@ -153,5 +152,17 @@
Returns:
Module: self
"""
self._dtype = torch.half
self.__update_properties(dtype=torch.half)
return super().half()

def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):

def apply_fn(module):
if not isinstance(module, DeviceDtypeModuleMixin):
return
if device is not None:
module._device = device
if dtype is not None:
module._dtype = dtype

self.apply(apply_fn)
Member:
What is the reason for using `apply`?

@awaelchli (Contributor, Author), Jul 21, 2020:

Part of the answer is in my comment above: `apply`, in contrast to `to`, works recursively on all submodules and therefore lets us update our custom properties. I'm writing a test right now to make sure it fixes what failed before.
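To make the distinction concrete, here is a toy sketch (the `Tracked`/`Outer` classes are illustrative, not code from this PR). `Module.to()` casts parameters and buffers recursively, but an overridden `to()` only runs on the module it is called on, so cached attributes on nested modules go stale; `Module.apply(fn)` calls `fn` on the module and every submodule, which is what the fix relies on:

```python
import torch
import torch.nn as nn


class Tracked(nn.Module):
    """Toy stand-in for DeviceDtypeModuleMixin: caches the dtype it was cast to."""

    def __init__(self):
        super().__init__()
        self._dtype = torch.get_default_dtype()


class Outer(Tracked):
    def __init__(self):
        super().__init__()
        self.child = Tracked()

    def to(self, *args, **kwargs):
        dtype = torch._C._nn._parse_to(*args, **kwargs)[1]
        if dtype is not None:
            # old behavior: `self._dtype = dtype` updated only the outer module
            # and left self.child._dtype stale; apply() recurses into all children
            def update(module):
                if isinstance(module, Tracked):
                    module._dtype = dtype

            self.apply(update)
        return super().to(*args, **kwargs)


model = Outer()
model.to(dtype=torch.half)
assert model._dtype == model.child._dtype == torch.half
```

With the old single-assignment behavior, the final assertion fails because `model.child._dtype` is still the default.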

89 changes: 89 additions & 0 deletions tests/utilities/test_dtype_device_mixin.py
@@ -0,0 +1,89 @@
import pytest
import torch
import torch.nn as nn

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from tests.base import EvalModelTemplate


class SubSubModule(DeviceDtypeModuleMixin):
pass


class SubModule(nn.Module):

def __init__(self):
super().__init__()
self.module = SubSubModule()


class TopModule(EvalModelTemplate):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.module = SubModule()


class DeviceAssertCallback(Callback):

def on_batch_start(self, trainer, model):
rank = trainer.local_rank
assert isinstance(model, TopModule)
# index = None also means first device
assert (model.device.index is None and rank == 0) or model.device.index == rank
assert model.device == model.module.module.device


@pytest.mark.parametrize(['dst_dtype'], [
pytest.param(torch.float),
pytest.param(torch.double),
pytest.param(torch.half),
])
@pytest.mark.parametrize(['dst_device'], [
pytest.param(torch.device('cpu')),
pytest.param(torch.device('cuda')),
pytest.param(torch.device('cuda', 0)),
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_submodules_device_and_dtype(dst_device, dst_dtype):
"""
Test that the device and dtype property updates propagate through mixed nesting of regular
nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule).
"""

model = TopModule()
assert model.device == torch.device('cpu')
model = model.to(device=dst_device, dtype=dst_dtype)
# nn.Module does not have these attributes
assert not hasattr(model.module, '_device')
assert not hasattr(model.module, '_dtype')
# device and dtype change should propagate down into all children
assert model.device == model.module.module.device == dst_device
assert model.dtype == model.module.module.dtype == dst_dtype


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_submodules_multi_gpu_dp(tmpdir):
model = TopModule()
trainer = Trainer(
default_root_dir=tmpdir,
distributed_backend='dp',
gpus=2,
callbacks=[DeviceAssertCallback()],
max_steps=1,
)
trainer.fit(model)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_submodules_multi_gpu_ddp_spawn(tmpdir):
model = TopModule()
trainer = Trainer(
default_root_dir=tmpdir,
distributed_backend='ddp_spawn',
gpus=2,
callbacks=[DeviceAssertCallback()],
max_steps=1,
)
trainer.fit(model)
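For reference, a minimal CPU-only sketch of the propagation the tests above assert, usable without a Trainer (the class names here are illustrative, not part of the test suite):

```python
import torch
import torch.nn as nn

from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin


class Inner(DeviceDtypeModuleMixin):
    pass


class Plain(nn.Module):
    # a plain nn.Module sandwiched in between, as in the tests above
    def __init__(self):
        super().__init__()
        self.inner = Inner()


class Top(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        self.plain = Plain()


model = Top().to(dtype=torch.float64)
# the dtype change passes through the plain nn.Module into Inner
assert model.dtype == model.plain.inner.dtype == torch.float64
```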