From 7dabd9625147b6c911be470f304fb2b48731b407 Mon Sep 17 00:00:00 2001 From: Lukas Weber <32765578+lukasalexanderweber@users.noreply.github.com> Date: Fri, 18 Aug 2023 01:06:28 +0200 Subject: [PATCH] implement feature masks (#130) --- stitching/cli/stitch.py | 24 +-- stitching/feature_detector.py | 18 +++ stitching/images.py | 7 + stitching/stitcher.py | 20 ++- stitching/verbose.py | 16 +- tests/test_detector.py | 22 +++ tests/test_stitch_cli.py | 21 +++ tests/test_stitcher.py | 272 +++++++++++++++++++++++---------- tests/test_timelapse.py | 6 +- tests/test_verbose.py | 26 ---- tests/testdata/TEST_IMAGES.txt | 4 + 11 files changed, 301 insertions(+), 135 deletions(-) delete mode 100644 tests/test_verbose.py diff --git a/stitching/cli/stitch.py b/stitching/cli/stitch.py index dfa3334..871ce46 100644 --- a/stitching/cli/stitch.py +++ b/stitching/cli/stitch.py @@ -3,7 +3,6 @@ """ import argparse -import glob import os import sys from datetime import datetime @@ -29,7 +28,7 @@ def create_parser(): parser = argparse.ArgumentParser(prog="stitch.py") - parser.add_argument("img_names", nargs="+", help="Files to stitch", type=str) + parser.add_argument("images", nargs="+", help="Files to stitch", type=str) parser.add_argument( "-v", "--verbose", @@ -74,6 +73,13 @@ def create_parser(): "The default is 500.", type=int, ) + parser.add_argument( + "--feature_masks", + nargs="*", + default=[], + help="Masks for selecting where features should be detected.", + type=str, + ) parser.add_argument( "--matcher_type", action="store", @@ -280,9 +286,9 @@ def main(): args_dict = vars(args) # Extract In- and Output - img_names = args_dict.pop("img_names") - if len(img_names) == 1: - img_names = glob.glob(img_names[0]) + images = args_dict.pop("images") + feature_masks = args_dict.pop("feature_masks") + verbose = args_dict.pop("verbose") verbose_dir = args_dict.pop("verbose_dir") preview = args_dict.pop("preview") @@ -297,12 +303,12 @@ def main(): stitcher = Stitcher(**args_dict) if verbose: - print("stitching " + " ".join(img_names) + " into " + verbose_dir) + print("stitching " + " ".join(images) + " into " + verbose_dir) os.makedirs(verbose_dir) - stitcher.stitch_verbose(img_names, verbose_dir) + panorama = stitcher.stitch_verbose(images, feature_masks, verbose_dir) else: - print("stitching " + " ".join(img_names) + " into " + output) - panorama = stitcher.stitch(img_names) + print("stitching " + " ".join(images) + " into " + output) + panorama = stitcher.stitch(images, feature_masks) cv.imwrite(output, panorama) if preview: diff --git a/stitching/feature_detector.py b/stitching/feature_detector.py index 2cebd31..85d174c 100644 --- a/stitching/feature_detector.py +++ b/stitching/feature_detector.py @@ -1,6 +1,9 @@ from collections import OrderedDict import cv2 as cv +import numpy as np + +from .stitching_error import StitchingError class FeatureDetector: @@ -21,6 +24,21 @@ def __init__(self, detector=DEFAULT_DETECTOR, **kwargs): def detect_features(self, img, *args, **kwargs): return cv.detail.computeImageFeatures2(self.detector, img, *args, **kwargs) + def detect(self, imgs): + return [self.detect_features(img) for img in imgs] + + def detect_with_masks(self, imgs, masks): + features = [] + for idx, (img, mask) in enumerate(zip(imgs, masks)): + assert len(img.shape) == 3 and len(mask.shape) == 2 + if not np.array_equal(img.shape[:2], mask.shape): + raise StitchingError( + f"Resolution of mask '{idx+1}' ({mask.shape}) does not match" + f" the resolution of image '{idx+1}' ({img.shape[:2]})." + ) + features.append(self.detect_features(img, mask=mask)) + return features + @staticmethod def draw_keypoints(img, features, **kwargs): kwargs.setdefault("color", (0, 255, 0)) diff --git a/stitching/images.py b/stitching/images.py index 2892149..7efcf57 100644 --- a/stitching/images.py +++ b/stitching/images.py @@ -138,6 +138,13 @@ def resolve_wildcards(img_names): def check_list_element_types(list_, type_): return all([isinstance(element, type_) for element in list_]) + @staticmethod + def to_binary(img): + if len(img.shape) == 3: + img = cv.cvtColor(img, cv.COLOR_BGR2GRAY) + _, binary = cv.threshold(img, 0.5, 255.0, cv.THRESH_BINARY) + return binary.astype(np.uint8) + class _NumpyImages(Images): def __init__(self, images, medium_megapix, low_megapix, final_megapix): diff --git a/stitching/stitcher.py b/stitching/stitcher.py index 1b7b1e2..98a554f 100644 --- a/stitching/stitcher.py +++ b/stitching/stitcher.py @@ -86,16 +86,16 @@ def initialize_stitcher(self, **kwargs): self.blender = Blender(args.blender_type, args.blend_strength) self.timelapser = Timelapser(args.timelapse, args.timelapse_prefix) - def stitch_verbose(self, images, verbose_dir=None): - return verbose_stitching(self, images, verbose_dir) + def stitch_verbose(self, images, feature_masks=[], verbose_dir=None): + return verbose_stitching(self, images, feature_masks, verbose_dir) - def stitch(self, images): + def stitch(self, images, feature_masks=[]): self.images = Images.of( images, self.medium_megapix, self.low_megapix, self.final_megapix ) imgs = self.resize_medium_resolution() - features = self.find_features(imgs) + features = self.find_features(imgs, feature_masks) matches = self.match_features(features) imgs, features, matches = self.subset(imgs, features, matches) cameras = self.estimate_camera_parameters(features, matches) @@ -128,8 +128,16 @@ def stitch(self, images): def resize_medium_resolution(self): return list(self.images.resize(Images.Resolution.MEDIUM)) - def find_features(self, imgs): - return [self.detector.detect_features(img) for img in imgs] + def find_features(self, imgs, feature_masks=[]): + if len(feature_masks) == 0: + return self.detector.detect(imgs) + else: + feature_masks = Images.of( + feature_masks, self.medium_megapix, self.low_megapix, self.final_megapix + ) + feature_masks = list(feature_masks.resize(Images.Resolution.MEDIUM)) + feature_masks = [Images.to_binary(mask) for mask in feature_masks] + return self.detector.detect_with_masks(imgs, feature_masks) def match_features(self, features): return self.matcher.match_features(features) diff --git a/stitching/verbose.py b/stitching/verbose.py index db353f2..03123d0 100644 --- a/stitching/verbose.py +++ b/stitching/verbose.py @@ -7,7 +7,7 @@ from .timelapser import Timelapser -def verbose_stitching(stitcher, images, verbose_dir=None): +def verbose_stitching(stitcher, images, feature_masks=[], verbose_dir=None): _dir = "." if verbose_dir is None else verbose_dir images = Images.of( @@ -19,7 +19,7 @@ def verbose_stitching(stitcher, images, verbose_dir=None): # Find Features finder = stitcher.detector - features = [finder.detect_features(img) for img in imgs] + features = stitcher.find_features(imgs, feature_masks) for idx, img_features in enumerate(features): img_with_features = finder.draw_keypoints(imgs[idx], img_features) write_verbose_result(_dir, f"01_features_img{idx+1}.jpg", img_with_features) @@ -123,8 +123,8 @@ def verbose_stitching(stitcher, images, verbose_dir=None): low_corners, low_sizes = cropper.crop_rois(low_corners, low_sizes) lir_aspect = images.get_ratio(Images.Resolution.LOW, Images.Resolution.FINAL) - cropped_final_masks = list(cropper.crop_images(final_masks, lir_aspect)) - cropped_final_imgs = list(cropper.crop_images(final_imgs, lir_aspect)) + final_masks = list(cropper.crop_images(final_masks, lir_aspect)) + final_imgs = list(cropper.crop_images(final_imgs, lir_aspect)) final_corners, final_sizes = cropper.crop_rois( final_corners, final_sizes, lir_aspect ) @@ -132,7 +132,7 @@ def verbose_stitching(stitcher, images, verbose_dir=None): timelapser = Timelapser("as_is") timelapser.initialize(final_corners, final_sizes) - for idx, (img, corner) in enumerate(zip(cropped_final_imgs, final_corners)): + for idx, (img, corner) in enumerate(zip(final_imgs, final_corners)): timelapser.process_frame(img, corner) frame = timelapser.get_frame() write_verbose_result(_dir, f"07_timelapse_cropped_img{idx+1}.jpg", frame) @@ -143,11 +143,11 @@ def verbose_stitching(stitcher, images, verbose_dir=None): seam_masks = seam_finder.find(low_imgs, low_corners, low_masks) seam_masks = [ seam_finder.resize(seam_mask, mask) - for seam_mask, mask in zip(seam_masks, cropped_final_masks) + for seam_mask, mask in zip(seam_masks, final_masks) ] seam_masks_plots = [ SeamFinder.draw_seam_mask(img, seam_mask) - for img, seam_mask in zip(cropped_final_imgs, seam_masks) + for img, seam_mask in zip(final_imgs, seam_masks) ] for idx, seam_mask in enumerate(seam_masks_plots): @@ -161,7 +161,7 @@ def verbose_stitching(stitcher, images, verbose_dir=None): compensated_imgs = [ compensator.apply(idx, corner, img, mask) for idx, (img, mask, corner) in enumerate( - zip(cropped_final_imgs, cropped_final_masks, final_corners) + zip(final_imgs, final_masks, final_corners) ) ] diff --git a/tests/test_detector.py b/tests/test_detector.py index 92db66e..48c6a9c 100644 --- a/tests/test_detector.py +++ b/tests/test_detector.py @@ -1,5 +1,7 @@ import unittest +import numpy as np + from .context import FeatureDetector, load_test_img @@ -17,6 +19,26 @@ def test_number_of_keypoints(self): features = detector.detect_features(img1) self.assertEqual(len(features.getKeypoints()), other_keypoints) + def test_feature_masking(self): + img1 = load_test_img("s1.jpg") + + # creating the image mask and setting only the middle 20% as enabled + height, width = img1.shape[:2] + top, bottom, left, right = map( + int, (0.4 * height, 0.6 * height, 0.4 * width, 0.6 * width) + ) + mask = np.zeros(shape=(height, width), dtype=np.uint8) + mask[top:bottom, left:right] = 255 + + num_features = 1000 + detector = FeatureDetector("orb", nfeatures=num_features) + keypoints = detector.detect_features(img1, mask=mask).getKeypoints() + self.assertTrue(len(keypoints) > 0) + for point in keypoints: + x, y = point.pt + self.assertTrue(left <= x < right) + self.assertTrue(top <= y < bottom) + def start_test(): unittest.main() diff --git a/tests/test_stitch_cli.py b/tests/test_stitch_cli.py index 995f33a..923807b 100644 --- a/tests/test_stitch_cli.py +++ b/tests/test_stitch_cli.py @@ -58,6 +58,27 @@ def test_main_verbose(self): img.shape[:2], (150, 590), atol=max_image_shape_derivation ) + def test_main_feature_masks(self): + output = test_output("features_with_mask_from_cli.jpg") + test_args = [ + "stitch.py", + test_input("barcode1.png"), + test_input("barcode2.png"), + "--feature_masks", + test_input("mask1.png"), + test_input("mask2.png"), + "--output", + output, + ] + with patch.object(sys, "argv", test_args): + main() + + img = cv.imread(output) + max_image_shape_derivation = 15 + np.testing.assert_allclose( + img.shape[:2], (716, 1852), atol=max_image_shape_derivation + ) + def start_test(): unittest.main() diff --git a/tests/test_stitcher.py b/tests/test_stitcher.py index de136b6..aa16606 100644 --- a/tests/test_stitcher.py +++ b/tests/test_stitcher.py @@ -1,8 +1,11 @@ +import os import unittest +from datetime import datetime import numpy as np from .context import ( + VERBOSE_DIR, AffineStitcher, Stitcher, StitchingError, @@ -17,49 +20,65 @@ class TestStitcher(unittest.TestCase): def test_stitcher_weir(self): stitcher = Stitcher() - max_image_shape_derivation = 15 + max_derivation = 30 expected_shape = (673, 2636) - # from filenames - images = [test_input("weir*.jpg")] - result = stitcher.stitch(images) - write_test_result("weir_from_filenames.jpg", result) - - np.testing.assert_allclose( - result.shape[:2], expected_shape, atol=max_image_shape_derivation + # from image filenames + imgs = [test_input("weir*.jpg")] + name = "weir_from_filenames" + + self.stitch_test_with_warning( + stitcher, + imgs, + expected_shape, + max_derivation, + name, + StitchingWarning, + "Not all images are included", ) # from loaded numpy arrays - images = [ + imgs = [ load_test_img("weir_1.jpg"), load_test_img("weir_2.jpg"), load_test_img("weir_3.jpg"), load_test_img("weir_noise.jpg"), ] - result = stitcher.stitch(images) - write_test_result("weir_from_numpy_images.jpg", result) - - np.testing.assert_allclose( - result.shape[:2], expected_shape, atol=max_image_shape_derivation + name = "weir_from_numpy_images" + + self.stitch_test_with_warning( + stitcher, + imgs, + expected_shape, + max_derivation, + name, + StitchingWarning, + "Not all images are included", ) def test_stitcher_with_not_matching_images(self): stitcher = Stitcher() - with self.assertRaises(StitchingError) as cm: - stitcher.stitch([test_input("s1.jpg"), test_input("boat1.jpg")]) - self.assertTrue( - "No match exceeds the given confidence threshold" in str(cm.exception) + imgs = [test_input("s1.jpg"), test_input("boat1.jpg")] + + self.stitch_test_with_error( + stitcher, + imgs, + (), + 0, + "", + StitchingError, + "No match exceeds the given confidence threshold", + verbose=False, ) def test_stitcher_aquaduct(self): stitcher = Stitcher(nfeatures=250, crop=False) - result = stitcher.stitch([test_input("s?.jpg")]) - write_test_result("s_result.jpg", result) + imgs = [test_input("s?.jpg")] + max_derivation = 3 + expected_shape = (700, 1811) + name = "s_result" - max_image_shape_derivation = 3 - np.testing.assert_allclose( - result.shape[:2], (700, 1811), atol=max_image_shape_derivation - ) + self.stitch_test(stitcher, imgs, expected_shape, max_derivation, name) def test_stitcher_boat1(self): settings = { @@ -69,24 +88,21 @@ def test_stitcher_boat1(self): "compensator": "no", "crop": False, } - stitcher = Stitcher(**settings) - result = stitcher.stitch( - [ - test_input("boat5.jpg"), - test_input("boat2.jpg"), - test_input("boat3.jpg"), - test_input("boat4.jpg"), - test_input("boat1.jpg"), - test_input("boat6.jpg"), - ] - ) - - write_test_result("boat_fisheye.jpg", result) + imgs = [ + test_input("boat5.jpg"), + test_input("boat2.jpg"), + test_input("boat3.jpg"), + test_input("boat4.jpg"), + test_input("boat1.jpg"), + test_input("boat6.jpg"), + ] + max_derivation = 600 + expected_shape = (14488, 7556) + name = "boat_fisheye" - max_image_shape_derivation = 600 - np.testing.assert_allclose( - result.shape[:2], (14488, 7556), atol=max_image_shape_derivation + self.stitch_test( + stitcher, imgs, expected_shape, max_derivation, name, verbose=False ) def test_stitcher_boat2(self): @@ -96,53 +112,49 @@ def test_stitcher_boat2(self): "compensator": "channel_blocks", "crop": False, } - stitcher = Stitcher(**settings) - result = stitcher.stitch( - [ - test_input("boat5.jpg"), - test_input("boat2.jpg"), - test_input("boat3.jpg"), - test_input("boat4.jpg"), - test_input("boat1.jpg"), - test_input("boat6.jpg"), - ] - ) - - write_test_result("boat_plane.jpg", result) + imgs = [ + test_input("boat5.jpg"), + test_input("boat2.jpg"), + test_input("boat3.jpg"), + test_input("boat4.jpg"), + test_input("boat1.jpg"), + test_input("boat6.jpg"), + ] + max_derivation = 600 + expected_shape = (7400, 12340) + name = "boat_fisheye" - max_image_shape_derivation = 600 - np.testing.assert_allclose( - result.shape[:2], (7400, 12340), atol=max_image_shape_derivation + self.stitch_test( + stitcher, imgs, expected_shape, max_derivation, name, verbose=False ) def test_stitcher_boat_aquaduct_subset(self): graph = test_output("boat_subset_matches_graph.txt") settings = {"final_megapix": 1, "matches_graph_dot_file": graph} - stitcher = Stitcher(**settings) - - with self.assertWarns(StitchingWarning) as cm: - result = stitcher.stitch( - [ - test_input("boat5.jpg"), - test_input("s1.jpg"), - test_input("s2.jpg"), - test_input("boat2.jpg"), - test_input("boat3.jpg"), - test_input("boat4.jpg"), - test_input("boat1.jpg"), - test_input("boat6.jpg"), - ] - ) - - self.assertTrue(str(cm.warning).startswith("Not all images are included")) - - write_test_result("boat_subset_low_res.jpg", result) - - max_image_shape_derivation = 100 - np.testing.assert_allclose( - result.shape[:2], (705, 3374), atol=max_image_shape_derivation + imgs = [ + test_input("boat5.jpg"), + test_input("s1.jpg"), + test_input("s2.jpg"), + test_input("boat2.jpg"), + test_input("boat3.jpg"), + test_input("boat4.jpg"), + test_input("boat1.jpg"), + test_input("boat6.jpg"), + ] + max_derivation = 100 + expected_shape = (705, 3374) + name = "boat_subset_low_res" + + self.stitch_test_with_warning( + stitcher, + imgs, + expected_shape, + max_derivation, + name, + StitchingWarning, + "Not all images are included", ) with open(graph, "r") as file: @@ -156,15 +168,109 @@ def test_affine_stitcher_budapest(self): } stitcher = AffineStitcher(**settings) - result = stitcher.stitch([test_input("budapest?.jpg")]) + imgs = [test_input("budapest?.jpg")] + max_derivation = 50 + expected_shape = (1155, 2310) + name = "budapest" + + self.stitch_test(stitcher, imgs, expected_shape, max_derivation, name) + + def test_stitcher_feature_masks(self): + stitcher = Stitcher(crop=False) + + # without masks + imgs = [test_input("barcode1.png"), test_input("barcode2.png")] + max_derivation = 25 + expected_shape = (905, 2124) + name = "features_without_mask" - write_test_result("budapest.jpg", result) + self.stitch_test(stitcher, imgs, expected_shape, max_derivation, name) + + # with masks + masks = [test_input("mask1.png"), test_input("mask2.png")] + max_derivation = 15 + expected_shape = (716, 1852) + name = "features_with_mask" + + self.stitch_test( + stitcher, imgs, expected_shape, max_derivation, name, feature_masks=masks + ) + + def stitch_test( + self, + stitcher, + imgs, + expected_shape, + max_derivation, + name, + feature_masks=[], + verbose=True, + ): + result = stitcher.stitch(imgs, feature_masks) + + if verbose: + verbose_dir_name = datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + name + verbose_dir = os.path.join(VERBOSE_DIR, verbose_dir_name) + os.makedirs(verbose_dir) + result_verbose = stitcher.stitch_verbose(imgs, feature_masks, verbose_dir) + np.testing.assert_allclose( + result.shape, result_verbose.shape, atol=max_derivation + ) - max_image_shape_derivation = 50 np.testing.assert_allclose( - result.shape[:2], (1155, 2310), atol=max_image_shape_derivation + result.shape[:2], expected_shape, atol=max_derivation ) + write_test_result(name + ".jpg", result) + + def stitch_test_with_warning( + self, + stitcher, + imgs, + expected_shape, + max_derivation, + name, + expected_warning_type, + expected_warning_message, + feature_masks=[], + verbose=True, + ): + with self.assertWarns(expected_warning_type) as cm: + self.stitch_test( + stitcher, + imgs, + expected_shape, + max_derivation, + name, + feature_masks, + verbose, + ) + self.assertTrue(str(cm.warning).startswith(expected_warning_message)) + + def stitch_test_with_error( + self, + stitcher, + imgs, + expected_shape, + max_derivation, + name, + expected_error_type, + expected_error_message, + feature_masks=[], + verbose=True, + ): + with self.assertRaises(expected_error_type) as cm: + self.stitch_test( + stitcher, + imgs, + expected_shape, + max_derivation, + name, + feature_masks, + verbose, + ) + self.assertTrue(str(cm.exception).startswith(expected_error_message)) + def test_use_of_a_stitcher_for_multiple_image_sets(self): # the scale should not be fixed by the first run but set dynamically # based on every input image set. diff --git a/tests/test_timelapse.py b/tests/test_timelapse.py index f250d88..63d193d 100644 --- a/tests/test_timelapse.py +++ b/tests/test_timelapse.py @@ -3,18 +3,18 @@ import cv2 as cv import numpy as np -from .context import Stitcher, load_test_img, test_input +from .context import Stitcher, test_input, test_output class TestImageComposition(unittest.TestCase): def test_timelapse(self): stitcher = Stitcher( timelapse="as_is", - timelapse_prefix=test_input("timelapse_"), + timelapse_prefix=test_output("timelapse_"), crop=False, ) _ = stitcher.stitch([test_input("s?.jpg")]) - frame1 = load_test_img("timelapse_s1.jpg") + frame1 = cv.imread(test_output("timelapse_s1.jpg")) max_image_shape_derivation = 3 np.testing.assert_allclose( diff --git a/tests/test_verbose.py b/tests/test_verbose.py deleted file mode 100644 index f7890b0..0000000 --- a/tests/test_verbose.py +++ /dev/null @@ -1,26 +0,0 @@ -import unittest - -import numpy as np - -from .context import VERBOSE_DIR, Stitcher, test_input - - -class TestStitcherVerbose(unittest.TestCase): - def test_verbose(self): - stitcher = Stitcher() - panorama = stitcher.stitch_verbose([test_input("weir*")], VERBOSE_DIR) - - # Check only that the result is correct. - # Mostly this test is for checking that no error occurs during verbose mode. - max_image_shape_derivation = 25 - np.testing.assert_allclose( - panorama.shape[:2], (673, 2636), atol=max_image_shape_derivation - ) - - -def start_test(): - unittest.main() - - -if __name__ == "__main__": - start_test() diff --git a/tests/testdata/TEST_IMAGES.txt b/tests/testdata/TEST_IMAGES.txt index 7f50495..1e9c496 100644 --- a/tests/testdata/TEST_IMAGES.txt +++ b/tests/testdata/TEST_IMAGES.txt @@ -16,3 +16,7 @@ https://raw.githubusercontent.com/lukasalexanderweber/stitching_tutorial/master/ https://raw.githubusercontent.com/lukasalexanderweber/stitching_tutorial/master/imgs/weir_2.jpg https://raw.githubusercontent.com/lukasalexanderweber/stitching_tutorial/master/imgs/weir_3.jpg https://raw.githubusercontent.com/lukasalexanderweber/stitching_tutorial/master/imgs/weir_noise.jpg +https://raw.githubusercontent.com/lukasalexanderweber/stitching_tutorial/master/imgs/barcode1.png +https://raw.githubusercontent.com/lukasalexanderweber/stitching_tutorial/master/imgs/barcode2.png +https://raw.githubusercontent.com/lukasalexanderweber/stitching_tutorial/master/imgs/mask1.png +https://raw.githubusercontent.com/lukasalexanderweber/stitching_tutorial/master/imgs/mask2.png