Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stephencrowell committed Nov 12, 2024
1 parent 0927628 commit a72efc1
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 33 deletions.
105 changes: 72 additions & 33 deletions tests/impls/gen_object_detector_blackbox_sal/test_occlusion_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,31 @@

class TestPerturbationOcclusion:

class StubPI (PerturbImage):
"""
Stub perturber that returns masks of ones.
"""

def perturb(self, ref_image: np.ndarray) -> np.ndarray:
return np.ones((6, *ref_image.shape[:2]), dtype=bool)

get_config = None # type: ignore

class StubGen (GenerateDetectorProposalSaliency):
"""
Stub saliency generator that returns zeros with correct shape.
"""

def generate(
self,
ref_dets: np.ndarray,
pert_dets: np.ndarray,
pert_masks: np.ndarray
) -> np.ndarray:
return np.zeros((ref_dets.shape[0], *pert_masks.shape[1:]), dtype=np.float16)

get_config = None # type: ignore

def teardown(self) -> None:
# Collect any temporary implementations so they are not returned during
# later `*.get_impl()` requests.
Expand All @@ -24,7 +49,7 @@ def test_configuration(self) -> None:
Test configuration suite using known simple implementations.
"""

class StubPI (PerturbImage):
class StubPIWithArgs (PerturbImage):
perturb = None # type: ignore

def __init__(self, stub_param: int):
Expand All @@ -33,7 +58,7 @@ def __init__(self, stub_param: int):
def get_config(self) -> Dict[str, Any]:
return {'stub_param': self.p}

class StubGen (GenerateDetectorProposalSaliency):
class StubGenWithArgs (GenerateDetectorProposalSaliency):
generate = None # type: ignore

def __init__(self, stub_param: int):
Expand All @@ -46,47 +71,22 @@ def get_config(self) -> Dict[str, Any]:
test_spi_p = 0
test_sgn_p = 1
inst = PerturbationOcclusion(
StubPI(test_spi_p),
StubGen(test_sgn_p),
StubPIWithArgs(test_spi_p),
StubGenWithArgs(test_sgn_p),
threads=87
)
for inst_i in configuration_test_helper(inst):
assert inst_i._threads == test_threads
assert isinstance(inst_i._perturber, StubPI)
assert isinstance(inst_i._perturber, StubPIWithArgs)
assert inst_i._perturber.p == test_spi_p
assert isinstance(inst_i._generator, StubGen)
assert isinstance(inst_i._generator, StubGenWithArgs)
assert inst_i._generator.p == test_sgn_p

def test_generate_success(self) -> None:
"""
Test successfully invoking _generate().
"""

class StubPI (PerturbImage):
"""
Stub perturber that returns masks of ones.
"""

def perturb(self, ref_image: np.ndarray) -> np.ndarray:
return np.ones((6, *ref_image.shape[:2]), dtype=bool)

get_config = None # type: ignore

class StubGen (GenerateDetectorProposalSaliency):
"""
Stub saliency generator that returns zeros with correct shape.
"""

def generate(
self,
ref_dets: np.ndarray,
pert_dets: np.ndarray,
pert_masks: np.ndarray
) -> np.ndarray:
return np.zeros((ref_dets.shape[0], *pert_masks.shape[1:]), dtype=np.float16)

get_config = None # type: ignore

class StubDetector (DetectImageObjects):
"""
Stub object detector that returns known detections.
Expand All @@ -106,8 +106,8 @@ def detect_objects(

get_config = None # type: ignore

test_pi = StubPI()
test_gen = StubGen()
test_pi = TestPerturbationOcclusion.StubPI()
test_gen = TestPerturbationOcclusion.StubGen()
test_detector = StubDetector()

test_image = np.ones((64, 64, 3), dtype=np.uint8)
Expand Down Expand Up @@ -160,3 +160,42 @@ def detect_objects(
m_kwargs = m_occ_img.call_args[-1]
assert "fill" in m_kwargs
assert m_kwargs['fill'] == test_fill

def test_empty_detections(self) -> None:
"""
Test invoking _generate() with empty detections.
"""

class StubDetector (DetectImageObjects):
"""
Stub object detector that returns known detections.
"""

def detect_objects(
self,
img_iter: Iterable[np.ndarray]
) -> Iterable[Iterable[Tuple[AxisAlignedBoundingBox, Dict[Hashable, float]]]]:
for i, _ in enumerate(img_iter):
# Return zero detections for each image
yield []

get_config = None # type: ignore

test_pi = TestPerturbationOcclusion.StubPI()
test_gen = TestPerturbationOcclusion.StubGen()
test_detector = StubDetector()

test_image = np.ones((64, 64, 3), dtype=np.uint8)

test_bboxes = np.ones((3, 4))
test_scores = np.ones((3, 2))

inst = PerturbationOcclusion(test_pi, test_gen)
test_result = inst._generate(
test_image,
test_bboxes,
test_scores,
test_detector
)

assert len(test_result) == 0
37 changes: 37 additions & 0 deletions tests/interfaces/test_gen_object_detector_blackbox_sal.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import unittest.mock as mock
import numpy as np
import pytest
from typing import Iterable, Tuple, Dict, Hashable

from smqtk_detection import DetectImageObjects
from smqtk_detection.utils.bbox import AxisAlignedBoundingBox

from xaitk_saliency.interfaces.gen_object_detector_blackbox_sal import GenerateObjectDetectorBlackboxSaliency
from xaitk_saliency.exceptions import ShapeMismatchError
Expand Down Expand Up @@ -301,3 +303,38 @@ def test_call_alias() -> None:
None # no objectness passed
)
assert test_ret == expected_return


def test_return_empty_map() -> None:
"""
Test that an empty array of maps is returned properly
"""
m_impl = mock.Mock(spec=GenerateObjectDetectorBlackboxSaliency)
m_detector = mock.Mock(spec=DetectImageObjects)

# test reference detections inputs with matching lengths
test_bboxes = np.ones((5, 4), dtype=float)
test_scores = np.ones((5, 3), dtype=float)

# 2-channel image as just HxW should work
test_image = np.ones((256, 256), dtype=np.uint8)

expected_return = np.array([])
m_impl._generate.return_value = expected_return

test_ret = GenerateObjectDetectorBlackboxSaliency.generate(
m_impl,
test_image,
test_bboxes,
test_scores,
m_detector,
)

m_impl._generate.assert_called_with(
test_image,
test_bboxes,
test_scores,
m_detector,
None # no objectness passed
)
assert len(test_ret) == 0

0 comments on commit a72efc1

Please sign in to comment.