Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable ZeRO tests for CI, fix to/half function calls for LightningDistributedWrapper #6070

Merged
merged 20 commits into from
Feb 21, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions pytorch_lightning/plugins/training_type/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def forward(self, *inputs, **kwargs):

return super().forward(*inputs, **kwargs)

def half(self):
    """Cast the wrapped module to half precision.

    The call is forwarded to ``self.module.half()`` so the wrapped module
    handles the conversion itself.

    Returns:
        This wrapper, so that ``wrapper = wrapper.half()`` and chained calls
        keep working (the same convention ``torch.nn.Module.half`` follows).
    """
    self.module.half()
    # Return the wrapper rather than None; callers commonly re-assign the result.
    return self

def to(self, *args, **kwargs):
    """Move and/or cast the wrapped module.

    All positional and keyword arguments are passed unchanged to
    ``self.module.to(...)``.

    Returns:
        This wrapper, so that ``wrapper = wrapper.to(device)`` and chained
        calls keep working (the same convention ``torch.nn.Module.to`` follows).
    """
    self.module.to(*args, **kwargs)
    # Return the wrapper rather than None; callers commonly re-assign the result.
    return self

@staticmethod
def batch_to(data):
    """Return the result of calling ``half()`` on ``data``."""
    halved = data.half()
    return halved
Expand Down
29 changes: 24 additions & 5 deletions tests/plugins/test_deepspeed_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from unittest.mock import patch

import pytest
import torch
Expand All @@ -8,11 +9,28 @@

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.utilities import _APEX_AVAILABLE, _DEEPSPEED_AVAILABLE, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


@patch.object(BoringModel, 'to')
def test_deepspeed_wrapper(mocked_to, tmpdir):
    """
    Check that `half` and `to` calls made on a `LightningDeepSpeedModule` reach the wrapped model.

    `BoringModel.to` is patched, so the device move is only observed through the mock.
    """
    wrapped_model = BoringModel()
    wrapper = LightningDeepSpeedModule(wrapped_model, precision=16)

    # Precision change: the wrapped model itself must report float16 afterwards.
    wrapper.half()
    assert wrapped_model.dtype == torch.float16

    # Device change: the wrapper is expected to invoke the (patched) `to` of the wrapped model.
    wrapper.to('cuda')
    to_was_forwarded = mocked_to.called
    assert to_was_forwarded, "LightningDeepSpeedModule did not call LightningModule `to` hook when transferring device"


@pytest.fixture
def deepspeed_config():
return {
Expand Down Expand Up @@ -182,7 +200,7 @@ def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
plugins=DeepSpeedPlugin(zero_optimization=False),
plugins=DeepSpeedPlugin(),
gpus=1,
)
with pytest.warns(UserWarning, match='Overridden backward hook in the LightningModule will be ignored'):
Expand Down Expand Up @@ -210,7 +228,7 @@ def on_train_start(self) -> None:

model = TestModel()
trainer = Trainer(
plugins=DeepSpeedPlugin(zero_optimization=False),
plugins=DeepSpeedPlugin(),
default_root_dir=tmpdir,
gpus=1,
fast_dev_run=True,
Expand Down Expand Up @@ -267,7 +285,7 @@ def test_deepspeed_multigpu(tmpdir, deepspeed_config):
"""
model = BoringModel()
trainer = Trainer(
plugins=[DeepSpeedPlugin(zero_optimization=False)],
plugins=[DeepSpeedPlugin()],
default_root_dir=tmpdir,
gpus=2,
fast_dev_run=True,
Expand All @@ -285,8 +303,9 @@ def _assert_save_model_is_equal(model, tmpdir, trainer):
# carry out the check only on rank 0
if trainer.global_rank == 0:
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
saved_model = saved_model.float()
model = model.float().cpu()
if model.dtype == torch.half:
saved_model = saved_model.half() # model is loaded in float32 as default, move it to float16
model = model.cpu()
# Assert model parameters are identical after loading
for orig_param, trained_model_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(orig_param, trained_model_param)