integrate metrics API with self.log #3961

Merged (8 commits), Oct 8, 2020
41 changes: 35 additions & 6 deletions docs/source/metrics.rst
@@ -25,12 +25,6 @@ logic present in ``.compute()`` is applied to state information from all processes

The example below shows how to use a metric in your ``LightningModule``:

.. note::

    For v0.10.0 the user is expected to call ``.compute()`` on the metric at the end of each epoch.
    This is shown in the example below. For the v1.0 release, we will integrate metrics
    with logging and ``.compute()`` will be called automatically by PyTorch Lightning.

.. code-block:: python

    def __init__(self):
@@ -49,6 +43,41 @@ The example below shows how to use a metric in your ``LightningModule``:
        self.log('train_acc_epoch', self.accuracy.compute())


``Metric`` objects can also be logged directly, in which case Lightning will log
the metric based on the ``on_step`` and ``on_epoch`` flags present in ``self.log(...)``.
If ``on_epoch`` is True, the logger automatically logs the end-of-epoch metric value by calling
``.compute()``.

.. note::
    The ``sync_dist``, ``sync_dist_op``, ``sync_dist_group``, ``reduce_fx`` and ``tbptt_reduce_fx``
    flags from ``self.log(...)`` do not affect metric logging in any manner, because the metric class
    contains its own distributed synchronization logic.

    This, however, is only true for metrics that inherit from the base class ``Metric``;
    the functional metric API provides no built-in support for distributed synchronization
    or reduction (a short functional sketch follows the example below).


.. code-block:: python

    def __init__(self):
        ...
        self.train_acc = pl.metrics.Accuracy()
        self.valid_acc = pl.metrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.train_acc(logits, y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc(logits, y)
        self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
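
As the note above points out, functional metrics carry no synchronization state of their own.
Below is a minimal, hypothetical sketch of the equivalent functional usage; it assumes
``pytorch_lightning.metrics.functional.accuracy(preds, target)`` and the key name
``valid_acc_fn`` is illustrative, not part of the original example. Any cross-process
reduction has to be requested explicitly through ``sync_dist=True``:

.. code-block:: python

    from pytorch_lightning.metrics.functional import accuracy

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # the functional call returns a plain tensor and keeps no state between batches
        acc = accuracy(torch.argmax(logits, dim=-1), y)
        # sync_dist is needed here because the functional API performs no DDP reduction itself
        self.log('valid_acc_fn', acc, on_step=False, on_epoch=True, sync_dist=True)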


This metrics API is independent of PyTorch Lightning. Metrics can be used directly in plain PyTorch, as shown in the example:

.. code-block:: python
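
The body of that example is collapsed in this diff. As a rough sketch of the kind of
standalone loop it refers to (the model and data below are placeholders, not the original
example):

.. code-block:: python

    import torch
    from pytorch_lightning import metrics

    model = torch.nn.Linear(10, 3)
    accuracy = metrics.Accuracy()

    for _ in range(4):  # stand-in for iterating over a dataloader
        x = torch.randn(8, 10)
        y = torch.randint(0, 3, (8,))
        logits = model(x)
        # calling the metric updates its internal state and returns the batch value
        batch_acc = accuracy(torch.argmax(logits, dim=-1), y)

    # compute() aggregates the state accumulated over all batches seen so far
    overall_acc = accuracy.compute()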
61 changes: 47 additions & 14 deletions pytorch_lightning/core/step_result.py
@@ -21,7 +21,7 @@
import os

from pytorch_lightning.utilities.distributed import sync_ddp_if_available

from pytorch_lightning.metrics import Metric

class Result(Dict):
def __init__(
@@ -251,35 +251,57 @@ def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
                continue

            if options['logger'] and options['on_step']:
                result[k] = self[k]
                if isinstance(self[k], Metric):
                    result[k] = self[k]._forward_cache
                else:
                    result[k] = self[k]

        return result

    def get_epoch_log_metrics(self) -> dict:
        """
        Gets the metrics to log at the end of the batch step
        Gets the metrics to log at the end of the epoch
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue

            if options['logger'] and options['on_epoch']:
                result[k] = self[k]
                if isinstance(self[k], Metric):
                    result[k] = self[k].compute()
                else:
                    result[k] = self[k]

            if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
                # compute metric on epoch anyway so state does not accumulate
                self[k].compute()

        return result

    def get_epoch_pbar_metrics(self):
        """
        Gets the metrics to log at the end of the batch step
        Gets the metrics to log at the end of the epoch
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue

            if options['prog_bar'] and options['on_epoch']:
                result[k] = self[k]
                if isinstance(self[k], Metric):
                    result[k] = self[k].compute()
                else:
                    result[k] = self[k]

            if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
                # compute metric on epoch anyway so state does not accumulate
                self[k].compute()

        return result

    def get_batch_pbar_metrics(self, include_forked_originals=True):
@@ -292,11 +314,16 @@ def get_batch_pbar_metrics(self, include_forked_originals=True):
        for k, options in meta.items():
            if k == '_internal':
                continue

            if options['forked'] and not include_forked_originals:
                continue

            if options['prog_bar'] and options['on_step']:
                result[k] = self[k]
                if isinstance(self[k], Metric):
                    result[k] = self[k]._forward_cache
                else:
                    result[k] = self[k]

        return result

    def detach(self):
@@ -405,7 +432,7 @@ def reduce_on_epoch_end(cls, outputs):
        recursive_stack(result)

        for k, option in meta.items():
            if k == '_internal':
            if k == '_internal' or isinstance(result[k], Metric):
                continue

            if option['on_epoch']:
@@ -439,7 +466,7 @@ def reduce_across_time(cls, time_outputs):
        recursive_stack(result)

        for k, value in result.items():
            if k in ['meta', 'extra']:
            if k in ['meta', 'extra'] or isinstance(value, Metric):
                continue

            # pick the reduce fx
@@ -459,10 +486,12 @@ def reduce_across_time(cls, time_outputs):

    def dp_reduce(self):
        for k, value in self.items():
            if k == 'meta':
            if k == 'meta' or isinstance(value, Metric):
                continue

            if isinstance(value, list):
                value = torch.tensor(value)

            self[k] = value.mean(dim=-1)

    @property
@@ -502,10 +531,14 @@ def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] =
                v = recursive_gather([v], in_d)
                result[k] = v
            else:
                if k not in result:
                    result[k] = []

                result[k].append(v)
                if isinstance(v, Metric):
                    # if v is a metric, just keep one of them,
                    # don't keep on adding a list of them
                    result[k] = v
                else:
                    if k not in result:
                        result[k] = []
                    result[k].append(v)

    return result

6 changes: 4 additions & 2 deletions pytorch_lightning/metrics/metric.py
@@ -59,6 +59,7 @@ def __init__(
        self.update = self._wrap_update(self.update)
        self.compute = self._wrap_compute(self.compute)
        self._computed = None
        self._forward_cache = None

        # initialize state
        self._reductions = {}
@@ -125,6 +126,7 @@ def forward(self, *args, **kwargs):
"""
# add current step
self.update(*args, **kwargs)
self._forward_cache = None

if self.compute_on_step:
self._to_sync = self.ddp_sync_on_step
@@ -135,15 +137,15 @@
            # call reset, update, compute, on single batch
            self.reset()
            self.update(*args, **kwargs)
            result = self.compute()
            self._forward_cache = self.compute()

            # restore context
            for attr, val in self._cache.items():
                setattr(self, attr, val)
            self._to_sync = True
            self._computed = None

            return result
            return self._forward_cache

    def _sync_dist(self):
        input_dict = {attr: getattr(self, attr) for attr in self._reductions.keys()}
131 changes: 131 additions & 0 deletions tests/core/test_metric_result_integration.py
@@ -0,0 +1,131 @@

import pytest
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from pytorch_lightning.core.step_result import Result
from pytorch_lightning.metrics import Metric
import tests.base.develop_utils as tutils


class DummyMetric(Metric):
    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x


def _setup_ddp(rank, worldsize):
    import os

    os.environ["MASTER_ADDR"] = "localhost"

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize):
    _setup_ddp(rank, worldsize)
    tensor = torch.tensor([1.0])

    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    # ddp_sync_on_step is False by default
    result = Result()

    for epoch in range(3):
        cumulative_sum = 0

        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log('a', metric_a, on_step=True, on_epoch=True)
            result.log('b', metric_b, on_step=False, on_epoch=True)
            result.log('c', metric_c, on_step=True, on_epoch=False)

            batch_log = result.get_batch_log_metrics()
            batch_expected = {"a_step": i, "a": i, "c": i}
            assert set(batch_log.keys()) == set(batch_expected.keys())
            for k in batch_expected.keys():
                assert batch_expected[k] == batch_log[k]

        epoch_log = result.get_epoch_log_metrics()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        epoch_expected = {
            "b": cumulative_sum * worldsize,
            "a": cumulative_sum * worldsize,
            "a_epoch": cumulative_sum * worldsize
        }

        assert set(epoch_log.keys()) == set(epoch_expected.keys())
        for k in epoch_expected.keys():
            assert epoch_expected[k] == epoch_log[k]


@pytest.mark.skipif(sys.platform == "win32", reason="DDP not available on windows")
def test_result_reduce_ddp():
    """Make sure result logging works with DDP"""
    tutils.reset_seed()
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)


def test_result_metric_integration():
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    result = Result()

    for epoch in range(3):
        cumulative_sum = 0

        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log('a', metric_a, on_step=True, on_epoch=True)
            result.log('b', metric_b, on_step=False, on_epoch=True)
            result.log('c', metric_c, on_step=True, on_epoch=False)

            batch_log = result.get_batch_log_metrics()
            batch_expected = {"a_step": i, "a": i, "c": i}
            assert set(batch_log.keys()) == set(batch_expected.keys())
            for k in batch_expected.keys():
                assert batch_expected[k] == batch_log[k]

        epoch_log = result.get_epoch_log_metrics()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        epoch_expected = {"b": cumulative_sum, "a": cumulative_sum, "a_epoch": cumulative_sum}

        assert set(epoch_log.keys()) == set(epoch_expected.keys())
        for k in epoch_expected.keys():
            assert epoch_expected[k] == epoch_log[k]
18 changes: 18 additions & 0 deletions tests/metrics/test_metric.py
@@ -111,6 +111,24 @@ def compute(self):
    assert a.compute() == 5


def test_forward():
    class A(Dummy):
        def update(self, x):
            self.x += x

        def compute(self):
            return self.x

    a = A()
    assert a(5) == 5
    assert a._forward_cache == 5

    assert a(8) == 8
    assert a._forward_cache == 8

    assert a.compute() == 13


class ToPickle(Dummy):
    def update(self, x):
        self.x += x