Fix metric attribute lookup #8181

Merged · 3 commits · Jun 28, 2021
CHANGELOG.md (2 additions, 0 deletions)

@@ -321,6 +321,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `log_gpu_memory` metrics not being added to `logging` when nothing else is logged ([#8174](https://github.com/PyTorchLightning/pytorch-lightning/pull/8174))


+- Fixed a bug where calling `log` with a `Metric` instance would raise an error if it was a nested attribute of the model ([#8181](https://github.com/PyTorchLightning/pytorch-lightning/pull/8181))

## [1.3.7] - 2021-06-22

- Fixed a bug where skipping an optimizer while using amp causes amp to trigger an assertion error ([#7975](https://github.com/PyTorchLightning/pytorch-lightning/pull/7975))
pytorch_lightning/core/lightning.py (2 additions, 2 deletions)

@@ -375,15 +375,15 @@ def log(
            # compute once
            self._metric_attributes = {
                id(module): name
-                for name, module in self.named_children() if isinstance(module, Metric)
+                for name, module in self.named_modules() if isinstance(module, Metric)
            }
            if not self._metric_attributes:
                raise MisconfigurationException(
                    "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
                    " You can fix this by setting an attribute for the metric in your `LightningModule`."
                )
        # try to find the passed metric in the LightningModule
-        metric_attribute = self._metric_attributes.get(id(value))
+        metric_attribute = self._metric_attributes.get(id(value), None)
        if metric_attribute is None:
            raise MisconfigurationException(
                "Could not find the `LightningModule` attribute for the `torchmetrics.Metric` logged."
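
The core of the fix is the switch from `named_children()` to `named_modules()`: `named_children()` yields only a module's immediate submodules, while `named_modules()` recurses through the whole tree, so metrics nested inside containers such as `nn.ModuleDict` are found too. A minimal sketch of the difference (`DummyMetric` and `Wrapper` are hypothetical classes, just to show the lookup):

import torch.nn as nn
from torchmetrics import Metric


class DummyMetric(Metric):
    # Minimal concrete Metric: update/compute must be defined to instantiate.
    def update(self):
        pass

    def compute(self):
        return 0


class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        # The metric lives one level below the root module.
        self.metrics = nn.ModuleDict({"acc": DummyMetric()})


model = Wrapper()

# named_children() only sees the direct child "metrics" (a ModuleDict),
# so no Metric instance is found and logging would raise.
assert not {n for n, m in model.named_children() if isinstance(m, Metric)}

# named_modules() recurses and finds the nested metric as "metrics.acc".
assert {n for n, m in model.named_modules() if isinstance(m, Metric)} == {"metrics.acc"}
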
tests/metrics/test_metric_lightning.py (93 additions, 0 deletions)

@@ -1,9 +1,27 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from re import escape

import pytest
import torch
from torch import nn
from torchmetrics import Metric as TMetric

from pytorch_lightning import Trainer
from pytorch_lightning.metrics import Metric as PLMetric
from pytorch_lightning.metrics import MetricCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel


@@ -192,3 +210,78 @@ def training_epoch_end(self, outputs):
    logged = trainer.logged_metrics
    assert torch.allclose(torch.tensor(logged["SumMetric_epoch"]), model.sum)
    assert torch.allclose(torch.tensor(logged["DiffMetric_epoch"]), model.diff)


def test_log_metric_no_attributes_raises(tmpdir):

    class TestModel(BoringModel):

        def training_step(self, *args):
            metric = SumMetric()
            self.log("foo", metric)

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    model = TestModel()
    with pytest.raises(MisconfigurationException, match="Could not find the `LightningModule` attribute"):
        trainer.fit(model)


def test_log_metrics_wrong_attributes_raises(tmpdir):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.a_metric = SumMetric()

        def training_step(self, *args):
            metric = SumMetric()
            self.log("foo", metric)

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    model = TestModel()
    with pytest.raises(MisconfigurationException, match=escape("where `name` is one of ['a_metric']")):
        trainer.fit(model)

Review comments on lines +236 to +245

Contributor:
@carmocca why do we need this dependency on the attribute name in the module vs. the key name used for publishing? If it's solely for restoration, could we think of other approaches? This style of metrics logging is a very common pattern for us.

Contributor (mooey5775):
Not @carmocca, but I'll take a shot at this. I don't think there is a dependency on the key name used for publishing. This test is failing because the metric being logged isn't the same instance as the `self.a_metric` metric that's an attribute in the module.

`test_log_metric_dict` below logs some metrics whose key names differ from their attribute names, and it still passes.

Contributor (carmocca):
@mooey5775 is correct.

This is to match the logged metric with the metric attribute in the `LightningModule` so the state can be restored.

> could we think of other approaches?

Open to ideas!
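
To make the instance-matching point concrete: `log` looks the metric up by `id()`, so the object you log must be the very instance registered (possibly nested) as an attribute of the module. A rough sketch of the working and failing patterns, reusing `SumMetric` and `BoringModel` from this test file (`WorkingModel` and `FailingModel` are hypothetical names):

class WorkingModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.train_sum = SumMetric()  # registered as a submodule attribute

    def training_step(self, batch, batch_idx):
        self.train_sum(batch.sum())
        # OK: id(self.train_sum) is found in the named_modules() map,
        # so Lightning knows which attribute to restore state into.
        self.log("sum", self.train_sum)
        return self.step(batch)


class FailingModel(BoringModel):
    def training_step(self, batch, batch_idx):
        metric = SumMetric()  # fresh instance, never assigned to self
        # Raises MisconfigurationException: this instance is not registered
        # on the model, so there is no attribute to restore its state from.
        self.log("sum", metric)
        return self.step(batch)
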



def test_log_metric_dict(tmpdir):

    class TestModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metrics = nn.ModuleDict({'sum': SumMetric(), 'diff': DiffMetric()})
            self.sum = 0.0
            self.diff = 0.0

        def training_step(self, batch, batch_idx):
            x = batch
            self.metrics['sum'](x.sum())
            self.metrics['diff'](x.sum())
            self.sum += x.sum()
            self.diff -= x.sum()
            self.log_dict({f'{k}_step': v for k, v in self.metrics.items()})
            return self.step(x)

        def training_epoch_end(self, outputs):
            self.metrics['sum'].compute()
            self.metrics['diff'].compute()
            self.log_dict({f'{k}_epoch': v for k, v in self.metrics.items()})

    model = TestModel()
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        max_epochs=1,
        log_every_n_steps=1,
        weights_summary=None,
    )
    trainer.fit(model)

    logged = trainer.logged_metrics
    assert torch.allclose(torch.tensor(logged["sum_epoch"]), model.sum)
    assert torch.allclose(torch.tensor(logged["diff_epoch"]), model.diff)