Skip to content

Commit

Permalink
Set masks to zero where masks overlap (#8213)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
Co-authored-by: Nicolas Hug <nh.nicolas.hug@gmail.com>
  • Loading branch information
3 people authored Jan 19, 2024
1 parent 660868b commit 6f0deb9
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 10 deletions.
20 changes: 10 additions & 10 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,11 @@ def test_draw_segmentation_masks(colors, alpha, device):
num_masks, h, w = 2, 100, 100
dtype = torch.uint8
img = torch.randint(0, 256, size=(3, h, w), dtype=dtype, device=device)
masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool, device=device)
masks = torch.zeros((num_masks, h, w), dtype=torch.bool, device=device)
masks[0, 10:20, 10:20] = True
masks[1, 15:25, 15:25] = True

# For testing we enforce that there's no overlap between the masks. The
# current behaviour is that the last mask's color will take priority when
# masks overlap, but this makes testing slightly harder, so we don't really
# care
overlap = masks[0] & masks[1]
masks[:, overlap] = False

out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
assert out.dtype == dtype
Expand All @@ -239,12 +236,15 @@ def test_draw_segmentation_masks(colors, alpha, device):
color = torch.tensor(color, dtype=dtype, device=device)

if alpha == 1:
assert (out[:, mask] == color[:, None]).all()
assert (out[:, mask & ~overlap] == color[:, None]).all()
elif alpha == 0:
assert (out[:, mask] == img[:, mask]).all()
assert (out[:, mask & ~overlap] == img[:, mask & ~overlap]).all()

interpolated_color = (img[:, mask] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask], interpolated_color, rtol=0.0, atol=1.0)
interpolated_color = (img[:, mask & ~overlap] * (1 - alpha) + color[:, None] * alpha).to(dtype)
torch.testing.assert_close(out[:, mask & ~overlap], interpolated_color, rtol=0.0, atol=1.0)

interpolated_overlap = (img[:, overlap] * (1 - alpha)).to(dtype)
torch.testing.assert_close(out[:, overlap], interpolated_overlap, rtol=0.0, atol=1.0)


def test_draw_segmentation_masks_dtypes():
Expand Down
3 changes: 3 additions & 0 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def draw_segmentation_masks(
raise ValueError("The image and the masks must have the same height and width")

num_masks = masks.size()[0]
overlapping_masks = masks.sum(dim=0) > 1

if num_masks == 0:
warnings.warn("masks doesn't contain any mask. No mask was drawn")
Expand All @@ -315,6 +316,8 @@ def draw_segmentation_masks(
for mask, color in zip(masks, colors):
img_to_draw[:, mask] = color[:, None]

img_to_draw[:, overlapping_masks] = 0

out = image * (1 - alpha) + img_to_draw * alpha
# Note: at this point, out is a float tensor in [0, 1] or [0, 255] depending on original_dtype
return out.to(original_dtype)
Expand Down

0 comments on commit 6f0deb9

Please sign in to comment.