diff --git a/CHANGELOG.md b/CHANGELOG.md
index cf200ea15f007..6a1e23adaa369 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-
+- Fixed `dtype` and `device` properties not getting updated in submodules ([#2657](https://github.com/PyTorchLightning/pytorch-lightning/pull/2657))
 
 
 ## [0.8.5] - 2020-07-09
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index f2b591cbacfbe..1739133edabf3 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -68,12 +68,6 @@ def __init__(self, *args, **kwargs):
         #: True if using amp
         self.use_amp = False
 
-        #: Current dtype
-        self._dtype = torch.float
-
-        #: device reference
-        self._device = torch.device('cpu')
-
         # optionally can be set by user
         self._example_input_array = None
 
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 349a6ecfa2f82..94e8a0ea4e442 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -27,8 +27,6 @@ def __init__(self, name: str):
         """
         super().__init__()
         self.name = name
-        self._dtype = torch.get_default_dtype()
-        self._device = torch.device('cpu')
 
     @abstractmethod
     def forward(self, *args, **kwargs) -> torch.Tensor:
diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py
index 48ccad5307552..bea3df3e5ced9 100644
--- a/pytorch_lightning/utilities/device_dtype_mixin.py
+++ b/pytorch_lightning/utilities/device_dtype_mixin.py
@@ -5,8 +5,11 @@
 
 
 class DeviceDtypeModuleMixin(Module):
-    _device: ...
-    _dtype: Union[str, torch.dtype]
+
+    def __init__(self):
+        super().__init__()
+        self._dtype = torch.get_default_dtype()
+        self._device = torch.device('cpu')
 
     @property
     def dtype(self) -> Union[str, torch.dtype]:
@@ -79,17 +82,14 @@ def to(self, *args, **kwargs) -> Module:
         ExampleModule()
         >>> module.weight #doctest: +ELLIPSIS
         tensor([[...]], dtype=torch.float16)
+        >>> module.device
+        device(type='cpu')
+        >>> module.dtype
+        torch.float16
         """
         # there is diff nb vars in PT 1.5
         out = torch._C._nn._parse_to(*args, **kwargs)
-        device = out[0]
-        dtype = out[1]
-        if device is not None:
-            self._device = device
-
-        if dtype is not None:
-            self._dtype = dtype
-
+        self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)
 
     def cuda(self, device: Optional[int] = None) -> Module:
@@ -105,8 +105,7 @@ def cuda(self, device: Optional[int] = None) -> Module:
         Returns:
             Module: self
         """
-
-        self._device = torch.device('cuda', index=device)
+        self.__update_properties(device=torch.device('cuda', index=device))
         return super().cuda(device=device)
 
     def cpu(self) -> Module:
@@ -114,7 +113,7 @@ def cpu(self) -> Module:
         Returns:
             Module: self
         """
-        self._device = torch.device('cpu')
+        self.__update_properties(device=torch.device('cpu'))
         return super().cpu()
 
     def type(self, dst_type: Union[str, torch.dtype]) -> Module:
@@ -126,7 +125,7 @@ def type(self, dst_type: Union[str, torch.dtype]) -> Module:
         Returns:
             Module: self
         """
-        self._dtype = dst_type
+        self.__update_properties(dtype=dst_type)
         return super().type(dst_type=dst_type)
 
     def float(self) -> Module:
@@ -135,7 +134,7 @@ def float(self) -> Module:
         Returns:
             Module: self
         """
-        self._dtype = torch.float
+        self.__update_properties(dtype=torch.float)
         return super().float()
 
     def double(self) -> Module:
@@ -144,7 +143,7 @@ def double(self) -> Module:
         Returns:
             Module: self
         """
-        self._dtype = torch.double
+        self.__update_properties(dtype=torch.double)
         return super().double()
 
     def half(self) -> Module:
@@ -153,5 +152,17 @@ def half(self) -> Module:
         Returns:
             Module: self
         """
-        self._dtype = torch.half
+        self.__update_properties(dtype=torch.half)
         return super().half()
+
+    def __update_properties(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
+
+        def apply_fn(module):
+            if not isinstance(module, DeviceDtypeModuleMixin):
+                return
+            if device is not None:
+                module._device = device
+            if dtype is not None:
+                module._dtype = dtype
+
+        self.apply(apply_fn)
diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py
new file mode 100644
index 0000000000000..f755cf5c634ed
--- /dev/null
+++ b/tests/utilities/test_dtype_device_mixin.py
@@ -0,0 +1,89 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from pytorch_lightning import Trainer, Callback
+from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from tests.base import EvalModelTemplate
+
+
+class SubSubModule(DeviceDtypeModuleMixin):
+    pass
+
+
+class SubModule(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.module = SubSubModule()
+
+
+class TopModule(EvalModelTemplate):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.module = SubModule()
+
+
+class DeviceAssertCallback(Callback):
+
+    def on_batch_start(self, trainer, model):
+        rank = trainer.local_rank
+        assert isinstance(model, TopModule)
+        # index = None also means first device
+        assert (model.device.index is None and rank == 0) or model.device.index == rank
+        assert model.device == model.module.module.device
+
+
+@pytest.mark.parametrize(['dst_dtype'], [
+    pytest.param(torch.float),
+    pytest.param(torch.double),
+    pytest.param(torch.half),
+])
+@pytest.mark.parametrize(['dst_device'], [
+    pytest.param(torch.device('cpu')),
+    pytest.param(torch.device('cuda')),
+    pytest.param(torch.device('cuda', 0)),
+])
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
+def test_submodules_device_and_dtype(dst_device, dst_dtype):
+    """
+    Test that the device and dtype property updates propagate through mixed nesting of regular
+    nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule).
+    """
+
+    model = TopModule()
+    assert model.device == torch.device('cpu')
+    model = model.to(device=dst_device, dtype=dst_dtype)
+    # nn.Module does not have these attributes
+    assert not hasattr(model.module, '_device')
+    assert not hasattr(model.module, '_dtype')
+    # device and dtype change should propagate down into all children
+    assert model.device == model.module.module.device == dst_device
+    assert model.dtype == model.module.module.dtype == dst_dtype
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_submodules_multi_gpu_dp(tmpdir):
+    model = TopModule()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        distributed_backend='dp',
+        gpus=2,
+        callbacks=[DeviceAssertCallback()],
+        max_steps=1,
+    )
+    trainer.fit(model)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+def test_submodules_multi_gpu_ddp_spawn(tmpdir):
+    model = TopModule()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        distributed_backend='ddp_spawn',
+        gpus=2,
+        callbacks=[DeviceAssertCallback()],
+        max_steps=1,
+    )
+    trainer.fit(model)
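
For context, here is a minimal sketch (not part of the diff) of the behaviour the new `__update_properties` hook enables. It assumes the `pytorch_lightning.utilities.device_dtype_mixin` import path from this branch; the `OuterMixinModule`, `PlainWrapper`, and `InnerMixinModule` names are made up for illustration.

```python
import torch
import torch.nn as nn

# Import path as it exists on this branch
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin


class InnerMixinModule(DeviceDtypeModuleMixin):
    """Hypothetical leaf that tracks its own device/dtype (stand-in for a Metric)."""


class PlainWrapper(nn.Module):
    """Hypothetical plain nn.Module sitting between the two mixin modules."""

    def __init__(self):
        super().__init__()
        self.inner = InnerMixinModule()


class OuterMixinModule(DeviceDtypeModuleMixin):
    """Hypothetical top-level module (stand-in for a LightningModule)."""

    def __init__(self):
        super().__init__()
        self.wrapper = PlainWrapper()


model = OuterMixinModule()
assert model.device == torch.device('cpu')

# Casting the top-level module now also updates the tracked dtype of every nested
# DeviceDtypeModuleMixin, because __update_properties walks the tree via self.apply().
model = model.double()
assert model.dtype == model.wrapper.inner.dtype == torch.double
```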