From 47bf24bc1d2625f7d1d353a492e6e43665685ff9 Mon Sep 17 00:00:00 2001 From: Asaph Zylbertal <30349799+azylbertal@users.noreply.github.com> Date: Thu, 13 Apr 2023 16:29:44 +0100 Subject: [PATCH 1/2] Allow optional moving mask and optionally apply masks to all registration stages --- ants/registration/interface.py | 194 +++++++++++++-------------------- 1 file changed, 77 insertions(+), 117 deletions(-) diff --git a/ants/registration/interface.py b/ants/registration/interface.py index 33365dcb..c8b45959 100644 --- a/ants/registration/interface.py +++ b/ants/registration/interface.py @@ -25,6 +25,8 @@ def registration( initial_transform=None, outprefix="", mask=None, + moving_mask=None, + mask_all_stages=False, grad_step=0.2, flow_sigma=3, total_sigma=0, @@ -70,8 +72,14 @@ def registration( output will be named with this prefix. mask : ANTsImage (optional) - mask the registration. - + mask the fixed image. + + moving_mask : ANTsImage (optional) + mask the moving image. + + mask_all_stages : boolean + apply mask(s) to all registration stages rather than to the last stage only + grad_step : scalar gradient step size (not for all tx) @@ -397,9 +405,25 @@ def registration( mask_scale = mask - mask.min() mask_scale = mask_scale / mask_scale.max() * 255.0 charmask = mask_scale.clone("unsigned char") - maskopt = "[%s,NA]" % (utils.get_pointer_string(charmask)) + f_mask_str = utils.get_pointer_string(charmask) + else + f_mask_str = "NA" + + if moving_mask is not None: + moving_mask_scale = moving_mask - moving_mask.min() + moving_mask_scale = moving_mask_scale / moving_mask_scale.max() * 255.0 + moving_charmask = moving_mask_scale.clone("unsigned char") + m_mask_str = utils.get_pointer_string(moving_charmask) + else + m_mask_str = "NA" + + maskopt = "[%s,%s]" % (f_mask_str, m_mask_str) + + if mask_all_stages: + earlymaskopt = maskopt; else: - maskopt = None + earlymaskopt = "[NA,NA]" + if initx is None: initx = "[%s,%s,1]" % (f, m) # ------------------------------------------------------------ @@ -421,7 +445,7 @@ def registration( "-f", "4x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -438,13 +462,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "SyNBoldAff": args = [ @@ -464,7 +484,7 @@ def registration( "-f", "4x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s,regular,%s]" % (aff_metric, f, m, aff_sampling, aff_random_sampling_rate), @@ -477,7 +497,7 @@ def registration( "-f", "2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -494,13 +514,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "ElasticSyN": args = [ @@ -520,7 +536,7 @@ def registration( "-f", "4x2x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -537,13 +553,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "SyN" or type_of_transform == "Elastic": args = [ @@ -563,7 +575,7 @@ def registration( "-f", "4x2x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -580,13 +592,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "SyNRA": args = [ @@ -606,7 +614,7 @@ def registration( "-f", "4x2x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s,regular,%s]" % (aff_metric, f, m, aff_sampling, aff_random_sampling_rate), @@ -619,7 +627,7 @@ def registration( "-f", "4x2x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -636,13 +644,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "SyNOnly": args = [ @@ -716,12 +720,8 @@ def registration( args.append(metrics[kk]) for kk in range(len(args1)): args.append(args1[kk]) - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") + args.append("-x") + args.append(maskopt) # ------------------------------------------------------------ elif type_of_transform == "SyNAggro": args = [ @@ -741,7 +741,7 @@ def registration( "-f", "4x2x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -758,13 +758,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "SyNCC": syn_metric = "CC" @@ -791,7 +787,7 @@ def registration( "-f", "4x4x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s,regular,%s]" % (aff_metric, f, m, aff_sampling, aff_random_sampling_rate), @@ -804,7 +800,7 @@ def registration( "-f", "4x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -821,13 +817,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "TRSAA": itlen = len(reg_iterations) @@ -854,7 +846,7 @@ def registration( "-f", shrinkfactors, "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s,regular,%s]" % (aff_metric, f, m, aff_sampling, aff_random_sampling_rate), @@ -867,7 +859,7 @@ def registration( "-f", shrinkfactors, "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s,regular,%s]" % (aff_metric, f, m, aff_sampling, aff_random_sampling_rate), @@ -880,7 +872,7 @@ def registration( "-f", shrinkfactors, "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s,regular,%s]" % (aff_metric, f, m, aff_sampling, aff_random_sampling_rate), @@ -893,7 +885,7 @@ def registration( "-f", shrinkfactors, "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s,regular,%s]" % (aff_metric, f, m, aff_sampling, aff_random_sampling_rate), @@ -911,13 +903,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------s elif type_of_transform == "SyNabp": args = [ @@ -936,7 +924,7 @@ def registration( "-f", "8x4x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "mattes[%s,%s,1,32,regular,0.25]" % (f, m), "-t", @@ -948,7 +936,7 @@ def registration( "-f", "8x4x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "CC[%s,%s,0.5,4]" % (f, m), "-t", @@ -965,13 +953,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "SyNLessAggro": args = [ @@ -991,7 +975,7 @@ def registration( "-f", "4x2x2x1", "-x", - "[NA,NA]", + earlymaskopt, "-m", "%s[%s,%s,1,%s]" % (syn_metric, f, m, syn_sampling), "-t", @@ -1008,13 +992,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform in tvTypes: if grad_step is None: @@ -1052,13 +1032,9 @@ def registration( "0", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") elif type_of_transform == "TVMSQ": if grad_step is None: grad_step = 1.0 @@ -1086,13 +1062,9 @@ def registration( "0", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif type_of_transform == "TVMSQC": if grad_step is None: @@ -1123,13 +1095,9 @@ def registration( "0", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif ( (type_of_transform == "Rigid") @@ -1159,13 +1127,9 @@ def registration( "1", "-o", "[%s,%s,%s]" % (outprefix, wmo, wfo), + "-x", + maskopt ] - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") # ------------------------------------------------------------ elif "antsRegistrationSyN" in type_of_transform: @@ -1349,12 +1313,8 @@ def registration( if subtype_of_transform == "bo" or subtype_of_transform == "so": args.append(syn_stage) - if maskopt is not None: - args.append("-x") - args.append(maskopt) - else: - args.append("-x") - args.append("[NA,NA]") + args.append("-x") + args.append(maskopt) args = list( itertools.chain.from_iterable( From 4641c36b525f01e10ea9bc7e01416d7a227a5df0 Mon Sep 17 00:00:00 2001 From: Asaph Zylbertal <30349799+azylbertal@users.noreply.github.com> Date: Thu, 13 Apr 2023 16:50:31 +0100 Subject: [PATCH 2/2] Typo fix --- ants/registration/interface.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ants/registration/interface.py b/ants/registration/interface.py index c8b45959..d14a2170 100644 --- a/ants/registration/interface.py +++ b/ants/registration/interface.py @@ -406,7 +406,7 @@ def registration( mask_scale = mask_scale / mask_scale.max() * 255.0 charmask = mask_scale.clone("unsigned char") f_mask_str = utils.get_pointer_string(charmask) - else + else: f_mask_str = "NA" if moving_mask is not None: @@ -414,7 +414,7 @@ def registration( moving_mask_scale = moving_mask_scale / moving_mask_scale.max() * 255.0 moving_charmask = moving_mask_scale.clone("unsigned char") m_mask_str = utils.get_pointer_string(moving_charmask) - else + else: m_mask_str = "NA" maskopt = "[%s,%s]" % (f_mask_str, m_mask_str)