Skip to content

Commit

Permalink
feat: choices for canny random thresholds
Browse files Browse the repository at this point in the history
  • Loading branch information
killian31 authored and beniz committed Apr 17, 2023
1 parent 05e8959 commit 9573fc1
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 12 deletions.
35 changes: 31 additions & 4 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def modify_commandline_options(parser, is_train=True):
],
)

# Bounds from which random Canny thresholds are drawn. The per-image sketch
# path in set_input takes min() and max() of the supplied values as the
# low/high bounds, so the order of the values does not matter.
parser.add_argument(
"--alg_palette_sketch_canny_range",
type=int,
nargs="+",
default=[0, 255 * 3],
help="range for Canny thresholds",
)
parser.add_argument(
"--alg_palette_prob_use_previous_frame",
type=float,
Expand Down Expand Up @@ -225,16 +232,36 @@ def set_input(self, data):
fill_img_with_random_sketch = random_edge_mask(
fn_list=self.opt.alg_palette_computed_sketch_list
)
batch_cond_image = fill_img_with_random_sketch(
image.unsqueeze(0), mask.unsqueeze(0)
).squeeze(0)
if "canny" in self.opt.alg_palette_computed_sketch_list:
low = min(self.opt.alg_palette_sketch_canny_range)
high = max(self.opt.alg_palette_sketch_canny_range)
batch_cond_image = fill_img_with_random_sketch(
image.unsqueeze(0),
mask.unsqueeze(0),
low_threshold_random=low,
high_threshold_random=high,
).squeeze(0)
else:
batch_cond_image = fill_img_with_random_sketch(
image.unsqueeze(0), mask.unsqueeze(0)
).squeeze(0)
cond_images.append(batch_cond_image)
self.cond_image = torch.stack(cond_images)
else:
fill_img_with_random_sketch = random_edge_mask(
fn_list=self.opt.alg_palette_computed_sketch_list
)
self.cond_image = fill_img_with_random_sketch(self.gt_image, self.mask)
if "canny" in self.opt.alg_palette_computed_sketch_list:
    # Derive the random-threshold bounds from --alg_palette_sketch_canny_range,
    # the same option the per-image branch of this method reads. The
    # previously referenced alg_palette_canny_random_low/high options are
    # never registered on the parser, so reading them fails at runtime.
    canny_range = self.opt.alg_palette_sketch_canny_range
    self.cond_image = fill_img_with_random_sketch(
        self.gt_image,
        self.mask,
        low_threshold_random=min(canny_range),
        high_threshold_random=max(canny_range),
    )
else:
    # No canny in the sketch list: no threshold bounds need to be forwarded.
    self.cond_image = fill_img_with_random_sketch(
        self.gt_image, self.mask
    )

self.batch_size = self.cond_image.shape[0]

Expand Down
29 changes: 21 additions & 8 deletions util/mask_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
sys.path.append("./../")


def fill_img_with_sketch(img, mask):
def fill_img_with_sketch(img, mask, **kwargs):
"""Fill the masked region with sketch edges."""

grayscale = Grayscale(3)
Expand All @@ -32,12 +32,20 @@ def fill_img_with_sketch(img, mask):
return mask * thresh + (1 - mask) * img


def fill_img_with_canny(img, mask, low_threshold=None, high_threshold=None):
def fill_img_with_canny(
img,
mask,
low_threshold=None,
high_threshold=None,
**kwargs,
):
"""Fill the masked region with canny edges."""
max_value = 255 * 3
# The random-range bounds are optional: callers that pass explicit thresholds,
# or that do not forward the extra keyword arguments, must not fail with a
# KeyError. When absent, fall back to the full [0, max_value] range.
low_threshold_random = kwargs.get("low_threshold_random", 0)
high_threshold_random = kwargs.get("high_threshold_random", max_value)
if high_threshold is None and low_threshold is None:
threshold_1 = random.randint(0, max_value)
threshold_2 = random.randint(0, max_value)
threshold_1 = random.randint(low_threshold_random, high_threshold_random)
threshold_2 = random.randint(low_threshold_random, high_threshold_random)
high_threshold = max(threshold_1, threshold_2)
low_threshold = min(threshold_1, threshold_2)
elif high_threshold is None and low_threshold is not None:
Expand All @@ -64,7 +72,7 @@ def fill_img_with_canny(img, mask, low_threshold=None, high_threshold=None):
return mask * edges + (1 - mask) * img


def fill_img_with_hed(img, mask):
def fill_img_with_hed(img, mask, **kwargs):
"""Fill the masked region with HED edges from the ControlNet paper."""

apply_hed = HEDdetector()
Expand All @@ -88,12 +96,17 @@ def fill_img_with_hed(img, mask):


def fill_img_with_hough(
img, mask, value_threshold=1e-05, distance_threshold=10.0, with_canny=False
img,
mask,
value_threshold=1e-05,
distance_threshold=10.0,
with_canny=False,
**kwargs,
):
"""Fill the masked region with Hough lines detection from the ControlNet paper."""

if with_canny:
img = fill_img_with_canny(img, mask)
img = fill_img_with_canny(img, mask, **kwargs)

device = img.device
apply_mlsd = MLSDdetector()
Expand All @@ -117,7 +130,7 @@ def fill_img_with_hough(
return mask * edges + (1 - mask) * img


def fill_img_with_depth(img, mask, depth_network="DPT_SwinV2_T_256"):
def fill_img_with_depth(img, mask, depth_network="DPT_SwinV2_T_256", **kwargs):
"""Fill the masked region with depth map."""

device = img.device
Expand Down

0 comments on commit 9573fc1

Please sign in to comment.