fix dtype/device property not getting updated in submodules (#2657)
* recursive dtype device apply
* simplify
* simple test
* submodule test
* rename
* explicit
* type hints
* test for dp backend
* fix test skip
* rename
* add ddp_spawn test
* fix None index in test
* try fix ddp_spawn test
* changelog
* move _dtype and _device to mixin
* additional doctest
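Per the commit message, the fix moves the cached _device and _dtype attributes onto the mixin and applies updates recursively. A minimal sketch of that idea, assuming an override of to() that walks submodules with nn.Module.apply (names and structure here are illustrative, not the library's exact code):

import torch
import torch.nn as nn


class DeviceDtypeModuleMixin(nn.Module):
    # cached values; plain nn.Modules in the tree never get these attributes
    _device: torch.device = torch.device('cpu')
    _dtype: torch.dtype = torch.get_default_dtype()

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def to(self, *args, **kwargs):
        # torch's private helper normalizes the many legal .to() call forms;
        # PyTorch itself uses it inside nn.Module.to
        device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]

        def apply_fn(module):
            # update only modules that carry the cached properties
            if isinstance(module, DeviceDtypeModuleMixin):
                if device is not None:
                    module._device = device
                if dtype is not None:
                    module._dtype = dtype

        self.apply(apply_fn)  # nn.Module.apply recurses into every child
        return super().to(*args, **kwargs)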
Showing 5 changed files with 118 additions and 26 deletions.
@@ -0,0 +1,89 @@
import pytest
import torch
import torch.nn as nn

from pytorch_lightning import Trainer, Callback
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from tests.base import EvalModelTemplate


class SubSubModule(DeviceDtypeModuleMixin):
    pass


class SubModule(nn.Module):

    def __init__(self):
        super().__init__()
        self.module = SubSubModule()


class TopModule(EvalModelTemplate):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.module = SubModule()


class DeviceAssertCallback(Callback):

    def on_batch_start(self, trainer, model):
        rank = trainer.local_rank
        assert isinstance(model, TopModule)
        # index = None also means first device
        assert (model.device.index is None and rank == 0) or model.device.index == rank
        assert model.device == model.module.module.device


@pytest.mark.parametrize(['dst_dtype'], [
    pytest.param(torch.float),
    pytest.param(torch.double),
    pytest.param(torch.half),
])
@pytest.mark.parametrize(['dst_device'], [
    pytest.param(torch.device('cpu')),
    pytest.param(torch.device('cuda')),
    pytest.param(torch.device('cuda', 0)),
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_submodules_device_and_dtype(dst_device, dst_dtype):
    """
    Test that the device and dtype property updates propagate through mixed nesting of regular
    nn.Modules and the special modules of type DeviceDtypeModuleMixin (e.g. Metric or LightningModule).
    """
    model = TopModule()
    assert model.device == torch.device('cpu')
    model = model.to(device=dst_device, dtype=dst_dtype)
    # nn.Module does not have these attributes
    assert not hasattr(model.module, '_device')
    assert not hasattr(model.module, '_dtype')
    # device and dtype change should propagate down into all children
    assert model.device == model.module.module.device == dst_device
    assert model.dtype == model.module.module.dtype == dst_dtype


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_submodules_multi_gpu_dp(tmpdir):
    model = TopModule()
    trainer = Trainer(
        default_root_dir=tmpdir,
        distributed_backend='dp',
        gpus=2,
        callbacks=[DeviceAssertCallback()],
        max_steps=1,
    )
    trainer.fit(model)

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_submodules_multi_gpu_ddp_spawn(tmpdir):
    model = TopModule()
    trainer = Trainer(
        default_root_dir=tmpdir,
        distributed_backend='ddp_spawn',
        gpus=2,
        callbacks=[DeviceAssertCallback()],
        max_steps=1,
    )
    trainer.fit(model)
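The propagation behavior the first test pins down can also be seen without a GPU. A CPU-only usage sketch under the same import path as the diff above (the Inner/Outer class names are illustrative, not from the test file):

import torch
import torch.nn as nn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin


class Inner(DeviceDtypeModuleMixin):
    pass


class Outer(DeviceDtypeModuleMixin):
    def __init__(self):
        super().__init__()
        # a plain nn.Module sits between the two mixin modules
        self.seq = nn.Sequential(Inner())


model = Outer()
assert model.device == torch.device('cpu')  # default before any .to()
model = model.to(dtype=torch.double)
# the dtype update reaches the innermost mixin through the plain nn.Module
assert model.dtype == model.seq[0].dtype == torch.double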