Skip to content

Commit

Permalink
Intersection over Union Metric/Loss (#469)
Browse files Browse the repository at this point in the history
* Implement IoU metric and IoU loss

* Add tests for IoU metric & IoU loss

* Add documentation for IoU loss

* Update CHANGELOG

* Update IoU docstring

* Add tensor shape for IoU docstring

* Add tests for IoU/GIoU from torchvision
  • Loading branch information
briankosw authored Jan 3, 2021
1 parent a189d7c commit 88e30b2
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added GIoU loss ([#347](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/347))

- Added IoU loss ([#469](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/469))

### Changed

- Decoupled datamodules from models ([#332](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/332), [#270](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/270))
Expand Down
8 changes: 8 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ GIoU Loss

---------------

IoU Loss
--------

.. autofunction:: pl_bolts.losses.object_detection.iou_loss
:noindex:

---------------

Reinforcement Learning
======================
These are common losses used in RL.
Expand Down
28 changes: 26 additions & 2 deletions pl_bolts/losses/object_detection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
"""
Generalized Intersection over Union (GIoU) loss (Rezatofighi et. al)
Loss functions for Object Detection task
"""

import torch

from pl_bolts.metrics.object_detection import giou
from pl_bolts.metrics.object_detection import giou, iou


def iou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Calculates the intersection over union loss.
Args:
preds: batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
target: batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
Example:
>>> import torch
>>> from pl_bolts.losses.object_detection import iou_loss
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> iou_loss(preds, target)
tensor([[0.8571]])
Returns:
IoU loss
"""
loss = 1 - iou(preds, target)
return loss


def giou_loss(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
Expand Down
33 changes: 33 additions & 0 deletions pl_bolts/metrics/object_detection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,39 @@
import torch


def iou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Calculates the intersection over union.
Args:
preds: an Nx4 batch of prediction bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
target: an Mx4 batch of target bounding boxes with representation ``[x_min, y_min, x_max, y_max]``
Example:
>>> import torch
>>> from pl_bolts.metrics.object_detection import iou
>>> preds = torch.tensor([[100, 100, 200, 200]])
>>> target = torch.tensor([[150, 150, 250, 250]])
>>> iou(preds, target)
tensor([[0.1429]])
Returns:
IoU tensor: an NxM tensor containing the pairwise IoU values for every element in preds and target,
where N is the number of prediction bounding boxes and M is the number of target bounding boxes
"""
x_min = torch.max(preds[:, None, 0], target[:, 0])
y_min = torch.max(preds[:, None, 1], target[:, 1])
x_max = torch.min(preds[:, None, 2], target[:, 2])
y_max = torch.min(preds[:, None, 3], target[:, 3])
intersection = (x_max - x_min).clamp(min=0) * (y_max - y_min).clamp(min=0)
pred_area = (preds[:, 2] - preds[:, 0]) * (preds[:, 3] - preds[:, 1])
target_area = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
union = pred_area[:, None] + target_area - intersection
iou = torch.true_divide(intersection, union)
return iou


def giou(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Calculates the generalized intersection over union.
Expand Down
17 changes: 16 additions & 1 deletion tests/losses/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,22 @@
import pytest
import torch

from pl_bolts.losses.object_detection import giou_loss
from pl_bolts.losses.object_detection import giou_loss, iou_loss


@pytest.mark.parametrize("preds, target, expected_loss", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([0.0]))
])
def test_iou_complete_overlap(preds, target, expected_loss):
torch.testing.assert_allclose(iou_loss(preds, target), expected_loss)


@pytest.mark.parametrize("preds, target, expected_loss", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([1.0])),
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([1.0])),
])
def test_iou_no_overlap(preds, target, expected_loss):
torch.testing.assert_allclose(iou_loss(preds, target), expected_loss)


@pytest.mark.parametrize("preds, target, expected_loss", [
Expand Down
39 changes: 38 additions & 1 deletion tests/metrics/test_object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,33 @@
import pytest
import torch

from pl_bolts.metrics.object_detection import giou
from pl_bolts.metrics.object_detection import giou, iou


@pytest.mark.parametrize("preds, target, expected_iou", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 100, 200, 200]]), torch.tensor([1.0]))
])
def test_iou_complete_overlap(preds, target, expected_iou):
torch.testing.assert_allclose(iou(preds, target), expected_iou)


@pytest.mark.parametrize("preds, target, expected_iou", [
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[100, 200, 200, 300]]), torch.tensor([0.0])),
(torch.tensor([[100, 100, 200, 200]]), torch.tensor([[200, 200, 300, 300]]), torch.tensor([0.0])),
])
def test_iou_no_overlap(preds, target, expected_iou):
torch.testing.assert_allclose(iou(preds, target), expected_iou)


@pytest.mark.parametrize("preds, target, expected_iou", [
(
torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]),
torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]),
torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
)
])
def test_iou_multi(preds, target, expected_iou):
torch.testing.assert_allclose(iou(preds, target), expected_iou)


@pytest.mark.parametrize("preds, target, expected_giou", [
Expand All @@ -21,3 +47,14 @@ def test_complete_overlap(preds, target, expected_giou):
])
def test_no_overlap(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)


@pytest.mark.parametrize("preds, target, expected_giou", [
(
torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]),
torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]),
torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]])
)
])
def test_giou_multi(preds, target, expected_giou):
torch.testing.assert_allclose(giou(preds, target), expected_giou)

0 comments on commit 88e30b2

Please sign in to comment.