Skip to content

Commit

Permalink
Add new metric usages and update RunningAverage accordingly (#2958)
Browse files Browse the repository at this point in the history
* Improve the metric, update tests & docstrings

* Adding epoch_bound and detach

Also a few improvements

* Add test for detach and epoch_bound

* Fix a bug and do a refactor in test_metric

and add test for SingleEpochRunningBatchWise in test_metric

* Fix docstrings

* autopep8 fix

* Improve code, docs and tests

* Improve code

* Fix mypy

* Update test_running_epoch_wise test

* Some improvements

---------

Co-authored-by: sadra-barikbin <sadra-barikbin@users.noreply.github.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
3 people authored Jun 29, 2023
1 parent eba5aae commit 193643c
Show file tree
Hide file tree
Showing 6 changed files with 501 additions and 356 deletions.
15 changes: 15 additions & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ Complete list of usages

- :class:`~ignite.metrics.metric.MetricUsage`
- :class:`~ignite.metrics.metric.EpochWise`
- :class:`~ignite.metrics.metric.RunningEpochWise`
- :class:`~ignite.metrics.metric.BatchWise`
- :class:`~ignite.metrics.metric.RunningBatchWise`
- :class:`~ignite.metrics.metric.SingleEpochRunningBatchWise`
- :class:`~ignite.metrics.metric.BatchFiltered`

Metrics and distributed computations
Expand Down Expand Up @@ -359,10 +362,22 @@ EpochWise
~~~~~~~~~
.. autoclass:: ignite.metrics.metric.EpochWise

RunningEpochWise
~~~~~~~~~~~~~~~~
.. autoclass:: ignite.metrics.metric.RunningEpochWise

BatchWise
~~~~~~~~~
.. autoclass:: ignite.metrics.metric.BatchWise

RunningBatchWise
~~~~~~~~~~~~~~~~
.. autoclass:: ignite.metrics.metric.RunningBatchWise

SingleEpochRunningBatchWise
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ignite.metrics.metric.SingleEpochRunningBatchWise

BatchFiltered
~~~~~~~~~~~~~
.. autoclass:: ignite.metrics.metric.BatchFiltered
Expand Down
5 changes: 3 additions & 2 deletions ignite/contrib/engines/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ignite.handlers.checkpoint import BaseSaveHandler
from ignite.handlers.param_scheduler import ParamScheduler
from ignite.metrics import RunningAverage
from ignite.metrics.metric import RunningBatchWise
from ignite.utils import deprecated


Expand Down Expand Up @@ -209,8 +210,8 @@ def output_transform(x: Any, index: int, name: str) -> Any:
)

for i, n in enumerate(output_names):
RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
trainer, n
RunningAverage(output_transform=partial(output_transform, index=i, name=n)).attach(
trainer, n, usage=RunningBatchWise()
)

if with_pbars:
Expand Down
109 changes: 102 additions & 7 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
if TYPE_CHECKING:
from ignite.metrics.metrics_lambda import MetricsLambda

__all__ = ["Metric", "MetricUsage", "EpochWise", "BatchWise", "BatchFiltered"]
__all__ = [
"Metric",
"MetricUsage",
"EpochWise",
"BatchWise",
"BatchFiltered",
"RunningEpochWise",
"RunningBatchWise",
"SingleEpochRunningBatchWise",
]


class MetricUsage:
Expand All @@ -31,6 +40,8 @@ class MetricUsage:
:meth:`~ignite.metrics.metric.Metric.iteration_completed`.
"""

usage_name: str

def __init__(self, started: Events, completed: Events, iteration_completed: CallableEventWithFilter) -> None:
self.__started = started
self.__completed = completed
Expand Down Expand Up @@ -74,6 +85,33 @@ def __init__(self) -> None:
)


class RunningEpochWise(EpochWise):
    """
    Running epoch-wise usage of Metrics.

    This is the running counterpart of the :class:`~.metrics.metric.EpochWise` usage: the metric is both
    updated and computed once per epoch. It is typically paired with an epoch-wise metric to track a running
    measure of it, e.g. a running average.

    Metric's methods are triggered on the following engine events:

    - :meth:`~ignite.metrics.metric.Metric.started` on every ``STARTED``
      (See :class:`~ignite.engine.events.Events`).
    - :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``EPOCH_COMPLETED``.
    - :meth:`~ignite.metrics.metric.Metric.completed` on every ``EPOCH_COMPLETED``.

    Attributes:
        usage_name: usage name string
    """

    usage_name: str = "running_epoch_wise"

    def __init__(self) -> None:
        # The two-argument ``super`` starts the MRO lookup *after* ``EpochWise``, so
        # ``EpochWise.__init__`` is bypassed and the events below go to the next initializer
        # in the chain (presumably ``MetricUsage.__init__`` -- confirm against ``EpochWise``'s bases).
        event_map = {
            "started": Events.STARTED,
            "completed": Events.EPOCH_COMPLETED,
            "iteration_completed": Events.EPOCH_COMPLETED,
        }
        super(EpochWise, self).__init__(**event_map)


class BatchWise(MetricUsage):
"""
Batch-wise usage of Metrics.
Expand All @@ -99,6 +137,59 @@ def __init__(self) -> None:
)


class RunningBatchWise(BatchWise):
    """
    Running batch-wise usage of Metrics.

    This is the running counterpart of the :class:`~.metrics.metric.BatchWise` usage: the metric is both
    updated and computed on every iteration, while it is only started once per run. It can, for example,
    accompany a batch-wise metric to track a running measure of it, e.g. a running average.

    Metric's methods are triggered on the following engine events:

    - :meth:`~ignite.metrics.metric.Metric.started` on every ``STARTED``
      (See :class:`~ignite.engine.events.Events`).
    - :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
    - :meth:`~ignite.metrics.metric.Metric.completed` on every ``ITERATION_COMPLETED``.

    Attributes:
        usage_name: usage name string
    """

    usage_name: str = "running_batch_wise"

    def __init__(self) -> None:
        # The two-argument ``super`` starts the MRO lookup *after* ``BatchWise``, so
        # ``BatchWise.__init__`` is bypassed and the events below are passed on to the
        # next initializer in the chain (``MetricUsage.__init__`` for a plain instance).
        event_map = {
            "started": Events.STARTED,
            "completed": Events.ITERATION_COMPLETED,
            "iteration_completed": Events.ITERATION_COMPLETED,
        }
        super(BatchWise, self).__init__(**event_map)


class SingleEpochRunningBatchWise(BatchWise):
    """
    Running batch-wise usage of Metrics, restricted to a single epoch.

    It behaves like the :class:`~.metrics.metric.RunningBatchWise` usage, except that the metric is
    started at the beginning of every epoch instead of once per run.

    Metric's methods are triggered on the following engine events:

    - :meth:`~ignite.metrics.metric.Metric.started` on every ``EPOCH_STARTED``
      (See :class:`~ignite.engine.events.Events`).
    - :meth:`~ignite.metrics.metric.Metric.iteration_completed` on every ``ITERATION_COMPLETED``.
    - :meth:`~ignite.metrics.metric.Metric.completed` on every ``ITERATION_COMPLETED``.

    Attributes:
        usage_name: usage name string
    """

    usage_name: str = "single_epoch_running_batch_wise"

    def __init__(self) -> None:
        # The two-argument ``super`` starts the MRO lookup *after* ``BatchWise``, so
        # ``BatchWise.__init__`` is bypassed and the events below are passed on to the
        # next initializer in the chain (``MetricUsage.__init__`` for a plain instance).
        event_map = {
            "started": Events.EPOCH_STARTED,
            "completed": Events.ITERATION_COMPLETED,
            "iteration_completed": Events.ITERATION_COMPLETED,
        }
        super(BatchWise, self).__init__(**event_map)


class BatchFiltered(MetricUsage):
"""
Batch filtered usage of Metrics. This usage is similar to epoch-wise but update event is filtered.
Expand Down Expand Up @@ -344,12 +435,16 @@ def completed(self, engine: Engine, name: str) -> None:

def _check_usage(self, usage: Union[str, MetricUsage]) -> MetricUsage:
if isinstance(usage, str):
if usage == EpochWise.usage_name:
usage = EpochWise()
elif usage == BatchWise.usage_name:
usage = BatchWise()
else:
raise ValueError(f"usage should be 'EpochWise.usage_name' or 'BatchWise.usage_name', get {usage}")
usages = [EpochWise, RunningEpochWise, BatchWise, RunningBatchWise, SingleEpochRunningBatchWise]
for usage_cls in usages:
if usage == usage_cls.usage_name:
usage = usage_cls()
break
if not isinstance(usage, MetricUsage):
raise ValueError(
"Argument usage should be '(Running)EpochWise.usage_name' or "
f"'((SingleEpoch)Running)BatchWise.usage_name', got {usage}"
)
if not isinstance(usage, MetricUsage):
raise TypeError(f"Unhandled usage type {type(usage)}")
return usage
Expand Down
150 changes: 101 additions & 49 deletions ignite/metrics/running_average.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Callable, cast, Optional, Sequence, Union
import warnings
from typing import Any, Callable, cast, Optional, Union

import torch

import ignite.distributed as idist
from ignite.engine import Engine, Events
from ignite.metrics.metric import EpochWise, Metric, MetricUsage, reinit__is_reduced, sync_all_reduce
from ignite.metrics.metric import Metric, MetricUsage, reinit__is_reduced, RunningBatchWise, SingleEpochRunningBatchWise

__all__ = ["RunningAverage"]

Expand All @@ -18,8 +19,10 @@ class RunningAverage(Metric):
alpha: running average decay factor, default 0.98
output_transform: a function to use to transform the output if `src` is None and
corresponds the output of process function. Otherwise it should be None.
epoch_bound: whether the running average should be reset after each epoch (defaults
to True).
epoch_bound: whether the running average should be reset after each epoch. It is deprecated in favor of
``usage`` argument in :meth:`attach` method. Setting ``epoch_bound`` to ``True`` is equivalent to
``usage=SingleEpochRunningBatchWise()`` and setting it to ``False`` is equivalent to
``usage=RunningBatchWise()`` in the :meth:`attach` method. Default None.
device: specifies which device updates are accumulated on. Should be
None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will
use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value
Expand Down Expand Up @@ -90,7 +93,7 @@ def __init__(
src: Optional[Metric] = None,
alpha: float = 0.98,
output_transform: Optional[Callable] = None,
epoch_bound: bool = True,
epoch_bound: Optional[bool] = None,
device: Optional[Union[str, torch.device]] = None,
):
if not (isinstance(src, Metric) or src is None):
Expand All @@ -101,70 +104,119 @@ def __init__(
if isinstance(src, Metric):
if output_transform is not None:
raise ValueError("Argument output_transform should be None if src is a Metric.")

def output_transform(x: Any) -> Any:
return x

if device is not None:
raise ValueError("Argument device should be None if src is a Metric.")
self.src = src
self._get_src_value = self._get_metric_value
setattr(self, "iteration_completed", self._metric_iteration_completed)
self.src: Union[Metric, None] = src
device = src._device
else:
if output_transform is None:
raise ValueError(
"Argument output_transform should not be None if src corresponds "
"to the output of process function."
)
self._get_src_value = self._get_output_value
setattr(self, "update", self._output_update)
self.src = None
if device is None:
device = torch.device("cpu")

self.alpha = alpha
if epoch_bound is not None:
warnings.warn(
"`epoch_bound` is deprecated and will be removed in the future. Consider using `usage` argument of"
"`attach` method instead. `epoch_bound=True` is equivalent with `usage=SingleEpochRunningBatchWise()`"
" and `epoch_bound=False` is equivalent with `usage=RunningBatchWise()`."
)
self.epoch_bound = epoch_bound
super(RunningAverage, self).__init__(output_transform=output_transform, device=device) # type: ignore[arg-type]
self.alpha = alpha
super(RunningAverage, self).__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self) -> None:
self._value: Optional[Union[float, torch.Tensor]] = None
if isinstance(self.src, Metric):
self.src.reset()

@reinit__is_reduced
def update(self, output: Sequence) -> None:
# Implement abstract method
pass

def compute(self) -> Union[torch.Tensor, float]:
if self._value is None:
self._value = self._get_src_value()
def update(self, output: Union[torch.Tensor, float]) -> None:
if self.src is None:
output = output.detach().to(self._device, copy=True) if isinstance(output, torch.Tensor) else output
value = idist.all_reduce(output) / idist.get_world_size()
else:
self._value = self._value * self.alpha + (1.0 - self.alpha) * self._get_src_value()
value = self.src.compute()
self.src.reset()

return self._value

def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = EpochWise()) -> None:
if self.epoch_bound:
# restart average every epoch
engine.add_event_handler(Events.EPOCH_STARTED, self.started)
if self._value is None:
self._value = value
else:
engine.add_event_handler(Events.STARTED, self.started)
# compute metric
engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
# apply running average
engine.add_event_handler(Events.ITERATION_COMPLETED, self.completed, name)

def _get_metric_value(self) -> Union[torch.Tensor, float]:
return self.src.compute()

@sync_all_reduce("src")
def _get_output_value(self) -> Union[torch.Tensor, float]:
# we need to compute average instead of sum produced by @sync_all_reduce("src")
output = cast(Union[torch.Tensor, float], self.src) / idist.get_world_size()
return output
self._value = self._value * self.alpha + (1.0 - self.alpha) * value

def _metric_iteration_completed(self, engine: Engine) -> None:
self.src.started(engine)
self.src.iteration_completed(engine)

@reinit__is_reduced
def _output_update(self, output: Union[torch.Tensor, float]) -> None:
if isinstance(output, torch.Tensor):
output = output.detach().to(self._device, copy=True)
self.src = output # type: ignore[assignment]
def compute(self) -> Union[torch.Tensor, float]:
return cast(Union[torch.Tensor, float], self._value)

def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
    r"""
    Attach the metric to the ``engine`` using the events determined by the ``usage``.

    Args:
        engine: the engine to get attached to.
        name: by which, the metric is inserted into ``engine.state.metrics`` dictionary.
        usage: the usage determining on which events the metric is reset, updated and computed. It should be an
            instance of the :class:`~ignite.metrics.metric.MetricUsage`\ s in the following table.

            ======================================================= ===========================================
            ``usage`` **class**                                     **Description**
            ======================================================= ===========================================
            :class:`~.metrics.metric.RunningBatchWise`              Running average of the ``src`` metric or
                                                                    ``engine.state.output`` is computed across
                                                                    batches. In the former case, on each batch,
                                                                    ``src`` is reset, updated and computed then
                                                                    its value is retrieved. Default.
            :class:`~.metrics.metric.SingleEpochRunningBatchWise`   Same as above but the running average is
                                                                    computed across batches in an epoch so it
                                                                    is restarted when a new epoch starts.
            :class:`~.metrics.metric.RunningEpochWise`              Running average of the ``src`` metric or
                                                                    ``engine.state.output`` is computed across
                                                                    epochs. In the former case, ``src`` works
                                                                    as if it was attached in a
                                                                    :class:`~ignite.metrics.metric.EpochWise`
                                                                    manner and its computed value is retrieved
                                                                    at the end of the epoch. The latter case
                                                                    doesn't make much sense for this usage as
                                                                    the ``engine.state.output`` of the last
                                                                    batch is retrieved then.
            ======================================================= ===========================================

    If the deprecated ``epoch_bound`` constructor argument was given (i.e. it is not ``None``), it takes
    precedence over ``usage``: a truthy value selects ``SingleEpochRunningBatchWise()`` and a falsy one
    selects ``RunningBatchWise()``.

    If ``src`` is a :class:`~ignite.metrics.metric.Metric`, its ``iteration_completed`` handler is also
    registered on ``Events.ITERATION_COMPLETED``, unless the engine already has it, so that ``src`` keeps
    being updated on every iteration.

    .. versionchanged:: 0.5.1
        Added `usage` argument
    """
    # Validate/normalize first so that an invalid ``usage`` is reported even when ``epoch_bound`` overrides it.
    usage = self._check_usage(usage)
    if self.epoch_bound is not None:
        if self.epoch_bound:
            usage = SingleEpochRunningBatchWise()
        else:
            usage = RunningBatchWise()

    source = self.src
    if isinstance(source, Metric):
        already_registered = engine.has_event_handler(source.iteration_completed, Events.ITERATION_COMPLETED)
        if not already_registered:
            engine.add_event_handler(Events.ITERATION_COMPLETED, source.iteration_completed)

    super().attach(engine, name, usage)

def detach(self, engine: Engine, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
    """
    Detach the metric from the ``engine``, mirroring :meth:`attach`.

    Args:
        engine: the engine to get detached from.
        usage: the usage the metric was attached with. As in :meth:`attach`, it is overridden when the
            deprecated ``epoch_bound`` constructor argument was given: a truthy value selects
            ``SingleEpochRunningBatchWise()`` and a falsy one selects ``RunningBatchWise()``.

    If ``src`` is a :class:`~ignite.metrics.metric.Metric` and the engine has its ``iteration_completed``
    handler registered on ``Events.ITERATION_COMPLETED``, that handler is removed as well.
    """
    # Validate/normalize first so that an invalid ``usage`` is reported even when ``epoch_bound`` overrides it.
    usage = self._check_usage(usage)
    if self.epoch_bound is not None:
        if self.epoch_bound:
            usage = SingleEpochRunningBatchWise()
        else:
            usage = RunningBatchWise()

    source = self.src
    if isinstance(source, Metric):
        is_registered = engine.has_event_handler(source.iteration_completed, Events.ITERATION_COMPLETED)
        if is_registered:
            engine.remove_event_handler(source.iteration_completed, Events.ITERATION_COMPLETED)

    super().detach(engine, usage)
Loading

0 comments on commit 193643c

Please sign in to comment.