Fix double precision casting complex buffers #8208

Merged
merged 7 commits into from Jun 30, 2021
Changes from 2 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -342,6 +342,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where calling `log` with a `Metric` instance would raise an error if it was a nested attribute of the model ([#8181](https://github.com/PyTorchLightning/pytorch-lightning/pull/8181))


- Fixed a bug where using `precision=64` would cause buffers with complex dtype to be cast to real ([#8208](https://github.com/PyTorchLightning/pytorch-lightning/pull/8208))

## [1.3.7] - 2021-06-22

- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
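For context, the changelog entry above concerns training with `Trainer(precision=64)`: before this fix, the double-precision plugin cast every buffer of the `LightningModule`, including complex-valued ones, to a real dtype when it connected the model. A hedged sketch of how that code path is triggered (the model class and dataloader are placeholders, not part of this PR):

```python
import pytorch_lightning as pl

# precision=64 selects the double-precision plugin; its connect() hook casts
# the LightningModule before fitting, which is where complex buffers used to
# be converted to real float64 tensors.
trainer = pl.Trainer(precision=64, max_epochs=1)
# trainer.fit(MyComplexBufferModel(), train_dataloader)  # placeholders only
```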
2 changes: 1 addition & 1 deletion pytorch_lightning/plugins/precision/double.py
@@ -96,7 +96,7 @@ def connect(
incoming floating point data to double (``torch.float64``) precision. Does not alter `optimizers` or
`lr_schedulers`.
"""
model = cast(pl.LightningModule, model.to(dtype=torch.float64))
model = cast(pl.LightningModule, model.double())
model = LightningDoublePrecisionModule(model)

return super().connect(model, optimizers, lr_schedulers)
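The switch from `model.to(dtype=torch.float64)` to `model.double()` matters because `nn.Module.double()` converts only floating-point parameters and buffers, whereas `nn.Module.to(dtype=...)` also converts complex ones, discarding the imaginary part. A minimal plain-PyTorch sketch of the difference (assuming standard PyTorch casting semantics; the toy module and buffer names are illustrative, not from this PR):

```python
import torch
from torch import nn


class ToyModule(nn.Module):
    """Stand-in for the wrapped LightningModule: one float parameter, one complex buffer."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.register_buffer("complex_buffer", torch.complex(torch.rand(4), torch.rand(4)))


# Old behaviour: Module.to(dtype=...) converts floating-point *and* complex
# tensors, so the buffer ends up real (PyTorch warns that the imaginary part
# is discarded).
old = ToyModule().to(dtype=torch.float64)
print(old.complex_buffer.dtype)    # torch.float64

# New behaviour: Module.double() touches only floating-point tensors, so the
# buffer keeps its complex dtype while parameters are widened to float64.
fixed = ToyModule().double()
print(fixed.layer.weight.dtype)    # torch.float64
print(fixed.complex_buffer.dtype)  # torch.complex64
```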
7 changes: 7 additions & 0 deletions tests/plugins/test_double_plugin.py
@@ -40,6 +40,11 @@ def __len__(self):

class DoublePrecisionBoringModel(BoringModel):

def __init__(self):
super().__init__()

self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False)

def training_step(self, batch, batch_idx):
float_data, int_data = batch
assert torch.tensor([0.]).dtype == torch.float64
@@ -77,9 +82,11 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):

def on_fit_start(self):
assert self.layer.weight.dtype == torch.float64
assert self.complex_buffer.dtype == torch.complex64

def on_after_backward(self):
assert self.layer.weight.grad.dtype == torch.float64
assert self.complex_buffer.dtype == torch.complex64

def train_dataloader(self):
dataset = RandomFloatIntDataset(32, 64)