Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephencrowell committed Nov 13, 2024
1 parent 0927628 commit ed9de5d
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 51 deletions.
113 changes: 62 additions & 51 deletions tests/impls/gen_object_detector_blackbox_sal/test_occlusion_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
from xaitk_saliency.utils.masking import occlude_image_batch


def _perturb(ref_image: np.ndarray) -> np.ndarray:
return np.ones((6, *ref_image.shape[:2]), dtype=bool)


class TestPerturbationOcclusion:

def teardown(self) -> None:
Expand Down Expand Up @@ -62,70 +66,40 @@ def test_generate_success(self) -> None:
Test successfully invoking _generate().
"""

class StubPI (PerturbImage):
"""
Stub perturber that returns masks of ones.
"""

def perturb(self, ref_image: np.ndarray) -> np.ndarray:
return np.ones((6, *ref_image.shape[:2]), dtype=bool)

get_config = None # type: ignore

class StubGen (GenerateDetectorProposalSaliency):
"""
Stub saliency generator that returns zeros with correct shape.
"""

def generate(
self,
ref_dets: np.ndarray,
pert_dets: np.ndarray,
pert_masks: np.ndarray
) -> np.ndarray:
return np.zeros((ref_dets.shape[0], *pert_masks.shape[1:]), dtype=np.float16)

get_config = None # type: ignore

class StubDetector (DetectImageObjects):
"""
Stub object detector that returns known detections.
"""

def detect_objects(
self,
img_iter: Iterable[np.ndarray]
) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]:
for i, _ in enumerate(img_iter):
# Return different number of detections for each image to
# test padding functinality
yield [(
AxisAlignedBoundingBox((0, 0), (1, 1)),
{'class0': 0.0, 'class1': 0.9}
) for _ in range(i)]

get_config = None # type: ignore

test_pi = StubPI()
test_gen = StubGen()
test_detector = StubDetector()
def detect_objects(
img_iter: Iterable[np.ndarray]
) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]:
for i, _ in enumerate(img_iter):
# Return different number of detections for each image to
# test padding functinality
yield [(
AxisAlignedBoundingBox((0, 0), (1, 1)),
{'class0': 0.0, 'class1': 0.9}
) for _ in range(i)]

test_image = np.ones((64, 64, 3), dtype=np.uint8)

test_bboxes = np.ones((3, 4))
test_scores = np.ones((3, 2))

m_perturb = mock.Mock(spec=PerturbImage)
m_perturb.return_value = _perturb(test_image)
m_gen = mock.Mock(spec=GenerateDetectorProposalSaliency)
m_gen.return_value = np.zeros((3, 64, 64))
m_detector = mock.Mock(spec=DetectImageObjects)
m_detector.detect_objects = detect_objects

# Call with default fill
with mock.patch(
'xaitk_saliency.impls.gen_object_detector_blackbox_sal.occlusion_based.occlude_image_batch',
wraps=occlude_image_batch
) as m_occ_img:
inst = PerturbationOcclusion(test_pi, test_gen)
inst = PerturbationOcclusion(m_perturb, m_gen)
test_result = inst._generate(
test_image,
test_bboxes,
test_scores,
test_detector
m_detector,
)

assert test_result.shape == (3, 64, 64)
Expand All @@ -143,13 +117,13 @@ def detect_objects(
'xaitk_saliency.impls.gen_object_detector_blackbox_sal.occlusion_based.occlude_image_batch',
wraps=occlude_image_batch
) as m_occ_img:
inst = PerturbationOcclusion(test_pi, test_gen)
inst = PerturbationOcclusion(m_perturb, m_gen)
inst.fill = test_fill
test_result = inst._generate(
test_image,
test_bboxes,
test_scores,
test_detector
m_detector,
)

assert test_result.shape == (3, 64, 64)
Expand All @@ -160,3 +134,40 @@ def detect_objects(
m_kwargs = m_occ_img.call_args[-1]
assert "fill" in m_kwargs
assert m_kwargs['fill'] == test_fill

def test_empty_detections(self) -> None:
"""
Test invoking _generate() with empty detections.
"""

def detect_objects(
img_iter: Iterable[np.ndarray]
) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]:
for i, _ in enumerate(img_iter):
# Return 0 detections for each image
yield []

m_detector = mock.Mock(spec=DetectImageObjects)
m_detector.detect_objects = detect_objects

test_image = np.ones((64, 64, 3), dtype=np.uint8)

test_bboxes = np.ones((3, 4))
test_scores = np.ones((3, 2))

m_perturb = mock.Mock(spec=PerturbImage)
m_perturb.return_value = _perturb(test_image)
m_gen = mock.Mock(spec=GenerateDetectorProposalSaliency)
m_gen.return_value = np.zeros((3, 64, 64))
m_detector = mock.Mock(spec=DetectImageObjects)
m_detector.detect_objects = detect_objects

inst = PerturbationOcclusion(m_perturb, m_gen)
test_result = inst._generate(
test_image,
test_bboxes,
test_scores,
m_detector,
)

assert len(test_result) == 0
35 changes: 35 additions & 0 deletions tests/interfaces/test_gen_object_detector_blackbox_sal.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,38 @@ def test_call_alias() -> None:
None # no objectness passed
)
assert test_ret == expected_return


def test_return_empty_map() -> None:
"""
Test that an empty array of maps is returned properly
"""
m_impl = mock.Mock(spec=GenerateObjectDetectorBlackboxSaliency)
m_detector = mock.Mock(spec=DetectImageObjects)

# test reference detections inputs with matching lengths
test_bboxes = np.ones((5, 4), dtype=float)
test_scores = np.ones((5, 3), dtype=float)

# 2-channel image as just HxW should work
test_image = np.ones((256, 256), dtype=np.uint8)

expected_return = np.array([])
m_impl._generate.return_value = expected_return

test_ret = GenerateObjectDetectorBlackboxSaliency.generate(
m_impl,
test_image,
test_bboxes,
test_scores,
m_detector,
)

m_impl._generate.assert_called_with(
test_image,
test_bboxes,
test_scores,
m_detector,
None # no objectness passed
)
assert len(test_ret) == 0

0 comments on commit ed9de5d

Please sign in to comment.