Skip to content

Commit

Permalink
PERF: Let ComputeImageExtremaFilter only compute min and max
Browse files Browse the repository at this point in the history
No longer the mean, sigma, variance, sum, or sum of squares. Removed m_ThreadSum, m_SumOfSquares, and m_Count. Removed inheritance from StatisticsImageFilter (which computed the mean, sigma, variance, sum, and sum of squares for images without a mask).

ComputeImageExtremaFilter is only being used in elastix by AdvancedImageToImageMetric and AdvancedMeanSquaresImageToImageMetric, and both of them only need to have the minimum and the maximum from ComputeImageExtremaFilter.

A performance improvement of ~8 percent was observed, from more than 0.13 sec. (before this commit) to less than 0.12 sec., on an 8192x8192 image with mask. Without a mask, the computation appears to be more than six times as fast as before, from more than 0.25 sec. (before this commit) to less than 0.04 sec. (after this commit), on a 16384x16384 image. Measured using Visual Studio 2022 (Release).
  • Loading branch information
N-Dekker committed Feb 19, 2024
1 parent e99644c commit 76b082d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 153 deletions.
72 changes: 32 additions & 40 deletions Common/itkComputeImageExtremaFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,26 @@
#ifndef itkComputeImageExtremaFilter_h
#define itkComputeImageExtremaFilter_h

#include "itkStatisticsImageFilter.h"
#include "itkImageSink.h"
#include "itkImageMaskSpatialObject.h"

namespace itk
{
/** \class ComputeImageExtremaFilter
* \brief Compute min. max, variance and mean of an Image.
*
* StatisticsImageFilter computes the minimum, maximum, sum, mean, variance
* sigma of an image. The filter needs all of its input image. It
* behaves as a filter with an input and output. Thus it can be inserted
* in a pipline with other filters and the statistics will only be
* recomputed if a downstream filter changes.
*
* The filter passes its input through unmodified. The filter is
* threaded. It computes statistics in each thread then combines them in
* its AfterThreadedGenerate method.
* \brief Compute minimum and maximum pixel value of an Image.
*
* \ingroup MathematicalStatisticsImageFilters
* \ingroup ITKImageStatistics
*
* \wiki
* \wikiexample{Statistics/StatisticsImageFilter,Compute min\, max\, variance and mean of an Image.}
* \endwiki
*/
template <typename TInputImage>
class ITK_TEMPLATE_EXPORT ComputeImageExtremaFilter : public StatisticsImageFilter<TInputImage>
class ITK_TEMPLATE_EXPORT ComputeImageExtremaFilter : public ImageSink<TInputImage>
{
public:
ITK_DISALLOW_COPY_AND_MOVE(ComputeImageExtremaFilter);

/** Standard Self typedef */
using Self = ComputeImageExtremaFilter;
using Superclass = StatisticsImageFilter<TInputImage>;
using Superclass = ImageSink<TInputImage>;
using Pointer = SmartPointer<Self>;
using ConstPointer = SmartPointer<const Self>;

Expand All @@ -64,52 +50,58 @@ class ITK_TEMPLATE_EXPORT ComputeImageExtremaFilter : public StatisticsImageFilt
/** Image related typedefs. */
using InputImagePointer = typename TInputImage::Pointer;

using typename Superclass::RegionType;
using typename Superclass::SizeType;
using typename Superclass::IndexType;
using typename Superclass::PixelType;
using PointType = typename TInputImage::PointType;
using Superclass::InputImageDimension;
using typename Superclass::InputImageRegionType;
using PixelType = typename Superclass::InputImagePixelType;

/** Image related typedefs. */
itkStaticConstMacro(ImageDimension, unsigned int, TInputImage::ImageDimension);

/** Type to use for computations. */
using typename Superclass::RealType;

using ImageSpatialMaskType = ImageMaskSpatialObject<Self::ImageDimension>;
using ImageSpatialMaskPointer = typename ImageSpatialMaskType::Pointer;
using ImageSpatialMaskConstPointer = typename ImageSpatialMaskType::ConstPointer;
itkSetConstObjectMacro(ImageSpatialMask, ImageSpatialMaskType);
itkGetConstObjectMacro(ImageSpatialMask, ImageSpatialMaskType);

/** Returns the minimum pixel value accumulated so far, i.e. the current value of m_ThreadMin. */
PixelType
GetMinimum() const
{
  return this->m_ThreadMin;
}

/** Returns the maximum pixel value accumulated so far, i.e. the current value of m_ThreadMax. */
PixelType
GetMaximum() const
{
  return this->m_ThreadMax;
}

protected:
ComputeImageExtremaFilter() = default;
~ComputeImageExtremaFilter() override = default;

/** Initialize some accumulators before the threads run. */
/** Initialize minimum and maximum before the threads run. */
void
BeforeStreamedGenerateData() override;

/** Do final mean and variance computation from data accumulated in threads.
*/
void
AfterStreamedGenerateData() override;

/** Multi-thread version GenerateData. */
void
ThreadedStreamedGenerateData(const RegionType &) override;
virtual void
ThreadedGenerateDataImageSpatialMask(const RegionType &);
ThreadedStreamedGenerateData(const InputImageRegionType &) override;

private:
/** Pair of extreme pixel values; this is the return type of RetrieveMinMax. */
struct MinMaxResult
{
// Smallest pixel value found.
PixelType Min;
// Largest pixel value found.
PixelType Max;
};

static MinMaxResult
RetrieveMinMax(const TInputImage &, const InputImageRegionType &, const ImageSpatialMaskType *, bool);

ImageSpatialMaskConstPointer m_ImageSpatialMask{};
bool m_SameGeometry{ false };

CompensatedSummation<RealType> m_ThreadSum{ 1 };
CompensatedSummation<RealType> m_SumOfSquares{ 1 };
SizeValueType m_Count{ 1 };
PixelType m_ThreadMin{ 1 };
PixelType m_ThreadMax{ 1 };
PixelType m_ThreadMin{ 1 };
PixelType m_ThreadMax{ 1 };

std::mutex m_Mutex{};
}; // end of class
Expand Down
174 changes: 61 additions & 113 deletions Common/itkComputeImageExtremaFilter.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
*=========================================================================*/
#ifndef itkComputeImageExtremaFilter_hxx
#define itkComputeImageExtremaFilter_hxx

#include "itkComputeImageExtremaFilter.h"

#include <itkImageRegionConstIterator.h>
#include <itkImageScanlineIterator.h>

#include "elxMaskHasSameImageDomain.h"

namespace itk
Expand All @@ -29,141 +32,86 @@ template <typename TInputImage>
void
ComputeImageExtremaFilter<TInputImage>::BeforeStreamedGenerateData()
{
if (m_ImageSpatialMask == nullptr)
{
Superclass::BeforeStreamedGenerateData();
}
else
{
// Resize the thread temporaries
m_Count = SizeValueType{};
m_SumOfSquares = RealType{};
m_ThreadSum = RealType{};
m_ThreadMin = NumericTraits<PixelType>::max();
m_ThreadMax = NumericTraits<PixelType>::NonpositiveMin();

if (this->GetImageSpatialMask())
{
this->m_SameGeometry = elastix::MaskHasSameImageDomain(*m_ImageSpatialMask, *(this->GetInput()));
}
else
{
this->m_SameGeometry = false;
}
}
}

template <typename TInputImage>
void
ComputeImageExtremaFilter<TInputImage>::AfterStreamedGenerateData()
{
if (m_ImageSpatialMask == nullptr)
{
Superclass::AfterStreamedGenerateData();
}
else
{
const SizeValueType count = m_Count;
const RealType sumOfSquares(m_SumOfSquares);
const PixelType minimum = m_ThreadMin;
const PixelType maximum = m_ThreadMax;
const RealType sum(m_ThreadSum);

const RealType mean = sum / static_cast<RealType>(count);
const RealType variance =
(sumOfSquares - (sum * sum / static_cast<RealType>(count))) / (static_cast<RealType>(count) - 1);
const RealType sigma = std::sqrt(variance);

// Set the outputs
this->SetMinimum(minimum);
this->SetMaximum(maximum);
this->SetMean(mean);
this->SetSigma(sigma);
this->SetVariance(variance);
this->SetSum(sum);
this->SetSumOfSquares(sumOfSquares);
}
m_ThreadMin = NumericTraits<PixelType>::max();
m_ThreadMax = NumericTraits<PixelType>::NonpositiveMin();
m_SameGeometry =
(m_ImageSpatialMask != nullptr) && elastix::MaskHasSameImageDomain(*m_ImageSpatialMask, *(this->GetInput()));
}

template <typename TInputImage>
void
ComputeImageExtremaFilter<TInputImage>::ThreadedStreamedGenerateData(const RegionType & regionForThread)
{
if (m_ImageSpatialMask == nullptr)
{
Superclass::ThreadedStreamedGenerateData(regionForThread);
}
else
{
this->ThreadedGenerateDataImageSpatialMask(regionForThread);
}
} // end ThreadedGenerateData()

template <typename TInputImage>
void
ComputeImageExtremaFilter<TInputImage>::ThreadedGenerateDataImageSpatialMask(const RegionType & regionForThread)
auto
ComputeImageExtremaFilter<TInputImage>::RetrieveMinMax(const TInputImage & inputImage,
const InputImageRegionType & regionForThread,
const ImageSpatialMaskType * const imageSpatialMask,
const bool sameGeometry) -> MinMaxResult
{
if (regionForThread.GetSize(0) == 0)
{
return;
}
RealType sum{};
RealType sumOfSquares{};
SizeValueType count{};
PixelType min = NumericTraits<PixelType>::max();
PixelType max = NumericTraits<PixelType>::NonpositiveMin();
PixelType min = NumericTraits<PixelType>::max();
PixelType max = NumericTraits<PixelType>::NonpositiveMin();

const auto & inputImage = *(this->GetInput());

if (this->m_SameGeometry)
if (imageSpatialMask)
{
const auto & maskImage = *(this->m_ImageSpatialMask->GetImage());
if (sameGeometry)
{
const auto & maskImage = *(imageSpatialMask->GetImage());

for (ImageRegionConstIterator<TInputImage> it(&inputImage, regionForThread); !it.IsAtEnd(); ++it)
for (ImageRegionConstIterator<TInputImage> it(&inputImage, regionForThread); !it.IsAtEnd(); ++it)
{
if (maskImage.GetPixel(it.GetIndex()) != PixelType{})
{
const PixelType value = it.Get();
min = std::min(min, value);
max = std::max(max, value);
}
}
}
else
{
if (maskImage.GetPixel(it.GetIndex()) != PixelType{})
for (ImageRegionConstIterator<TInputImage> it(&inputImage, regionForThread); !it.IsAtEnd(); ++it)
{
const PixelType value = it.Get();
const auto realValue = static_cast<RealType>(value);

min = std::min(min, value);
max = std::max(max, value);

sum += realValue;
sumOfSquares += (realValue * realValue);
++count;
typename ImageSpatialMaskType::PointType point;
inputImage.TransformIndexToPhysicalPoint(it.GetIndex(), point);
if (imageSpatialMask->IsInsideInWorldSpace(point))
{
const PixelType value = it.Get();
min = std::min(min, value);
max = std::max(max, value);
}
}
} // end for
}
}
else
{
for (ImageRegionConstIterator<TInputImage> it(&inputImage, regionForThread); !it.IsAtEnd(); ++it)
for (ImageScanlineConstIterator<TInputImage> it(&inputImage, regionForThread); !it.IsAtEnd(); it.NextLine())
{
PointType point;
inputImage.TransformIndexToPhysicalPoint(it.GetIndex(), point);
if (this->m_ImageSpatialMask->IsInsideInWorldSpace(point))
while (!it.IsAtEndOfLine())
{
const PixelType value = it.Get();
const auto realValue = static_cast<RealType>(value);

min = std::min(min, value);
max = std::max(max, value);

sum += realValue;
sumOfSquares += (realValue * realValue);
++count;
++it;
}
} // end for
} // end if
}
}
return { min, max };
}

const std::lock_guard<std::mutex> lockGuard(m_Mutex);
m_ThreadSum += sum;
m_SumOfSquares += sumOfSquares;
m_Count += count;
m_ThreadMin = std::min(min, m_ThreadMin);
m_ThreadMax = std::max(max, m_ThreadMax);

} // end ThreadedGenerateDataImageSpatialMask()
/** Computes the minimum and maximum of the input image within the given region,
 * and merges them into m_ThreadMin and m_ThreadMax. Does nothing when the size of
 * the region along dimension 0 is not positive. */
template <typename TInputImage>
void
ComputeImageExtremaFilter<TInputImage>::ThreadedStreamedGenerateData(const InputImageRegionType & regionForThread)
{
  if (!(regionForThread.GetSize(0) > 0))
  {
    // Nothing to process for this region.
    return;
  }

  // Do the per-region scan before acquiring the mutex, so that the lock is held
  // only while the shared members are being updated.
  const auto regionExtrema = RetrieveMinMax(*(this->GetInput()), regionForThread, m_ImageSpatialMask, m_SameGeometry);

  const std::lock_guard<std::mutex> lockGuard(m_Mutex);
  m_ThreadMin = std::min(regionExtrema.Min, m_ThreadMin);
  m_ThreadMax = std::max(regionExtrema.Max, m_ThreadMax);
}

} // end namespace itk
#endif

0 comments on commit 76b082d

Please sign in to comment.