Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add masks to boundaries #7704

Open
wants to merge 26 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ The below operators perform pre-processing as well as post-processing required i

batched_nms
masks_to_boxes
masks_to_boundaries
nms
roi_align
roi_pool
Expand Down
4 changes: 4 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def pytest_collection_modifyitems(items):
items[:] = out_items


def pytest_addoption(parser):
parser.addoption("--debug-images", action="store_true", help="Enable debug mode for saving images.")


def pytest_sessionfinish(session, exitstatus):
# This hook is called after all tests have run, and just before returning an exit status.
# We here change exit code 5 into 0.
Expand Down
109 changes: 108 additions & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
import math
import os
from abc import ABC, abstractmethod
Expand All @@ -7,12 +8,13 @@

import numpy as np
import pytest
import scipy.ndimage
import torch
import torch.fx
import torch.nn.functional as F
import torch.testing._internal.optests as optests
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
from PIL import Image
from PIL import Image, ImageDraw
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.autograd import gradcheck
Expand Down Expand Up @@ -741,6 +743,111 @@ def test_is_leaf_node(self, device):
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


import matplotlib.pyplot as plt


class TestMasksToBoundaries(ABC):
def save_and_images(
self, original_masks, expected_boundaries, actual_boundaries, diff, filename_prefix, visualize=True
):
"""
Saves images separately for original masks, expected boundaries, actual boundaries, and their difference.

Parameters:
- original_masks: The starting binary masks tensor.
- expected_boundaries: The expected boundaries tensor.
- actual_boundaries: The actual boundaries tensor calculated by the function.
- diff: The absolute difference between expected and actual boundaries.
- filename_prefix: Prefix for the saved filename.
- visualize: Flag to enable or disable visualization.
"""
# Ensure directory exists
output_dir = "test_outputs"
os.makedirs(output_dir, exist_ok=True)
filepath_prefix = os.path.join(output_dir, filename_prefix)

num_images = original_masks.shape[0]

original_masks = original_masks.cpu().numpy() if original_masks.is_cuda else original_masks.numpy()
expected_boundaries = (
expected_boundaries.cpu().numpy() if expected_boundaries.is_cuda else expected_boundaries.numpy()
)
actual_boundaries = actual_boundaries.cpu().numpy() if actual_boundaries.is_cuda else actual_boundaries.numpy()
diff = diff.cpu().numpy() if diff.is_cuda else diff.numpy()

# Plot and save each image separately
for i in range(num_images):
original = original_masks[i].squeeze()
expected = expected_boundaries[i].squeeze()
actual = actual_boundaries[i].squeeze()
difference = diff[i].squeeze()

if visualize:
# Plotting
fig, axes = plt.subplots(1, 4, figsize=(20, 5))
titles = ["Original Mask", "Expected Boundaries", "Actual Boundaries", "Absolute Difference"]
images = [original, expected, actual, difference]

for ax, img, title in zip(axes, images, titles):
ax.imshow(img, cmap="gray", interpolation="nearest")
ax.axis("off")
ax.set_title(title)

plt.subplots_adjust(top=0.85)

# Save the figure
fig.tight_layout()
plt.savefig(f"{filepath_prefix}_image_{i}.png", bbox_inches="tight")
plt.close(fig)

@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("kernel_size", [3, 5]) # Example kernel sizes
@pytest.mark.parametrize("canvas_size", [32, 64]) # Example canvas sizes
@pytest.mark.parametrize("batch_size", [1, 4]) # Parametrizing over batch sizes, e.g., 1 and 4
def test_masks_to_boundaries(self, request, tmpdir, device, kernel_size, canvas_size, batch_size):
if device == "cuda" and not torch.cuda.is_available():
pytest.skip("CUDA is not available on this system.")
debug_mode = request.config.getoption("--debug-images")
# Create masks with the specified canvas size and batch size
mask = torch.zeros(batch_size, canvas_size, canvas_size, dtype=torch.bool)

for b in range(batch_size):
if b % 4 == 0:
mask[b, 1:10, 1:10] = True
elif b % 4 == 1:
mask[b, 15:23, 15:23] = True
elif b % 4 == 2:
mask[b, 1:5, 22:30] = True
elif b % 4 == 3:
pil_img = Image.new("L", (canvas_size, canvas_size))
draw = ImageDraw.Draw(pil_img)
draw.ellipse([2, 7, min(26, canvas_size - 6), min(26, canvas_size - 6)], fill=1, outline=1, width=1)
ellipse_mask = torch.from_numpy(np.array(pil_img, dtype=np.uint8)).bool()
mask[b, ...] = ellipse_mask
mask = mask.to(device)
actual_boundaries = ops.masks_to_boundaries(mask, kernel_size)
expected_boundaries = torch.zeros_like(mask)
struct = np.ones((kernel_size, kernel_size), dtype=np.uint8)

# Calculate expected boundaries using scipy's binary_erosion
for i in range(batch_size):
single_mask = mask[i].cpu().numpy()
eroded_mask = scipy.ndimage.binary_erosion(single_mask, structure=struct, border_value=0)
single_expected_boundary = single_mask ^ eroded_mask
expected_boundaries[i] = torch.from_numpy(single_expected_boundary).to(device)

if debug_mode:
diff = torch.abs(expected_boundaries.float() - actual_boundaries.float())
filename_prefix = f"kernel_{kernel_size}_canvas_{canvas_size}_batch_{batch_size}"
output_file_path = tmpdir.join(f"{filename_prefix}.png")
# Log the path where the debug image will be saved
logging.info(f"Debug image saved at: {output_file_path}")

self.save_and_images(mask, expected_boundaries, actual_boundaries, diff, str(output_file_path))

torch.testing.assert_close(actual_boundaries, expected_boundaries)


class TestNMS:
def _reference_nms(self, boxes, scores, iou_threshold):
"""
Expand Down
2 changes: 2 additions & 0 deletions torchvision/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
distance_box_iou,
generalized_box_iou,
masks_to_boxes,
masks_to_boundaries,
nms,
remove_small_boxes,
)
Expand All @@ -32,6 +33,7 @@

__all__ = [
"masks_to_boxes",
"masks_to_boundaries",
"deform_conv2d",
"DeformConv2d",
"nms",
Expand Down
41 changes: 40 additions & 1 deletion torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Tuple

import torch
import torch.nn.functional as F
import torchvision
from torch import Tensor
from torchvision.extension import _assert_has_ops
Expand Down Expand Up @@ -379,7 +380,6 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso


def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:

iou = box_iou(boxes1, boxes2)
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
Expand All @@ -399,6 +399,45 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
return iou - (centers_distance_squared / diagonal_distance_squared), iou


def masks_to_boundaries(masks: torch.Tensor, kernel_size: int) -> torch.Tensor:
"""
Compute the boundaries around the provided binary masks using morphological operations with a custom structuring element.
Enforces the use of an odd-sized kernel for the structuring element.

Parameters:
- masks: Input binary masks tensor of shape [N, H, W].
- kernel_size: Size of the kernel for the structuring element, must be odd.

Returns:
- Tensor representing the boundaries of the masks with shape [N, H, W].
"""
if masks.numel() == 0:
return torch.zeros_like(masks)

# Ensure kernel_size is odd
if kernel_size % 2 == 0:
raise ValueError("kernel_size must be odd.")

# Define the structuring element based on kernel_size
selem = torch.ones((1, 1, kernel_size, kernel_size), dtype=torch.float32, device=masks.device)

masks_float = masks.float().unsqueeze(1)

# Apply convolution with the structuring element
padding = (kernel_size - 1) // 2
eroded_masks = F.conv2d(masks_float, selem, padding=padding, stride=1)
eroded_masks = eroded_masks.squeeze(1) # Remove channel dimension after convolution

# Thresholding: a pixel in the eroded mask should be set if the convolution result
# is equal to the sum of the structuring element (i.e., all ones in the kernel)
threshold = torch.sum(selem).item()
eroded_masks = (eroded_masks == threshold).float()

contours = torch.logical_xor(masks, eroded_masks.bool())

return contours


def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:
"""
Compute the bounding boxes around the provided masks.
Expand Down