Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: InferenceSlicer overlap_ratio_wh argument changed to None by default #1547

Merged
merged 3 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ detections = sv.Detections.from_sam(sam_result=sam_result)

- Added [#1409](https://github.com/roboflow/supervision/pull/1409): `text_color` option for [`VertexLabelAnnotator`](https://supervision.roboflow.com/0.23.0/keypoint/annotators/#supervision.keypoint.annotators.VertexLabelAnnotator) keypoint annotator.

- Changed [#1434](https://github.com/roboflow/supervision/pull/1434): [`InferenceSlicer`](https://supervision.roboflow.com/0.23.0/detection/tools/inference_slicer/) now features an `overlap_ratio_wh` parameter, making it easier to compute slice sizes when handling overlapping slices.
- Changed [#1434](https://github.com/roboflow/supervision/pull/1434): [`InferenceSlicer`](https://supervision.roboflow.com/0.23.0/detection/tools/inference_slicer/) now features an `overlap_wh` parameter, making it easier to compute slice sizes when handling overlapping slices.

- Fix [#1448](https://github.com/roboflow/supervision/pull/1448): Various annotator type issues have been resolved, supporting expanded error handling.

Expand Down
18 changes: 6 additions & 12 deletions supervision/detection/tools/inference_slicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from supervision.utils.image import crop_image
from supervision.utils.internal import (
SupervisionWarnings,
deprecated_parameter,
warn_deprecated,
)

Expand Down Expand Up @@ -60,13 +59,15 @@ class InferenceSlicer:
Args:
slice_wh (Tuple[int, int]): Dimensions of each slice measured in pixels. The
tuple should be in the format `(width, height)`.
overlap_ratio_wh (Optional[Tuple[float, float]]): A tuple representing the
overlap_ratio_wh (Optional[Tuple[float, float]]): [⚠️ Deprecated: please set
to `None` and use `overlap_wh`] A tuple representing the
desired overlap ratio for width and height between consecutive slices.
Each value should be in the range [0, 1), where 0 means no overlap and
a value close to 1 means high overlap.
overlap_wh (Optional[Tuple[int, int]]): A tuple representing the desired
overlap for width and height between consecutive slices measured in pixels.
Each value should be greater than or equal to 0.
Each value should be greater than or equal to 0. Takes precedence over
`overlap_ratio_wh`.
overlap_filter (Union[OverlapFilter, str]): Strategy for
filtering or merging overlapping detections in slices.
iou_threshold (float): Intersection over Union (IoU) threshold
Expand All @@ -82,14 +83,6 @@ class InferenceSlicer:
not a multiple of the slice's width or height minus the overlap.
"""

@deprecated_parameter(
old_parameter="overlap_filter_strategy",
new_parameter="overlap_filter",
map_function=lambda x: x,
warning_message="`{old_parameter}` in `{function_name}` is deprecated and will "
"be removed in `supervision-0.27.0`. Use '{new_parameter}' "
"instead.",
)
def __init__(
self,
callback: Callable[[np.ndarray], Detections],
Expand All @@ -103,7 +96,8 @@ def __init__(
if overlap_ratio_wh is not None:
warn_deprecated(
"`overlap_ratio_wh` in `InferenceSlicer.__init__` is deprecated and "
"will be removed in `supervision-0.27.0`. Use `overlap_wh` instead."
"will be removed in `supervision-0.27.0`. Please manually set it to "
"`None` and use `overlap_wh` instead."
)

self._validate_overlap(overlap_ratio_wh, overlap_wh)
Expand Down
192 changes: 192 additions & 0 deletions test/detection/tools/test_inference_slicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from contextlib import ExitStack as DoesNotRaise
from typing import Optional, Tuple

import numpy as np
import pytest

from supervision.detection.core import Detections
from supervision.detection.overlap_filter import OverlapFilter
from supervision.detection.tools.inference_slicer import InferenceSlicer


@pytest.fixture
def mock_callback():
"""Mock callback function for testing."""

def callback(_: np.ndarray) -> Detections:
return Detections(xyxy=np.array([[0, 0, 10, 10]]))

return callback


@pytest.mark.parametrize(
"slice_wh, overlap_ratio_wh, overlap_wh, expected_overlap, exception",
[
# Valid case: overlap_ratio_wh provided, overlap calculated from the ratio
((128, 128), (0.2, 0.2), None, None, DoesNotRaise()),
# Valid case: overlap_wh in pixels, no ratio provided
((128, 128), None, (20, 20), (20, 20), DoesNotRaise()),
# Invalid case: overlap_ratio_wh greater than 1, should raise ValueError
((128, 128), (1.1, 0.5), None, None, pytest.raises(ValueError)),
# Invalid case: negative overlap_wh, should raise ValueError
((128, 128), None, (-10, 20), None, pytest.raises(ValueError)),
# Invalid case:
# overlap_ratio_wh and overlap_wh provided, should raise ValueError
((128, 128), (0.5, 0.5), (20, 20), (20, 20), pytest.raises(ValueError)),
# Valid case: no overlap_ratio_wh, overlap_wh = 50 pixels
((256, 256), None, (50, 50), (50, 50), DoesNotRaise()),
# Valid case: overlap_ratio_wh provided, overlap calculated from (0.3, 0.3)
((200, 200), (0.3, 0.3), None, None, DoesNotRaise()),
# Valid case: small overlap_ratio_wh values
((100, 100), (0.1, 0.1), None, None, DoesNotRaise()),
# Invalid case: negative overlap_ratio_wh value, should raise ValueError
((128, 128), (-0.1, 0.2), None, None, pytest.raises(ValueError)),
# Invalid case: negative overlap_ratio_wh with overlap_wh provided
((128, 128), (-0.1, 0.2), (30, 30), None, pytest.raises(ValueError)),
# Invalid case: overlap_wh greater than slice size, should raise ValueError
((128, 128), None, (150, 150), (150, 150), DoesNotRaise()),
# Valid case: overlap_ratio_wh is 0, no overlap
((128, 128), (0.0, 0.0), None, None, DoesNotRaise()),
# Invalid case: no overlaps defined, no overlap
((128, 128), None, None, None, pytest.raises(ValueError)),
],
)
def test_inference_slicer_overlap(
mock_callback,
slice_wh: Tuple[int, int],
overlap_ratio_wh: Optional[Tuple[float, float]],
overlap_wh: Optional[Tuple[int, int]],
expected_overlap: Optional[Tuple[int, int]],
exception: Exception,
) -> None:
with exception:
slicer = InferenceSlicer(
callback=mock_callback,
slice_wh=slice_wh,
overlap_ratio_wh=overlap_ratio_wh,
overlap_wh=overlap_wh,
overlap_filter=OverlapFilter.NONE,
)
assert slicer.overlap_wh == expected_overlap


@pytest.mark.parametrize(
"resolution_wh, slice_wh, overlap_wh, expected_offsets",
[
# Case 1: No overlap, exact slices fit within image dimensions
(
(256, 256),
(128, 128),
(0, 0),
np.array(
[
[0, 0, 128, 128],
[128, 0, 256, 128],
[0, 128, 128, 256],
[128, 128, 256, 256],
]
),
),
# Case 2: Overlap of 64 pixels in both directions
(
(256, 256),
(128, 128),
(64, 64),
np.array(
[
[0, 0, 128, 128],
[64, 0, 192, 128],
[128, 0, 256, 128],
[192, 0, 256, 128],
[0, 64, 128, 192],
[64, 64, 192, 192],
[128, 64, 256, 192],
[192, 64, 256, 192],
[0, 128, 128, 256],
[64, 128, 192, 256],
[128, 128, 256, 256],
[192, 128, 256, 256],
[0, 192, 128, 256],
[64, 192, 192, 256],
[128, 192, 256, 256],
[192, 192, 256, 256],
]
),
),
# Case 3: Image not perfectly divisible by slice size (no overlap)
(
(300, 300),
(128, 128),
(0, 0),
np.array(
[
[0, 0, 128, 128],
[128, 0, 256, 128],
[256, 0, 300, 128],
[0, 128, 128, 256],
[128, 128, 256, 256],
[256, 128, 300, 256],
[0, 256, 128, 300],
[128, 256, 256, 300],
[256, 256, 300, 300],
]
),
),
# Case 4: Overlap of 32 pixels, image not perfectly divisible by slice size
(
(300, 300),
(128, 128),
(32, 32),
np.array(
[
[0, 0, 128, 128],
[96, 0, 224, 128],
[192, 0, 300, 128],
[288, 0, 300, 128],
[0, 96, 128, 224],
[96, 96, 224, 224],
[192, 96, 300, 224],
[288, 96, 300, 224],
[0, 192, 128, 300],
[96, 192, 224, 300],
[192, 192, 300, 300],
[288, 192, 300, 300],
[0, 288, 128, 300],
[96, 288, 224, 300],
[192, 288, 300, 300],
[288, 288, 300, 300],
]
),
),
# Case 5: Image smaller than slice size (no overlap)
(
(100, 100),
(128, 128),
(0, 0),
np.array(
[
[0, 0, 100, 100],
]
),
),
# Case 6: Overlap_wh is greater than the slice size
((256, 256), (128, 128), (150, 150), np.array([]).reshape(0, 4)),
],
)
def test_generate_offset(
resolution_wh: Tuple[int, int],
slice_wh: Tuple[int, int],
overlap_wh: Optional[Tuple[int, int]],
expected_offsets: np.ndarray,
) -> None:
offsets = InferenceSlicer._generate_offset(
resolution_wh=resolution_wh,
slice_wh=slice_wh,
overlap_ratio_wh=None,
overlap_wh=overlap_wh,
)

# Verify that the generated offsets match the expected offsets
assert np.array_equal(
offsets, expected_offsets
), f"Expected {expected_offsets}, got {offsets}"