Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Jun 18, 2020
1 parent 287361a commit f50d991
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion pytorch_lightning/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
multiclass_roc,
multiclass_precision_recall_curve,
dice_score,
iou
iou,
)
from pytorch_lightning.metrics.metric import TensorMetric, TensorCollectionMetric

Expand Down
27 changes: 12 additions & 15 deletions tests/metrics/functional/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,24 +367,21 @@ def test_dice_score(pred, target, expected):
assert score == expected


@pytest.mark.parametrize(['target', 'pred', 'half_ones', 'reduction', 'remove_bg', 'expected'], [
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
False, 'none', False, torch.Tensor([1, 1, 1])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
False, 'elementwise_mean', False, torch.Tensor([1])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
False, 'none', True, torch.Tensor([1, 1])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
True, 'elementwise_mean', False, torch.Tensor([0.5])),
pytest.param((torch.arange(120) % 3).view(-1, 1), (torch.arange(120) % 3).view(-1, 1),
True, 'none', True, torch.Tensor([0.5, 0.5])),
@pytest.mark.parametrize(['half_ones', 'reduction', 'remove_bg', 'expected'], [
pytest.param(False, 'none', False, torch.Tensor([1, 1, 1])),
pytest.param(False, 'elementwise_mean', False, torch.Tensor([1])),
pytest.param(False, 'none', True, torch.Tensor([1, 1])),
pytest.param(True, 'none', False, torch.Tensor([0.5, 0.5, 0.5])),
pytest.param(True, 'elementwise_mean', False, torch.Tensor([0.5])),
pytest.param(True, 'none', True, torch.Tensor([0.5, 0.5])),
])
def test_iou(target, pred, half_ones, reduction, remove_bg, expected):
def test_iou(half_ones, reduction, remove_bg, expected):
pred = (torch.arange(120) % 3).view(-1, 1)
target = (torch.arange(120) % 3).view(-1, 1)
if half_ones:
pred[:60] = 1
assert torch.all(torch.eq(iou(pred, target, remove_bg=remove_bg, reduction=reduction), expected))
iou_val = iou(pred, target, remove_bg=remove_bg, reduction=reduction)
assert torch.allclose(iou_val, expected, atol=1e-9)


# example data taken from
Expand Down

0 comments on commit f50d991

Please sign in to comment.