diff --git a/src/pytorch_lightning/plugins/precision/deepspeed.py b/src/pytorch_lightning/plugins/precision/deepspeed.py index 4cc12de400ef44..0e0b2d833647cc 100644 --- a/src/pytorch_lightning/plugins/precision/deepspeed.py +++ b/src/pytorch_lightning/plugins/precision/deepspeed.py @@ -22,11 +22,9 @@ from pytorch_lightning.utilities import GradClipAlgorithmType from pytorch_lightning.utilities.enums import PrecisionType from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _RequirementAvailable from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.warnings import WarningCache -_DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0") if TYPE_CHECKING: if pl.strategies.deepspeed._DEEPSPEED_AVAILABLE: import deepspeed @@ -52,12 +50,6 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin): """ def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None: - if precision == PrecisionType.BFLOAT and not _DEEPSPEED_GREATER_EQUAL_0_6: - raise MisconfigurationException( - f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported" - " with `deepspeed < v0.6`. Please upgrade it using `pip install -U deepspeed`." - ) - supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED) if precision not in supported_precision: raise ValueError( diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py index 8a68f7c73209b2..a4698e7c19c97c 100644 --- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py @@ -11,20 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest import mock - import pytest from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin -from pytorch_lightning.utilities.exceptions import MisconfigurationException def test_invalid_precision_with_deepspeed_precision(): with pytest.raises(ValueError, match="is not supported. `precision` must be one of"): DeepSpeedPrecisionPlugin(precision=64, amp_type="native") - - -@mock.patch("pytorch_lightning.plugins.precision.deepspeed._DEEPSPEED_GREATER_EQUAL_0_6", False) -def test_incompatible_bfloat16_raises_error_with_deepspeed_version(): - with pytest.raises(MisconfigurationException, match="is not supported with `deepspeed < v0.6`"): - DeepSpeedPrecisionPlugin(precision="bf16", amp_type="native")