Skip to content

Commit

Permalink
ENH: Pass in the class instance to the PyImageFilter PyGenerateData
Browse files Browse the repository at this point in the history
It was difficult or not possible to write a useful PyGenerateData
function without access to the calling filter. Pass the filter in as an
argument to the callable.

Fix an issue where Modified() was not called if a new function was not
added.

Add tests.
  • Loading branch information
thewtex authored and hjmjohnson committed Jan 21, 2022
1 parent 8f5241e commit 32186c8
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 20 deletions.
3 changes: 3 additions & 0 deletions Wrapping/Generators/Python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,9 @@ macro(itk_wrap_simple_type_python wrap_class swig_name)
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYTHON_IMAGE_CLASS(${swig_name})\n\n")
endif()

if("${cpp_name}" STREQUAL "itk::PyImageFilter" AND NOT "${swig_name}" MATCHES "Pointer$")
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYIMAGEFILTER_CLASS(${swig_name})\n\n")
endif()

if("${cpp_name}" STREQUAL "itk::StatisticsLabelObject" AND NOT "${swig_name}" MATCHES "Pointer$")
set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}%template(map${swig_name}) std::map< unsigned long, ${swig_name}_Pointer, std::less< unsigned long > >;\n")
Expand Down
14 changes: 14 additions & 0 deletions Wrapping/Generators/Python/PyBase/pyBase.i
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,20 @@ str = str
// $1 = ptr;
// }

%define DECL_PYIMAGEFILTER_CLASS(swig_name)
%extend swig_name {
%pythoncode {
def Update(self):
"""Internal method to pass a pointer to the Python object wrapper, then call Update() on the filter."""
self._SetSelf(self)
super().Update()
def UpdateLargestPossibleRegion(self):
"""Internal method to pass a pointer to the Python object wrapper, then call UpdateLargestPossibleRegion() on the filter."""
self._SetSelf(self)
super().UpdateLargestPossibleRegion()
}
}
%enddef

%extend itkComponentTreeNode {
%pythoncode {
Expand Down
25 changes: 17 additions & 8 deletions Wrapping/Generators/Python/PyUtils/itkPyImageFilter.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,12 @@ namespace itk
{

/** \class PyImageFilter
* \brief ImageToImageFilter subclass that calls a Python callable object, e.g.
* a Python function.
* \brief ImageToImageFilter subclass that calls a Python callable object, e.g.
* a Python function or a class with a __call__ method.
*
* For more information on ITK filters, the GenerateData() method, and other filter pipeline methods,
* see the ITK Software Guide.
*/


template <class TInputImage, class TOutputImage>
class ITK_TEMPLATE_EXPORT PyImageFilter : public ImageToImageFilter<TInputImage, TOutputImage>
{
Expand Down Expand Up @@ -67,18 +68,26 @@ class ITK_TEMPLATE_EXPORT PyImageFilter : public ImageToImageFilter<TInputImage,
static constexpr unsigned int InputImageDimension = TInputImage::ImageDimension;
static constexpr unsigned int OutputImageDimension = TOutputImage::ImageDimension;


/** Python callable called during the filter's GenerateData. */
void
SetPyGenerateData(PyObject * obj);

/** Python internal method to pass a pointer to the wrapping Python object. */
void
_SetSelf(PyObject * self)
{
this->m_Self = self;
}

protected:
PyImageFilter();
virtual ~PyImageFilter();
virtual void
GenerateData();
void
GenerateData() override;

private:
PyObject * m_Object;
PyObject * m_Self;
PyObject * m_GenerateDataCallable;
};

} // end namespace itk
Expand Down
27 changes: 15 additions & 12 deletions Wrapping/Generators/Python/PyUtils/itkPyImageFilter.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,40 @@ namespace itk
template <class TInputImage, class TOutputImage>
PyImageFilter<TInputImage, TOutputImage>::PyImageFilter()
{
this->m_Object = nullptr;
this->m_GenerateDataCallable = nullptr;
}

template <class TInputImage, class TOutputImage>
PyImageFilter<TInputImage, TOutputImage>::~PyImageFilter()
{
if (this->m_Object)
if (this->m_GenerateDataCallable)
{
Py_DECREF(this->m_Object);
Py_DECREF(this->m_GenerateDataCallable);
}
this->m_Object = nullptr;
this->m_GenerateDataCallable = nullptr;
}

template <class TInputImage, class TOutputImage>
void
PyImageFilter<TInputImage, TOutputImage>::SetPyGenerateData(PyObject * o)
{
if (o != this->m_Object)
if (o != this->m_GenerateDataCallable)
{
if (this->m_Object)
if (this->m_GenerateDataCallable)
{
// get rid of our reference
Py_DECREF(this->m_Object);
Py_DECREF(this->m_GenerateDataCallable);
}

// store the new object
this->m_Object = o;
this->m_GenerateDataCallable = o;
this->Modified();

if (this->m_Object)
if (this->m_GenerateDataCallable)
{
// take out reference (so that the calling code doesn't
// have to keep a binding to the callable around)
Py_INCREF(this->m_Object);
Py_INCREF(this->m_GenerateDataCallable);
}
}
}
Expand All @@ -69,7 +70,7 @@ void
PyImageFilter<TInputImage, TOutputImage>::GenerateData()
{
// make sure that the CommandCallable is in fact callable
if (!PyCallable_Check(this->m_Object))
if (!PyCallable_Check(this->m_GenerateDataCallable))
{
// we throw a standard ITK exception: this makes it possible for
// our standard Swig exception handling logic to take this
Expand All @@ -81,7 +82,9 @@ PyImageFilter<TInputImage, TOutputImage>::GenerateData()
{
PyObject * result;

result = PyEval_CallObject(this->m_Object, (PyObject *)nullptr);
PyObject * args = PyTuple_Pack(1, this->m_Self);
result = PyObject_Call(this->m_GenerateDataCallable, args, (PyObject *)NULL);
Py_DECREF(args);

if (result)
{
Expand Down
1 change: 1 addition & 0 deletions Wrapping/Generators/Python/Tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ itk_python_add_test(NAME PythonDICOMSeries COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/d
DATA{${ITK_DATA_ROOT}/Input/DicomSeries/Image0075.dcm}
DATA{${ITK_DATA_ROOT}/Input/DicomSeries/Image0076.dcm}
DATA{${ITK_DATA_ROOT}/Input/DicomSeries/Image0077.dcm})
itk_python_add_test(NAME PyImageFilterTest COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/PyImageFilterTest.py)

# some tests will fail if dim=2 and unsigned short are not wrapped
INTERSECTION(WRAP_2 2 "${ITK_WRAP_IMAGE_DIMS}")
Expand Down
53 changes: 53 additions & 0 deletions Wrapping/Generators/Python/Tests/PyImageFilterTest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# ==========================================================================
#
# Copyright NumFOCUS
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==========================================================================*/

import itk
import numpy as np
import functools

input_image = itk.image_from_array(np.random.randint(0,255,size=(6,6), dtype=np.uint8))

py_filter = itk.PyImageFilter.New(input_image)

def constant_wrapper(f, constant=42):
@functools.wraps(f)
def wrapper(*args):
return f(*args, constant=constant)
return wrapper

def constant_output(py_image_filter, constant=42):
output = py_image_filter.GetOutput()
output.SetBufferedRegion(output.GetRequestedRegion())
output.Allocate()
output.FillBuffer(constant)

py_filter.SetPyGenerateData(constant_output)
py_filter.Update()
output_image = py_filter.GetOutput()
assert np.all(np.asarray(output_image) == 42)

# Filter calls Modified because a new PyGenerateData was passed
py_filter.SetPyGenerateData(constant_wrapper(constant_output, 10))
py_filter.Update()
output_image = py_filter.GetOutput()
assert np.all(np.asarray(output_image) == 10)

# Functional interface
output_image = itk.py_image_filter(input_image,
py_generate_data=constant_wrapper(constant_output, 7))
assert np.all(np.asarray(output_image) == 7)

0 comments on commit 32186c8

Please sign in to comment.