Merge branch 'master' into bugfix/toggle
mergify[bot] authored Jan 25, 2021
2 parents 75af9e8 + e87424a commit 9f11c76
Showing 3 changed files with 37 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -32,6 +32,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an error when logging a progress bar metric with a reserved name ([#5620](https://github.com/PyTorchLightning/pytorch-lightning/pull/5620))


- Fixed `Metric` states not being included in the `state_dict` when the metric is a child module ([#5614](https://github.com/PyTorchLightning/pytorch-lightning/pull/5614))


## [1.1.5] - 2021-01-19
20 changes: 16 additions & 4 deletions pytorch_lightning/metrics/metric.py
@@ -284,11 +284,23 @@ def persistent(self, mode: bool = False):
         for key in self._persistent.keys():
             self._persistent[key] = mode
 
-    def state_dict(self, *args, **kwargs):
+    def state_dict(self, destination=None, prefix='', keep_vars=False):
+        destination = super().state_dict(
+            destination=destination,
+            prefix=prefix,
+            keep_vars=keep_vars
+        )
         # Register metric states to be part of the state_dict
-        state_dict = super().state_dict()
         for key in self._defaults.keys():
             if self._persistent[key]:
                 current_val = getattr(self, key)
-                state_dict.update({key: current_val})
-        return state_dict
+                if not keep_vars:
+                    if torch.is_tensor(current_val):
+                        current_val = current_val.detach()
+                    elif isinstance(current_val, list):
+                        current_val = [
+                            cur_v.detach() if torch.is_tensor(cur_v) else cur_v
+                            for cur_v in current_val
+                        ]
+                destination[prefix + key] = current_val
+        return destination
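
For context on the change: the new signature mirrors `nn.Module.state_dict(destination, prefix, keep_vars)`, so when a `Metric` is nested inside another module, PyTorch's recursive `state_dict` call hands it the shared `destination` dict and a composed key `prefix`, and the metric writes its states under `prefix + key` instead of returning a private dict. Unless `keep_vars=True`, tensor states (and tensors inside list states) are detached from the autograd graph before being stored. A minimal sketch of the same pattern on a plain `nn.Module` — `RunningSum` and `Parent` are illustrative names, not Lightning code:

```python
import torch
from torch import nn


class RunningSum(nn.Module):
    """Illustrative stand-in for a metric: holds custom state that is
    neither a parameter nor a buffer, so the default nn.Module.state_dict()
    would not save it without the override below."""

    def __init__(self):
        super().__init__()
        self.total = torch.tensor(0.)

    def update(self, x: torch.Tensor) -> None:
        self.total = self.total + x.sum()

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        # Same pattern as the fix: let nn.Module fill `destination`,
        # then append custom state under the caller-supplied prefix.
        destination = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        val = self.total
        if not keep_vars and torch.is_tensor(val):
            val = val.detach()  # drop autograd history for serialization
        destination[prefix + 'total'] = val
        return destination


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = RunningSum()


parent = Parent()
parent.child.update(torch.ones(3, requires_grad=True))
# The custom state appears under the composed key 'child.total'
print(parent.state_dict())  # OrderedDict([('child.total', tensor(3.))])
```

Passing `keep_vars=True` would keep `total` attached to its graph; the list-handling branch in the fix applies the same detach logic to states that are lists of tensors.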
20 changes: 20 additions & 0 deletions tests/metrics/test_metric.py
@@ -6,6 +6,7 @@
 import numpy as np
 import pytest
 import torch
+from torch import nn
 
 from pytorch_lightning.metrics.metric import Metric

@@ -201,3 +202,22 @@ def test_state_dict(tmpdir):
     assert metric.state_dict() == OrderedDict(x=0)
     metric.persistent(False)
     assert metric.state_dict() == OrderedDict()
+
+
+def test_child_metric_state_dict():
+    """ test that child metric states will be added to parent state dict """
+    class TestModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.metric = Dummy()
+            self.metric.add_state('a', torch.tensor(0), persistent=True)
+            self.metric.add_state('b', [], persistent=True)
+            self.metric.register_buffer('c', torch.tensor(0))
+
+    module = TestModule()
+    expected_state_dict = {
+        'metric.a': torch.tensor(0),
+        'metric.b': [],
+        'metric.c': torch.tensor(0)
+    }
+    assert module.state_dict() == expected_state_dict
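
The expected keys in this test also show the prefixing rule the fix relies on: state names compose with the attribute path of the owning module. A standalone sketch of that rule using only built-in buffers — `Inner` and `Outer` are made-up names for illustration:

```python
import torch
from torch import nn


class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('count', torch.tensor(0))


class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()


# Keys are '<attribute path>.<state name>', which is why the fix writes
# destination[prefix + key] rather than a bare key.
print(Outer().state_dict())  # OrderedDict([('inner.count', tensor(0))])
```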
