From 83294ad8a5387771f08c31ca1ceca723b228af0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 30 Jul 2021 12:47:34 +0200 Subject: [PATCH] Fix references for `ResultCollection.extra` and improve `str` and `repr` (#8622) --- .../connectors/logger_connector/result.py | 19 +++++++++++---- tests/core/test_metric_result_integration.py | 23 ++++++++++++++++++- 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index aab86976fe76f..44774105cdb49 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -377,9 +377,8 @@ def minimize(self) -> Optional[torch.Tensor]: @minimize.setter def minimize(self, loss: Optional[torch.Tensor]) -> None: - if loss is not None: - if not isinstance(loss, torch.Tensor): - raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") + if loss is not None and not isinstance(loss, torch.Tensor): + raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}") self._minimize = loss @property @@ -388,7 +387,8 @@ def extra(self) -> Dict[str, Any]: Extras are any keys other than the loss returned by :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` """ - return self.get("_extra", {}) + self.setdefault("_extra", {}) + return self["_extra"] @extra.setter def extra(self, extra: Dict[str, Any]) -> None: @@ -605,7 +605,16 @@ def cpu(self) -> "ResultCollection": return self.to(device="cpu") def __str__(self) -> str: - return f"{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})" + # sample output: `ResultCollection(minimize=1.23, {})` + minimize = f"minimize={self.minimize}, " if self.minimize is not None else "" + # remove empty values + self_str = str({k: v for k, v in self.items() if v}) + return f"{self.__class__.__name__}({minimize}{self_str})" + + def __repr__(self): + # sample output: `{True, cpu, minimize=tensor(1.23 grad_fn=), {'_extra': {}}}` + minimize = f"minimize={repr(self.minimize)}, " if self.minimize is not None else "" + return f"{{{self.training}, {repr(self.device)}, " + minimize + f"{super().__repr__()}}}" def __getstate__(self, drop_value: bool = True) -> dict: d = self.__dict__.copy() diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index 27503fc1d3339..fa2f9ccdf7c50 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -132,13 +132,28 @@ def test_result_metric_integration(): assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum} + result.minimize = torch.tensor(1.0) + result.extra = {} assert str(result) == ( - "ResultCollection(True, cpu, {" + "ResultCollection(" + "minimize=1.0, " + "{" "'h.a': ResultMetric('a', value=DummyMetric()), " "'h.b': ResultMetric('b', value=DummyMetric()), " "'h.c': ResultMetric('c', value=DummyMetric())" "})" ) + assert repr(result) == ( + "{" + "True, " + "device(type='cpu'), " + "minimize=tensor(1.), " + "{'h.a': ResultMetric('a', value=DummyMetric()), " + "'h.b': ResultMetric('b', value=DummyMetric()), " + "'h.c': ResultMetric('c', value=DummyMetric()), " + "'_extra': {}}" + "}" + ) def test_result_collection_simple_loop(): @@ -332,3 +347,9 @@ def on_save_checkpoint(self, checkpoint) -> None: gpus=1 if device == "cuda" else 0, ) trainer.fit(model) + + +def test_result_collection_extra_reference(): + """Unit-test to check that the `extra` dict reference is properly set.""" + rc = ResultCollection(True) + assert rc.extra is rc["_extra"]