From b0da1069ce6b1bb4acd3a87e8d71527e1ac09e0a Mon Sep 17 00:00:00 2001 From: ntustison Date: Sun, 24 Nov 2024 16:56:25 -0800 Subject: [PATCH] BUG: Minor fixes. --- ants/registration/registration.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/ants/registration/registration.py b/ants/registration/registration.py index 90ab1280..495d927e 100644 --- a/ants/registration/registration.py +++ b/ants/registration/registration.py @@ -1738,7 +1738,8 @@ def label_image_registration(fixed_label_images, count += 1 if do_deformable: deformable_multivariate_extras.append(["MSQ", fixed_single_label_image, - moving_single_label_image, label_image_weighting, 0]) + moving_single_label_image, + label_image_weights[i], 0]) linear_xfrm = ants.fit_transform_to_paired_points(moving_centers_of_mass, fixed_centers_of_mass, @@ -1793,7 +1794,7 @@ def label_image_registration(fixed_label_images, intensity_metric = "CC" if intensity_metric_parameter is None: intensity_metric_parameter = 2 - for i in range(1, len(fixed_intensity_images)): + for i in range(len(fixed_intensity_images)): syn_stage.append("--metric") metric_string = "%s[%s,%s,%s,%s]" % ( intensity_metric, @@ -1808,7 +1809,7 @@ def label_image_registration(fixed_label_images, "MSQ", get_pointer_string(deformable_multivariate_extras[kk][1]), get_pointer_string(deformable_multivariate_extras[kk][2]), - 1.0, 0.0) + label_image_weights[kk], 0.0) syn_stage.append(metricString) syn_shrink_factors = "8x4x2x1" @@ -1832,9 +1833,14 @@ def label_image_registration(fixed_label_images, syn_stage.insert(0, "SyN[0.1,3,0]") syn_stage.insert(0, "--transform") - args = ["-d", str(image_dimension), - "-r", linear_xfrm_file, - "-o", output_prefix] + args = None + if linear_xfrm is None: + args = ["-d", str(image_dimension), + "-o", output_prefix] + else: + args = ["-d", str(image_dimension), + "-r", linear_xfrm_file, + "-o", output_prefix] args.append(syn_stage) fixed_mask_string = 'NA' @@ -1884,14 +1890,15 @@ def label_image_registration(fixed_label_images, find_forward_warps = np.where([re.search("[0-9]Warp.nii.gz", ff) for ff in all_xfrms])[0] if len(find_inverse_warps) > 0: - fwdtransforms = list(reversed([ff for idx, ff in enumerate(all_xfrms) if idx != find_inverse_warps[0]])) + fwdtransforms = [find_forward_warps[0], linear_xfrm_file] + invtransforms = [linear_xfrm_file, find_inverse_warps[0]] invtransforms = [ff for idx, ff in enumerate(all_xfrms) if idx != find_forward_warps[0]] else: fwdtransforms = list(reversed(all_xfrms)) invtransforms = all_xfrms if verbose: - print("\n\nResulting transforms:") + print("\n\nResulting transforms") print(" fwdtransforms: ", fwdtransforms) print(" invtransforms: ", invtransforms)