Skip to content

Commit

Permalink
implement feature masks (#130)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasalexanderweber authored Aug 17, 2023
1 parent 76ed406 commit 7dabd96
Show file tree
Hide file tree
Showing 11 changed files with 301 additions and 135 deletions.
24 changes: 15 additions & 9 deletions stitching/cli/stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import argparse
import glob
import os
import sys
from datetime import datetime
Expand All @@ -29,7 +28,7 @@

def create_parser():
parser = argparse.ArgumentParser(prog="stitch.py")
parser.add_argument("img_names", nargs="+", help="Files to stitch", type=str)
parser.add_argument("images", nargs="+", help="Files to stitch", type=str)
parser.add_argument(
"-v",
"--verbose",
Expand Down Expand Up @@ -74,6 +73,13 @@ def create_parser():
"The default is 500.",
type=int,
)
parser.add_argument(
"--feature_masks",
nargs="*",
default=[],
help="Masks for selecting where features should be detected.",
type=str,
)
parser.add_argument(
"--matcher_type",
action="store",
Expand Down Expand Up @@ -280,9 +286,9 @@ def main():
args_dict = vars(args)

# Extract In- and Output
img_names = args_dict.pop("img_names")
if len(img_names) == 1:
img_names = glob.glob(img_names[0])
images = args_dict.pop("images")
feature_masks = args_dict.pop("feature_masks")

verbose = args_dict.pop("verbose")
verbose_dir = args_dict.pop("verbose_dir")
preview = args_dict.pop("preview")
Expand All @@ -297,12 +303,12 @@ def main():
stitcher = Stitcher(**args_dict)

if verbose:
print("stitching " + " ".join(img_names) + " into " + verbose_dir)
print("stitching " + " ".join(images) + " into " + verbose_dir)
os.makedirs(verbose_dir)
stitcher.stitch_verbose(img_names, verbose_dir)
panorama = stitcher.stitch_verbose(images, feature_masks, verbose_dir)
else:
print("stitching " + " ".join(img_names) + " into " + output)
panorama = stitcher.stitch(img_names)
print("stitching " + " ".join(images) + " into " + output)
panorama = stitcher.stitch(images, feature_masks)
cv.imwrite(output, panorama)

if preview:
Expand Down
18 changes: 18 additions & 0 deletions stitching/feature_detector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from collections import OrderedDict

import cv2 as cv
import numpy as np

from .stitching_error import StitchingError


class FeatureDetector:
Expand All @@ -21,6 +24,21 @@ def __init__(self, detector=DEFAULT_DETECTOR, **kwargs):
def detect_features(self, img, *args, **kwargs):
return cv.detail.computeImageFeatures2(self.detector, img, *args, **kwargs)

def detect(self, imgs):
return [self.detect_features(img) for img in imgs]

def detect_with_masks(self, imgs, masks):
features = []
for idx, (img, mask) in enumerate(zip(imgs, masks)):
assert len(img.shape) == 3 and len(mask.shape) == 2
if not np.array_equal(img.shape[:2], mask.shape):
raise StitchingError(
f"Resolution of mask '{idx+1}' ({mask.shape}) does not match"
f" the resolution of image '{idx+1}' ({img.shape[:2]})."
)
features.append(self.detect_features(img, mask=mask))
return features

@staticmethod
def draw_keypoints(img, features, **kwargs):
kwargs.setdefault("color", (0, 255, 0))
Expand Down
7 changes: 7 additions & 0 deletions stitching/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ def resolve_wildcards(img_names):
def check_list_element_types(list_, type_):
return all([isinstance(element, type_) for element in list_])

@staticmethod
def to_binary(img):
if len(img.shape) == 3:
img = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
_, binary = cv.threshold(img, 0.5, 255.0, cv.THRESH_BINARY)
return binary.astype(np.uint8)


class _NumpyImages(Images):
def __init__(self, images, medium_megapix, low_megapix, final_megapix):
Expand Down
20 changes: 14 additions & 6 deletions stitching/stitcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ def initialize_stitcher(self, **kwargs):
self.blender = Blender(args.blender_type, args.blend_strength)
self.timelapser = Timelapser(args.timelapse, args.timelapse_prefix)

def stitch_verbose(self, images, verbose_dir=None):
return verbose_stitching(self, images, verbose_dir)
def stitch_verbose(self, images, feature_masks=[], verbose_dir=None):
return verbose_stitching(self, images, feature_masks, verbose_dir)

def stitch(self, images):
def stitch(self, images, feature_masks=[]):
self.images = Images.of(
images, self.medium_megapix, self.low_megapix, self.final_megapix
)

imgs = self.resize_medium_resolution()
features = self.find_features(imgs)
features = self.find_features(imgs, feature_masks)
matches = self.match_features(features)
imgs, features, matches = self.subset(imgs, features, matches)
cameras = self.estimate_camera_parameters(features, matches)
Expand Down Expand Up @@ -128,8 +128,16 @@ def stitch(self, images):
def resize_medium_resolution(self):
return list(self.images.resize(Images.Resolution.MEDIUM))

def find_features(self, imgs):
return [self.detector.detect_features(img) for img in imgs]
def find_features(self, imgs, feature_masks=[]):
if len(feature_masks) == 0:
return self.detector.detect(imgs)
else:
feature_masks = Images.of(
feature_masks, self.medium_megapix, self.low_megapix, self.final_megapix
)
feature_masks = list(feature_masks.resize(Images.Resolution.MEDIUM))
feature_masks = [Images.to_binary(mask) for mask in feature_masks]
return self.detector.detect_with_masks(imgs, feature_masks)

def match_features(self, features):
return self.matcher.match_features(features)
Expand Down
16 changes: 8 additions & 8 deletions stitching/verbose.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .timelapser import Timelapser


def verbose_stitching(stitcher, images, verbose_dir=None):
def verbose_stitching(stitcher, images, feature_masks=[], verbose_dir=None):
_dir = "." if verbose_dir is None else verbose_dir

images = Images.of(
Expand All @@ -19,7 +19,7 @@ def verbose_stitching(stitcher, images, verbose_dir=None):

# Find Features
finder = stitcher.detector
features = [finder.detect_features(img) for img in imgs]
features = stitcher.find_features(imgs, feature_masks)
for idx, img_features in enumerate(features):
img_with_features = finder.draw_keypoints(imgs[idx], img_features)
write_verbose_result(_dir, f"01_features_img{idx+1}.jpg", img_with_features)
Expand Down Expand Up @@ -123,16 +123,16 @@ def verbose_stitching(stitcher, images, verbose_dir=None):
low_corners, low_sizes = cropper.crop_rois(low_corners, low_sizes)

lir_aspect = images.get_ratio(Images.Resolution.LOW, Images.Resolution.FINAL)
cropped_final_masks = list(cropper.crop_images(final_masks, lir_aspect))
cropped_final_imgs = list(cropper.crop_images(final_imgs, lir_aspect))
final_masks = list(cropper.crop_images(final_masks, lir_aspect))
final_imgs = list(cropper.crop_images(final_imgs, lir_aspect))
final_corners, final_sizes = cropper.crop_rois(
final_corners, final_sizes, lir_aspect
)

timelapser = Timelapser("as_is")
timelapser.initialize(final_corners, final_sizes)

for idx, (img, corner) in enumerate(zip(cropped_final_imgs, final_corners)):
for idx, (img, corner) in enumerate(zip(final_imgs, final_corners)):
timelapser.process_frame(img, corner)
frame = timelapser.get_frame()
write_verbose_result(_dir, f"07_timelapse_cropped_img{idx+1}.jpg", frame)
Expand All @@ -143,11 +143,11 @@ def verbose_stitching(stitcher, images, verbose_dir=None):
seam_masks = seam_finder.find(low_imgs, low_corners, low_masks)
seam_masks = [
seam_finder.resize(seam_mask, mask)
for seam_mask, mask in zip(seam_masks, cropped_final_masks)
for seam_mask, mask in zip(seam_masks, final_masks)
]
seam_masks_plots = [
SeamFinder.draw_seam_mask(img, seam_mask)
for img, seam_mask in zip(cropped_final_imgs, seam_masks)
for img, seam_mask in zip(final_imgs, seam_masks)
]

for idx, seam_mask in enumerate(seam_masks_plots):
Expand All @@ -161,7 +161,7 @@ def verbose_stitching(stitcher, images, verbose_dir=None):
compensated_imgs = [
compensator.apply(idx, corner, img, mask)
for idx, (img, mask, corner) in enumerate(
zip(cropped_final_imgs, cropped_final_masks, final_corners)
zip(final_imgs, final_masks, final_corners)
)
]

Expand Down
22 changes: 22 additions & 0 deletions tests/test_detector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest

import numpy as np

from .context import FeatureDetector, load_test_img


Expand All @@ -17,6 +19,26 @@ def test_number_of_keypoints(self):
features = detector.detect_features(img1)
self.assertEqual(len(features.getKeypoints()), other_keypoints)

def test_feature_masking(self):
img1 = load_test_img("s1.jpg")

# creating the image mask and setting only the middle 20% as enabled
height, width = img1.shape[:2]
top, bottom, left, right = map(
int, (0.4 * height, 0.6 * height, 0.4 * width, 0.6 * width)
)
mask = np.zeros(shape=(height, width), dtype=np.uint8)
mask[top:bottom, left:right] = 255

num_features = 1000
detector = FeatureDetector("orb", nfeatures=num_features)
keypoints = detector.detect_features(img1, mask=mask).getKeypoints()
self.assertTrue(len(keypoints) > 0)
for point in keypoints:
x, y = point.pt
self.assertTrue(left <= x < right)
self.assertTrue(top <= y < bottom)


def start_test():
unittest.main()
Expand Down
21 changes: 21 additions & 0 deletions tests/test_stitch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,27 @@ def test_main_verbose(self):
img.shape[:2], (150, 590), atol=max_image_shape_derivation
)

def test_main_feature_masks(self):
output = test_output("features_with_mask_from_cli.jpg")
test_args = [
"stitch.py",
test_input("barcode1.png"),
test_input("barcode2.png"),
"--feature_masks",
test_input("mask1.png"),
test_input("mask2.png"),
"--output",
output,
]
with patch.object(sys, "argv", test_args):
main()

img = cv.imread(output)
max_image_shape_derivation = 15
np.testing.assert_allclose(
img.shape[:2], (716, 1852), atol=max_image_shape_derivation
)


def start_test():
unittest.main()
Expand Down
Loading

0 comments on commit 7dabd96

Please sign in to comment.