Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make precision (--float) configurable for registration #741

Merged
merged 5 commits into from
Nov 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions ants/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def registration(
multivariate_extras=None,
restrict_transformation=None,
smoothing_in_mm=False,
singleprecision=True,
**kwargs
):
"""
Expand Down Expand Up @@ -152,6 +153,10 @@ def registration(

smoothing_in_mm : boolean ; currently only impacts low dimensional registration

singleprecision : boolean
if True, use float32 for computations. This is useful for reducing memory
usage for large datasets, at the cost of precision.

kwargs : keyword args
extra arguments

Expand Down Expand Up @@ -351,6 +356,8 @@ def registration(
synits = "x".join([str(ri) for ri in reg_iterations])

inpixeltype = fixed.pixeltype
output_pixel_type = 'float' if singleprecision else 'double'

tvTypes = [
"TV[1]",
"TV[2]",
Expand Down Expand Up @@ -409,8 +416,8 @@ def registration(
# initx = invertAntsrTransform( initx )
# writeAntsrTransform( initx, tempTXfilename )
# initx = tempTXfilename
moving = moving.clone("float")
fixed = fixed.clone("float")
moving = moving.clone(output_pixel_type)
fixed = fixed.clone(output_pixel_type)
# NOTE: this may be better for general purpose applications: TBD
# moving = ants.iMath( moving.clone("float"), "Normalize" )
# fixed = ants.iMath( fixed.clone("float"), "Normalize" )
Expand Down Expand Up @@ -1349,7 +1356,7 @@ def registration(
args.append(restrict_transformationchar)

args.append("--float")
args.append("1")
args.append(str(int(singleprecision)))
args.append("--write-composite-transform")
args.append(write_composite_transform * 1)
if verbose:
Expand Down