Skip to content

Commit

Permalink
ENH: Add typehints for itk functional filters
Browse files Browse the repository at this point in the history
These are helpful for type checkers, IDE's, and for building interfaces
off the types.

Since we do not know the default values until the constructor has been
executed at runtime, use an Ellipsis as a placeholder per Python
typehint stub convention.

Types are reduced to equivalent Python types when possible, e.g. float,
int, str, and Abstract
Base Classes, abc's, like Sequence's.

This adds a dependency on numpy>=1.20 and typing-extensions for
numpy.typing.ArrayLike.

We add type annotation for a number of types that cover most argument
types -- more can gradually be added in the future.

This also reduces runtime initialization of the functional docstring to
compile time, which is helpful for performance.

A few bugs were also addressed: we do not generate functional interfaces
for abstract classes, and do not write out duplicate functional
interfaces.

Example result:

  from itk.support import helpers
  import itk.support.types as itkt
  from typing import Sequence, Tuple, Union

  @helpers.accept_array_like_xarray_torch
  def median_image_filter(*args: itkt.ImageLike,  radius: Union[Sequence[int], int]=...,**kwargs)-> itkt.ImageSourceReturn:
      """Functional interface for MedianImageFilter"""
      import itk

      kwarg_typehints = { 'radius':radius }
      specified_kwarg_typehints = { k:v for (k,v) in kwarg_typehints.items() if kwarg_typehints[k] != ... }
      kwargs.update(specified_kwarg_typehints)

      instance = itk.MedianImageFilter.New(*args, **kwargs)
      return instance.__internal_call__()

  def median_image_filter_init_docstring():
      import itk
      from itk.support import template_class

      filter_class = itk.ITKSmoothing.MedianImageFilter
      is_template = isinstance(filter_class, template_class.itkTemplate)
      if is_template:
          filter_object = filter_class.values()[0]
      else:
          filter_object = filter_class

      median_image_filter.__doc__ = filter_object.__doc__
  • Loading branch information
thewtex committed Mar 24, 2021
1 parent 14992c6 commit ed11fdc
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 121 deletions.
23 changes: 23 additions & 0 deletions Modules/Filtering/Smoothing/wrapping/test/MedianImageFilterTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,12 @@
#

import itk
import itk.support.types as itkt
from sys import argv
import warnings
from typing import Sequence, TypeVar, get_type_hints, get_args, get_origin, Union

import numpy.typing as npt

input_filename = argv[1]
output_filename = argv[2]
Expand All @@ -45,6 +49,25 @@
compare_filter.Update()
assert compare_filter.GetMaximumDifference() < 0.000000001

# Check the type hints
type_hints = get_type_hints(itk.median_image_filter, globalns= { 'itk': itk })

assert 'args' in type_hints
args_hints = type_hints['args']
assert get_origin(args_hints) is Union
assert itk.ImageBase in get_args(args_hints)

assert 'radius' in type_hints
radius_hints = type_hints['radius']
assert get_origin(radius_hints) is Union
assert int in get_args(radius_hints)
assert Sequence[int] in get_args(radius_hints)

assert 'return' in type_hints
result_hints = type_hints['return']
assert itk.ImageBase in get_args(args_hints)


# Test that `__call__()` inside itkTemplate is deprecated. Replaced
# by snake_case functions
with warnings.catch_warnings(record=True) as w:
Expand Down
4 changes: 2 additions & 2 deletions Testing/ContinuousIntegration/AzurePipelinesLinuxPython.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ jobs:
- bash: |
set -x
sudo python3 -m pip install ninja
sudo python3 -m pip install ninja numpy>=1.20 typing-extensions
sudo apt-get update
sudo apt-get install -y python3-venv python3-numpy python-numpy
sudo apt-get install -y python3-venv
sudo python3 -m pip install --upgrade setuptools
sudo python3 -m pip install scikit-ci-addons dask distributed
displayName: 'Install dependencies'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
- bash: |
set -x
sudo pip3 install ninja numpy
sudo pip3 install ninja numpy>=1.20 typing-extensions
sudo python3 -m pip install --upgrade setuptools
sudo python3 -m pip install scikit-ci-addons
sudo python3 -m pip install lxml dask distributed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
architecture: 'x64'

- script: |
python -m pip install ninja numpy
python -m pip install ninja numpy>=1.20 typing-extensions
python -m pip install --upgrade setuptools
python -m pip install scikit-ci-addons dask distributed
displayName: 'Install dependencies'
Expand Down
2 changes: 0 additions & 2 deletions Wrapping/Generators/Python/Tests/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
import numpy as np

import itk
import itk.support.types as itk_types


def custom_callback(name, progress):
if progress == 0:
Expand Down
54 changes: 27 additions & 27 deletions Wrapping/Generators/Python/itk/support/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@

fileiotype = Union[str, bytes, os.PathLike]

import itk.support.types as itk_types
import itk.support.types as itkt

if TYPE_CHECKING:
try:
import xarray
import xarray as xr
except ImportError:
pass
try:
Expand Down Expand Up @@ -154,7 +154,7 @@ def echo(obj, f=system_error_stream) -> None:
print(f, obj)


def size(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[int]:
def size(image_or_filter: "itkt.ImageOrImageSource") -> Sequence[int]:
"""Return the size of an image, or of the output image of a filter
This method take care of updating the needed information
Expand All @@ -168,7 +168,7 @@ def size(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[int]:
return img.GetLargestPossibleRegion().GetSize()


def physical_size(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[float]:
def physical_size(image_or_filter: "itkt.ImageOrImageSource") -> Sequence[float]:
"""Return the physical size of an image, or of the output image of a filter
This method take care of updating the needed information
Expand All @@ -184,7 +184,7 @@ def physical_size(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[f
return result


def spacing(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[float]:
def spacing(image_or_filter: "itkt.ImageOrImageSource") -> Sequence[float]:
"""Return the spacing of an image, or of the output image of a filter
This method take care of updating the needed information
Expand All @@ -197,7 +197,7 @@ def spacing(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[float]:
return img.GetSpacing()


def origin(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[float]:
def origin(image_or_filter: "itkt.ImageOrImageSource") -> Sequence[float]:
"""Return the origin of an image, or of the output image of a filter
This method take care of updating the needed information
Expand All @@ -210,7 +210,7 @@ def origin(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[float]:
return img.GetOrigin()


def index(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[int]:
def index(image_or_filter: "itkt.ImageOrImageSource") -> Sequence[int]:
"""Return the index of an image, or of the output image of a filter
This method take care of updating the needed information
Expand All @@ -223,7 +223,7 @@ def index(image_or_filter: "itk_types.ImageOrImageSource") -> Sequence[int]:
return img.GetLargestPossibleRegion().GetIndex()


def region(image_or_filter: "itk_types.ImageOrImageSource") -> "itk_types.ImageRegion":
def region(image_or_filter: "itkt.ImageOrImageSource") -> "itkt.ImageRegion":
"""Return the region of an image, or of the output image of a filter
This method take care of updating the needed information
Expand Down Expand Up @@ -291,7 +291,7 @@ def _GetArrayFromImage(


def GetArrayFromImage(
image_or_filter: "itk_types.ImageOrImageSource",
image_or_filter: "itkt.ImageOrImageSource",
keep_axes: bool = False,
update: bool = True,
ttype=None,
Expand All @@ -306,7 +306,7 @@ def GetArrayFromImage(


def GetArrayViewFromImage(
image_or_filter: "itk_types.ImageOrImageSource",
image_or_filter: "itkt.ImageOrImageSource",
keep_axes: bool = False,
update: bool = True,
ttype=None,
Expand Down Expand Up @@ -495,7 +495,7 @@ def GetMatrixFromArray(arr):
matrix_from_array = GetMatrixFromArray


def xarray_from_image(l_image: "itk_types.ImageOrImageSource") -> "xarray.DataArray":
def xarray_from_image(l_image: "itkt.ImageOrImageSource") -> "xr.DataArray":
"""Convert an itk.Image to an xarray.DataArray.
Origin and spacing metadata is preserved in the xarray's coords. The
Expand Down Expand Up @@ -542,7 +542,7 @@ def xarray_from_image(l_image: "itk_types.ImageOrImageSource") -> "xarray.DataAr
return data_array


def image_from_xarray(data_array: "xarray.DataArray") -> "itk_types.ImageBase":
def image_from_xarray(data_array: "xr.DataArray") -> "itkt.ImageBase":
"""Convert an xarray.DataArray to an itk.Image.
Metadata encoded with xarray_from_image is applied to the itk.Image.
Expand Down Expand Up @@ -594,7 +594,7 @@ def image_from_xarray(data_array: "xarray.DataArray") -> "itk_types.ImageBase":
return itk_image


def vtk_image_from_image(l_image: "itk_types.ImageOrImageSource") -> "vtk.vtkImageData":
def vtk_image_from_image(l_image: "itkt.ImageOrImageSource") -> "vtk.vtkImageData":
"""Convert an itk.Image to a vtk.vtkImageData."""
import itk
import vtk
Expand Down Expand Up @@ -632,7 +632,7 @@ def vtk_image_from_image(l_image: "itk_types.ImageOrImageSource") -> "vtk.vtkIma
return vtk_image


def image_from_vtk_image(vtk_image: "vtk.vtkImageData") -> "itk_types.ImageBase":
def image_from_vtk_image(vtk_image: "vtk.vtkImageData") -> "itkt.ImageBase":
"""Convert a vtk.vtkImageData to an itk.Image."""
import itk
from vtk.util.numpy_support import vtk_to_numpy
Expand Down Expand Up @@ -666,7 +666,7 @@ def image_from_vtk_image(vtk_image: "vtk.vtkImageData") -> "itk_types.ImageBase"
# return an image


def image_intensity_min_max(image_or_filter: "itk_types.ImageOrImageSource"):
def image_intensity_min_max(image_or_filter: "itkt.ImageOrImageSource"):
"""Return the minimum and maximum of values in a image of in the output image of a filter
The minimum and maximum values are returned in a tuple: (min, max)
Expand Down Expand Up @@ -694,10 +694,10 @@ def range(image_or_filter):


def imwrite(
image_or_filter: "itk_types.ImageOrImageSource",
image_or_filter: "itkt.ImageOrImageSource",
filename: fileiotype,
compression: bool = False,
imageio: Optional["itk_types.ImageIOBase"] = None,
imageio: Optional["itkt.ImageIOBase"] = None,
) -> None:
"""Write a image or the output image of a filter to a file.
Expand Down Expand Up @@ -737,10 +737,10 @@ def imwrite(

def imread(
filename: fileiotype,
pixel_type: Optional["itk_types.PixelTypes"] = None,
pixel_type: Optional["itkt.PixelTypes"] = None,
fallback_only: bool = False,
imageio: Optional["itk_types.ImageIOBase"] = None,
) -> "itk_types.ImageBase":
imageio: Optional["itkt.ImageIOBase"] = None,
) -> "itkt.ImageBase":
"""Read an image from a file or series of files and return an itk.Image.
Parameters
Expand Down Expand Up @@ -847,7 +847,7 @@ def imread(


def meshwrite(
mesh: "itk_types.Mesh", filename: fileiotype, compression: bool = False
mesh: "itkt.Mesh", filename: fileiotype, compression: bool = False
) -> None:
"""Write a mesh to a file.
Expand All @@ -868,9 +868,9 @@ def meshwrite(

def meshread(
filename: fileiotype,
pixel_type: Optional["itk_types.PixelTypes"] = None,
pixel_type: Optional["itkt.PixelTypes"] = None,
fallback_only: bool = False,
) -> "itk_types.Mesh":
) -> "itkt.Mesh":
"""Read a mesh from a file and return an itk.Mesh.
The reader is instantiated with the mesh type of the mesh file if
Expand Down Expand Up @@ -918,7 +918,7 @@ def meshread(
return reader.GetOutput()


def transformread(filename: fileiotype) -> List["itk_types.TransformBase"]:
def transformread(filename: fileiotype) -> List["itkt.TransformBase"]:
"""Read an itk Transform file.
Parameters
Expand Down Expand Up @@ -949,7 +949,7 @@ def transformread(filename: fileiotype) -> List["itk_types.TransformBase"]:


def transformwrite(
transforms: List["itk_types.TransformBase"],
transforms: List["itkt.TransformBase"],
filename: fileiotype,
compression: bool = False,
) -> None:
Expand Down Expand Up @@ -1434,7 +1434,7 @@ def Stop() -> None:
auto_pipeline.current = None


def down_cast(obj: "itk_types.LightObject"):
def down_cast(obj: "itkt.LightObject"):
"""Down cast an itk.LightObject (or a object of a subclass) to its most
specialized type.
"""
Expand Down Expand Up @@ -1664,7 +1664,7 @@ def template(cl):
return itkTemplateBase.__template_instantiations_object_to_name__[class_(cl)]


def ctype(s: str) -> "itk_types.itkCType":
def ctype(s: str) -> "itkt.itkCType":
"""Return the c type corresponding to the string passed in parameter
The string can contain some extra spaces.
Expand Down
61 changes: 6 additions & 55 deletions Wrapping/Generators/Python/itk/support/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,57 +40,6 @@ def camel_to_snake_case(name):
snake = re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake)
return snake.replace("__", "_").lower()


def filter_args(filter_object):
"""
This function accepts an itk filter object,
returns a) specific arguments of this filter
b) common arguments of its super class (i.e., itk.ProcessObject).
Both args exclude some useless args denoted as useless_args.
"""
import itk

exclude_args = [
camel_to_snake_case(item[3:])
for item in dir(itk.Object)
if item.startswith("Set")
]
common_args = [
camel_to_snake_case(item[3:])
for item in dir(itk.ProcessObject)
if item.startswith("Set")
]
useless_args = [
"abort_generate_data",
"release_data_flag",
"release_data_before_update_flag",
]
specific_args = [
camel_to_snake_case(item[3:])
for item in dir(filter_object)
if item.startswith("Set")
]

str_ret_args = "".join(
[
" " + item + "\n"
for item in specific_args
if item not in exclude_args
and item not in common_args
and item not in useless_args
]
)

str_common_args = "".join(
[
" " + item + "\n"
for item in common_args
if item not in useless_args and item not in exclude_args
]
)
return str_ret_args, str_common_args


def is_arraylike(arr):
return (
hasattr(arr, "shape")
Expand Down Expand Up @@ -118,11 +67,13 @@ def move_last_dimension_to_first(arr):
arr_interleaved_channels = np.moveaxis(arr, dest, source).copy()
return arr_interleaved_channels

def accept_numpy_array_like_xarray(image_filter):
def accept_array_like_xarray_torch(image_filter):
"""Decorator that allows itk.ProcessObject snake_case functions to accept
NumPy array-like or xarray DataArray inputs for itk.Image inputs. If a NumPy array-like is
passed as an input, output itk.Image's are converted to numpy.ndarray's. If a xarray DataArray is
passed as an input, output itk.Image's are converted to xarray.DataArray's."""
NumPy array-like, PyTorch Tensor's or xarray DataArray inputs for itk.Image inputs.
If a NumPy array-like is passed as an input, output itk.Image's are converted to numpy.ndarray's.
If a torch.Tensor is passed as an input, output itk.Image's are converted to torch.Tensors.
If a xarray DataArray is passed as an input, output itk.Image's are converted to xarray.DataArray's."""
import numpy as np
import itk

Expand Down
8 changes: 8 additions & 0 deletions Wrapping/Generators/Python/itk/support/template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,14 @@ def __instancecheck__(self, instance) -> bool:
return True
return False

def __hash__(self):
"""Overloads `hash()` when called on an `itkTemplate` object.
Identify with the __name__, e.g. `itk.Image.__name__` is `itk::Image`.
Used by frozenset construction in typing._GenericAlias
"""
return hash(self.__name__)

def __find_param__(self, paramSetString) -> List[Any]:
"""Find the parameters of the template.
Expand Down
Loading

0 comments on commit ed11fdc

Please sign in to comment.