diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e2bde5be2d0b..424744a4b0aa8 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -407,6 +407,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed wrong typehint for `Trainer.lightning_optimizers` ([#11155](https://github.com/PyTorchLightning/pytorch-lightning/pull/11155))
 
 
+- Fixed type promotion when tensors of higher category than float are logged ([#11401](https://github.com/PyTorchLightning/pytorch-lightning/pull/11401))
+
+
 - Fixed the lr-scheduler state not being dumped to checkpoint when using the deepspeed strategy ([#11307](https://github.com/PyTorchLightning/pytorch-lightning/pull/11307))
 
 
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py
index a9d3f0cbf55db..a3acea7fcb181 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -216,6 +216,7 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
             # do not set a dtype in case the default dtype was changed
             self.add_state("value", torch.tensor(default), dist_reduce_fx=torch.sum)
             if self.meta.is_mean_reduction:
+                self.cumulated_batch_size: torch.Tensor
                 self.add_state("cumulated_batch_size", torch.tensor(0), dist_reduce_fx=torch.sum)
         # this is defined here only because upstream is missing the type annotation
         self._forward_cache: Optional[Any] = None
@@ -241,14 +242,13 @@ def update(self, value: _IN_METRIC, batch_size: int) -> None: # type: ignore[ov
 
             # perform accumulation with reduction
             if self.meta.is_mean_reduction:
-                self.value += value.mean() * batch_size
-                # `Metric.add_state` does not work well with mypy, mypy doesn't know this is a `Tensor`
-                # we could add an assertion, but this is a hot code path
-                self.cumulated_batch_size += batch_size  # type: ignore[operator]
+                # do not use `+=` as it doesn't do type promotion
+                self.value = self.value + value.mean() * batch_size
+                self.cumulated_batch_size = self.cumulated_batch_size + batch_size
             elif self.meta.is_max_reduction or self.meta.is_min_reduction:
                 self.value = self.meta.reduce_fx(self.value, value.mean())
             elif self.meta.is_sum_reduction:
-                self.value += value.mean()
+                self.value = self.value + value.mean()
         else:
             value = cast(Metric, value)
             self.value = value
diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py
index 5667772800875..bd37b760f346a 100644
--- a/tests/core/test_metric_result_integration.py
+++ b/tests/core/test_metric_result_integration.py
@@ -590,6 +590,26 @@ def test_metric_result_respects_dtype(floating_dtype):
     torch.set_default_dtype(torch.float)
 
 
+@pytest.mark.parametrize("reduce_fx", ("mean", sum))
+def test_metric_result_dtype_promotion(reduce_fx):
+    metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
+    metadata.sync = _Sync()
+    rm = _ResultMetric(metadata, is_tensor=True)
+    assert rm.value.dtype == torch.float
+
+    # log a double
+    rm.update(torch.tensor(0, dtype=torch.double), 1)
+    # `rm.value.dtype` is promoted
+    assert rm.value.dtype == torch.double
+    # log a float
+    rm.update(torch.tensor(0, dtype=torch.float), 1)
+    # the previous dtype stays
+    assert rm.value.dtype == torch.double
+
+    total = rm.compute()
+    assert total.dtype == torch.double
+
+
 @pytest.mark.parametrize(["reduce_fx", "expected"], [(max, -2), (min, 2)])
 def test_result_metric_max_min(reduce_fx, expected):
     metadata = _Metadata("foo", "bar", reduce_fx=reduce_fx)
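
Note on the core change: the new comment in `update` ("do not use `+=` as it doesn't do type promotion") reflects PyTorch's dtype rules, where an in-place operator writes into the existing tensor and keeps its dtype (silently casting the result down), while out-of-place addition promotes the result to the wider dtype. A minimal standalone sketch of that behavior, separate from the patch and using illustrative variable names:

```python
import torch

# `acc` stands in for an accumulator like `self.value`, created with the default dtype.
acc = torch.tensor(0.0)                         # torch.float32
logged = torch.tensor(1.0, dtype=torch.double)  # a float64 value being logged

# In-place addition keeps the accumulator's dtype; the float64 operand is cast down.
acc += logged
assert acc.dtype == torch.float32

# Out-of-place addition applies type promotion, so the result becomes float64.
acc = acc + logged
assert acc.dtype == torch.float64
```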