Skip to content

Commit

Permalink
Merge branch 'more-distrib-tests-updatest' of https://github.com/pyto…
Browse files Browse the repository at this point in the history
…rch/ignite into more-distrib-tests-updatest
  • Loading branch information
vfdev-5 committed Mar 20, 2024
2 parents bdc4497 + 74db37b commit 3a62b35
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/source/metrics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,7 @@ Complete list of metrics
RougeN
InceptionScore
FID
CosineSimilarity

Helpers for customizing metrics
-------------------------------
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ignite.metrics.accuracy import Accuracy
from ignite.metrics.classification_report import ClassificationReport
from ignite.metrics.confusion_matrix import ConfusionMatrix, DiceCoefficient, IoU, JaccardIndex, mIoU
from ignite.metrics.cosine_similarity import CosineSimilarity
from ignite.metrics.epoch_metric import EpochMetric
from ignite.metrics.fbeta import Fbeta
from ignite.metrics.frequency import Frequency
Expand Down Expand Up @@ -33,6 +34,7 @@
"MeanPairwiseDistance",
"MeanSquaredError",
"ConfusionMatrix",
"CosineSimilarity",
"ClassificationReport",
"TopKCategoricalAccuracy",
"Average",
Expand Down
99 changes: 99 additions & 0 deletions ignite/metrics/cosine_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Callable, Sequence, Union

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["CosineSimilarity"]


class CosineSimilarity(Metric):
r"""Calculates the mean of the `cosine similarity <https://en.wikipedia.org/wiki/Cosine_similarity>`_.
.. math::
\text{cosine\_similarity} = \frac{1}{N} \sum_{i=1}^N
\frac{x_i \cdot y_i}{\max ( \| x_i \|_2 \| y_i \|_2 , \epsilon)}
where :math:`y_{i}` is the prediction tensor and :math:`x_{i}` is ground true tensor.
- ``update`` must receive output of the form ``(y_pred, y)``.
Args:
eps: a small value to avoid division by zero. Default: 1e-8
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
Examples:
To use with ``Engine`` and ``process_function``, simply attach the metric instance to the engine.
The output of the engine's ``process_function`` needs to be in the format of
``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y, ...}``. If not, ``output_tranform`` can be added
to the metric to transform the output into the form expected by the metric.
``y_pred`` and ``y`` should have the same shape.
For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
.. include:: defaults.rst
:start-after: :orphan:
.. testcode::
metric = CosineSimilarity()
metric.attach(default_evaluator, 'cosine_similarity')
preds = torch.tensor([
[1, 2, 4, 1],
[2, 3, 1, 5],
[1, 3, 5, 1],
[1, 5, 1 ,11]
]).float()
target = torch.tensor([
[1, 5, 1 ,11],
[1, 3, 5, 1],
[2, 3, 1, 5],
[1, 2, 4, 1]
]).float()
state = default_evaluator.run([[preds, target]])
print(state.metrics['cosine_similarity'])
.. testoutput::
0.5080491304397583
"""

def __init__(
self,
eps: float = 1e-8,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
):
super().__init__(output_transform, device)

self.eps = eps

_state_dict_all_req_keys = ("_sum_of_cos_similarities", "_num_examples")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_cos_similarities = torch.tensor(0.0, device=self._device)
self._num_examples = 0

@reinit__is_reduced
def update(self, output: Sequence[torch.Tensor]) -> None:
y_pred = output[0].flatten(start_dim=1).detach()
y = output[1].flatten(start_dim=1).detach()
cos_similarities = torch.nn.functional.cosine_similarity(y_pred, y, dim=1, eps=self.eps)
self._sum_of_cos_similarities += torch.sum(cos_similarities).to(self._device)
self._num_examples += y.shape[0]

@sync_all_reduce("_sum_of_cos_similarities", "_num_examples")
def compute(self) -> float:
if self._num_examples == 0:
raise NotComputableError("CosineSimilarity must have at least one example before it can be computed.")
return self._sum_of_cos_similarities.item() / self._num_examples
211 changes: 211 additions & 0 deletions tests/ignite/metrics/test_cosine_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
import os

import numpy as np
import pytest
import torch

import ignite.distributed as idist
from ignite.exceptions import NotComputableError
from ignite.metrics import CosineSimilarity


def test_zero_sample():
cos_sim = CosineSimilarity()
with pytest.raises(
NotComputableError, match=r"CosineSimilarity must have at least one example before it can be computed"
):
cos_sim.compute()


@pytest.fixture(params=[item for item in range(4)])
def test_case(request):
return [
(torch.randn((100, 50)), torch.randn((100, 50)), 10 ** np.random.uniform(-8, 0), 1),
(
torch.normal(1.0, 2.0, size=(100, 10)),
torch.normal(3.0, 4.0, size=(100, 10)),
10 ** np.random.uniform(-8, 0),
1,
),
# updated batches
(torch.rand((100, 128)), torch.rand((100, 128)), 10 ** np.random.uniform(-8, 0), 16),
(
torch.normal(0.0, 5.0, size=(100, 30)),
torch.normal(5.0, 1.0, size=(100, 30)),
10 ** np.random.uniform(-8, 0),
16,
),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_compute(n_times, test_case):
y_pred, y, eps, batch_size = test_case

cos = CosineSimilarity(eps=eps)

cos.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
cos.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
cos.update((y_pred, y))

np_y = y.numpy()
np_y_pred = y_pred.numpy()

np_y_norm = np.clip(np.linalg.norm(np_y, axis=1, keepdims=True), eps, None)
np_y_pred_norm = np.clip(np.linalg.norm(np_y_pred, axis=1, keepdims=True), eps, None)
np_res = np.sum((np_y / np_y_norm) * (np_y_pred / np_y_pred_norm), axis=1)
np_res = np.mean(np_res)

assert isinstance(cos.compute(), float)
assert pytest.approx(np_res, rel=2e-5) == cos.compute()


def _test_distrib_integration(device, tol=2e-5):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12 + rank)

def _test(metric_device):
n_iters = 100
batch_size = 10
n_dims = 100

y_true = torch.randn((n_iters * batch_size, n_dims), dtype=torch.float).to(device)
y_preds = torch.normal(2.0, 3.0, size=(n_iters * batch_size, n_dims), dtype=torch.float).to(device)

def update(engine, i):
return (
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)

m = CosineSimilarity(device=metric_device)
m.attach(engine, "cosine_similarity")

data = list(range(n_iters))
engine.run(data=data, max_epochs=1)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "cosine_similarity" in engine.state.metrics
res = engine.state.metrics["cosine_similarity"]

y_true_np = y_true.cpu().numpy()
y_preds_np = y_preds.cpu().numpy()
y_true_norm = np.clip(np.linalg.norm(y_true_np, axis=1, keepdims=True), 1e-8, None)
y_preds_norm = np.clip(np.linalg.norm(y_preds, axis=1, keepdims=True), 1e-8, None)
true_res = np.sum((y_true_np / y_true_norm) * (y_preds_np / y_preds_norm), axis=1)
true_res = np.mean(true_res)

assert pytest.approx(res, rel=tol) == true_res

_test("cpu")
if device.type != "xla":
_test(idist.device())


def _test_distrib_accumulator_device(device):
metric_devices = [torch.device("cpu")]
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
device = torch.device(device)
cos = CosineSimilarity(device=metric_device)

for dev in [cos._device, cos._sum_of_cos_similarities.device]:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"

y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
y = torch.ones(2, 2, dtype=torch.float)
cos.update((y_pred, y))

for dev in [cos._device, cos._sum_of_cos_similarities.device]:
assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"


def test_accumulator_detached():
cos = CosineSimilarity()

y_pred = torch.tensor([[2.0, 3.0], [-2.0, 1.0]], dtype=torch.float)
y = torch.ones(2, 2, dtype=torch.float)
cos.update((y_pred, y))

assert not cos._sum_of_cos_similarities.requires_grad


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):
device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
device = idist.device()
_test_distrib_integration(device)
_test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_integration(device, tol=1e-4)
_test_distrib_accumulator_device(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_integration(device, tol=1e-4)
_test_distrib_accumulator_device(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)

0 comments on commit 3a62b35

Please sign in to comment.