Add context manager to properly convert the precision #10079
Changes from 1 commit
```diff
@@ -124,12 +124,18 @@ def predict_dataloader(self):
 class DoublePrecisionBoringModelComplexBuffer(BoringModel):
     def __init__(self):
         super().__init__()

-        self.register_buffer("complex_buffer", torch.complex(torch.rand(10), torch.rand(10)), False)
+        self.register_buffer("complex_buffer", torch.tensor([1.2, 3.4j]), False)

     def on_fit_start(self):
         assert self.layer.weight.dtype == torch.float64
-        assert self.complex_buffer.dtype == torch.complex64
         super().on_fit_start()
+        # when the default floating point type is float64 the default complex type is complex128
+        assert self.complex_buffer.dtype == torch.complex128
+        # this hook is not wrapped. # TODO: should it be?
+        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64
```
Comment on lines +134 to +135:

Not sure whether this is working as expected or a bug. The precision context manager is only active during the forward context, and this hook is not part of it. Should we instead enter the context manager on setup and exit on teardown?

I am not sure, but I would say yes. real + img in float32 -> complex64, and real + img in float64 -> complex128. Makes sense to me at least.

Yes, that's as expected. The problem here is that we only wrap the precision for the forward hooks, so other hooks like `on_fit_start` do not run under it. Maybe we could change this to wrap everything from `setup` to `teardown`.

Note after discussion with Thomas: it's likely we would need to disable it for backward and `optimizer.step`. This will also need to be considered for Lite.
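For reference, the complex-dtype promotion the reviewers describe is plain PyTorch behaviour and can be checked standalone, outside Lightning:

```python
import torch

# Under the stock float32 default, mixed real/complex literals become complex64 ...
torch.set_default_dtype(torch.float32)
assert torch.tensor([1.2, 3.4j]).dtype == torch.complex64

# ... and under a float64 default they become complex128. This is why the assert in
# the unwrapped on_fit_start hook sees complex64 while training_step sees complex128.
torch.set_default_dtype(torch.float64)
assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128

torch.set_default_dtype(torch.float32)  # restore the default for anything that follows
```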
The rest of the hunk adds a `training_step` override with the complementary assert:

```diff
+
+    def training_step(self, batch, batch_idx):
+        assert torch.tensor([1.2, 3.4j]).dtype == torch.complex128
+        return super().training_step(batch, batch_idx)
+

 @pytest.mark.parametrize(
```
```diff
@@ -144,18 +150,16 @@ def on_fit_start(self):
     ],
 )
 def test_double_precision(tmpdir, boring_model):
-    model = boring_model()
-
     trainer = Trainer(max_epochs=2, default_root_dir=tmpdir, fast_dev_run=2, precision=64, log_every_n_steps=1)
+    with trainer.precision_plugin.autodtype():
+        model = boring_model()
     trainer.fit(model)
     trainer.test(model)
     trainer.predict(model)


 @RunIf(min_gpus=2)
 def test_double_precision_ddp(tmpdir):
-    model = DoublePrecisionBoringModel()
-
     trainer = Trainer(
         max_epochs=1,
         default_root_dir=tmpdir,
```
```diff
@@ -165,6 +169,8 @@ def test_double_precision_ddp(tmpdir):
         precision=64,
         log_every_n_steps=1,
     )
+    with trainer.precision_plugin.autodtype():
+        model = DoublePrecisionBoringModel()
     trainer.fit(model)
```
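In both tests the model is now constructed inside `trainer.precision_plugin.autodtype()`, because parameters and buffers take on whatever default dtype is active when they are created. A standalone illustration of that effect in plain PyTorch (not the plugin itself):

```python
import torch

torch.set_default_dtype(torch.float32)
outside = torch.nn.Linear(4, 4)         # created under the float32 default
assert outside.weight.dtype == torch.float32

previous = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)  # roughly what autodtype() is expected to do on enter
try:
    inside = torch.nn.Linear(4, 4)      # created while the float64 default is active
    assert inside.weight.dtype == torch.float64
finally:
    torch.set_default_dtype(previous)   # put the default back on exit
```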
```diff
@@ -173,3 +179,21 @@ def test_double_precision_pickle(tmpdir):
     plugin = DoublePrecisionPlugin()
     model, _, __ = plugin.connect(model, MagicMock(), MagicMock())
     pickle.dumps(model)
+
+
+def test_double_precision_restores_dtype():
+    class DummyException(BaseException):
+        ...
+
+    class Model(BoringModel):
+        def training_step(self, batch, batch_idx):
+            assert torch.get_default_dtype() == torch.double
+            raise DummyException
+
+    model = Model()
+    trainer = Trainer(precision=64, num_sanity_val_steps=0)
+
+    assert torch.get_default_dtype() == torch.float
+    with pytest.raises(DummyException):
+        trainer.fit(model)
+    assert torch.get_default_dtype() == torch.float
```
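This new test pins down the cleanup behaviour: the global default dtype must return to float32 even when `training_step` raises. A minimal sketch of how a context manager can guarantee that, using a hypothetical `_double_dtype` helper with a try/finally around the dtype switch (an illustration only, not necessarily the actual `autodtype` implementation):

```python
import contextlib

import torch


@contextlib.contextmanager
def _double_dtype():
    # Hypothetical stand-in for what autodtype() is expected to do.
    previous = torch.get_default_dtype()
    torch.set_default_dtype(torch.float64)
    try:
        yield
    finally:
        # Runs on normal exit and when an exception propagates out of the block.
        torch.set_default_dtype(previous)


assert torch.get_default_dtype() == torch.float32
try:
    with _double_dtype():
        raise RuntimeError("simulated failure inside fit()")
except RuntimeError:
    pass
assert torch.get_default_dtype() == torch.float32  # restored despite the exception
```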
Could this be a dataclass?

Maybe, but I don't think we want to. It's still a `PrecisionPlugin` (not a dataclass).