diff --git a/keras_preprocessing/image/dataframe_iterator.py b/keras_preprocessing/image/dataframe_iterator.py index 5412df24..dfc0b9a8 100644 --- a/keras_preprocessing/image/dataframe_iterator.py +++ b/keras_preprocessing/image/dataframe_iterator.py @@ -80,6 +80,11 @@ class DataFrameIterator(BatchFromFilesMixin, Iterator): If PIL version 1.1.3 or newer is installed, "lanczos" is also supported. If PIL version 3.4.0 or newer is installed, "box" and "hamming" are also supported. By default, "nearest" is used. + resizing_function: function, used to resize the loaded images to the + target size. this will overrule interpolation. If None, then + interpolation will happen. The input is an image in the specified + data format, and the output has to be an image in the specified + data format with the target size. dtype: Dtype to use for the generated arrays. validate_filenames: Boolean, whether to validate image filenames in `x_col`. If `True`, invalid images will be ignored. Disabling this option @@ -109,6 +114,7 @@ def __init__(self, save_format='png', subset=None, interpolation='nearest', + resizing_function=None, dtype='float32', validate_filenames=True): @@ -120,7 +126,8 @@ def __init__(self, save_prefix, save_format, subset, - interpolation) + interpolation, + resizing_function) df = dataframe.copy() self.directory = directory or '' self.class_mode = class_mode diff --git a/keras_preprocessing/image/directory_iterator.py b/keras_preprocessing/image/directory_iterator.py index 3f75d835..e03cf31d 100644 --- a/keras_preprocessing/image/directory_iterator.py +++ b/keras_preprocessing/image/directory_iterator.py @@ -60,6 +60,11 @@ class DirectoryIterator(BatchFromFilesMixin, Iterator): If PIL version 1.1.3 or newer is installed, "lanczos" is also supported. If PIL version 3.4.0 or newer is installed, "box" and "hamming" are also supported. By default, "nearest" is used. + resizing_function: function, used to resize the loaded images to the + target size. this will overrule interpolation. If None, then + interpolation will happen. The input is an image in the specified + data format, and the output has to be an image in the specified + data format with the target size. dtype: Dtype to use for generated arrays. """ allowed_class_modes = {'categorical', 'binary', 'sparse', 'input', None} @@ -81,6 +86,7 @@ def __init__(self, follow_links=False, subset=None, interpolation='nearest', + resizing_function=None, dtype='float32'): super(DirectoryIterator, self).set_processing_attrs(image_data_generator, target_size, @@ -90,7 +96,8 @@ def __init__(self, save_prefix, save_format, subset, - interpolation) + interpolation, + resizing_function) self.directory = directory self.classes = classes if class_mode not in self.allowed_class_modes: diff --git a/keras_preprocessing/image/image_data_generator.py b/keras_preprocessing/image/image_data_generator.py index 0d3d92a7..eab56407 100644 --- a/keras_preprocessing/image/image_data_generator.py +++ b/keras_preprocessing/image/image_data_generator.py @@ -447,7 +447,8 @@ def flow_from_directory(self, save_format='png', follow_links=False, subset=None, - interpolation='nearest'): + interpolation='nearest', + resizing_function=None): """Takes the path to a directory & generates batches of augmented data. # Arguments @@ -515,6 +516,11 @@ class subdirectories (default: False). supported. If PIL version 3.4.0 or newer is installed, `"box"` and `"hamming"` are also supported. By default, `"nearest"` is used. + resizing_function: function, used to resize the loaded images to the + target size. this will overrule interpolation. If None, then + interpolation will happen. The input is an image in the specified + data format, and the output has to be an image in the specified + data format with the target size. # Returns A `DirectoryIterator` yielding tuples of `(x, y)` @@ -539,7 +545,8 @@ class subdirectories (default: False). follow_links=follow_links, subset=subset, interpolation=interpolation, - dtype=self.dtype + dtype=self.dtype, + resizing_function=resizing_function ) def flow_from_dataframe(self, diff --git a/keras_preprocessing/image/iterator.py b/keras_preprocessing/image/iterator.py index f5a9b6cb..ada7ecb4 100644 --- a/keras_preprocessing/image/iterator.py +++ b/keras_preprocessing/image/iterator.py @@ -142,7 +142,8 @@ def set_processing_attrs(self, save_prefix, save_format, subset, - interpolation): + interpolation, + resizing_function): """Sets attributes to use later for processing files into a batch. # Arguments @@ -168,6 +169,11 @@ def set_processing_attrs(self, If PIL version 1.1.3 or newer is installed, "lanczos" is also supported. If PIL version 3.4.0 or newer is installed, "box" and "hamming" are also supported. By default, "nearest" is used. + resizing_function: function, used to resize the loaded images to the + target size. this will overrule interpolation. If None, then + interpolation will happen. The input is an image in the specified + data format, and the output has to be an image in the specified + data format with the target size. """ self.image_data_generator = image_data_generator self.target_size = tuple(target_size) @@ -195,6 +201,7 @@ def set_processing_attrs(self, self.save_prefix = save_prefix self.save_format = save_format self.interpolation = interpolation + self.resizing_function = resizing_function if subset is not None: validation_split = self.image_data_generator._validation_split if subset == 'validation': @@ -223,12 +230,24 @@ def _get_batches_of_transformed_samples(self, index_array): # build batch of image data # self.filepaths is dynamic, is better to call it once outside the loop filepaths = self.filepaths + load_target_size = self.target_size + if self.resizing_function is not None: + load_target_size = None for i, j in enumerate(index_array): img = load_img(filepaths[j], color_mode=self.color_mode, - target_size=self.target_size, + target_size=load_target_size, interpolation=self.interpolation) x = img_to_array(img, data_format=self.data_format) + # NOTE: we could potentially have keyword arguments for the + # resizing function + if self.resizing_function is not None: + x = self.resizing_function(x) + if x.shape != self.image_shape: + raise ValueError( + 'The loaded image shape %s (at %s) does not correspond to' + 'the specified image shape %s' % + (str(x.shape), filepaths[j], str(self.image_shape))) # Pillow images should be closed after `load_img`, # but not PIL images. if hasattr(img, 'close'):