Skip to content

Commit

Permalink
PERF: Make AdvancedNormalizedCorrelationImageToImageMetric faster
Browse files Browse the repository at this point in the history
Added a new member function to `AdvancedImageToImageMetric`, `FastEvaluateMovingImageValueAndDerivative`, which calls the multi-threaded overload of ITK's `itk::BSplineInterpolateImageFunction::EvaluateValueAndDerivativeAtContinuousIndex`, indirectly. (It does so via `EvaluateMovingImageValueAndDerivativeWithOptionalThreadId`, another new member function.)

Made `AdvancedNormalizedCorrelationImageToImageMetric::ThreadedGetValueAndDerivative` faster, by calling this new `FastEvaluateMovingImageValueAndDerivative`.

A large performance improvement was observed for GoogleTest unit test `itkElastixRegistrationMethod.EulerDiscRotation2D` (which uses the "AdvancedNormalizedCorrelation" metric), from ~1.5 second before this commit down to ~0.9 second after this commit, using Visual Studio 2019, Release configuration. (For a Debug configuration even from ~15 seconds before, down to ~4 seconds after this commit.)
  • Loading branch information
N-Dekker committed Nov 26, 2021
1 parent 10cfb97 commit c8c4a6a
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 8 deletions.
23 changes: 22 additions & 1 deletion Common/CostFunctions/itkAdvancedImageToImageMetric.h
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,21 @@ class ITK_TEMPLATE_EXPORT AdvancedImageToImageMetric : public ImageToImageMetric
virtual bool
EvaluateMovingImageValueAndDerivative(const MovingImagePointType & mappedPoint,
RealType & movingImageValue,
MovingImageDerivativeType * gradient) const;
MovingImageDerivativeType * gradient) const
{
return EvaluateMovingImageValueAndDerivativeWithOptionalThreadId(mappedPoint, movingImageValue, gradient);
}

/* A faster version of `EvaluateMovingImageValueAndDerivative`: Non-virtual, using multithreading, and doing less
* dynamic memory allocation/decallocation operations, internally. */
bool
FastEvaluateMovingImageValueAndDerivative(const MovingImagePointType & mappedPoint,
RealType & movingImageValue,
MovingImageDerivativeType * gradient,
const ThreadIdType threadId) const
{
return EvaluateMovingImageValueAndDerivativeWithOptionalThreadId(mappedPoint, movingImageValue, gradient, threadId);
}

/** Computes the inner product of transform Jacobian with moving image gradient.
* The results are stored in imageJacobian, which is supposed
Expand Down Expand Up @@ -571,6 +585,13 @@ class ITK_TEMPLATE_EXPORT AdvancedImageToImageMetric : public ImageToImageMetric
void
operator=(const Self &) = delete;

template <typename... TOptionalThreadId>
bool
EvaluateMovingImageValueAndDerivativeWithOptionalThreadId(const MovingImagePointType & mappedPoint,
RealType & movingImageValue,
MovingImageDerivativeType * gradient,
const TOptionalThreadId... optionalThreadId) const;

/** Private member variables. */
bool m_UseImageSampler{ false };
bool m_UseFixedImageLimiter{ false };
Expand Down
24 changes: 18 additions & 6 deletions Common/CostFunctions/itkAdvancedImageToImageMetric.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ AdvancedImageToImageMetric<TFixedImage, TMovingImage>::Initialize(void)
if (this->m_UseMultiThread)
{
this->InitializeThreadingParameters();

const auto setNumberOfWorkUnitsIfNotNull = [this](const auto bsplineInterpolator) {
if (!bsplineInterpolator.IsNull())
{
bsplineInterpolator->SetNumberOfWorkUnits(this->Superclass::GetNumberOfWorkUnits());
}
};
setNumberOfWorkUnitsIfNotNull(m_BSplineInterpolator);
setNumberOfWorkUnitsIfNotNull(m_BSplineInterpolatorFloat);
}

} // end Initialize()
Expand Down Expand Up @@ -501,15 +510,17 @@ AdvancedImageToImageMetric<TFixedImage, TMovingImage>::CheckForBSplineTransform(


/**
* ******************* EvaluateMovingImageValueAndDerivative ******************
* ******************* EvaluateMovingImageValueAndDerivativeWithOptionalThreadId ******************
*/

template <class TFixedImage, class TMovingImage>
template <typename... TOptionalThreadId>
bool
AdvancedImageToImageMetric<TFixedImage, TMovingImage>::EvaluateMovingImageValueAndDerivative(
AdvancedImageToImageMetric<TFixedImage, TMovingImage>::EvaluateMovingImageValueAndDerivativeWithOptionalThreadId(
const MovingImagePointType & mappedPoint,
RealType & movingImageValue,
MovingImageDerivativeType * gradient) const
MovingImageDerivativeType * gradient,
const TOptionalThreadId... optionalThreadId) const
{
/** Check if mapped point inside image buffer. */
MovingImageContinuousIndexType cindex;
Expand All @@ -523,13 +534,14 @@ AdvancedImageToImageMetric<TFixedImage, TMovingImage>::EvaluateMovingImageValueA
if (this->m_InterpolatorIsBSpline && !this->GetComputeGradient())
{
/** Compute moving image value and gradient using the B-spline kernel. */
this->m_BSplineInterpolator->EvaluateValueAndDerivativeAtContinuousIndex(cindex, movingImageValue, *gradient);
this->m_BSplineInterpolator->EvaluateValueAndDerivativeAtContinuousIndex(
cindex, movingImageValue, *gradient, optionalThreadId...);
}
else if (this->m_InterpolatorIsBSplineFloat && !this->GetComputeGradient())
{
/** Compute moving image value and gradient using the B-spline kernel. */
this->m_BSplineInterpolatorFloat->EvaluateValueAndDerivativeAtContinuousIndex(
cindex, movingImageValue, *gradient);
cindex, movingImageValue, *gradient, optionalThreadId...);
}
else if (this->m_InterpolatorIsReducedBSpline && !this->GetComputeGradient())
{
Expand Down Expand Up @@ -606,7 +618,7 @@ AdvancedImageToImageMetric<TFixedImage, TMovingImage>::EvaluateMovingImageValueA

return sampleOk;

} // end EvaluateMovingImageValueAndDerivative()
} // end EvaluateMovingImageValueAndDerivativeWithOptionalThreadId()


/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,8 @@ AdvancedNormalizedCorrelationImageToImageMetric<TFixedImage, TMovingImage>::Thre
*/
if (sampleOk)
{
sampleOk = this->EvaluateMovingImageValueAndDerivative(mappedPoint, movingImageValue, &movingImageDerivative);
sampleOk = this->FastEvaluateMovingImageValueAndDerivative(
mappedPoint, movingImageValue, &movingImageDerivative, threadId);
}

if (sampleOk)
Expand Down

0 comments on commit c8c4a6a

Please sign in to comment.