From a72efc1a03751512d0a4d97e9ff1a622e2406d22 Mon Sep 17 00:00:00 2001 From: Stephen Crowell Date: Tue, 12 Nov 2024 13:06:59 -0500 Subject: [PATCH] Add tests --- .../test_occlusion_based.py | 105 ++++++++++++------ .../test_gen_object_detector_blackbox_sal.py | 37 ++++++ 2 files changed, 109 insertions(+), 33 deletions(-) diff --git a/tests/impls/gen_object_detector_blackbox_sal/test_occlusion_based.py b/tests/impls/gen_object_detector_blackbox_sal/test_occlusion_based.py index 8326de14..fd327f3d 100644 --- a/tests/impls/gen_object_detector_blackbox_sal/test_occlusion_based.py +++ b/tests/impls/gen_object_detector_blackbox_sal/test_occlusion_based.py @@ -14,6 +14,31 @@ class TestPerturbationOcclusion: + class StubPI (PerturbImage): + """ + Stub perturber that returns masks of ones. + """ + + def perturb(self, ref_image: np.ndarray) -> np.ndarray: + return np.ones((6, *ref_image.shape[:2]), dtype=bool) + + get_config = None # type: ignore + + class StubGen (GenerateDetectorProposalSaliency): + """ + Stub saliency generator that returns zeros with correct shape. + """ + + def generate( + self, + ref_dets: np.ndarray, + pert_dets: np.ndarray, + pert_masks: np.ndarray + ) -> np.ndarray: + return np.zeros((ref_dets.shape[0], *pert_masks.shape[1:]), dtype=np.float16) + + get_config = None # type: ignore + def teardown(self) -> None: # Collect any temporary implementations so they are not returned during # later `*.get_impl()` requests. @@ -24,7 +49,7 @@ def test_configuration(self) -> None: Test configuration suite using known simple implementations. """ - class StubPI (PerturbImage): + class StubPIWithArgs (PerturbImage): perturb = None # type: ignore def __init__(self, stub_param: int): @@ -33,7 +58,7 @@ def __init__(self, stub_param: int): def get_config(self) -> Dict[str, Any]: return {'stub_param': self.p} - class StubGen (GenerateDetectorProposalSaliency): + class StubGenWithArgs (GenerateDetectorProposalSaliency): generate = None # type: ignore def __init__(self, stub_param: int): @@ -46,15 +71,15 @@ def get_config(self) -> Dict[str, Any]: test_spi_p = 0 test_sgn_p = 1 inst = PerturbationOcclusion( - StubPI(test_spi_p), - StubGen(test_sgn_p), + StubPIWithArgs(test_spi_p), + StubGenWithArgs(test_sgn_p), threads=87 ) for inst_i in configuration_test_helper(inst): assert inst_i._threads == test_threads - assert isinstance(inst_i._perturber, StubPI) + assert isinstance(inst_i._perturber, StubPIWithArgs) assert inst_i._perturber.p == test_spi_p - assert isinstance(inst_i._generator, StubGen) + assert isinstance(inst_i._generator, StubGenWithArgs) assert inst_i._generator.p == test_sgn_p def test_generate_success(self) -> None: @@ -62,31 +87,6 @@ def test_generate_success(self) -> None: Test successfully invoking _generate(). """ - class StubPI (PerturbImage): - """ - Stub perturber that returns masks of ones. - """ - - def perturb(self, ref_image: np.ndarray) -> np.ndarray: - return np.ones((6, *ref_image.shape[:2]), dtype=bool) - - get_config = None # type: ignore - - class StubGen (GenerateDetectorProposalSaliency): - """ - Stub saliency generator that returns zeros with correct shape. - """ - - def generate( - self, - ref_dets: np.ndarray, - pert_dets: np.ndarray, - pert_masks: np.ndarray - ) -> np.ndarray: - return np.zeros((ref_dets.shape[0], *pert_masks.shape[1:]), dtype=np.float16) - - get_config = None # type: ignore - class StubDetector (DetectImageObjects): """ Stub object detector that returns known detections. @@ -106,8 +106,8 @@ def detect_objects( get_config = None # type: ignore - test_pi = StubPI() - test_gen = StubGen() + test_pi = TestPerturbationOcclusion.StubPI() + test_gen = TestPerturbationOcclusion.StubGen() test_detector = StubDetector() test_image = np.ones((64, 64, 3), dtype=np.uint8) @@ -160,3 +160,42 @@ def detect_objects( m_kwargs = m_occ_img.call_args[-1] assert "fill" in m_kwargs assert m_kwargs['fill'] == test_fill + + def test_empty_detections(self) -> None: + """ + Test invoking _generate() with empty detections. + """ + + class StubDetector (DetectImageObjects): + """ + Stub object detector that returns known detections. + """ + + def detect_objects( + self, + img_iter: Iterable[np.ndarray] + ) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]: + for i, _ in enumerate(img_iter): + # Return zero detections for each image + yield [] + + get_config = None # type: ignore + + test_pi = TestPerturbationOcclusion.StubPI() + test_gen = TestPerturbationOcclusion.StubGen() + test_detector = StubDetector() + + test_image = np.ones((64, 64, 3), dtype=np.uint8) + + test_bboxes = np.ones((3, 4)) + test_scores = np.ones((3, 2)) + + inst = PerturbationOcclusion(test_pi, test_gen) + test_result = inst._generate( + test_image, + test_bboxes, + test_scores, + test_detector + ) + + assert len(test_result) == 0 diff --git a/tests/interfaces/test_gen_object_detector_blackbox_sal.py b/tests/interfaces/test_gen_object_detector_blackbox_sal.py index 1d457a60..9a92da28 100644 --- a/tests/interfaces/test_gen_object_detector_blackbox_sal.py +++ b/tests/interfaces/test_gen_object_detector_blackbox_sal.py @@ -1,8 +1,10 @@ import unittest.mock as mock import numpy as np import pytest +from typing import Iterable, Tuple, Dict, Hashable from smqtk_detection import DetectImageObjects +from smqtk_detection.utils.bbox import AxisAlignedBoundingBox from xaitk_saliency.interfaces.gen_object_detector_blackbox_sal import GenerateObjectDetectorBlackboxSaliency from xaitk_saliency.exceptions import ShapeMismatchError @@ -301,3 +303,38 @@ def test_call_alias() -> None: None # no objectness passed ) assert test_ret == expected_return + + +def test_return_empty_map() -> None: + """ + Test that an empty array of maps is returned properly + """ + m_impl = mock.Mock(spec=GenerateObjectDetectorBlackboxSaliency) + m_detector = mock.Mock(spec=DetectImageObjects) + + # test reference detections inputs with matching lengths + test_bboxes = np.ones((5, 4), dtype=float) + test_scores = np.ones((5, 3), dtype=float) + + # 2-channel image as just HxW should work + test_image = np.ones((256, 256), dtype=np.uint8) + + expected_return = np.array([]) + m_impl._generate.return_value = expected_return + + test_ret = GenerateObjectDetectorBlackboxSaliency.generate( + m_impl, + test_image, + test_bboxes, + test_scores, + m_detector, + ) + + m_impl._generate.assert_called_with( + test_image, + test_bboxes, + test_scores, + m_detector, + None # no objectness passed + ) + assert len(test_ret) == 0