Add context manager to properly convert the precision #10079

Status: Closed · wants to merge 5 commits

Changes from 1 commit

106 changes: 53 additions & 53 deletions pytorch_lightning/plugins/precision/double.py

@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Generator, List, Tuple

import torch
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
@@ -24,79 +25,78 @@
from pytorch_lightning.utilities.apply_func import apply_to_collection


class LightningPrecisionModule(_LightningPrecisionModuleWrapperBase):
    """LightningModule wrapper which converts incoming data in ``*_step`` and ``forward`` to a specific
    precision."""

    def __init__(self, pl_module: "pl.LightningModule", dtype: torch.dtype) -> None:
        """Wraps the user's LightningModule.

        Requires overriding all ``*_step`` methods and ``forward`` so that it can safely be wrapped by a
        ``_LightningModuleWrapperBase`` and a ``*DataParallel``.
        """
        super().__init__(pl_module)
        self.__dtype = dtype

    @staticmethod
    def _to(data: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        return data.to(dtype)

    def _move_tensors(self, collection: Any) -> Any:
        return apply_to_collection(collection, torch.Tensor, LightningPrecisionModule._to, self.__dtype)

    def training_step(self, *args: Any, **kwargs: Any) -> Any:
        return self.module.training_step(*self._move_tensors(args), **self._move_tensors(kwargs))

    def validation_step(self, *args: Any, **kwargs: Any) -> Any:
        return self.module.validation_step(*self._move_tensors(args), **self._move_tensors(kwargs))

    def test_step(self, *args: Any, **kwargs: Any) -> Any:
        return self.module.test_step(*self._move_tensors(args), **self._move_tensors(kwargs))

    def predict_step(self, *args: Any, **kwargs: Any) -> Any:
        return self.module.predict_step(*self._move_tensors(args), **self._move_tensors(kwargs))

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self.module(*self._move_tensors(args), **self._move_tensors(kwargs))

(This class replaces the previous ``LightningDoublePrecisionModule``, whose ``_to_double_precision`` helper only converted floating-point tensors via ``data.double()``; the new ``_to`` converts every tensor with ``data.to(dtype)``.)
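As an aside, not part of the diff: ``_move_tensors`` relies on ``apply_to_collection`` to walk arbitrarily nested step inputs. Below is a minimal sketch of that behaviour with a made-up ``batch`` structure; it also shows that, unlike the old float-only conversion, integer tensors are converted as well.

import torch
from pytorch_lightning.utilities.apply_func import apply_to_collection

# hypothetical nested step input
batch = {"x": torch.rand(2, 3), "meta": {"ids": torch.arange(2)}, "name": "sample"}

# apply the conversion to every torch.Tensor leaf, leaving other values untouched
converted = apply_to_collection(batch, torch.Tensor, lambda t: t.to(torch.float64))

assert converted["x"].dtype == torch.float64
assert converted["meta"]["ids"].dtype == torch.float64  # integer tensor converted too
assert converted["name"] == "sample"                     # non-tensor leaves pass through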


class DtypePrecisionPlugin(PrecisionPlugin):
    """Plugin for training with a specific :class:`torch.dtype`."""

    def __init__(self, dtype: torch.dtype) -> None:
        self.__dtype = dtype

    def connect(
        self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Wraps the model in a ``LightningPrecisionModule`` to convert incoming data to a specific
        precision."""
        model = LightningPrecisionModule(model, self.__dtype)
        return super().connect(model, optimizers, lr_schedulers)

    @contextmanager
    def autodtype(self) -> Generator[None, None, None]:
        """A context manager to change the default tensor type.

        See: :meth:`torch.set_default_dtype`
        """
        previous = torch.get_default_dtype()
        torch.set_default_dtype(self.__dtype)
        try:
            yield
        finally:
            # make sure the default dtype is restored, otherwise the new dtype can leak if the program fails
            torch.set_default_dtype(previous)

    def forward_context(self) -> Generator[None, None, None]:
        return self.autodtype()


class DoublePrecisionPlugin(DtypePrecisionPlugin):
    """Plugin for training with double (``torch.float64``) precision."""

    precision: int = 64

    def __init__(self):
        super().__init__(torch.double)

(Previously, ``DoublePrecisionPlugin.connect`` cast the model with ``model.double()``, and ``forward_context`` switched ``torch.set_default_tensor_type`` between ``torch.DoubleTensor`` and ``torch.FloatTensor`` without a try/finally.)

Review thread on ``class DoublePrecisionPlugin(DtypePrecisionPlugin):``

Contributor: Could this be a dataclass?

Author: Maybe, but I don't think we want to. It's still a PrecisionPlugin (not a dataclass).
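For reference, the save-and-restore pattern that ``autodtype`` implements can be exercised on its own. This is a standalone sketch; ``run_in_dtype`` is a made-up name, not part of the PR.

from contextlib import contextmanager
from typing import Generator

import torch


@contextmanager
def run_in_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """Temporarily change the global default dtype and restore it even if the body raises."""
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


assert torch.get_default_dtype() == torch.float32
with run_in_dtype(torch.float64):
    # tensors created from Python floats now default to float64 ...
    assert torch.rand(3).dtype == torch.float64
    # ... and the default complex type follows along
    assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128
assert torch.get_default_dtype() == torch.float32

The try/finally mirrors what ``test_double_precision_restores_dtype`` below checks: without it, an exception during fitting would leave the process-wide default dtype changed.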
40 changes: 32 additions & 8 deletions tests/plugins/test_double_plugin.py
@@ -124,12 +124,18 @@ def predict_dataloader(self):
class DoublePrecisionBoringModelComplexBuffer(BoringModel):
    def __init__(self):
        super().__init__()
        self.register_buffer("complex_buffer", torch.tensor([1.2, 3.4j]), False)

    def on_fit_start(self):
        super().on_fit_start()
        # when the default floating point type is float64 the default complex type is complex128
        assert self.complex_buffer.dtype == torch.complex128
        # this hook is not wrapped. # TODO: should it be?
        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64

    def training_step(self, batch, batch_idx):
        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128
        return super().training_step(batch, batch_idx)

(The buffer was previously built with ``torch.complex(torch.rand(10), torch.rand(10))``, and ``on_fit_start`` asserted ``self.layer.weight.dtype == torch.float64`` and ``self.complex_buffer.dtype == torch.complex64``.)

Review thread on lines +134 to +135 (the unwrapped-hook assertions in ``on_fit_start``):

Author: Not sure whether this is working as expected or a bug. The precision context manager is only active during the forward context, and this hook is not part of it. Should we instead enter the context manager on setup and exit on teardown?

Contributor: I am not sure, but I would say yes: real + imag in float32 -> complex64, and real + imag in float64 -> complex128. Makes sense to me at least.

Author: Yes, that's as expected. The problem here is that we only wrap the precision for the forward hooks, so other hooks like setup and on_fit_start are not wrapped and, as tested here, they do not use the correct precision. Maybe we could change this to wrap everything from setup to teardown.

Author: Note after discussion with Thomas: it's likely we would need to disable it for backward and optimizer.step. This will also need to be considered for Lite.
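The promotion rules discussed in the thread above can be checked directly in plain PyTorch; this snippet is a standalone illustration, not part of the PR.

import torch

# real and imaginary parts in float32 promote to complex64
assert torch.complex(torch.rand(2), torch.rand(2)).dtype == torch.complex64

# real and imaginary parts in float64 promote to complex128
float64_part = torch.rand(2, dtype=torch.float64)
assert torch.complex(float64_part, float64_part.clone()).dtype == torch.complex128

# complex literals follow the *default* dtype, which is why the wrapped and
# unwrapped hooks above see different results
assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64  # under the default float32
torch.set_default_dtype(torch.float64)
assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128
torch.set_default_dtype(torch.float32)  # restore the default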


@@ -144,18 +150,16 @@ def on_fit_start(self):
@pytest.mark.parametrize(
    ...  # parametrized model classes elided by the diff view
)
def test_double_precision(tmpdir, boring_model):
    trainer = Trainer(max_epochs=2, default_root_dir=tmpdir, fast_dev_run=2, precision=64, log_every_n_steps=1)
    with trainer.precision_plugin.autodtype():
        model = boring_model()
    trainer.fit(model)
    trainer.test(model)
    trainer.predict(model)


@RunIf(min_gpus=2)
def test_double_precision_ddp(tmpdir):
    trainer = Trainer(
        max_epochs=1,
        default_root_dir=tmpdir,
        # ... further Trainer arguments elided by the diff view ...
        precision=64,
        log_every_n_steps=1,
    )
    with trainer.precision_plugin.autodtype():
        model = DoublePrecisionBoringModel()
    trainer.fit(model)

(In both tests, the model is now constructed inside ``autodtype`` rather than before the ``Trainer``.)

@@ -173,3 +179,21 @@ def test_double_precision_pickle(tmpdir):
    plugin = DoublePrecisionPlugin()
    model, _, __ = plugin.connect(model, MagicMock(), MagicMock())
    pickle.dumps(model)


def test_double_precision_restores_dtype():
    class DummyException(BaseException):
        ...

    class Model(BoringModel):
        def training_step(self, batch, batch_idx):
            assert torch.get_default_dtype() == torch.double
            raise DummyException

    model = Model()
    trainer = Trainer(precision=64, num_sanity_val_steps=0)

    assert torch.get_default_dtype() == torch.float
    with pytest.raises(DummyException):
        trainer.fit(model)
    assert torch.get_default_dtype() == torch.float
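Putting the pieces together, the usage pattern these tests exercise would look roughly like the sketch below on this PR's branch. ``TinyModel`` and the toy dataset are made up for illustration, and ``trainer.precision_plugin.autodtype()`` only exists with this change applied.

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """A minimal stand-in LightningModule, just for illustration."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(TensorDataset(torch.rand(8, 4), torch.rand(8, 1)), batch_size=4)
trainer = Trainer(precision=64, max_epochs=1)

# build the model inside the context manager so its parameters (and any buffers,
# including complex ones) are created with the float64 defaults
with trainer.precision_plugin.autodtype():
    model = TinyModel()

trainer.fit(model, data)

# the process-wide default dtype is restored outside the context manager
assert torch.get_default_dtype() == torch.float32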