Skip to content

Commit

Permalink
add iou tests
Browse files Browse the repository at this point in the history
  • Loading branch information
j-dsouza authored and Borda committed Jun 18, 2020
1 parent 826d688 commit d814fab
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
21 changes: 21 additions & 0 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
precision_recall_curve,
roc,
auc,
iou
)


Expand Down Expand Up @@ -366,5 +367,25 @@ def test_dice_score(pred, target, expected):
assert score == expected


@pytest.mark.parametrize(['target', 'pred', 'half_ones', 'reduction', 'remove_bg', 'expected'], [
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
False, 'none', False, torch.Tensor([1, 1, 1])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
False, 'elementwise_mean', False, torch.Tensor([1])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
False, 'none', True, torch.Tensor([1, 1])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
True, 'elementwise_mean', False, torch.Tensor([0.5])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
True, 'none', True, torch.Tensor([0.5, 0.5])),
])
def test_iou(target, pred, half_ones, reduction, remove_bg, expected):
if half_ones:
pred[:60] = 1
assert torch.all(torch.eq(iou(pred, target, remove_bg=remove_bg, reduction=reduction), expected))


# example data taken from
# https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/metrics/tests/test_ranking.py
12 changes: 12 additions & 0 deletions tests/metrics/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MulticlassROC,
MulticlassPrecisionRecall,
DiceCoefficient,
IoU
)


Expand Down Expand Up @@ -205,3 +206,14 @@ def test_dice_coefficient(include_background):
dice = dice_coeff(torch.randint(0, 1, (10, 25, 25)),
torch.randint(0, 1, (10, 25, 25)))
assert isinstance(dice, torch.Tensor)


@pytest.mark.parametrize('remove_bg', [True, False])
def test_iou(remove_bg):
iou = IoU(remove_bg=remove_bg)
assert iou.name == 'iou'

score = iou(torch.randint(0, 1, (10, 25, 25)),
torch.randint(0, 1, (10, 25, 25)))

assert isinstance(score, torch.Tensor)

0 comments on commit d814fab

Please sign in to comment.