From 3b0e4e0b2bc5b62bba09df5976e1460774ae7337 Mon Sep 17 00:00:00 2001
From: Sean Naren
Date: Sun, 21 Feb 2021 00:24:44 +0000
Subject: [PATCH] Enable ZeRO tests for CI, fix to/half function calls (#6070)

* Enable ZeRO optimization, and make sure that the lightning module hook is called when we move to half precision

* Added test, update to function
---
 CHANGELOG.md                               |  3 +
 pytorch_lightning/overrides/base.py        |  6 +-
 .../utilities/device_dtype_mixin.py        |  5 +-
 tests/plugins/test_deepspeed_plugin.py     | 84 +++++++++++++++----
 tests/utilities/test_dtype_device_mixin.py | 13 ++-
 5 files changed, 88 insertions(+), 23 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 52cdb000a2a0f..55895318cba4f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed priority of plugin/accelerator when setting distributed mode ([#6089](https://github.com/PyTorchLightning/pytorch-lightning/pull/6089))
 
+- Moved lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))
+
+
 ## [1.2.0] - 2021-02-18
 
 ### Added
diff --git a/pytorch_lightning/overrides/base.py b/pytorch_lightning/overrides/base.py
index 2fcb4b11a0b7f..c0b691bb07cb8 100644
--- a/pytorch_lightning/overrides/base.py
+++ b/pytorch_lightning/overrides/base.py
@@ -19,12 +19,13 @@
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.states import RunningStage
+from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.warnings import WarningCache
 
 warning_cache = WarningCache()
 
 
-class _LightningModuleWrapperBase(torch.nn.Module):
+class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
 
     def __init__(self, pl_module: LightningModule):
         """
@@ -72,6 +73,9 @@ def forward(self, *inputs, **kwargs):
 
         return output
 
+    def on_post_move_to_device(self):
+        pass
+
 
 def warn_if_output_is_none(output: Any, method_name: str) -> None:
     """ Warns user about which method returned None. """
diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py
index 6408c6e21cad4..3e3eccc93b368 100644
--- a/pytorch_lightning/utilities/device_dtype_mixin.py
+++ b/pytorch_lightning/utilities/device_dtype_mixin.py
@@ -119,7 +119,7 @@ def to(self, *args, **kwargs) -> Module:
         self.__update_properties(device=out[0], dtype=out[1])
         return super().to(*args, **kwargs)
 
-    def cuda(self, device: Optional[int] = None) -> Module:
+    def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Module:
         """Moves all model parameters and buffers to the GPU.
         This also makes associated parameters and buffers different objects.
         So it should be called before constructing optimizer if the module will
@@ -132,7 +132,8 @@ def cuda(self, device: Optional[int] = None) -> Module:
         Returns:
             Module: self
         """
-        self.__update_properties(device=torch.device('cuda', index=device))
+        property_device = device if isinstance(device, torch.device) else torch.device('cuda', index=device)
+        self.__update_properties(device=property_device)
         return super().cuda(device=device)
 
     def cpu(self) -> Module:
diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py
index 9c9c5c097b4c5..fbb53974efd33 100644
--- a/tests/plugins/test_deepspeed_plugin.py
+++ b/tests/plugins/test_deepspeed_plugin.py
@@ -8,11 +8,52 @@
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
+from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
 from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers.boring_model import BoringModel
 
 
+def test_deepspeed_lightning_module(tmpdir):
+    """
+    Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves the dtype and device correctly.
+    """
+
+    model = BoringModel()
+    module = LightningDeepSpeedModule(model, precision=16)
+
+    module.half()
+    assert module.dtype == torch.half
+    assert model.dtype == torch.half
+
+    module.to(torch.double)
+    assert module.dtype == torch.double
+    assert model.dtype == torch.double
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
+def test_deepspeed_lightning_module_precision(tmpdir):
+    """
+    Test to ensure that a model wrapped in `LightningDeepSpeedModule` moves tensors to half precision when precision is set to 16.
+    """
+
+    model = BoringModel()
+    module = LightningDeepSpeedModule(model, precision=16)
+
+    module.cuda().half()
+    assert module.dtype == torch.half
+    assert model.dtype == torch.half
+
+    x = torch.randn((1, 32), dtype=torch.float).cuda()
+    out = module(x)
+
+    assert out.dtype == torch.half
+
+    module.to(torch.double)
+    assert module.dtype == torch.double
+    assert model.dtype == torch.double
+
+
 @pytest.fixture
 def deepspeed_config():
     return {
@@ -34,6 +75,11 @@ def deepspeed_config():
     }
 
 
+@pytest.fixture
+def deepspeed_zero_config(deepspeed_config):
+    return {**deepspeed_config, 'zero_allow_untested_optimizer': True, 'zero_optimization': {'stage': 2}}
+
+
 @pytest.mark.skipif(not _DEEPSPEED_AVAILABLE, reason="DeepSpeed not available.")
 def test_deepspeed_plugin_string(tmpdir):
     """
@@ -179,12 +225,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
             return loss.backward()
 
     model = TestModel()
-    trainer = Trainer(
-        fast_dev_run=True,
-        default_root_dir=tmpdir,
-        plugins=DeepSpeedPlugin(zero_optimization=False),
-        gpus=1,
-    )
+    trainer = Trainer(fast_dev_run=True, default_root_dir=tmpdir, plugins=DeepSpeedPlugin(), gpus=1, precision=16)
     with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
         trainer.fit(model)
 
@@ -203,17 +244,21 @@ def test_deepspeed_run_configure_optimizers(tmpdir):
     class TestModel(BoringModel):
 
         def on_train_start(self) -> None:
-            assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
+            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
+
+            assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
             assert self.trainer.lr_schedulers == []  # DeepSpeed manages LR scheduler internally
             # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
             assert isinstance(self.trainer.model.lr_scheduler, torch.optim.lr_scheduler.StepLR)
 
     model = TestModel()
     trainer = Trainer(
-        plugins=DeepSpeedPlugin(zero_optimization=False),
+        plugins=DeepSpeedPlugin(),  # ZeRO is enabled by default, so the optimizer is wrapped by DeepSpeed
         default_root_dir=tmpdir,
         gpus=1,
         fast_dev_run=True,
+        precision=16
     )
     trainer.fit(model)
 
@@ -226,7 +271,7 @@ def on_train_start(self) -> None:
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
 )
-def test_deepspeed_config(tmpdir, deepspeed_config):
+def test_deepspeed_config(tmpdir, deepspeed_zero_config):
     """
     Test to ensure deepspeed works correctly when passed a DeepSpeed config object including optimizers/schedulers
     and saves the model weights to load correctly.
@@ -235,18 +280,22 @@ def test_deepspeed_config(tmpdir, deepspeed_config):
     class TestModel(BoringModel):
 
         def on_train_start(self) -> None:
-            import deepspeed
-            assert isinstance(self.trainer.optimizers[0], torch.optim.SGD)
+            from deepspeed.runtime.lr_schedules import WarmupLR
+            from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
+
+            assert isinstance(self.trainer.optimizers[0], FP16_DeepSpeedZeroOptimizer)
+            assert isinstance(self.trainer.optimizers[0].optimizer, torch.optim.SGD)
             assert self.trainer.lr_schedulers == []  # DeepSpeed manages LR scheduler internally
-            assert isinstance(self.trainer.model.optimizer, torch.optim.SGD)
-            assert isinstance(self.trainer.model.lr_scheduler, deepspeed.runtime.lr_schedules.WarmupLR)
+            # Ensure DeepSpeed engine has initialized with our optimizer/lr_scheduler
+            assert isinstance(self.trainer.model.lr_scheduler, WarmupLR)
 
     model = TestModel()
     trainer = Trainer(
-        plugins=[DeepSpeedPlugin(config=deepspeed_config)],
+        plugins=[DeepSpeedPlugin(config=deepspeed_zero_config)],
         default_root_dir=tmpdir,
         gpus=1,
         fast_dev_run=True,
+        precision=16
     )
     trainer.fit(model)
 
@@ -267,7 +316,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
     """
     model = BoringModel()
     trainer = Trainer(
-        plugins=[DeepSpeedPlugin(zero_optimization=False)],
+        plugins=[DeepSpeedPlugin()],
        default_root_dir=tmpdir,
         gpus=2,
         fast_dev_run=True,
@@ -285,8 +334,9 @@ def _assert_save_model_is_equal(model, tmpdir, trainer):
     # carry out the check only on rank 0
     if trainer.global_rank == 0:
         saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
-        saved_model = saved_model.float()
-        model = model.float().cpu()
+        if model.dtype == torch.half:
+            saved_model = saved_model.half()  # the model is loaded in float32 by default, move it to float16
+        model = model.cpu()
         # Assert model parameters are identical after loading
         for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
             assert torch.equal(orig_param, trained_model_param)
diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py
index 17e208022a5ac..45a85744f0415 100644
--- a/tests/utilities/test_dtype_device_mixin.py
+++ b/tests/utilities/test_dtype_device_mixin.py
@@ -101,12 +101,19 @@ def test_submodules_multi_gpu_ddp_spawn(tmpdir):
     trainer.fit(model)
 
 
+@pytest.mark.parametrize(
+    ['device'],
+    [
+        pytest.param(None),  # explicitly call without an index to see if the returning device contains an index
+        pytest.param(0),
+        pytest.param(torch.device('cuda', 0)),
+    ]
+)
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-def test_gpu_device_includes_index():
+def test_gpu_cuda_device(device):
     model = TopModule()
-    # explicitly call without an index to see if the returning device contains an index (it should!)
-    model.cuda()
+    model.cuda(device)
     device = model.device
     assert device.type == 'cuda'
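
A minimal usage sketch of what this patch enables (not part of the patch itself): `DeviceDtypeModuleMixin.cuda()` now accepts a `torch.device` as well as an integer index, and the mixin propagates the tracked `device`/`dtype` properties to any child modules that also use the mixin, which is what keeps a wrapped LightningModule in sync when a wrapper such as `LightningDeepSpeedModule` is moved to half precision. The `ExampleChild`/`ExampleWrapper` classes below are hypothetical and a CUDA-capable GPU is assumed.

    import torch
    from torch import nn

    from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin


    class ExampleChild(DeviceDtypeModuleMixin, nn.Module):
        # Hypothetical stand-in for a LightningModule that tracks device/dtype.
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 2)


    class ExampleWrapper(DeviceDtypeModuleMixin, nn.Module):
        # Hypothetical stand-in for _LightningModuleWrapperBase after this patch:
        # the wrapper itself mixes in DeviceDtypeModuleMixin.
        def __init__(self, child):
            super().__init__()
            self.child = child


    child = ExampleChild()
    wrapper = ExampleWrapper(child)

    # cuda() now accepts a torch.device as well as an int index, and the tracked
    # `device` property is updated on the wrapper and on the child mixin.
    wrapper.cuda(torch.device('cuda', 0))
    assert wrapper.device == child.device == torch.device('cuda', 0)

    # dtype changes are tracked and propagated the same way, which is what the
    # new LightningDeepSpeedModule tests above check for precision=16.
    wrapper.half()
    assert wrapper.dtype == child.dtype == torch.half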