Skip to content

Commit

Permalink
Fix references for ResultCollection.extra and improve str and `re…
Browse files Browse the repository at this point in the history
…pr` (#8622)
  • Loading branch information
carmocca authored Jul 30, 2021
1 parent 07b7dc9 commit 9720e26
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 8 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Saved checkpoints will no longer use the type of a `Callback` as the key to avoid issues with unpickling ([#6886](https://github.com/PyTorchLightning/pytorch-lightning/pull/6886))


-
- Improved string conversion for `ResultCollection` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))


-
Expand Down Expand Up @@ -74,7 +74,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed references for `ResultCollection.extra` ([#8622](https://github.com/PyTorchLightning/pytorch-lightning/pull/8622))


-
Expand Down
19 changes: 14 additions & 5 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,8 @@ def minimize(self) -> Optional[torch.Tensor]:

@minimize.setter
def minimize(self, loss: Optional[torch.Tensor]) -> None:
if loss is not None:
if not isinstance(loss, torch.Tensor):
raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}")
if loss is not None and not isinstance(loss, torch.Tensor):
raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}")
self._minimize = loss

@property
Expand All @@ -388,7 +387,8 @@ def extra(self) -> Dict[str, Any]:
Extras are any keys other than the loss returned by
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
"""
return self.get("_extra", {})
self.setdefault("_extra", {})
return self["_extra"]

@extra.setter
def extra(self, extra: Dict[str, Any]) -> None:
Expand Down Expand Up @@ -605,7 +605,16 @@ def cpu(self) -> "ResultCollection":
return self.to(device="cpu")

def __str__(self) -> str:
return f"{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})"
# sample output: `ResultCollection(minimize=1.23, {})`
minimize = f"minimize={self.minimize}, " if self.minimize is not None else ""
# remove empty values
self_str = str({k: v for k, v in self.items() if v})
return f"{self.__class__.__name__}({minimize}{self_str})"

def __repr__(self):
# sample output: `{True, cpu, minimize=tensor(1.23 grad_fn=<SumBackward0>), {'_extra': {}}}`
minimize = f"minimize={repr(self.minimize)}, " if self.minimize is not None else ""
return f"{{{self.training}, {repr(self.device)}, " + minimize + f"{super().__repr__()}}}"

def __getstate__(self, drop_value: bool = True) -> dict:
d = self.__dict__.copy()
Expand Down
23 changes: 22 additions & 1 deletion tests/core/test_metric_result_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,28 @@ def test_result_metric_integration():

assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

result.minimize = torch.tensor(1.0)
result.extra = {}
assert str(result) == (
"ResultCollection(True, cpu, {"
"ResultCollection("
"minimize=1.0, "
"{"
"'h.a': ResultMetric('a', value=DummyMetric()), "
"'h.b': ResultMetric('b', value=DummyMetric()), "
"'h.c': ResultMetric('c', value=DummyMetric())"
"})"
)
assert repr(result) == (
"{"
"True, "
"device(type='cpu'), "
"minimize=tensor(1.), "
"{'h.a': ResultMetric('a', value=DummyMetric()), "
"'h.b': ResultMetric('b', value=DummyMetric()), "
"'h.c': ResultMetric('c', value=DummyMetric()), "
"'_extra': {}}"
"}"
)


def test_result_collection_simple_loop():
Expand Down Expand Up @@ -332,3 +347,9 @@ def on_save_checkpoint(self, checkpoint) -> None:
gpus=1 if device == "cuda" else 0,
)
trainer.fit(model)


def test_result_collection_extra_reference():
"""Unit-test to check that the `extra` dict reference is properly set."""
rc = ResultCollection(True)
assert rc.extra is rc["_extra"]

0 comments on commit 9720e26

Please sign in to comment.