Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Custom resizing function in flow from directory #248

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
9 changes: 8 additions & 1 deletion keras_preprocessing/image/dataframe_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ class DataFrameIterator(BatchFromFilesMixin, Iterator):
If PIL version 1.1.3 or newer is installed, "lanczos" is also
supported. If PIL version 3.4.0 or newer is installed, "box" and
"hamming" are also supported. By default, "nearest" is used.
resizing_function: function, used to resize the loaded images to the
target size. this will overrule interpolation. If None, then
interpolation will happen. The input is an image in the specified
data format, and the output has to be an image in the specified
data format with the target size.
dtype: Dtype to use for the generated arrays.
validate_filenames: Boolean, whether to validate image filenames in
`x_col`. If `True`, invalid images will be ignored. Disabling this option
Expand Down Expand Up @@ -109,6 +114,7 @@ def __init__(self,
save_format='png',
subset=None,
interpolation='nearest',
resizing_function=None,
dtype='float32',
validate_filenames=True):

Expand All @@ -120,7 +126,8 @@ def __init__(self,
save_prefix,
save_format,
subset,
interpolation)
interpolation,
resizing_function)
df = dataframe.copy()
self.directory = directory or ''
self.class_mode = class_mode
Expand Down
9 changes: 8 additions & 1 deletion keras_preprocessing/image/directory_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ class DirectoryIterator(BatchFromFilesMixin, Iterator):
If PIL version 1.1.3 or newer is installed, "lanczos" is also
supported. If PIL version 3.4.0 or newer is installed, "box" and
"hamming" are also supported. By default, "nearest" is used.
resizing_function: function, used to resize the loaded images to the
target size. this will overrule interpolation. If None, then
interpolation will happen. The input is an image in the specified
data format, and the output has to be an image in the specified
data format with the target size.
dtype: Dtype to use for generated arrays.
"""
allowed_class_modes = {'categorical', 'binary', 'sparse', 'input', None}
Expand All @@ -81,6 +86,7 @@ def __init__(self,
follow_links=False,
subset=None,
interpolation='nearest',
resizing_function=None,
dtype='float32'):
super(DirectoryIterator, self).set_processing_attrs(image_data_generator,
target_size,
Expand All @@ -90,7 +96,8 @@ def __init__(self,
save_prefix,
save_format,
subset,
interpolation)
interpolation,
resizing_function)
self.directory = directory
self.classes = classes
if class_mode not in self.allowed_class_modes:
Expand Down
11 changes: 9 additions & 2 deletions keras_preprocessing/image/image_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,8 @@ def flow_from_directory(self,
save_format='png',
follow_links=False,
subset=None,
interpolation='nearest'):
interpolation='nearest',
resizing_function=None):
"""Takes the path to a directory & generates batches of augmented data.

# Arguments
Expand Down Expand Up @@ -515,6 +516,11 @@ class subdirectories (default: False).
supported. If PIL version 3.4.0 or newer is installed,
`"box"` and `"hamming"` are also supported.
By default, `"nearest"` is used.
resizing_function: function, used to resize the loaded images to the
target size. this will overrule interpolation. If None, then
interpolation will happen. The input is an image in the specified
data format, and the output has to be an image in the specified
data format with the target size.

# Returns
A `DirectoryIterator` yielding tuples of `(x, y)`
Expand All @@ -539,7 +545,8 @@ class subdirectories (default: False).
follow_links=follow_links,
subset=subset,
interpolation=interpolation,
dtype=self.dtype
dtype=self.dtype,
resizing_function=resizing_function
)

def flow_from_dataframe(self,
Expand Down
23 changes: 21 additions & 2 deletions keras_preprocessing/image/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ def set_processing_attrs(self,
save_prefix,
save_format,
subset,
interpolation):
interpolation,
resizing_function):
"""Sets attributes to use later for processing files into a batch.

# Arguments
Expand All @@ -168,6 +169,11 @@ def set_processing_attrs(self,
If PIL version 1.1.3 or newer is installed, "lanczos" is also
supported. If PIL version 3.4.0 or newer is installed, "box" and
"hamming" are also supported. By default, "nearest" is used.
resizing_function: function, used to resize the loaded images to the
target size. this will overrule interpolation. If None, then
interpolation will happen. The input is an image in the specified
data format, and the output has to be an image in the specified
data format with the target size.
"""
self.image_data_generator = image_data_generator
self.target_size = tuple(target_size)
Expand Down Expand Up @@ -195,6 +201,7 @@ def set_processing_attrs(self,
self.save_prefix = save_prefix
self.save_format = save_format
self.interpolation = interpolation
self.resizing_function = resizing_function
if subset is not None:
validation_split = self.image_data_generator._validation_split
if subset == 'validation':
Expand Down Expand Up @@ -223,12 +230,24 @@ def _get_batches_of_transformed_samples(self, index_array):
# build batch of image data
# self.filepaths is dynamic, is better to call it once outside the loop
filepaths = self.filepaths
load_target_size = self.target_size
if self.resizing_function is not None:
load_target_size = None
for i, j in enumerate(index_array):
img = load_img(filepaths[j],
color_mode=self.color_mode,
target_size=self.target_size,
target_size=load_target_size,
interpolation=self.interpolation)
x = img_to_array(img, data_format=self.data_format)
# NOTE: we could potentially have keyword arguments for the
# resizing function
if self.resizing_function is not None:
x = self.resizing_function(x)
if x.shape != self.image_shape:
raise ValueError(
'The loaded image shape %s (at %s) does not correspond to'
'the specified image shape %s' %
(str(x.shape), filepaths[j], str(self.image_shape)))
# Pillow images should be closed after `load_img`,
# but not PIL images.
if hasattr(img, 'close'):
Expand Down