Make metrics serializable (#3001)
* Implement the feature

* Add docstrings

* Apply changes

* Fix a bug

* Fix a mypy issue

* Fix a flake issue in test_lr_finder
sadra-barikbin committed Aug 10, 2023
1 parent 2f7246c commit cf3fdd1
Showing 22 changed files with 170 additions and 6 deletions.
8 changes: 5 additions & 3 deletions docs/source/engine.rst
@@ -69,7 +69,7 @@ Resuming the training
It is possible to resume the training from a checkpoint and approximately reproduce the original run's behaviour.
Using Ignite, this can be easily done using :class:`~ignite.handlers.checkpoint.Checkpoint` handler. Engine provides two methods
to serialize and deserialize its internal state :meth:`~ignite.engine.engine.Engine.state_dict` and
-:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing model, optimizer, lr scheduler etc user can
+:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing model, optimizer, lr scheduler, metrics, etc., user can
store the trainer and then resume the training. For example:

.. code-block:: python
@@ -82,8 +82,9 @@ store the trainer and then resume the training. For example:
optimizer = ...
lr_scheduler = ...
data_loader = ...
+metric = ...
-to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
+to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)
@@ -104,8 +105,9 @@ We can then restore the training from the last checkpoint.
optimizer = ...
lr_scheduler = ...
data_loader = ...
+metric = ...
-to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
+to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
checkpoint = torch.load(checkpoint_file)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)
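The ``'metric'`` entry works because metrics now expose ``state_dict``/``load_state_dict`` themselves. As a quick illustration, here is a minimal single-process sketch (not part of this commit; the ``Accuracy`` metric and the toy tensors are assumptions chosen for the example):

.. code-block:: python

    import torch

    from ignite.metrics import Accuracy

    metric = Accuracy()
    # Two samples: the first prediction matches the target, the second does not.
    metric.update((torch.tensor([[0.2, 0.8], [0.9, 0.1]]), torch.tensor([1, 1])))

    state = metric.state_dict()    # OrderedDict over _state_dict_all_req_keys
    metric.reset()                 # counters return to their initial values
    metric.load_state_dict(state)  # ... and are restored from the state dict
    assert metric.compute() == 0.5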
2 changes: 1 addition & 1 deletion ignite/handlers/checkpoint.py
@@ -710,7 +710,7 @@ def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

def load_state_dict(self, state_dict: Mapping) -> None:
"""Method replace internal state of the class with provided state dict data.
"""Method replaces internal state of the class with provided state dict data.
Args:
state_dict: a dict with "saved" key and list of ``(priority, filename)`` pairs as values.
1 change: 1 addition & 0 deletions ignite/metrics/accumulation.py
@@ -38,6 +38,7 @@ class VariableAccumulation(Metric):
"""

required_output_keys = None
_state_dict_all_req_keys = ("accumulator", "num_examples")

def __init__(
self,
2 changes: 2 additions & 0 deletions ignite/metrics/accuracy.py
@@ -208,6 +208,8 @@ def thresholded_output_transform(output):
0.6666...
"""

_state_dict_all_req_keys = ("_num_correct", "_num_examples")

def __init__(
self,
output_transform: Callable = lambda x: x,
2 changes: 2 additions & 0 deletions ignite/metrics/confusion_matrix.py
@@ -99,6 +99,8 @@ def binary_one_hot_output_transform(output):
[1, 1]])
"""

_state_dict_all_req_keys = ("confusion_matrix", "_num_examples")

def __init__(
self,
num_classes: int,
2 changes: 2 additions & 0 deletions ignite/metrics/epoch_metric.py
@@ -67,6 +67,8 @@ def mse_fn(y_preds, y_targets):
To disable the warning, set ``check_compute_fn=False``.
"""

_state_dict_all_req_keys = ("_predictions", "_targets")

def __init__(
self,
compute_fn: Callable[[torch.Tensor, torch.Tensor], float],
2 changes: 2 additions & 0 deletions ignite/metrics/gan/fid.py
@@ -163,6 +163,8 @@ def forward(self, x):
.. versionadded:: 0.4.6
"""

_state_dict_all_req_keys = ("_num_examples", "_train_total", "_test_total", "_train_sigma", "_test_sigma")

def __init__(
self,
num_features: Optional[int] = None,
2 changes: 2 additions & 0 deletions ignite/metrics/gan/inception_score.py
@@ -77,6 +77,8 @@ class InceptionScore(_BaseInceptionMetric):
.. versionadded:: 0.4.6
"""

_state_dict_all_req_keys = ("_num_examples", "_prob_total", "_total_kl_d")

def __init__(
self,
num_features: Optional[int] = None,
1 change: 1 addition & 0 deletions ignite/metrics/loss.py
@@ -65,6 +65,7 @@ class Loss(Metric):
"""

required_output_keys = ("y_pred", "y", "criterion_kwargs")
_state_dict_all_req_keys = ("_sum", "_num_examples")

def __init__(
self,
2 changes: 2 additions & 0 deletions ignite/metrics/mean_absolute_error.py
@@ -59,6 +59,8 @@ class MeanAbsoluteError(Metric):
2.9375
"""

_state_dict_all_req_keys = ("_sum_of_absolute_errors", "_num_examples")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_absolute_errors = torch.tensor(0.0, device=self._device)
2 changes: 2 additions & 0 deletions ignite/metrics/mean_pairwise_distance.py
@@ -59,6 +59,8 @@ class MeanPairwiseDistance(Metric):
1.5955...
"""

_state_dict_all_req_keys = ("_sum_of_distances", "_num_examples")

def __init__(
self,
p: int = 2,
2 changes: 2 additions & 0 deletions ignite/metrics/mean_squared_error.py
@@ -59,6 +59,8 @@ class MeanSquaredError(Metric):
3.828125
"""

_state_dict_all_req_keys = ("_sum_of_squared_errors", "_num_examples")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_squared_errors = torch.tensor(0.0, device=self._device)
54 changes: 52 additions & 2 deletions ignite/metrics/metric.py
@@ -1,12 +1,15 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Mapping
from functools import wraps
from numbers import Number
-from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch

import ignite.distributed as idist

from ignite.base.mixins import Serializable
from ignite.engine import CallableEventWithFilter, Engine, Events

if TYPE_CHECKING:
@@ -216,7 +219,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
)


-class Metric(metaclass=ABCMeta):
+class Metric(Serializable, metaclass=ABCMeta):
"""
Base class for all Metrics.
@@ -546,6 +549,53 @@ def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise
usage = self._check_usage(usage)
return engine.has_event_handler(self.completed, usage.COMPLETED)

def state_dict(self) -> OrderedDict:
"""Method returns state dict with attributes of the metric specified in its
`_state_dict_all_req_keys` attribute. Can be used to save internal state of the class.
If there's an active distributed configuration, some collective operations are done and
the list of values across ranks is saved under each attribute's name in the dict.
"""
state = OrderedDict()
for attr_name in self._state_dict_all_req_keys:
if attr_name not in self.__dict__:
raise ValueError(
f"Found a value in _state_dict_all_req_keys that is not among metric attributes: {attr_name}"
)
attr = getattr(self, attr_name)
if not isinstance(attr, (int, float, torch.Tensor)):
raise TypeError(
"Currently, only numeric or tensor-typed attributes of the metric"
" could be added to its state_dict."
)
if idist.get_world_size() == 1:
state[attr_name] = [attr]
else:
if isinstance(attr, (int, float)):
attr_type = type(attr)
attr = float(attr)
gathered_attr = cast(List[Any], idist.all_gather(attr))
if isinstance(attr, float):
gathered_attr = [attr_type(process_attr) for process_attr in gathered_attr]
state[attr_name] = gathered_attr

return state

def load_state_dict(self, state_dict: Mapping) -> None:
"""Method replaces internal state of the class with provided state dict data.
If there's an active distributed configuration, the process uses its rank to pick the proper value from
the list of values saved under each attribute's name in the dict.
Args:
state_dict: a dict containing attributes of the metric specified in its `_state_dict_all_req_keys`
attribute.
"""
super().load_state_dict(state_dict)
rank = idist.get_rank()
for attr in self._state_dict_all_req_keys:
setattr(self, attr, state_dict[attr][rank])

def __add__(self, other: Any) -> "MetricsLambda":
from ignite.metrics.metrics_lambda import MetricsLambda

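To make the gather/pick behaviour of ``Metric.state_dict``/``load_state_dict`` concrete, here is a hedged round-trip sketch (not from this commit; the metric choice, values, and checkpoint path are illustrative). With ``world_size == 1`` every entry of the state dict is a one-element list; under a distributed configuration each list would hold one gathered value per rank, and ``load_state_dict`` picks the entry at ``idist.get_rank()``:

.. code-block:: python

    import torch

    from ignite.metrics import MeanSquaredError

    metric = MeanSquaredError()
    metric.update((torch.tensor([2.0, 3.0]), torch.tensor([1.0, 1.0])))

    state = metric.state_dict()
    # Single process: OrderedDict([('_sum_of_squared_errors', [tensor(5.)]),
    #                              ('_num_examples', [2])])
    # With two ranks, each list would instead have two entries, e.g.
    # state['_num_examples'] == [n_rank0, n_rank1].

    torch.save(state, "/tmp/mse_metric.pt")  # illustrative path
    restored = MeanSquaredError()
    restored.load_state_dict(torch.load("/tmp/mse_metric.pt"))
    assert restored.compute() == metric.compute()  # both equal 2.5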
2 changes: 2 additions & 0 deletions ignite/metrics/multilabel_confusion_matrix.py
@@ -81,6 +81,8 @@ class MultiLabelConfusionMatrix(Metric):
"""

_state_dict_all_req_keys = ("confusion_matrix", "_num_examples")

def __init__(
self,
num_classes: int,
5 changes: 5 additions & 0 deletions ignite/metrics/nlp/bleu.py
@@ -147,6 +147,11 @@ def __init__(
raise ValueError(f'Average must be either "macro" or "micro" (got: {average})')
self.average = average

if average == "micro":
self._state_dict_all_req_keys = ("p_numerators", "p_denominators", "hyp_length_sum", "ref_length_sum")
else:
self._state_dict_all_req_keys = ("_sum_of_bleu", "_num_sentences")

super(Bleu, self).__init__(output_transform=output_transform, device=device)

def _n_gram_counter(
2 changes: 2 additions & 0 deletions ignite/metrics/nlp/rouge.py
@@ -119,6 +119,8 @@ class _BaseRouge(Metric):
Rouge interface for Rouge-L and Rouge-N
"""

_state_dict_all_req_keys = ("_recall", "_precision", "_fmeasure", "_num_examples")

def __init__(
self,
multiref: str = "average",
2 changes: 2 additions & 0 deletions ignite/metrics/precision.py
@@ -13,6 +13,8 @@


class _BasePrecisionRecall(_BaseClassification):
_state_dict_all_req_keys = ("_numerator", "_denominator", "_weight")

def __init__(
self,
output_transform: Callable = lambda x: x,
2 changes: 2 additions & 0 deletions ignite/metrics/psnr.py
@@ -81,6 +81,8 @@ def get_y_channel(output):
.. versionadded:: 0.4.3
"""

_state_dict_all_req_keys = ("_sum_of_batchwise_psnr", "_num_examples")

def __init__(
self,
data_range: Union[int, float],
3 changes: 3 additions & 0 deletions ignite/metrics/running_average.py
@@ -87,6 +87,9 @@ def log_running_avg_metrics():
"""

required_output_keys = None
# TODO Shall we put `src` here? Then we should add a new branch for metric-typed attributes in `state_dict`
# and `load_state_dict`. Examples: this class; `Rouge`, which has a `List[_BaseRouge]`.
_state_dict_all_req_keys = ("_value",)

def __init__(
self,
2 changes: 2 additions & 0 deletions ignite/metrics/ssim.py
@@ -60,6 +60,8 @@ class SSIM(Metric):
.. versionadded:: 0.4.2
"""

_state_dict_all_req_keys = ("_sum_of_ssim", "_num_examples", "_kernel")

def __init__(
self,
data_range: Union[int, float],
2 changes: 2 additions & 0 deletions ignite/metrics/top_k_categorical_accuracy.py
@@ -73,6 +73,8 @@ def one_hot_to_binary_output_transform(output):
0.75
"""

_state_dict_all_req_keys = ("_num_correct", "_num_examples")

def __init__(
self,
k: int = 5,
74 changes: 74 additions & 0 deletions tests/ignite/metrics/test_metric.py
@@ -713,6 +713,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)
_test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
_test_distrib_state_dict(device)


@pytest.mark.distributed
@@ -722,6 +723,7 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)
_test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
_test_distrib_state_dict(device)


@pytest.mark.distributed
@@ -744,6 +746,7 @@ def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)
_test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
_test_distrib_state_dict(device)


@pytest.mark.multinode_distributed
@@ -1125,3 +1128,74 @@ def update(self, output):

with pytest.raises(ValueError, match=r"Output should have 2 items of the same length"):
engine.run([0] * 10)


class DummyMetric4(Metric):
_state_dict_all_req_keys = ("dnumber", "fnumber", "tensor")

def __init__(self, value: int):
super().reset()
self.dnumber = value
self.fnumber = float(value + 1)
self.tensor = torch.tensor([value + 2])

def reset(self):
self.dnumber = -1
self.fnumber = -2.0
self.tensor = torch.tensor([-3])

def update(self, output):
pass

def compute(self):
pass


def test_wrong_state_dict():
class WrongMetric(Metric):
_state_dict_all_req_keys = ("object",)

def __init__(self, value):
super().__init__()
self.object = {"a": [value]}

def reset(self):
pass

def update(self, output):
pass

def compute(self):
pass

metric = WrongMetric(2)
with pytest.raises(TypeError, match="Currently, only numeric or tensor-typed attributes of the metric"):
metric.state_dict()

delattr(metric, "object")
with pytest.raises(ValueError, match="Found a value in _state_dict_all_req_keys that is not among"):
metric.state_dict()


def test_state_dict():
metric = DummyMetric4(1)
state = metric.state_dict()
assert state.keys() == {"dnumber", "fnumber", "tensor"}
metric.reset()
metric.load_state_dict(state)
assert metric.dnumber == 1
assert metric.fnumber == 2
assert metric.tensor == torch.tensor([3])


def _test_distrib_state_dict(device):
rank = idist.get_local_rank()
metric = DummyMetric4(rank)
state = metric.state_dict()
assert isinstance(state["dnumber"][rank], int)
assert isinstance(state["fnumber"][rank], float)
metric.reset()
metric.load_state_dict(state)
assert metric.dnumber == rank and isinstance(metric.dnumber, int)
assert metric.fnumber == rank + 1 and isinstance(metric.fnumber, float)
assert metric.tensor == torch.tensor([rank + 2])
