Skip to content

Commit

Permalink
updated test average precision
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Sep 24, 2023
1 parent ea7cb1d commit 9a76c7b
Showing 1 changed file with 62 additions and 71 deletions.
133 changes: 62 additions & 71 deletions tests/ignite/contrib/metrics/test_average_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,91 +63,82 @@ def test_check_shape():
ap._check_shape((torch.rand(4, 3), torch.rand(4, 3, 1)))


def test_binary_and_multilabel_inputs():
@pytest.fixture(params=[item for item in range(8)])
def test_data_binary_and_multilabel(request):
return [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 1),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
# Binary input data of shape (N, L)
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 1),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
][request.param]


@pytest.mark.parametrize("n_times", range(5))
def test_binary_and_multilabel_inputs(n_times, test_data_binary_and_multilabel):
y_pred, y, batch_size = test_data_binary_and_multilabel
ap = AveragePrecision()
ap.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
ap.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
ap.update((y_pred, y))

def _test(y_pred, y, batch_size):
ap.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
for i in range(n_iters):
idx = i * batch_size
ap.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
else:
ap.update((y_pred, y))

np_y = y.numpy()
np_y_pred = y_pred.numpy()

res = ap.compute()
assert isinstance(res, float)
assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)

def get_test_cases():
test_cases = [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 1),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
(torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
# Binary input data of shape (N, L)
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 1),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 1),
# updated batches
(torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
(torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
]
np_y = y.numpy()
np_y_pred = y_pred.numpy()

return test_cases
res = ap.compute()
assert isinstance(res, float)
assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)

for _ in range(5):
# check multiple random inputs as random exact occurencies are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)

@pytest.fixture(params=[item for item in range(4)])
def test_data_integration_binary_and_multilabel(request):
return [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(100,)).long(), torch.randint(0, 2, size=(100,)).long(), 10),
(torch.randint(0, 2, size=(100, 1)).long(), torch.randint(0, 2, size=(100, 1)).long(), 10),
# Binary input data of shape (N, L)
(torch.randint(0, 2, size=(100, 3)).long(), torch.randint(0, 2, size=(100, 3)).long(), 10),
(torch.randint(0, 2, size=(100, 4)).long(), torch.randint(0, 2, size=(100, 4)).long(), 10),
][request.param]

def test_integration_binary_and_mulitlabel_inputs():
def _test(y_pred, y, batch_size):
def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

engine = Engine(update_fn)
@pytest.mark.parametrize("n_times", range(5))
def test_integration_binary_and_mulitlabel_inputs(n_times, test_data_integration_binary_and_multilabel):
y_pred, y, batch_size = test_data_integration_binary_and_multilabel

ap_metric = AveragePrecision()
ap_metric.attach(engine, "ap")
def update_fn(engine, batch):
idx = (engine.state.iteration - 1) * batch_size
y_true_batch = np_y[idx : idx + batch_size]
y_pred_batch = np_y_pred[idx : idx + batch_size]
return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

np_y = y.numpy()
np_y_pred = y_pred.numpy()
engine = Engine(update_fn)

np_ap = average_precision_score(np_y, np_y_pred)
ap_metric = AveragePrecision()
ap_metric.attach(engine, "ap")

data = list(range(y_pred.shape[0] // batch_size))
ap = engine.run(data, max_epochs=1).metrics["ap"]
np_y = y.numpy()
np_y_pred = y_pred.numpy()

assert isinstance(ap, float)
assert np_ap == pytest.approx(ap)
np_ap = average_precision_score(np_y, np_y_pred)

def get_test_cases():
test_cases = [
# Binary input data of shape (N,) or (N, 1)
(torch.randint(0, 2, size=(100,)).long(), torch.randint(0, 2, size=(100,)).long(), 10),
(torch.randint(0, 2, size=(100, 1)).long(), torch.randint(0, 2, size=(100, 1)).long(), 10),
# Binary input data of shape (N, L)
(torch.randint(0, 2, size=(100, 3)).long(), torch.randint(0, 2, size=(100, 3)).long(), 10),
(torch.randint(0, 2, size=(100, 4)).long(), torch.randint(0, 2, size=(100, 4)).long(), 10),
]
return test_cases
data = list(range(y_pred.shape[0] // batch_size))
ap = engine.run(data, max_epochs=1).metrics["ap"]

for _ in range(5):
# check multiple random inputs as random exact occurencies are rare
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size)
assert isinstance(ap, float)
assert np_ap == pytest.approx(ap)


def _test_distrib_binary_and_multilabel_inputs(device):
Expand Down

0 comments on commit 9a76c7b

Please sign in to comment.