diff --git a/.azure/gpu-tests.yml b/.azure/gpu-tests.yml
index 74c1df4553fe0..7e5af0ea6cfe6 100644
--- a/.azure/gpu-tests.yml
+++ b/.azure/gpu-tests.yml
@@ -73,7 +73,7 @@ jobs:
       CUDA_VERSION_MM=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda.split('.')[:2])))")
       pip install "bagua-cuda$CUDA_VERSION_MM>=0.9.0"
       pip install -e .[strategies]
-      pip install deepspeed==0.6.4 # TODO: remove when docker images are upgraded
+      pip install "deepspeed>0.6.4" # TODO: remove when docker images are upgraded
       pip install --requirement requirements/pytorch/devel.txt
       pip list
     env:
diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt
index db29ce556e839..4eafac99b8c66 100644
--- a/requirements/pytorch/strategies.txt
+++ b/requirements/pytorch/strategies.txt
@@ -1,5 +1,5 @@
 fairscale>=0.4.5, <=0.4.6
-deepspeed>=0.6.0, <0.6.5
+deepspeed>=0.6.0, <0.7.0
 # no need to install with [pytorch] as pytorch is already installed
 horovod>=0.21.2, !=0.24.0, <0.25.1
 hivemind>=1.0.1, <=1.0.1; sys_platform == 'linux'
diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py
index 981eed30635f6..5125bf4486a9d 100644
--- a/src/pytorch_lightning/lite/lite.py
+++ b/src/pytorch_lightning/lite/lite.py
@@ -40,6 +40,7 @@
     has_iterable_dataset,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.seed import seed_everything
 
 
@@ -105,6 +106,8 @@ def __init__(
         self._precision_plugin = self._strategy.precision_plugin
         self._models_setup: int = 0
 
+        self._check_deepspeed_support()
+
         # wrap the run method so we can inject setup logic or spawn processes for the user
         setattr(self, "run", partial(self._run_impl, self.run))
 
@@ -456,6 +459,18 @@ def _check_strategy_support(self, strategy: Optional[Union[str, Strategy]]) -> N
                 f" Choose one of {supported} or pass in a `Strategy` instance."
             )
 
+    def _check_deepspeed_support(self) -> None:
+        if (
+            isinstance(self._strategy, DeepSpeedStrategy)
+            and self._strategy.zero_stage_3
+            and _RequirementAvailable("deepspeed>=0.6.5")
+        ):
+            # https://github.com/microsoft/DeepSpeed/issues/2139
+            raise RuntimeError(
+                "DeepSpeed ZeRO-3 is not supported with this version of Lightning Lite and `deepspeed>=0.6.5`."
+                " Please downgrade deepspeed to 0.6.4 or check if a newer version of Lightning is available."
+            )
+
     @staticmethod
     def _supported_device_types() -> Sequence[_AcceleratorType]:
         return (
diff --git a/src/pytorch_lightning/utilities/deepspeed_model_summary.py b/src/pytorch_lightning/utilities/deepspeed_model_summary.py
index 89dd6a9f9a25f..45d55392df51d 100644
--- a/src/pytorch_lightning/utilities/deepspeed_model_summary.py
+++ b/src/pytorch_lightning/utilities/deepspeed_model_summary.py
@@ -17,7 +17,9 @@
 from typing import Dict, List, Tuple
 
 import torch
+from torch.nn import Parameter
 
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.model_summary import (
     _is_lazy_weight_tensor,
     get_human_readable_count,
@@ -40,7 +42,11 @@ def num_parameters(self) -> int:
     @property
     def average_shard_parameters(self) -> int:
         """Returns the number of parameters in this module."""
-        return sum(p.partitioned_size() if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
+
+        def partitioned_size(p: Parameter) -> int:
+            return p.partitioned_size() if _RequirementAvailable("deepspeed<0.6.6") else p.partition_numel()
+
+        return sum(partitioned_size(p) if not _is_lazy_weight_tensor(p) else 0 for p in self._module.parameters())
 
 
 class DeepSpeedSummary(ModelSummary):
diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py
index 86a0a5a82195a..2215ab3129780 100644
--- a/tests/tests_pytorch/lite/test_lite.py
+++ b/tests/tests_pytorch/lite/test_lite.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
 import os
 from copy import deepcopy
 from unittest import mock
@@ -29,6 +30,7 @@
 from pytorch_lightning.strategies import DeepSpeedStrategy, Strategy
 from pytorch_lightning.utilities import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.imports import _RequirementAvailable
 from pytorch_lightning.utilities.seed import pl_worker_init_function
 from tests_pytorch.helpers.runif import RunIf
 
@@ -478,4 +480,13 @@ def run(self):
             assert self.broadcast(True)
             assert self.is_global_zero == (self.local_rank == 0)
 
-    Lite(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
+    if _RequirementAvailable("deepspeed>=0.6.5"):
+        # https://github.com/microsoft/DeepSpeed/issues/2139
+        raise_if_deepspeed_incompatible = pytest.raises(
+            RuntimeError, match="DeepSpeed ZeRO-3 is not supported with this version of Lightning Lite"
+        )
+    else:
+        raise_if_deepspeed_incompatible = contextlib.suppress()
+
+    with raise_if_deepspeed_incompatible:
+        Lite(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
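
A note on the pattern above: both `lite.py` and `deepspeed_model_summary.py` gate behavior on the installed deepspeed version via `_RequirementAvailable` instead of wrapping the newer API in try/except, so an incompatible install fails loudly (or falls back) at a well-defined point. The sketch below shows roughly how such a runtime requirement check can be expressed. It is illustrative only and assumes `importlib.metadata` plus `packaging`; the helpers `requirement_available` and `shard_numel` are hypothetical and are not Lightning's actual `_RequirementAvailable` implementation.

```python
# Illustrative sketch only: a version-gated fallback in the spirit of the
# `_RequirementAvailable` checks above. `requirement_available` and `shard_numel`
# are hypothetical helpers, not part of this PR or of Lightning's API.
from importlib.metadata import PackageNotFoundError, version

from packaging.specifiers import SpecifierSet
from packaging.version import Version


def requirement_available(requirement: str) -> bool:
    """Check whether an installed package satisfies a spec such as 'deepspeed>=0.6.5'."""
    # naive split of "<name><specifier>"; assumes a simple requirement string
    for i, char in enumerate(requirement):
        if char in "<>=!~":
            name, spec = requirement[:i], requirement[i:]
            break
    else:
        name, spec = requirement, ""
    try:
        installed = Version(version(name))
    except PackageNotFoundError:
        return False
    return installed in SpecifierSet(spec) if spec else True


def shard_numel(param) -> int:
    # Newer deepspeed releases expose `partition_numel()` in place of
    # `partitioned_size()`, which is why the summary utility in the diff
    # selects the call based on the installed version.
    if requirement_available("deepspeed<0.6.6"):
        return param.partitioned_size()
    return param.partition_numel()
```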