Skip to content

Commit

Permalink
Remove <0.6.0 checks
Browse files — browse the repository at this point in the history
  • Loading branch information
carmocca committed Jul 27, 2022
1 parent 1318cdc commit edff47c
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 26 deletions.
8 changes: 0 additions & 8 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _RequirementAvailable
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

_DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0")
if TYPE_CHECKING:
if pl.strategies.deepspeed._DEEPSPEED_AVAILABLE:
import deepspeed
Expand All @@ -52,12 +50,6 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
"""

def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
if precision == PrecisionType.BFLOAT and not _DEEPSPEED_GREATER_EQUAL_0_6:
raise MisconfigurationException(
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported"
" with `deepspeed < v0.6`. Please upgrade it using `pip install -U deepspeed`."
)

supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
if precision not in supported_precision:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest

from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_invalid_precision_with_deepspeed_precision():
    """Constructing the plugin with an unsupported precision (64-bit) raises a ``ValueError``."""
    expected_message = "is not supported. `precision` must be one of"
    with pytest.raises(ValueError, match=expected_message):
        DeepSpeedPrecisionPlugin(precision=64, amp_type="native")


@mock.patch("pytorch_lightning.plugins.precision.deepspeed._DEEPSPEED_GREATER_EQUAL_0_6", False)
def test_incompatible_bfloat16_raises_error_with_deepspeed_version():
    """With the deepspeed>=0.6 availability flag patched to ``False``, requesting bf16 precision
    raises a ``MisconfigurationException``."""
    expected_message = "is not supported with `deepspeed < v0.6`"
    with pytest.raises(MisconfigurationException, match=expected_message):
        DeepSpeedPrecisionPlugin(precision="bf16", amp_type="native")
10 changes: 1 addition & 9 deletions tests/tests_pytorch/strategies/test_deepspeed_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,19 @@
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import _DEEPSPEED_GREATER_EQUAL_0_6
from pytorch_lightning.strategies import DeepSpeedStrategy
from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _RequirementAvailable
from pytorch_lightning.utilities.meta import init_meta_context
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.datasets import RandomIterableDataset
from tests_pytorch.helpers.runif import RunIf

if _DEEPSPEED_AVAILABLE:
import deepspeed
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

_DEEPSPEED_GREATER_EQUAL_0_5_9 = _RequirementAvailable("deepspeed>=0.5.9")
if _DEEPSPEED_GREATER_EQUAL_0_5_9:
from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
else:
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer as DeepSpeedZeroOptimizer


class ModelParallelBoringModel(BoringModel):
def __init__(self):
Expand Down Expand Up @@ -1294,7 +1287,6 @@ def training_step(self, *args, **kwargs):


@RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
@pytest.mark.skipif(not _DEEPSPEED_GREATER_EQUAL_0_6, reason="requires deepspeed >= 0.6")
def test_deepspeed_with_bfloat16_precision(tmpdir):
"""Test that deepspeed works with bfloat16 precision."""
model = BoringModel()
Expand Down

0 comments on commit edff47c

Please sign in to comment.