Skip to content

Commit

Permalink
Remove 0.6.0 check
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Jul 27, 2022
1 parent 1318cdc commit 811c432
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 17 deletions.
8 changes: 0 additions & 8 deletions src/pytorch_lightning/plugins/precision/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _RequirementAvailable
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.warnings import WarningCache

_DEEPSPEED_GREATER_EQUAL_0_6 = _RequirementAvailable("deepspeed>=0.6.0")
if TYPE_CHECKING:
if pl.strategies.deepspeed._DEEPSPEED_AVAILABLE:
import deepspeed
Expand All @@ -52,12 +50,6 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
"""

def __init__(self, precision: Union[str, int], amp_type: str, amp_level: Optional[str] = None) -> None:
if precision == PrecisionType.BFLOAT and not _DEEPSPEED_GREATER_EQUAL_0_6:
raise MisconfigurationException(
f"`Trainer(strategy='deepspeed', precision={precision!r})` is not supported"
" with `deepspeed < v0.6`. Please upgrade it using `pip install -U deepspeed`."
)

supported_precision = (PrecisionType.HALF, PrecisionType.FLOAT, PrecisionType.BFLOAT, PrecisionType.MIXED)
if precision not in supported_precision:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock

import pytest

from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_invalid_precision_with_deepspeed_precision():
with pytest.raises(ValueError, match="is not supported. `precision` must be one of"):
DeepSpeedPrecisionPlugin(precision=64, amp_type="native")


@mock.patch("pytorch_lightning.plugins.precision.deepspeed._DEEPSPEED_GREATER_EQUAL_0_6", False)
def test_incompatible_bfloat16_raises_error_with_deepspeed_version():
with pytest.raises(MisconfigurationException, match="is not supported with `deepspeed < v0.6`"):
DeepSpeedPrecisionPlugin(precision="bf16", amp_type="native")

0 comments on commit 811c432

Please sign in to comment.