diff --git a/Modules/Core/Transform/wrapping/test/CMakeLists.txt b/Modules/Core/Transform/wrapping/test/CMakeLists.txt new file mode 100644 index 00000000000..73b5e06130b --- /dev/null +++ b/Modules/Core/Transform/wrapping/test/CMakeLists.txt @@ -0,0 +1,3 @@ +if(ITK_WRAP_PYTHON) + itk_python_add_test(NAME itkTransformSerializationTest COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/itkTransformSerializationTest.py) +endif() diff --git a/Modules/Core/Transform/wrapping/test/itkTransformSerializationTest.py b/Modules/Core/Transform/wrapping/test/itkTransformSerializationTest.py new file mode 100644 index 00000000000..af4327a7845 --- /dev/null +++ b/Modules/Core/Transform/wrapping/test/itkTransformSerializationTest.py @@ -0,0 +1,89 @@ +# ========================================================================== +# +# Copyright NumFOCUS +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0.txt +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# ==========================================================================*/ + +import itk +import numpy as np +import pickle + +Dimension = 3 +PixelType = itk.D + +# List of Transforms to test +transforms_to_test = [itk.AffineTransform[PixelType, Dimension], itk.DisplacementFieldTransform[PixelType, Dimension], itk.Rigid3DTransform[PixelType], itk.BSplineTransform[PixelType, Dimension, 3], itk.QuaternionRigidTransform[PixelType]] + +keys_to_test1 = ["name", "parametersValueType", "transformName", "transformType", "inDimension", "outDimension", "numberOfParameters", "numberOfFixedParameters"] +keys_to_test2 = ["parameters", "fixedParameters"] + +transform_object_list = [] +for i, transform_type in enumerate(transforms_to_test): + transform = transform_type.New() + transform.SetObjectName("transform"+str(i)) + + # Check the serialization + serialize_deserialize = pickle.loads(pickle.dumps(transform)) + + # Test all the attributes + for k in keys_to_test1: + assert serialize_deserialize[k] == transform[k] + + # Test all the parameters + for k in keys_to_test2: + assert np.array_equal(serialize_deserialize[k], transform[k]) + + transform_object_list.append(transform) + +print('Individual Transforms Test Done') + +# Test Composite Transform +transformType = itk.CompositeTransform[PixelType, Dimension] +composite_transform = transformType.New() +composite_transform.SetObjectName('composite_transform') + +# Add the above created transforms in the composite transform +for transform in transform_object_list: + composite_transform.AddTransform(transform) + +# Check the serialization of composite transform +serialize_deserialize = pickle.loads(pickle.dumps(composite_transform)) + +assert serialize_deserialize.GetObjectName() == composite_transform.GetObjectName() +assert serialize_deserialize.GetNumberOfTransforms() == 5 +assert serialize_deserialize["name"] == composite_transform["name"] + +deserialized_object_list = [] + +keys_to_test1 = ["name", "parametersValueType", "transformName", "inDimension", "outDimension", "numberOfParameters", "numberOfFixedParameters"] + +# Get the individual transform objects from the composite transform for testing +for i in range(len(transforms_to_test)): + transform_obj = serialize_deserialize.GetNthTransform(i) + + # Test all the attributes + for k in keys_to_test1: + assert transform_obj[k] == transform_object_list[i][k] + + # Test all the parameter arrays + for k in keys_to_test2: + assert np.array_equal(transform_obj[k], transform_object_list[i][k]) + + # Testing for loss of transformType in Composite transform + if i == 3: + # BSpline has same type here D33 + assert transform_obj["transformType"] == transform_object_list[i]["transformType"] + else: + assert transform_obj["transformType"] != transform_object_list[i]["transformType"] diff --git a/Wrapping/Generators/Python/PyBase/pyBase.i b/Wrapping/Generators/Python/PyBase/pyBase.i index 74e5a5b20e9..a47ca0c6ee9 100644 --- a/Wrapping/Generators/Python/PyBase/pyBase.i +++ b/Wrapping/Generators/Python/PyBase/pyBase.i @@ -405,6 +405,63 @@ str = str %enddef +%define DECL_PYTHON_TRANSFORMBASETEMPLATE_CLASS(swig_name) + %extend swig_name { + %pythoncode %{ + def keys(self): + """ + Return keys related to the transform's metadata. + These keys are used in the dictionary resulting from dict(transform). + """ + result = ['name', 'transformType', 'inDimension', 'outDimension', 'numberOfParameters', 'numberOfFixedParameters', 'parameters', 'fixedParameters'] + return result + + def __getitem__(self, key): + """Access metadata keys, see help(transform.keys), for string keys.""" + import itk + if isinstance(key, str): + state = itk.dict_from_transform(self) + return state[0][key] + + def __setitem__(self, key, value): + if isinstance(key, str): + import numpy as np + if key == 'name': + self.SetObjectName(value) + elif key == 'fixedParameters' or key == 'parameters': + if key == 'fixedParameters': + o1 = self.GetFixedParameters() + else: + o1 = self.GetParameters() + + o1.SetSize(value.shape[0]) + for i, v in enumerate(value): + o1.SetElement(i, v) + + if key == 'fixedParameters': + self.SetFixedParameters(o1) + else: + self.SetParameters(o1) + + + def __getstate__(self): + """Get object state, necessary for serialization with pickle.""" + import itk + state = itk.dict_from_transform(self) + return state + + def __setstate__(self, state): + """Set object state, necessary for serialization with pickle.""" + import itk + import numpy as np + deserialized = itk.transform_from_dict(state) + self.__dict__['this'] = deserialized + %} + } + +%enddef + + %define DECL_PYTHON_IMAGEBASE_CLASS(swig_name, template_params) %inline %{ #include "itkContinuousIndexSwigInterface.h" diff --git a/Wrapping/Generators/Python/itk/support/extras.py b/Wrapping/Generators/Python/itk/support/extras.py index 7cee2b66e45..291f6bd2ea0 100644 --- a/Wrapping/Generators/Python/itk/support/extras.py +++ b/Wrapping/Generators/Python/itk/support/extras.py @@ -16,6 +16,7 @@ # # ==========================================================================*/ +import enum import re from typing import Optional, Union, Dict, Any, List, Tuple, Sequence, TYPE_CHECKING from sys import stderr as system_error_stream @@ -104,6 +105,8 @@ "dict_from_mesh", "pointset_from_dict", "dict_from_pointset", + "transform_from_dict", + "dict_from_transform", "transformwrite", "transformread", "search", @@ -1012,6 +1015,144 @@ def dict_from_pointset(pointset: "itkt.PointSet") -> Dict: pointData=point_data_numpy, ) +def dict_from_transform(transform: "itkt.TransformBase") -> Dict: + import itk + + def update_transform_dict(current_transform): + current_transform_type = current_transform.GetTransformTypeAsString() + current_transform_type_split = current_transform_type.split('_') + component = itk.template(current_transform) + + in_transform_dict = dict() + in_transform_dict["name"] = current_transform.GetObjectName() + in_transform_dict["numberOfTransforms"] = 1 + + datatype_dict = {'double': itk.D, 'float': itk.F} + in_transform_dict["parametersValueType"] = python_to_js(datatype_dict[current_transform_type_split[1]]) + in_transform_dict["inDimension"] = int(current_transform_type_split[2]) + in_transform_dict["outDimension"] = int(current_transform_type_split[3]) + in_transform_dict["transformName"] = current_transform_type_split[0] + + # transformType field to be used for single transform object only. + # For composite transforms we lose the information for child transform objects. + data_type_dict = {itk.D: 'D', itk.F: 'F'} + mangle = data_type_dict[component[1][0]] + for p in component[1][1:]: + mangle += str(p) + in_transform_dict["transformType"] = mangle + + # To avoid copying the parameters for the Composite Transform as it is a copy of child transforms. + if 'Composite' not in current_transform_type_split[0]: + p = np.array(current_transform.GetParameters()) + in_transform_dict["parameters"] = p + + fp = np.array(current_transform.GetFixedParameters()) + in_transform_dict["fixedParameters"] = fp + + in_transform_dict["numberOfParameters"] = p.shape[0] + in_transform_dict["numberOfFixedParameters"] = fp.shape[0] + else: + in_transform_dict["parameters"] = np.array([]) + in_transform_dict["fixedParameters"] = np.array([]) + in_transform_dict["numberOfParameters"] = 0 + in_transform_dict["numberOfFixedParameters"] = 0 + + return in_transform_dict + + + dict_array = [] + transform_type = transform.GetTransformTypeAsString() + if 'CompositeTransform' in transform_type: + transform_dict = update_transform_dict(transform) + transform_dict["numberOfTransforms"] = transform.GetNumberOfTransforms() + + # Add the first entry for the composite transform + dict_array.append(transform_dict) + + # Rest follows the transforms inside the composite transform + # range is over-ridden so using this hack to create a list + for i, _ in enumerate([0]*transform.GetNumberOfTransforms()): + current_transform = transform.GetNthTransform(i) + dict_array.append(update_transform_dict(current_transform)) + else: + dict_array.append(update_transform_dict(transform)) + + return dict_array + +def transform_from_dict(transform_dict: Dict)-> "itkt.TransformBase": + import itk + + def set_parameters(transform, transform_parameters, transform_fixed_parameters): + o1 = transform.GetParameters() + o1.SetSize(transform_parameters.shape[0]) + for j, v in enumerate(transform_parameters): + o1.SetElement(j, v) + transform.SetParameters(o1) + + o2 = transform.GetFixedParameters() + o2.SetSize(transform_fixed_parameters.shape[0]) + for j, v in enumerate(transform_fixed_parameters): + o2.SetElement(j, v) + transform.SetFixedParameters(o2) + + + # For checking transforms which don't take additional parameters while instantiation + def special_transform_check(transform_name): + if '2D' in transform_name or '3D' in transform_name: + return True + + check_list = ['VersorTransform', 'QuaternionRigidTransform'] + for t in check_list: + if transform_name == t: + return True + return False + + # We only check for the first transform as composite similar to the + # convention followed in the itkTxtTransformIO.cxx + if 'CompositeTransform' in transform_dict[0]["transformName"]: + # Loop over all the transforms in the dictionary + transforms_list = [] + for i, _ in enumerate(transform_dict): + if transform_dict[i]["parametersValueType"] == "float32": + data_type = itk.F + else: + data_type = itk.D + + # No template parameter needed for transforms having 2D or 3D name + # Also for some selected transforms + if special_transform_check(transform_dict[i]["transformName"]): + transform_template = getattr(itk, transform_dict[i]["transformName"]) + transform = transform_template[data_type].New() + # Currently only BSpline Transform has 3 template parameters + # For future extensions the information will have to be encoded in + # the transformType variable. The transform object once added in a + # composite transform lose the information for other template parameters ex. BSpline. + # The Spline order is fixed as 3 here. + elif transform_dict[i]["transformName"] == 'BSplineTransform': + transform_template = getattr(itk, transform_dict[i]["transformName"]) + transform = transform_template[data_type, transform_dict[i]["inDimension"], 3].New() + else: + transform_template = getattr(itk, transform_dict[i]["transformName"]) + transform = transform_template[data_type, transform_dict[i]["inDimension"]].New() + + transform.SetObjectName(transform_dict[i]["name"]) + transforms_list.append(transform) + + # Obtain the first object which is composite transform object + # and add all the transforms in it. + transform = transforms_list[0] + for current_transform in transforms_list[1:]: + transform.AddTransform(current_transform) + else: + # For handling single transform objects we rely on itk.template + # because that way we can handle future extensions easily. + transform_template = getattr(itk, transform_dict[0]["transformName"]) + transform = getattr(transform_template, transform_dict[0]["transformType"]).New() + transform.SetObjectName(transform_dict[0]["name"]) + set_parameters(transform, transform_dict[0]["parameters"], transform_dict[0]["fixedParameters"]) + + return transform + def image_intensity_min_max(image_or_filter: "itkt.ImageOrImageSource"): """Return the minimum and maximum of values in a image of in the output image of a filter diff --git a/Wrapping/TypedefMacros.cmake b/Wrapping/TypedefMacros.cmake index 8f0302f0039..b06400106b9 100644 --- a/Wrapping/TypedefMacros.cmake +++ b/Wrapping/TypedefMacros.cmake @@ -1326,6 +1326,10 @@ macro(itk_wrap_simple_type wrap_class swig_name) set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYTHON_MESH_CLASS(${swig_name})\n\n") endif() + if("${cpp_name}" STREQUAL "itk::TransformBaseTemplate") + set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYTHON_TRANSFORMBASETEMPLATE_CLASS(${swig_name})\n\n") + endif() + if("${cpp_name}" STREQUAL "itk::PyImageFilter" AND NOT "${swig_name}" MATCHES "Pointer$") set(ITK_WRAP_PYTHON_SWIG_EXT "${ITK_WRAP_PYTHON_SWIG_EXT}DECL_PYIMAGEFILTER_CLASS(${swig_name})\n\n") endif()