Skip to content

Commit

Permalink
Ported more distrib tests using TestDistributed
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Mar 20, 2024
1 parent 9b23519 commit 0a02515
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 339 deletions.
129 changes: 51 additions & 78 deletions tests/ignite/metrics/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sklearn.metrics import precision_score

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import Precision

Expand Down Expand Up @@ -419,25 +420,30 @@ def test_incorrect_y_classes(average):

@pytest.mark.usefixtures("distributed")
class TestDistributed:
def test_integration_multiclass(self):
from ignite.engine import Engine

@pytest.mark.parametrize("average", [False, "macro", "weighted", "micro"])
@pytest.mark.parametrize("n_epochs", [1, 2])
def test_integration_multiclass(self, average, n_epochs):
rank = idist.get_rank()
torch.manual_seed(12)
torch.manual_seed(12 + rank)

def _test(average, n_epochs, metric_device):
n_iters = 60
s = 16
n_classes = 7
n_iters = 60
batch_size = 16
n_classes = 7

offset = n_iters * s
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
y_preds = torch.rand(offset * idist.get_world_size(), n_classes).to(device)
metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())

for metric_device in metric_devices:
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(n_iters * batch_size, n_classes).to(device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
y_preds[i * batch_size : (i + 1) * batch_size, :],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)
Expand All @@ -449,6 +455,9 @@ def update(engine, i):
data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "pr" in engine.state.metrics
assert pr._updated is True
res = engine.state.metrics["pr"]
Expand All @@ -464,40 +473,29 @@ def update(engine, i):

assert pytest.approx(res) == true_res

metric_devices = [torch.device("cpu")]
@pytest.mark.parametrize("average", [False, "macro", "weighted", "micro", "samples"])
@pytest.mark.parametrize("n_epochs", [1, 2])
def test_integration_multilabel(self, average, n_epochs):
rank = idist.get_rank()
torch.manual_seed(12 + rank)

n_iters = 60
batch_size = 16
n_classes = 7

metric_devices = ["cpu"]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())
for _ in range(2):
for metric_device in metric_devices:
_test(average=False, n_epochs=1, metric_device=metric_device)
_test(average=False, n_epochs=2, metric_device=metric_device)
_test(average="macro", n_epochs=1, metric_device=metric_device)
_test(average="macro", n_epochs=2, metric_device=metric_device)
_test(average="weighted", n_epochs=1, metric_device=metric_device)
_test(average="weighted", n_epochs=2, metric_device=metric_device)
_test(average="micro", n_epochs=1, metric_device=metric_device)
_test(average="micro", n_epochs=2, metric_device=metric_device)

def test_integration_multilabel(self):
from ignite.engine import Engine

rank = idist.get_rank()
torch.manual_seed(12)

def _test(average, n_epochs, metric_device):
n_iters = 60
s = 16
n_classes = 7

offset = n_iters * s
y_true = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
y_preds = torch.randint(0, 2, size=(offset * idist.get_world_size(), n_classes, 6, 8)).to(device)
for metric_device in metric_devices:
y_true = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)
y_preds = torch.randint(0, 2, size=(n_iters * batch_size, n_classes, 6, 8)).to(device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset, ...],
y_preds[i * batch_size : (i + 1) * batch_size, ...],
y_true[i * batch_size : (i + 1) * batch_size, ...],
)

engine = Engine(update)
Expand All @@ -509,6 +507,9 @@ def update(engine, i):
data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "pr" in engine.state.metrics
assert pr._updated is True
res = engine.state.metrics["pr"]
Expand All @@ -528,27 +529,16 @@ def update(engine, i):
warnings.simplefilter("ignore", category=UndefinedMetricWarning)
assert precision_score(np_y_true, np_y_preds, average=sk_average_parameter) == pytest.approx(res)

metric_devices = ["cpu"]
@pytest.mark.parametrize("average", [False, "macro", "weighted", "micro"])
def test_accumulator_device(self, average):
# Binary accuracy on input of shape (N, 1) or (N, )

metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())
for _ in range(2):
for metric_device in metric_devices:
_test(average=False, n_epochs=1, metric_device=metric_device)
_test(average=False, n_epochs=2, metric_device=metric_device)
_test(average="macro", n_epochs=1, metric_device=metric_device)
_test(average="macro", n_epochs=2, metric_device=metric_device)
_test(average="micro", n_epochs=1, metric_device=metric_device)
_test(average="micro", n_epochs=2, metric_device=metric_device)
_test(average="weighted", n_epochs=1, metric_device=metric_device)
_test(average="weighted", n_epochs=2, metric_device=metric_device)
_test(average="samples", n_epochs=1, metric_device=metric_device)
_test(average="samples", n_epochs=2, metric_device=metric_device)

def test_accumulator_device(self):
# Binary accuracy on input of shape (N, 1) or (N, )

def _test(average, metric_device):
for metric_device in metric_devices:
pr = Precision(average=average, device=metric_device)
assert pr._device == metric_device
assert pr._updated is False
Expand All @@ -575,24 +565,18 @@ def _test(average, metric_device):
assert pr._weight.device == metric_device, f"{type(pr._weight.device)}:{pr._weight.device} vs "
f"{type(metric_device)}:{metric_device}"

@pytest.mark.parametrize("average", [False, "macro", "weighted", "micro", "samples"])
def test_multilabel_accumulator_device(self, average):
# Multiclass input data of shape (N, ) and (N, C)

metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
_test(False, metric_device=metric_device)
_test("macro", metric_device=metric_device)
_test("micro", metric_device=metric_device)
_test("weighted", metric_device=metric_device)

def test_multilabel_accumulator_device(self):
# Multiclass input data of shape (N, ) and (N, C)

def _test(average, metric_device):
pr = Precision(is_multilabel=True, average=average, device=metric_device)

assert pr._updated is False
assert pr._device == metric_device
assert pr._updated is False

y_pred = torch.randint(0, 2, size=(10, 4, 20, 23))
y = torch.randint(0, 2, size=(10, 4, 20, 23)).long()
Expand All @@ -613,14 +597,3 @@ def _test(average, metric_device):
if average == "weighted":
assert pr._weight.device == metric_device, f"{type(pr._weight.device)}:{pr._weight.device} vs "
f"{type(metric_device)}:{metric_device}"

metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
_test(False, metric_device=metric_device)
_test("macro", metric_device=metric_device)
_test("micro", metric_device=metric_device)
_test("weighted", metric_device=metric_device)
_test("samples", metric_device=metric_device)
Loading

0 comments on commit 0a02515

Please sign in to comment.