Skip to content

Commit

Permalink
Added metric test for Horovod
Browse files Browse the repository at this point in the history
  • Loading branch information
tgaddair committed Oct 6, 2020
1 parent 48b1d74 commit 0a8fdae
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 7 deletions.
8 changes: 2 additions & 6 deletions pytorch_lightning/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch import nn

from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available
from pytorch_lightning.utilities.distributed import gather_all_tensors_if_available, is_distributed
from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum


Expand Down Expand Up @@ -179,11 +179,7 @@ def wrapped_func(*args, **kwargs):
if self._computed is not None:
return self._computed

if (
self._to_sync
and torch.distributed.is_available() # noqa: W503
and torch.distributed.is_initialized() # noqa: W503
):
# `is_distributed` is a function: it has to be *called*. A bare reference
# to the function object is always truthy, which would trigger a sync even
# when no distributed backend has been initialized.
if self._to_sync and is_distributed():
    self._sync_dist()

self._computed = compute(*args, **kwargs)
Expand Down
15 changes: 14 additions & 1 deletion pytorch_lightning/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def find_free_network_port() -> int:
return port


def is_distributed() -> bool:
    """Report whether a distributed backend is currently initialized.

    Two backends are probed, in this order:

    1. ``torch.distributed`` -- it must be both available and initialized.
    2. Horovod -- only consulted when ``HOROVOD_AVAILABLE`` is truthy, so
       ``hvd`` is never touched when that flag is false.

    Note:
        This is a function and must be invoked (``is_distributed()``); the
        bare name is a function object and is always truthy.

    Returns:
        A truthy value when either backend reports it is initialized,
        otherwise a falsy value.
    """
    torch_active = torch.distributed.is_available() and torch.distributed.is_initialized()
    if torch_active:
        return torch_active

    # Fall back to Horovod; short-circuits before calling into `hvd` when unavailable.
    return HOROVOD_AVAILABLE and hvd.is_initialized()


def gather_all_tensors_if_available(result: Union[torch.Tensor], group: Optional[Any] = None):
"""
Function to gather all tensors from several distributed processes onto a list that
Expand Down Expand Up @@ -124,9 +129,15 @@ def gather_horovod(result: Union[torch.Tensor], group: Optional[Any] = None):
"Unset `group`."
)

if len(result.shape) == 0:
# Convert scalars to single dimension tensors
result = result.reshape(1)

# sync and gather all
hvd.join()
return hvd.allgather(result)
gathered = hvd.allgather(result)
gathered_result = list(gathered.split(1, dim=0))
return gathered_result


def sync_dist_if_available(
Expand Down Expand Up @@ -182,6 +193,8 @@ def sync_ddp(
if divide_by_world_size:
result = result / torch.distributed.get_world_size(group)

return result


def sync_horovod(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
Expand Down
47 changes: 47 additions & 0 deletions tests/models/test_horovod.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
import subprocess
import sys

import numpy as np
import pytest
import torch

from sklearn.metrics import accuracy_score

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from tests.base import EvalModelTemplate
from tests.base.models import BasicGAN

Expand Down Expand Up @@ -198,6 +202,7 @@ def get_optimizer_params(optimizer):


@pytest.mark.parametrize("result_cls", [Result, TrainResult, EvalResult])
@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_result_reduce_horovod(result_cls):
"""Make sure result logging works with Horovod."""
Expand All @@ -217,6 +222,48 @@ def hvd_test_fn():

horovod.run(hvd_test_fn, np=2)


@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_accuracy_metric_horovod():
    """Make sure the ``Accuracy`` metric syncs correctly with Horovod.

    ``_compute_batch`` is launched through ``horovod.run`` with ``np=2``.
    Each rank feeds its own strided subset of the batches into the metric;
    the per-step value (checked on rank 0) and the final ``compute()`` value
    (checked on every rank) are compared against scikit-learn's
    ``accuracy_score`` evaluated on the combined data.

    The test is skipped when Horovod is unavailable or on Windows, matching
    the guards on the other Horovod tests in this module.
    """
    num_batches = 10
    batch_size = 16
    threshold = 0.5

    def sk_metric(preds, target):
        # Reference implementation: binarize scores at `threshold`, then
        # score the flattened predictions with scikit-learn.
        sk_preds = (preds.view(-1).numpy() >= threshold).astype(np.uint8)
        sk_target = target.view(-1).numpy()
        return accuracy_score(y_true=sk_target, y_pred=sk_preds)

    preds = torch.rand(num_batches, batch_size)
    target = torch.randint(high=2, size=(num_batches, batch_size))

    def _compute_batch():
        import horovod.torch as hvd
        hvd.init()

        metric = Accuracy(compute_on_step=True,
                          dist_sync_on_step=True,
                          threshold=threshold)

        # Rank r handles batches r, r + size, r + 2 * size, ...
        # NOTE(review): assumes `num_batches` is divisible by the number of
        # ranks so that every rank performs the same number of steps.
        for i in range(hvd.rank(), num_batches, hvd.size()):
            batch_result = metric(preds[i], target[i])
            if hvd.rank() == 0:
                # With per-step syncing, the step result should cover the
                # batches processed by all ranks during this step.
                step_slice = slice(i, i + hvd.size())
                sk_batch_result = sk_metric(preds[step_slice], target[step_slice])
                assert np.allclose(batch_result.numpy(), sk_batch_result)

        # check on all batches on all ranks
        result = metric.compute()
        assert isinstance(result, torch.Tensor)

        # `preds` / `target` already hold every batch stacked along dim 0.
        sk_result = sk_metric(preds, target)

        assert np.allclose(result.numpy(), sk_result)

    horovod.run(_compute_batch, np=2)

# @pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
# def test_horovod_multi_optimizer_with_scheduling_stepping(tmpdir):
# hparams = EvalModelTemplate.get_default_hparams()
Expand Down

0 comments on commit 0a8fdae

Please sign in to comment.