Support DeepSpeed >=0.6.0, <0.6.5 (#13863)
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
carmocca and awaelchli authored Jul 27, 2022
1 parent fff62f0 commit 511875e
Showing 10 changed files with 42 additions and 39 deletions.
1 change: 1 addition & 0 deletions .azure/gpu-tests.yml
@@ -73,6 +73,7 @@ jobs:
CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
pip install "bagua-cuda$CUDA_VERSION_MM>=0.9.0"
pip install -e .[strategies]
pip install deepspeed==0.6.4 # TODO: remove when docker images are upgraded
pip install --requirement requirements/pytorch/devel.txt
pip list
env:
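For context, a readable sketch (not part of the pipeline itself) of what the inline `python -c` above computes: the major/minor CUDA suffix used to pick the matching `bagua-cuda` wheel. It assumes a CUDA build of torch, since `torch.version.cuda` is `None` on CPU-only builds.

```python
# Readable form of the CUDA_VERSION_MM one-liner from the Azure job above.
# Assumes a CUDA-enabled torch build; torch.version.cuda is None otherwise.
import torch

cuda_version_mm = "".join(torch.version.cuda.split(".")[:2])  # e.g. "11.3" -> "113"
print(f"pip install 'bagua-cuda{cuda_version_mm}>=0.9.0'")
```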
2 changes: 1 addition & 1 deletion requirements/pytorch/strategies.txt
@@ -1,5 +1,5 @@
fairscale>=0.4.5, <=0.4.6
deepspeed<0.6.0
deepspeed>=0.6.0, <0.6.5
# no need to install with [pytorch] as pytorch is already installed
horovod>=0.21.2, !=0.24.0, <0.25.1
hivemind>=1.0.1, <=1.0.1; sys_platform == 'linux'
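For reference, a minimal sketch (not part of this diff) of how the new range can be checked against the installed package at runtime. It assumes the `packaging` library is available, and the helper name is hypothetical.

```python
# Verify that the installed deepspeed falls inside the newly supported range.
from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet
from packaging.version import Version

DEEPSPEED_SPEC = SpecifierSet(">=0.6.0,<0.6.5")  # mirrors requirements/pytorch/strategies.txt


def deepspeed_pin_satisfied() -> bool:
    """Return True if the installed deepspeed version matches the supported range."""
    try:
        installed = Version(version("deepspeed"))
    except PackageNotFoundError:
        return False
    return installed in DEEPSPEED_SPEC


print(deepspeed_pin_satisfied())
```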
3 changes: 2 additions & 1 deletion src/pytorch_lightning/plugins/precision/apex_amp.py
@@ -59,6 +59,7 @@ def backward(
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
@@ -71,7 +72,7 @@ def backward(
"""
opt = optimizer or model.trainer.optimizers
with amp.scale_loss(closure_loss, opt) as closure_loss:
super().backward(model, closure_loss, optimizer, *args, **kwargs)
super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

def optimizer_step(
self,
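To illustrate the updated hook, here is a minimal sketch (not part of this diff) of a custom precision plugin written against the new `backward` signature; the class name is hypothetical and it assumes the pytorch_lightning version introduced by this commit.

```python
# A custom plugin that logs the optimizer index before delegating to the base
# implementation, matching the signature added in this commit.
from typing import Any, Optional

import pytorch_lightning as pl
from pytorch_lightning.plugins import PrecisionPlugin
from torch import Tensor
from torch.optim import Optimizer


class LoggingPrecisionPlugin(PrecisionPlugin):  # hypothetical example class
    def backward(
        self,
        model: "pl.LightningModule",
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        print(f"backward called for optimizer_idx={optimizer_idx}")
        super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
```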
22 changes: 11 additions & 11 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
@@ -27,10 +27,8 @@
from pytorch_lightning.utilities.warnings import WarningCache

_DEEPSPEED_AVAILABLE = _RequirementAvailable("deepspeed")
_DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0")
if TYPE_CHECKING:
if _DEEPSPEED_AVAILABLE:
import deepspeed
if TYPE_CHECKING and _DEEPSPEED_AVAILABLE:
import deepspeed

warning_cache = WarningCache()

@@ -53,12 +51,6 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
"""

def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
if precision == PrecisionType.BFLOAT and not _DEEPSPEED_GREATER_EQUAL_0_6:
raise MisconfigurationException(
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported"
" with `deepspeed < v0.6`. Please upgrade it using `pip install -U deepspeed`."
)

supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
if precision not in supported_precision:
raise ValueError(
@@ -71,7 +63,15 @@ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
self.amp_type = amp_type
self.amp_level = amp_level

def backward(self, model: "pl.LightningModule", closure_loss: Tensor, *args: Any, **kwargs: Any) -> None:
def backward(
self,
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
if is_overridden("backward", model):
warning_cache.warn(
"You have overridden the `LightningModule.backward` hook but it will be ignored since DeepSpeed handles"
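With the `deepspeed < 0.6` guard gone, the constructor's only gate is the supported-precision check kept above. A standalone sketch of that validation pattern follows (not the plugin's actual code; the constant and function names are illustrative):

```python
# Accept 16, 32, "bf16" or "mixed"; reject everything else with a ValueError,
# mirroring the check retained in DeepSpeedPrecisionPlugin.__init__.
from typing import Union

SUPPORTED_PRECISION = ("16", "32", "bf16", "mixed")  # illustrative stand-in for the PrecisionType values


def validate_precision(precision: Union[str, int]) -> str:
    value = str(precision)
    if value not in SUPPORTED_PRECISION:
        raise ValueError(
            f"`precision={precision!r}` is not supported. `precision` must be one of {SUPPORTED_PRECISION}."
        )
    return value


validate_precision("bf16")  # accepted now that deepspeed >= 0.6.0 is required
```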
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/ipu.py
@@ -44,7 +44,7 @@ def __init__(self, precision: int) -> None:
super().__init__()
self.precision = precision

def backward(self, model: "pl.LightningModule", *args: Any, **kwargs: Any) -> None:
def backward(self, model: "pl.LightningModule", *_: Any, **__: Any) -> None:
if is_overridden("backward", model):
warning_cache.warn(
"You have overridden the `LightningModule.backward` hook but it will be ignored since IPUs handle"
3 changes: 2 additions & 1 deletion src/pytorch_lightning/plugins/precision/precision_plugin.py
@@ -64,6 +64,7 @@ def backward(
model: "pl.LightningModule",
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
@@ -76,7 +77,7 @@ def backward(
"""
# do backward pass
if model is not None and isinstance(model, pl.LightningModule):
model.backward(closure_loss, optimizer, *args, **kwargs)
model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
else:
self._run_backward(closure_loss, *args, **kwargs)

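Because the base plugin now calls `model.backward(closure_loss, optimizer, optimizer_idx, ...)`, a `LightningModule` that overrides the `backward` hook receives both arguments explicitly. A minimal sketch (not part of this diff; the module is a hypothetical example):

```python
# A LightningModule whose `backward` override matches the explicit call above.
from typing import Any, Optional

import torch
from pytorch_lightning import LightningModule
from torch.optim import Optimizer


class ScaledBackwardModule(LightningModule):  # hypothetical example class
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)

    def backward(
        self,
        loss: torch.Tensor,
        optimizer: Optional[Optimizer],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # apply a custom loss scale before running the actual backward pass
        (loss * 0.5).backward(*args, **kwargs)
```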
11 changes: 9 additions & 2 deletions src/pytorch_lightning/strategies/strategy.py
@@ -171,7 +171,14 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
"""
return optimizer.state_dict()

def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
def backward(
self,
closure_loss: Tensor,
optimizer: Optional[Optimizer],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> Tensor:
"""Forwards backward-calls to the precision plugin.
Args:
@@ -181,7 +188,7 @@ def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
assert self.lightning_module is not None
closure_loss = self.precision_plugin.pre_backward(self.lightning_module, closure_loss)

self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
self.precision_plugin.backward(self.lightning_module, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

closure_loss = self.precision_plugin.post_backward(self.lightning_module, closure_loss)
self.post_backward(closure_loss)
18 changes: 14 additions & 4 deletions tests/tests_pytorch/lite/test_lite.py
@@ -412,15 +412,21 @@ def run(self):
model = BoringModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
model, optimizer = self.setup(model, optimizer)
state_dict = deepcopy(model.state_dict())

for _ in range(2):
for i in range(2):
optimizer.zero_grad()
x = model(torch.randn(1, 32).to(self.device))
loss = x.sum()
if i == 0:
# the weights are not initialized with stage 3 until backward is run once
assert all(w.nelement() == 0 for w in model.state_dict().values())
self.backward(loss, model=model)
if i == 0:
# save for later to check that the weights were updated
state_dict = deepcopy(model.state_dict())
optimizer.step()

# check that the model trained, the weights from step 1 do not match the weights from step 2
for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()):
assert not torch.allclose(mw_b, mw_a)

@@ -438,6 +444,7 @@ def run(self):
model_1, optimizer_1 = self.setup(model_1, optimizer_1)
model_2, optimizer_2 = self.setup(model_2, optimizer_2)

# train model_1 first
self.seed_everything(42)
data_list = []
for _ in range(2):
@@ -449,16 +456,19 @@ def run(self):
self.backward(loss, model=model_1)
optimizer_1.step()

for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
assert not torch.allclose(mw_1, mw_2)
# the weights do not match
assert all(w.nelement() > 1 for w in model_1.state_dict().values())
assert all(w.nelement() == 0 for w in model_2.state_dict().values())

# now train model_2 with the same data
for data in data_list:
optimizer_2.zero_grad()
x = model_2(data)
loss = x.sum()
self.backward(loss, model=model_2)
optimizer_2.step()

# the weights should match
for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
assert torch.allclose(mw_1, mw_2)

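The test relies on a snapshot-and-compare pattern to prove that training updated the weights. A standalone sketch of that pattern with plain torch (no DeepSpeed required; the linear model is an arbitrary stand-in):

```python
# Snapshot the weights after the first backward, take two optimizer steps in total,
# then assert that every parameter tensor changed.
from copy import deepcopy

import torch

model = torch.nn.Linear(32, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

snapshot = {}
for i in range(2):
    optimizer.zero_grad()
    loss = model(torch.randn(1, 32)).sum()
    loss.backward()
    if i == 0:
        # save for later to check that the weights were updated
        snapshot = deepcopy(model.state_dict())
    optimizer.step()

for before, after in zip(snapshot.values(), model.state_dict().values()):
    assert not torch.allclose(before, after)
print("weights were updated")
```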
9 changes: 0 additions & 9 deletions
@@ -11,20 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest

from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_invalid_precision_with_deepspeed_precision():
with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
DeepSpeedPrecisionPlugin(precision=64, amp_type="native")


@mock.patch("pytorch_lightning.plugins.precision.deepspeed._DEEPSPEED_GREATER_EQUAL_0_6", False)
def test_incompatible_bfloat16_raises_error_with_deepspeed_version():
with pytest.raises(MisconfigurationException, match="is not supported with `deepspeed < v0.6`"):
DeepSpeedPrecisionPlugin(precision="bf16", amp_type="native")
10 changes: 1 addition & 9 deletions tests/tests_pytorch/strategies/test_deepspeed_strategy.py
@@ -30,26 +30,19 @@
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import _DEEPSPEED_GREATER_EQUAL_0_6
from pytorch_lightning.strategies import DeepSpeedStrategy
from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _RequirementAvailable
from pytorch_lightning.utilities.meta import init_meta_context
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf

if _DEEPSPEED_AVAILABLE:
import deepspeed
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

_DEEPSPEED_GREATER_EQUAL_0_5_9 = _RequirementAvailable("deepspeed>=0.5.9")
if _DEEPSPEED_GREATER_EQUAL_0_5_9:
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
else:
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer as DeepSpeedZeroOptimizer


class ModelParallelBoringModel(BoringModel):
def __init__(self):
@@ -1294,7 +1287,6 @@ def training_step(self, *args, **kwargs):


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
@pytest.mark.skipif(not _DEEPSPEED_GREATER_EQUAL_0_6, reason="requires deepspeed >= 0.6")
def test_deepspeed_with_bfloat16_precision(tmpdir):
"""Test that deepspeed works with bfloat16 precision."""
model = BoringModel()
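This file already imports `convert_zero_checkpoint_to_fp32_state_dict` from DeepSpeed. As a usage sketch (not part of this diff), the helper consolidates a ZeRO-sharded checkpoint directory into a single fp32 state-dict file; the paths below are placeholders, and it is assumed the helper keeps its two positional arguments (checkpoint directory, output file).

```python
# Consolidate a ZeRO-sharded checkpoint into one fp32 state dict (requires deepspeed).
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# both paths are placeholders for illustration only
convert_zero_checkpoint_to_fp32_state_dict("path/to/zero_checkpoint_dir", "consolidated_fp32.pt")
```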
