diff --git a/pytorch_lightning/plugins/precision/double.py b/pytorch_lightning/plugins/precision/double.py
index 5e9e8bd43b820..0c155232da96d 100644
--- a/pytorch_lightning/plugins/precision/double.py
+++ b/pytorch_lightning/plugins/precision/double.py
@@ -11,92 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from contextlib import contextmanager
-from typing import Any, cast, Generator, List, Tuple
-
 import torch
-import torch.nn as nn
-from torch.optim import Optimizer
-
-import pytorch_lightning as pl
-from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
-from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
-from pytorch_lightning.utilities.apply_func import apply_to_collection
-
-
-class LightningDoublePrecisionModule(_LightningPrecisionModuleWrapperBase):
-    """LightningModule wrapper which converts incoming floating point data in ``*_step`` and ``forward`` to double
-    (``torch.float64``) precision.
-
-    Args:
-        pl_module: the model to wrap
-    """
-
-    @staticmethod
-    def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
-        if data.is_floating_point():
-            return data.double()
-        return data
-
-    @staticmethod
-    def _move_float_tensors_to_double(collection: Any) -> Any:
-        return apply_to_collection(collection, torch.Tensor, LightningDoublePrecisionModule._to_double_precision)
-
-    def training_step(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module.training_step(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
+import torch.nn
 
-    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module.validation_step(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
+from pytorch_lightning.plugins.precision.dtype import DtypePrecisionPlugin
 
-    def test_step(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module.test_step(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
 
-    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module.predict_step(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
-
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
-        return self.module(
-            *LightningDoublePrecisionModule._move_float_tensors_to_double(args),
-            **LightningDoublePrecisionModule._move_float_tensors_to_double(kwargs),
-        )
-
-
-class DoublePrecisionPlugin(PrecisionPlugin):
+class DoublePrecisionPlugin(DtypePrecisionPlugin):
     """Plugin for training with double (``torch.float64``) precision."""
 
     precision: int = 64
 
-    def connect(
-        self, model: nn.Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
-    ) -> Tuple[nn.Module, List["Optimizer"], List[Any]]:
-        """Converts the model to double precision and wraps it in a ``LightningDoublePrecisionModule`` to convert
-        incoming floating point data to double (``torch.float64``) precision.
-
-        Does not alter `optimizers` or `lr_schedulers`.
- """ - model = cast(pl.LightningModule, model.double()) - model = LightningDoublePrecisionModule(model) - - return super().connect(model, optimizers, lr_schedulers) - - @contextmanager - def forward_context(self) -> Generator[None, None, None]: - """A context manager to change the default tensor type. - - See: :meth:`torch.set_default_tensor_type` - """ - torch.set_default_tensor_type(torch.DoubleTensor) - yield - torch.set_default_tensor_type(torch.FloatTensor) + def __init__(self) -> None: + super().__init__(torch.double) diff --git a/pytorch_lightning/plugins/precision/dtype.py b/pytorch_lightning/plugins/precision/dtype.py new file mode 100644 index 0000000000000..9011333e0919f --- /dev/null +++ b/pytorch_lightning/plugins/precision/dtype.py @@ -0,0 +1,97 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from contextlib import contextmanager +from typing import Any, Generator, List, Tuple + +import torch +import torch.nn +from torch.nn import Module +from torch.optim import Optimizer + +import pytorch_lightning as pl +from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin +from pytorch_lightning.utilities.apply_func import apply_to_collection + + +class LightningPrecisionModule(_LightningPrecisionModuleWrapperBase): + """LightningModule wrapper which converts incoming data in ``*_step`` and ``forward`` to a specific + precision.""" + + def __init__(self, pl_module: "pl.LightningModule", dtype: torch.dtype) -> None: + """Wraps the user's LightningModule. + + Requires overriding all ``*_step`` methods and ``forward`` so that it can safely be wrapped by a + ``_LightningModuleWrapperBase`` and a ``*DataParallel``. + """ + super().__init__(pl_module) + self.__dtype = dtype + + def _move_tensors(self, *args: Any, **kwargs: Any) -> Any: + return apply_to_collection([args, kwargs], function=lambda t: t.to(self.__dtype), dtype=torch.Tensor) + + def training_step(self, *args: Any, **kwargs: Any) -> Any: + args, kwargs = self._move_tensors(*args, **kwargs) + return self.module.training_step(*args, **kwargs) + + def validation_step(self, *args: Any, **kwargs: Any) -> Any: + args, kwargs = self._move_tensors(*args, **kwargs) + return self.module.validation_step(*args, **kwargs) + + def test_step(self, *args: Any, **kwargs: Any) -> Any: + args, kwargs = self._move_tensors(*args, **kwargs) + return self.module.test_step(*args, **kwargs) + + def predict_step(self, *args: Any, **kwargs: Any) -> Any: + args, kwargs = self._move_tensors(*args, **kwargs) + return self.module.predict_step(*args, **kwargs) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + args, kwargs = self._move_tensors(*args, **kwargs) + return self.module(*args, **kwargs) + + +@contextmanager +def autodtype(dtype: torch.dtype) -> Generator[None, None, None]: + """A context manager to change the default tensor type. 
+
+    See: :meth:`torch.set_default_dtype`
+    """
+    previous = torch.get_default_dtype()
+    torch.set_default_dtype(dtype)
+    try:
+        yield
+    finally:
+        # make sure the default dtype is restored. otherwise, the new dtype can leak if the program fails
+        torch.set_default_dtype(previous)
+
+
+class DtypePrecisionPlugin(PrecisionPlugin):
+    """Plugin for training with a specific :class:`torch.dtype` precision."""
+
+    def __init__(self, dtype: torch.dtype) -> None:
+        self.__dtype = dtype
+
+    def connect(
+        self, model: "pl.LightningModule", optimizers: List[Optimizer], lr_schedulers: List[Any]
+    ) -> Tuple[Module, List[Optimizer], List[Any]]:
+        """Wraps the model in a ``LightningPrecisionModule`` to convert incoming data to a specific
+        precision."""
+        model = LightningPrecisionModule(model, self.__dtype)
+        return super().connect(model, optimizers, lr_schedulers)
+
+    @contextmanager
+    def forward_context(self) -> Generator[None, None, None]:
+        with autodtype(self.__dtype):
+            yield
diff --git a/tests/plugins/test_double_plugin.py b/tests/plugins/test_double_plugin.py
index 6893a7aee324a..7c15fa4f3214e 100644
--- a/tests/plugins/test_double_plugin.py
+++ b/tests/plugins/test_double_plugin.py
@@ -20,6 +20,7 @@
 from pytorch_lightning import Trainer
 from pytorch_lightning.plugins import DoublePrecisionPlugin
+from pytorch_lightning.plugins.precision.dtype import autodtype
 from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_7
 from tests.helpers.boring_model import BoringModel, RandomDataset
 from tests.helpers.runif import RunIf
@@ -124,12 +125,18 @@ def predict_dataloader(self):
 class DoublePrecisionBoringModelComplexBuffer(BoringModel):
     def __init__(self):
         super().__init__()
-
-        self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False)
+        self.register_buffer("complex_buffer", torch.tensor([1.2, 3.4j]), False)
 
     def on_fit_start(self):
-        assert self.layer.weight.dtype == torch.float64
-        assert self.complex_buffer.dtype == torch.complex64
+        super().on_fit_start()
+        # when the default floating point type is float64, the default complex type is complex128
+        assert self.complex_buffer.dtype == torch.complex128
+        # this hook is not wrapped. TODO: should it be?
+        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64
+
+    def training_step(self, batch, batch_idx):
+        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128
+        return super().training_step(batch, batch_idx)
 
 
 @pytest.mark.parametrize(
@@ -144,9 +151,9 @@ def on_fit_start(self):
     ],
 )
 def test_double_precision(tmpdir, boring_model):
-    model = boring_model()
-
     trainer = Trainer(max_epochs=2, default_root_dir=tmpdir, fast_dev_run=2, precision=64, log_every_n_steps=1)
+    with autodtype(torch.double):
+        model = boring_model()
     trainer.fit(model)
     trainer.test(model)
     trainer.predict(model)
@@ -154,8 +161,6 @@ def test_double_precision(tmpdir, boring_model):
 
 @RunIf(min_gpus=2)
 def test_double_precision_ddp(tmpdir):
-    model = DoublePrecisionBoringModel()
-
     trainer = Trainer(
         max_epochs=1,
         default_root_dir=tmpdir,
@@ -165,6 +170,8 @@ def test_double_precision_ddp(tmpdir):
         precision=64,
         log_every_n_steps=1,
     )
+    with trainer.precision_plugin.forward_context():
+        model = DoublePrecisionBoringModel()
 
     trainer.fit(model)
 
@@ -173,3 +180,21 @@ def test_double_precision_pickle(tmpdir):
     plugin = DoublePrecisionPlugin()
     model, _, __ = plugin.connect(model, MagicMock(), MagicMock())
     pickle.dumps(model)
+
+
+def test_double_precision_restores_dtype():
+    class DummyException(BaseException):
+        ...
+
+    class Model(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert torch.get_default_dtype() == torch.double
+            raise DummyException
+
+    model = Model()
+    trainer = Trainer(precision=64, num_sanity_val_steps=0)
+
+    assert torch.get_default_dtype() == torch.float
+    with pytest.raises(DummyException):
+        trainer.fit(model)
+    assert torch.get_default_dtype() == torch.float
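
A minimal usage sketch (not part of the diff above) of how the new pieces are meant to fit together, mirroring the tests in this patch. `MyLightningModule` is a hypothetical stand-in for any user-defined LightningModule; everything else comes from the code introduced here: `autodtype` temporarily swaps the default dtype while the model is instantiated, and `Trainer(precision=64)` selects `DoublePrecisionPlugin`, which wraps the module in a `LightningPrecisionModule` and applies `autodtype` around the forward passes.

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision.dtype import autodtype

# MyLightningModule is a placeholder for any LightningModule.
with autodtype(torch.double):
    # parameters and buffers created here default to torch.float64
    model = MyLightningModule()

trainer = Trainer(precision=64)  # selects DoublePrecisionPlugin
trainer.fit(model)               # inputs are cast to float64 by the wrapper

# the previous default dtype is restored once the context manager exits
assert torch.get_default_dtype() == torch.float
```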