From f64b0d74ebe4cff62f75d7ee9067326293ba1a19 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 10 Jan 2022 18:07:20 +0100
Subject: [PATCH 1/4] Avoid in-place ops during logging result updates

---
 .../connectors/logger_connector/result.py    |  7 ++++---
 tests/core/test_metric_result_integration.py | 21 +++++++++++++++++++
 2 files changed, 25 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index a9d3f0cbf55db..2048c964ee334 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -241,14 +241,15 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None:  # type: ignore[ov
             # perform accumulation with reduction
             if self.meta.is_mean_reduction:
-                self.value += value.mean() * batch_size
+                # do not use `+=` as it doesn't do type promotion
+                self.value = self.value + value.mean() * batch_size
                 # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor`
                 # we could add an assertion, but this is a hot code path
-                self.cumulated_batch_size += batch_size  # type: ignore[operator]
+                self.cumulated_batch_size = self.cumulated_batch_size + batch_size  # type: ignore[operator]
             elif self.meta.is_max_reduction or self.meta.is_min_reduction:
                 self.value = self.meta.reduce_fx(self.value, value.mean())
             elif self.meta.is_sum_reduction:
-                self.value += value.mean()
+                self.value = self.value + value.mean()
         else:
             value = cast(Metric, value)
             self.value = value
diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 5667772800875..3f3e36aa52ba1 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -590,6 +590,27 @@ def test_metric_result_respects_dtype(floating_dtype):
     torch.set_default_dtype(torch.float)
 
 
+@pytest.mark.parametrize("reduce_fx", ("mean", sum))
+def test_metric_result_dtype_promotion(reduce_fx):
+    metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
+    metadata.sync = _Sync()
+    rm = _ResultMetric(metadata, is_tensor=True)
+
+    value, batch_size = torch.tensor(2, dtype=torch.double), 3
+    assert rm.value.dtype == torch.float
+    # log a double
+    rm.update(value, batch_size)
+    # `rm.value.dtype` is promoted
+    assert rm.value.dtype == torch.double
+    # log a float
+    rm.update(torch.tensor(4.0, dtype=torch.float), 5)
+    # the previous dtype stays
+    assert rm.value.dtype == torch.double
+
+    total = rm.compute()
+    assert total.dtype == torch.double
+
+
 @pytest.mark.parametrize(["reduce_fx", "expected"], [(max, -2), (min, 2)])
 def test_result_metric_max_min(reduce_fx, expected):
     metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)

From b0428b9b8ca32351eac8a29a53ec2f2846c3233f Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 10 Jan 2022 18:13:17 +0100
Subject: [PATCH 2/4] CHANGELOG

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e2bde5be2d0b..424744a4b0aa8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -407,6 +407,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))
 
 
+- Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))
+
+
 - Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307))

From a23b2926d6cebddbdfdd5a213e85942d4574df76 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 10 Jan 2022 18:49:47 +0100
Subject: [PATCH 3/4] Fix mypy

---
 .../trainer/connectors/logger_connector/result.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index 2048c964ee334..a3acea7fcb181 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -216,6 +216,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
             # do not set a dtype in case the default dtype was changed
             self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum)
             if self.meta.is_mean_reduction:
+                self.cumulated_batch_size: torch.Tensor
                 self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
         # this is defined here only because upstream is missing the type annotation
         self._forward_cache: Optional[Any] = None
@@ -243,9 +244,7 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None:  # type: ignore[ov
             if self.meta.is_mean_reduction:
                 # do not use `+=` as it doesn't do type promotion
                 self.value = self.value + value.mean() * batch_size
-                # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor`
-                # we could add an assertion, but this is a hot code path
-                self.cumulated_batch_size = self.cumulated_batch_size + batch_size  # type: ignore[operator]
+                self.cumulated_batch_size = self.cumulated_batch_size + batch_size
             elif self.meta.is_max_reduction or self.meta.is_min_reduction:
                 self.value = self.meta.reduce_fx(self.value, value.mean())
             elif self.meta.is_sum_reduction:

From 4a0befbef4e6b264d9e9b26e26d0d43a79f91959 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 10 Jan 2022 18:52:07 +0100
Subject: [PATCH 4/4] Simplify test

---
 tests/core/test_metric_result_integration.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 3f3e36aa52ba1..bd37b760f346a 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -595,15 +595,14 @@ def test_metric_result_dtype_promotion(reduce_fx):
     metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
     metadata.sync = _Sync()
     rm = _ResultMetric(metadata, is_tensor=True)
-
-    value, batch_size = torch.tensor(2, dtype=torch.double), 3
     assert rm.value.dtype == torch.float
+
     # log a double
-    rm.update(value, batch_size)
+    rm.update(torch.tensor(0, dtype=torch.double), 1)
     # `rm.value.dtype` is promoted
     assert rm.value.dtype == torch.double
     # log a float
-    rm.update(torch.tensor(4.0, dtype=torch.float), 5)
+    rm.update(torch.tensor(0, dtype=torch.float), 1)
     # the previous dtype stays
     assert rm.value.dtype == torch.double
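
Note (not part of the patch series): a minimal standalone sketch of the PyTorch behaviour the first commit works around. Out-of-place `x = x + y` follows the usual type-promotion rules, while the in-place `x += y` must keep the dtype of `x`, so it cannot promote a float32 accumulator when a double value is logged; recent PyTorch versions reject it instead. The tensor names below are illustrative only.

```python
import torch

# float32 accumulator, as `add_state` would create it with the default dtype
acc = torch.tensor(0.0)
# a logged value of a higher floating-point category
logged = torch.tensor(2.0, dtype=torch.double)

# out-of-place addition promotes the result to float64
promoted = acc + logged
assert promoted.dtype == torch.double

# the in-place version cannot change `acc`'s dtype and refuses to
# cast the double result down to float
try:
    acc += logged
except RuntimeError as err:
    print(f"in-place update rejected: {err}")
```

This is why the mean- and sum-reduction branches switch from `+=` to an explicit re-assignment: the accumulator is rebound to the promoted result, and the new `test_metric_result_dtype_promotion` test asserts that `rm.value.dtype` widens to double and stays there.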