Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add masked/bspline fitting variant of Nyul histogram matching. #607

Merged
merged 1 commit into from
Apr 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ants/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .get_mask import get_mask
from .get_neighborhood import (get_neighborhood_in_mask,
get_neighborhood_at_voxel)
from .histogram_match_image import histogram_match_image
from .histogram_match_image import histogram_match_image, histogram_match_image2
from .histogram_equalize_image import histogram_equalize_image
from .hausdorff_distance import hausdorff_distance
from .image_similarity import image_similarity
Expand Down
108 changes: 105 additions & 3 deletions ants/utils/histogram_match_image.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@

__all__ = ['histogram_match_image']
__all__ = ['histogram_match_image',
'histogram_match_image2']

import math
import numpy as np

from ..core import ants_image as iio
from ..core import ants_image_io as iio
from .. import utils

from ..utils import fit_bspline_object_to_scattered_data


def histogram_match_image(source_image, reference_image, number_of_histogram_bins=255, number_of_match_points=64, use_threshold_at_mean_intensity=False):
"""
Expand Down Expand Up @@ -51,3 +54,102 @@ def histogram_match_image(source_image, reference_image, number_of_histogram_bin
return new_image


def histogram_match_image2(source_image, reference_image,
source_mask=None, reference_mask=None,
match_points=64,
transform_domain_size=255):
"""
Transform image intensities based on histogram mapping.

Apply B-spline 1-D maps to an input image for intensity warping.

Arguments
---------
source_image : ANTsImage
source image

reference_image : ANTsImage
reference image

source_mask : ANTsImage
source mask

reference_mask : ANTsImage
reference mask

match_points : integer or tuple
Parametric points at which the intensity transform displacements are
specified between [0, 1], i.e. quantiles. Alternatively, a single number
can be given and the sequence is linearly spaced in [0, 1].

transform_domain_size : integer
Defines the sampling resolution of the B-spline warping.

Returns
-------
ANTs image

Example
-------
>>> import ants
>>> src_img = ants.image_read(ants.get_data('r16'))
>>> ref_img = ants.image_read(ants.get_data('r64'))
>>> src_ref = ants.histogram_match_image(src_img, ref_img)
"""

if not isinstance(match_points, int):
if any(b < 0 for b in match_points) and any(b > 1 for b in match_points):
raise ValueError("If specifying match_points as a vector, values must be in the range [0, 1]")

# Use entire image if mask isn't specified
if source_mask is None:
source_mask = source_image * 0 + 1
if reference_mask is None:
reference_mask = reference_image * 0 + 1

source_array = source_image.numpy()
source_mask_array = source_mask.numpy()
source_masked_min = source_image[source_mask != 0].min()
source_masked_max = source_image[source_mask != 0].max()

reference_array = reference_image.numpy()
reference_mask_array = reference_mask.numpy()

parametric_points = None
if not isinstance(match_points, int):
parametric_points = match_points
else:
parametric_points = np.linspace(0, 1, match_points)

source_intensity_quantiles = np.quantile(source_array[source_mask_array != 0], parametric_points)
reference_intensity_quantiles = np.quantile(reference_array[reference_mask_array != 0], parametric_points)
displacements = reference_intensity_quantiles - source_intensity_quantiles

scattered_data = np.reshape(displacements, (len(displacements), 1))
parametric_data = np.reshape(parametric_points * (source_masked_max - source_masked_min) + source_masked_min, (len(parametric_points), 1))

transform_domain_origin = source_masked_min
transform_domain_spacing = (source_masked_max - transform_domain_origin) / (transform_domain_size - 1)

bspline_histogram_transform = fit_bspline_object_to_scattered_data(scattered_data,
parametric_data, [transform_domain_origin], [transform_domain_spacing], [transform_domain_size],
data_weights=None, is_parametric_dimension_closed=None, number_of_fitting_levels=8,
mesh_size=1, spline_order=3)

transform_domain = np.linspace(source_masked_min, source_masked_max, transform_domain_size)

transformed_source_array = source_image.numpy()
for i in range(len(transform_domain) - 1):
indices = np.where((source_array >= transform_domain[i]) & (source_array < transform_domain[i+1]))
intensities = source_array[indices]

alpha = (intensities - transform_domain[i])/(transform_domain[i+1] - transform_domain[i])
xfrm = alpha * (bspline_histogram_transform[i+1] - bspline_histogram_transform[i]) + bspline_histogram_transform[i]
transformed_source_array[indices] = intensities + xfrm

transformed_source_image = iio.from_numpy(transformed_source_array, origin=source_image.origin,
spacing=source_image.spacing, direction=source_image.direction)
transformed_source_image[source_mask == 0] = source_image[source_mask == 0]

return(transformed_source_image)