Enable ZeRO tests for CI, fix to/half function calls (#6070)
* Enable ZeRO optimization and ensure the LightningModule hook is called when the module is moved to half precision

* Added a test; updated the `to` function
SeanNaren committed Feb 21, 2021
1 parent 97a81c3 commit 3b0e4e0
Showing 5 changed files with 88 additions and 23 deletions.
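For context, here is a minimal, self-contained sketch of the idea behind the change. It is an illustration under stated assumptions, not Lightning's `DeviceDtypeModuleMixin`; all class and attribute names below are hypothetical. The point is that a mixin placed before `torch.nn.Module` in the MRO can keep `device`/`dtype` bookkeeping in sync when the wrapper (and therefore the wrapped module) is moved or cast.

```python
# Illustrative sketch only -- not the Lightning implementation.
import torch
import torch.nn as nn


class _DeviceDtypeTrackingMixin:
    _device: torch.device = torch.device("cpu")
    _dtype: torch.dtype = torch.float32

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def to(self, *args, **kwargs):
        out = super().to(*args, **kwargs)
        # infer the new device/dtype from the (already moved) parameters
        param = next(self.parameters(), None)
        if param is not None:
            self._device = param.device
            self._dtype = param.dtype
        return out

    def half(self):
        return self.to(torch.half)


class _Wrapper(_DeviceDtypeTrackingMixin, nn.Module):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)


wrapped = _Wrapper(nn.Linear(32, 2))
wrapped.half()
assert wrapped.dtype == torch.half                            # wrapper bookkeeping updated
assert next(wrapped.module.parameters()).dtype == torch.half  # wrapped module actually cast
```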
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))


- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))


## [1.2.0] - 2021-02-18

### Added
6 changes: 5 additions & 1 deletion pytorch_lightning/overrides/base.py
@@ -19,12 +19,13 @@

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class _LightningModuleWrapperBase(torch.nn.Module):
class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):

def __init__(self, pl_module: LightningModule):
"""
@@ -72,6 +73,9 @@ def forward(self, *inputs, **kwargs):

return output

def on_post_move_to_device(self):
pass


def warn_if_output_is_none(output: Any, method_name: str) -> None:
""" Warns user about which method returned None. """
5 changes: 3 additions & 2 deletions pytorch_lightning/utilities/device_dtype_mixin.py
@@ -119,7 +119,7 @@ def to(self, *args, **kwargs) -> Module:
self.__update_properties(device=out[0], dtype=out[1])
return super().to(*args, **kwargs)

def cuda(self, device: Optional[int] = None) -> Module:
def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module:
"""Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
@@ -132,7 +132,8 @@ def cuda(self, device: Optional[int] = None) -> Module:
Returns:
Module: self
"""
self.__update_properties(device=torch.device('cuda', index=device))
property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device)
self.__update_properties(device=property_device)
return super().cuda(device=device)

def cpu(self) -> Module:
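The `cuda()` change above accepts either a bare index or an already-constructed `torch.device`. A quick standalone illustration of the normalization expression from the patch (constructing `torch.device` objects does not require a GPU); the helper name is illustrative only:

```python
import torch


def _normalize(device):
    # same expression as the patched cuda(): keep a torch.device as-is,
    # otherwise wrap the bare index into a CUDA device
    return device if isinstance(device, torch.device) else torch.device("cuda", index=device)


print(_normalize(0))                        # cuda:0
print(_normalize(torch.device("cuda", 1)))  # cuda:1
```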
84 changes: 67 additions & 17 deletions tests/plugins/test_deepspeed_plugin.py
@@ -8,11 +8,52 @@

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


def test_deepspeed_lightning_module(tmpdir):
"""
Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves types and device correctly.
"""

model = BoringModel()
module = LightningDeepSpeedModule(model, precision=16)

module.half()
assert module.dtype == torch.half
assert model.dtype == torch.half

module.to(torch.double)
assert module.dtype == torch.double
assert model.dtype == torch.double


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
def test_deepspeed_lightning_module_precision(tmpdir):
"""
Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half when precision 16.
"""

model = BoringModel()
module = LightningDeepSpeedModule(model, precision=16)

module.cuda().half()
assert module.dtype == torch.half
assert model.dtype == torch.half

x = torch.randn((1, 32), dtype=torch.float).cuda()
out = module(x)

assert out.dtype == torch.half

module.to(torch.double)
assert module.dtype == torch.double
assert model.dtype == torch.double


@pytest.fixture
def deepspeed_config():
return {
@@ -34,6 +75,11 @@ def deepspeed_config():
}


@pytest.fixture
def deepspeed_zero_config(deepspeed_config):
return {**deepspeed_config, 'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2}}


@pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
def test_deepspeed_plugin_string(tmpdir):
"""
@@ -179,12 +225,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
return loss.backward()

model = TestModel()
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
plugins=DeepSpeedPlugin(zero_optimization=False),
gpus=1,
)
trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, plugins=DeepSpeedPlugin(), gpus=1, precision=16)
with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
trainer.fit(model)

@@ -203,17 +244,21 @@ def test_deepspeed_run_configure_optimizers(tmpdir):
class TestModel(BoringModel):

def on_train_start(self) -> None:
assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer

assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally
# Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR)

model = TestModel()
trainer = Trainer(
plugins=DeepSpeedPlugin(zero_optimization=False),
plugins=DeepSpeedPlugin(),  # ZeRO is enabled by default, so DeepSpeed wraps the optimizer
default_root_dir=tmpdir,
gpus=1,
fast_dev_run=True,
precision=16
)

trainer.fit(model)
@@ -226,7 +271,7 @@ def on_train_start(self) -> None:
@pytest.mark.skipif(
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
)
def test_deepspeed_config(tmpdir, deepspeed_config):
def test_deepspeed_config(tmpdir, deepspeed_zero_config):
"""
Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
and saves the model weights to load correctly.
@@ -235,18 +280,22 @@ def test_deepspeed_config(tmpdir, deepspeed_config):
class TestModel(BoringModel):

def on_train_start(self) -> None:
import deepspeed
assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
from deepspeed.runtime.lr_schedules import WarmupLR
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer

assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
assert self.trainer.lr_schedulers == [] # DeepSpeed manages LR scheduler internally
assert isinstance(self.trainer.model.optimizer, torch.optim.SGD)
assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR)
# Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
assert isinstance(self.trainer.model.lr_scheduler, WarmupLR)

model = TestModel()
trainer = Trainer(
plugins=[DeepSpeedPlugin(config=deepspeed_config)],
plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)],
default_root_dir=tmpdir,
gpus=1,
fast_dev_run=True,
precision=16
)

trainer.fit(model)
@@ -267,7 +316,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
"""
model = BoringModel()
trainer = Trainer(
plugins=[DeepSpeedPlugin(zero_optimization=False)],
plugins=[DeepSpeedPlugin()],
default_root_dir=tmpdir,
gpus=2,
fast_dev_run=True,
@@ -285,8 +334,9 @@ def _assert_save_model_is_equal(model, tmpdir, trainer):
# carry out the check only on rank 0
if trainer.global_rank == 0:
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
saved_model = saved_model.float()
model = model.float().cpu()
if model.dtype == torch.half:
saved_model = saved_model.half()  # model is loaded in float32 by default, move it to float16
model = model.cpu()
# Assert model parameters are identical after loading
for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(orig_param, trained_model_param)
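The two new tests at the top of this file check that the DeepSpeed wrapper both propagates dtype changes to the wrapped model and returns half-precision outputs. Below is a hedged sketch of that behaviour, under the assumption that a `precision=16` wrapper casts floating-point inputs to half before delegating to the wrapped module; the class name is an illustrative stand-in, not Lightning's `LightningDeepSpeedModule`:

```python
import torch
import torch.nn as nn


class HalfPrecisionWrapper(nn.Module):  # illustrative stand-in, not Lightning's class
    def __init__(self, module: nn.Module, precision: int = 32):
        super().__init__()
        self.module = module
        self.precision = precision

    def forward(self, *inputs, **kwargs):
        if self.precision == 16:
            # cast only floating-point tensors; leave ints/masks untouched
            inputs = tuple(
                x.half() if torch.is_tensor(x) and x.is_floating_point() else x
                for x in inputs
            )
        return self.module(*inputs, **kwargs)


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
net = HalfPrecisionWrapper(nn.Linear(32, 2), precision=16).to(device).half()

if device.type == "cuda":  # half matmul is only reliably supported on the GPU
    out = net(torch.randn(1, 32, device=device))  # float32 input is cast inside forward
    assert out.dtype == torch.half
```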
13 changes: 10 additions & 3 deletions tests/utilities/test_dtype_device_mixin.py
@@ -101,12 +101,19 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
trainer.fit(model)


@pytest.mark.parametrize(
['device'],
[
pytest.param(None), # explicitly call without an index to see if the returning device contains an index
pytest.param(0),
pytest.param(torch.device('cuda', 0)),
]
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_gpu_device_includes_index():
def test_gpu_cuda_device(device):
model = TopModule()

# explicitly call without an index to see if the returning device contains an index (it should!)
model.cuda()
model.cuda(device)

device = model.device
assert device.type == 'cuda'
