-
-
Notifications
You must be signed in to change notification settings - Fork 609
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add KLDivergence metric * add JSDivergence * fix variable name * update docstring for JSDivergence * Update ignite/metrics/js_divergence.py Co-authored-by: vfdev <vfdev.5@gmail.com> * Update ignite/metrics/kl_divergence.py Co-authored-by: vfdev <vfdev.5@gmail.com> * swap ground truth and prediction * swap the definitions of p and q --------- Co-authored-by: vfdev <vfdev.5@gmail.com>
- Loading branch information
Showing
6 changed files
with
512 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics.kl_divergence import KLDivergence | ||
from ignite.metrics.metric import sync_all_reduce | ||
|
||
__all__ = ["JSDivergence"] | ||
|
||
|
||
class JSDivergence(KLDivergence): | ||
r"""Calculates the mean of `Jensen-Shannon (JS) divergence | ||
<https://en.wikipedia.org/wiki/Jensen%E2%80%93Shannon_divergence>`_. | ||
.. math:: | ||
\begin{align*} | ||
D_\text{JS}(\mathbf{p}_i \| \mathbf{q}_i) &= \frac{1}{2} D_\text{KL}(\mathbf{p}_i \| \mathbf{m}_i) | ||
+ \frac{1}{2} D_\text{KL}(\mathbf{q}_i \| \mathbf{m}_i), \\ | ||
\mathbf{m}_i &= \frac{1}{2}(\mathbf{p}_i + \mathbf{q}_i), \\ | ||
D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) &= \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}}. | ||
\end{align*} | ||
where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the ground truth and prediction probability tensors, | ||
and :math:`D_\text{KL}` is the KL-divergence. | ||
- ``update`` must receive output of the form ``(y_pred, y)``. | ||
- ``y_pred`` and ``y`` are expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) | ||
or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed. | ||
Args: | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. | ||
device: specifies which device updates are accumulated on. Setting the | ||
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is | ||
non-blocking. By default, CPU. | ||
Examples: | ||
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. | ||
The output of the engine's ``process_function`` needs to be in the format of | ||
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added | ||
to the metric to transform the output into the form expected by the metric. | ||
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. | ||
.. include:: defaults.rst | ||
:start-after: :orphan: | ||
.. testcode:: | ||
metric = JSDivergence() | ||
metric.attach(default_evaluator, 'js-div') | ||
y_true = torch.tensor([ | ||
[ 0.0000, -2.3026, -2.3026], | ||
[ 1.3863, 1.6094, 1.6094], | ||
[ 0.0000, 0.6931, 1.0986] | ||
]) | ||
y_pred = torch.tensor([ | ||
[ 0.0000, 0.6931, 1.0986], | ||
[ 1.3863, 1.6094, 1.6094], | ||
[ 0.0000, -2.3026, -2.3026] | ||
]) | ||
state = default_evaluator.run([[y_pred, y_true]]) | ||
print(state.metrics['js-div']) | ||
.. testoutput:: | ||
0.16266516844431558 | ||
""" | ||
|
||
def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: | ||
m_prob = (F.softmax(y_pred, dim=1) + F.softmax(y, dim=1)) / 2 | ||
m_log = m_prob.log() | ||
y_pred = F.log_softmax(y_pred, dim=1) | ||
y = F.log_softmax(y, dim=1) | ||
self._sum_of_kl += ( | ||
F.kl_div(m_log, y_pred, log_target=True, reduction="sum") | ||
+ F.kl_div(m_log, y, log_target=True, reduction="sum") | ||
).to(self._device) | ||
|
||
@sync_all_reduce("_sum_of_kl", "_num_examples") | ||
def compute(self) -> float: | ||
if self._num_examples == 0: | ||
raise NotComputableError("JSDivergence must have at least one example before it can be computed.") | ||
return self._sum_of_kl.item() / (self._num_examples * 2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from typing import Sequence | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce | ||
|
||
__all__ = ["KLDivergence"] | ||
|
||
|
||
class KLDivergence(Metric): | ||
r"""Calculates the mean of `Kullback-Leibler (KL) divergence | ||
<https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>`_. | ||
.. math:: D_\text{KL}(\mathbf{p}_i \| \mathbf{q}_i) = \sum_{c=1}^C p_{i,c} \log \frac{p_{i,c}}{q_{i,c}} | ||
where :math:`\mathbf{p}_i` and :math:`\mathbf{q}_i` are the ground truth and prediction probability tensors. | ||
- ``update`` must receive output of the form ``(y_pred, y)``. | ||
- ``y_pred`` and ``y`` are expected to be the unnormalized logits for each class. :math:`(B, C)` (classification) | ||
or :math:`(B, C, ...)` (e.g., image segmentation) shapes are allowed. | ||
Args: | ||
output_transform: a callable that is used to transform the | ||
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the | ||
form expected by the metric. This can be useful if, for example, you have a multi-output model and | ||
you want to compute the metric with respect to one of the outputs. | ||
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. | ||
device: specifies which device updates are accumulated on. Setting the | ||
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is | ||
non-blocking. By default, CPU. | ||
Examples: | ||
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine. | ||
The output of the engine's ``process_function`` needs to be in the format of | ||
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added | ||
to the metric to transform the output into the form expected by the metric. | ||
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`. | ||
.. include:: defaults.rst | ||
:start-after: :orphan: | ||
.. testcode:: | ||
metric = KLDivergence() | ||
metric.attach(default_evaluator, 'kl-div') | ||
y_true = torch.tensor([ | ||
[ 0.0000, -2.3026, -2.3026], | ||
[ 1.3863, 1.6094, 1.6094], | ||
[ 0.0000, 0.6931, 1.0986] | ||
]) | ||
y_pred = torch.tensor([ | ||
[ 0.0000, 0.6931, 1.0986], | ||
[ 1.3863, 1.6094, 1.6094], | ||
[ 0.0000, -2.3026, -2.3026] | ||
]) | ||
state = default_evaluator.run([[y_pred, y_true]]) | ||
print(state.metrics['kl-div']) | ||
.. testoutput:: | ||
0.7220296859741211 | ||
""" | ||
|
||
_state_dict_all_req_keys = ("_sum_of_kl", "_num_examples") | ||
|
||
@reinit__is_reduced | ||
def reset(self) -> None: | ||
self._sum_of_kl = torch.tensor(0.0, device=self._device) | ||
self._num_examples = 0 | ||
|
||
@reinit__is_reduced | ||
def update(self, output: Sequence[torch.Tensor]) -> None: | ||
y_pred, y = output[0].detach(), output[1].detach() | ||
if y_pred.shape != y.shape: | ||
raise ValueError(f"y_pred and y must be in the same shape, got {y_pred.shape} != {y.shape}.") | ||
|
||
if y_pred.ndim >= 3: | ||
num_classes = y_pred.shape[1] | ||
# (B, C, ...) -> (B, ..., C) -> (B*..., C) | ||
# regarding as B*... predictions | ||
y_pred = y_pred.movedim(1, -1).reshape(-1, num_classes) | ||
y = y.movedim(1, -1).reshape(-1, num_classes) | ||
elif y_pred.ndim == 1: | ||
raise ValueError(f"y_pred must be in the shape of (B, C) or (B, C, ...), got {y_pred.shape}.") | ||
|
||
self._num_examples += y_pred.shape[0] | ||
self._update(y_pred, y) | ||
|
||
def _update(self, y_pred: torch.Tensor, y: torch.Tensor) -> None: | ||
y_pred = F.log_softmax(y_pred, dim=1) | ||
y = F.log_softmax(y, dim=1) | ||
kl_sum = F.kl_div(y_pred, y, log_target=True, reduction="sum") | ||
self._sum_of_kl += kl_sum.to(self._device) | ||
|
||
@sync_all_reduce("_sum_of_kl", "_num_examples") | ||
def compute(self) -> float: | ||
if self._num_examples == 0: | ||
raise NotComputableError("KLDivergence must have at least one example before it can be computed.") | ||
return self._sum_of_kl.item() / self._num_examples |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
import pytest | ||
import torch | ||
from scipy.spatial.distance import jensenshannon | ||
from scipy.special import softmax | ||
from torch import Tensor | ||
|
||
import ignite.distributed as idist | ||
from ignite.engine import Engine | ||
from ignite.exceptions import NotComputableError | ||
from ignite.metrics import JSDivergence | ||
|
||
|
||
def scipy_js_div(np_y_pred: np.ndarray, np_y: np.ndarray) -> float: | ||
y_pred_prob = softmax(np_y_pred, axis=1) | ||
y_prob = softmax(np_y, axis=1) | ||
# jensenshannon computes the sqrt of the JS divergence | ||
js_mean = np.mean(np.square(jensenshannon(y_pred_prob, y_prob, axis=1))) | ||
return js_mean | ||
|
||
|
||
def test_zero_sample(): | ||
js_div = JSDivergence() | ||
with pytest.raises( | ||
NotComputableError, match=r"JSDivergence must have at least one example before it can be computed" | ||
): | ||
js_div.compute() | ||
|
||
|
||
def test_shape_mismatch(): | ||
js_div = JSDivergence() | ||
y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) | ||
y = torch.tensor([[-2.0, 1.0]], dtype=torch.float) | ||
with pytest.raises(ValueError, match=r"y_pred and y must be in the same shape, got"): | ||
js_div.update((y_pred, y)) | ||
|
||
|
||
def test_invalid_shape(): | ||
js_div = JSDivergence() | ||
y_pred = torch.tensor([2.0, 3.0], dtype=torch.float) | ||
y = torch.tensor([4.0, 5.0], dtype=torch.float) | ||
with pytest.raises(ValueError, match=r"y_pred must be in the shape of \(B, C\) or \(B, C, ...\), got"): | ||
js_div.update((y_pred, y)) | ||
|
||
|
||
@pytest.fixture(params=list(range(4))) | ||
def test_case(request): | ||
return [ | ||
(torch.randn((100, 10)), torch.rand((100, 10)), 1), | ||
(torch.rand((100, 500)), torch.randn((100, 500)), 1), | ||
# updated batches | ||
(torch.normal(0.0, 5.0, size=(100, 10)), torch.rand((100, 10)), 16), | ||
(torch.normal(5.0, 3.0, size=(100, 200)), torch.rand((100, 200)), 16), | ||
# image segmentation | ||
(torch.randn((100, 5, 32, 32)), torch.rand((100, 5, 32, 32)), 16), | ||
(torch.rand((100, 5, 224, 224)), torch.randn((100, 5, 224, 224)), 16), | ||
][request.param] | ||
|
||
|
||
@pytest.mark.parametrize("n_times", range(5)) | ||
def test_compute(n_times, test_case: Tuple[Tensor, Tensor, int]): | ||
y_pred, y, batch_size = test_case | ||
|
||
js_div = JSDivergence() | ||
|
||
js_div.reset() | ||
if batch_size > 1: | ||
n_iters = y.shape[0] // batch_size + 1 | ||
for i in range(n_iters): | ||
idx = i * batch_size | ||
js_div.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size])) | ||
else: | ||
js_div.update((y_pred, y)) | ||
|
||
res = js_div.compute() | ||
|
||
np_y_pred = y_pred.numpy() | ||
np_y = y.numpy() | ||
|
||
np_res = scipy_js_div(np_y_pred, np_y) | ||
|
||
assert isinstance(res, float) | ||
assert pytest.approx(np_res, rel=1e-4) == res | ||
|
||
|
||
def test_accumulator_detached(): | ||
js_div = JSDivergence() | ||
|
||
y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float) | ||
y = torch.tensor([[-2.0, 1.0], [2.0, 3.0]], dtype=torch.float) | ||
js_div.update((y_pred, y)) | ||
|
||
assert not js_div._sum_of_kl.requires_grad | ||
|
||
|
||
@pytest.mark.usefixtures("distributed") | ||
class TestDistributed: | ||
def test_integration(self): | ||
tol = 1e-4 | ||
n_iters = 100 | ||
batch_size = 10 | ||
n_dims = 100 | ||
|
||
rank = idist.get_rank() | ||
torch.manual_seed(12 + rank) | ||
|
||
device = idist.device() | ||
metric_devices = [torch.device("cpu")] | ||
if device.type != "xla": | ||
metric_devices.append(device) | ||
|
||
for metric_device in metric_devices: | ||
y_true = torch.randn((n_iters * batch_size, n_dims)).float().to(device) | ||
y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims)).float().to(device) | ||
|
||
engine = Engine( | ||
lambda e, i: ( | ||
y_preds[i * batch_size : (i + 1) * batch_size], | ||
y_true[i * batch_size : (i + 1) * batch_size], | ||
) | ||
) | ||
|
||
m = JSDivergence(device=metric_device) | ||
m.attach(engine, "js_div") | ||
|
||
data = list(range(n_iters)) | ||
engine.run(data=data, max_epochs=1) | ||
|
||
y_preds = idist.all_gather(y_preds) | ||
y_true = idist.all_gather(y_true) | ||
|
||
assert "js_div" in engine.state.metrics | ||
res = engine.state.metrics["js_div"] | ||
|
||
y_true_np = y_true.cpu().numpy() | ||
y_preds_np = y_preds.cpu().numpy() | ||
true_res = scipy_js_div(y_preds_np, y_true_np) | ||
|
||
assert pytest.approx(true_res, rel=tol) == res | ||
|
||
def test_accumulator_device(self): | ||
device = idist.device() | ||
metric_devices = [torch.device("cpu")] | ||
if device.type != "xla": | ||
metric_devices.append(device) | ||
for metric_device in metric_devices: | ||
js_div = JSDivergence(device=metric_device) | ||
|
||
for dev in (js_div._device, js_div._sum_of_kl.device): | ||
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" | ||
|
||
y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]]).float() | ||
y = torch.ones(2, 2).float() | ||
js_div.update((y_pred, y)) | ||
|
||
for dev in (js_div._device, js_div._sum_of_kl.device): | ||
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}" |
Oops, something went wrong.