feat(DataGenerator): rework deprecated keras.Iterator to keras.Sequence - Closes #76,#189
muellerdo committed Feb 26, 2023
1 parent c70fc4c commit 2e8f2a9
Showing 3 changed files with 81 additions and 22 deletions.
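Background for the rework: a tensorflow.keras.utils.Sequence subclass must implement __len__ (number of batches per epoch) and __getitem__ (batch by index), plus an optional on_epoch_end hook, which is exactly the interface this commit adds to DataGenerator. A minimal self-contained sketch of that contract (array names and shapes are illustrative, not part of this commit):

import numpy as np
from tensorflow.keras.utils import Sequence

class MinimalSequence(Sequence):
    """ Minimal keras.Sequence contract: __len__ + __getitem__ (+ on_epoch_end). """
    def __init__(self, x, y, batch_size=32, shuffle=True):
        self.x, self.y = x, y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.index_array = np.arange(len(x))

    def __len__(self):
        # Number of batches per epoch (ceil division to cover the remainder)
        return (len(self.x) + self.batch_size - 1) // self.batch_size

    def __getitem__(self, idx):
        # Slice the (possibly shuffled) index array to select batch `idx`
        batch = self.index_array[idx * self.batch_size : (idx + 1) * self.batch_size]
        return self.x[batch], self.y[batch]

    def on_epoch_end(self):
        # Called by Keras between epochs; reshuffle for the next one
        if self.shuffle:
            np.random.shuffle(self.index_array)

Keras calls __getitem__ with indices 0..len-1 during each epoch and invokes on_epoch_end between epochs, which is why the reworked DataGenerator below manages its own index_array and seed_walk.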
87 changes: 72 additions & 15 deletions aucmedi/data_processing/data_generator.py
@@ -20,7 +20,7 @@
# Library imports #
#-----------------------------------------------------#
# External libraries
from tensorflow.keras.preprocessing.image import Iterator
from tensorflow.keras.utils import Sequence
import numpy as np
from multiprocessing.pool import ThreadPool
from itertools import repeat
@@ -34,7 +34,7 @@
#-----------------------------------------------------#
# Keras Data Generator #
#-----------------------------------------------------#
class DataGenerator(Iterator):
class DataGenerator(Sequence):
""" Infinite Data Generator which automatically creates batches from a list of samples.
The created batches are model ready. This generator can be supplied directly
@@ -188,21 +188,30 @@ def __init__(self, samples, path_imagedir, labels=None, metadata=None,
**kwargs (dict): Additional parameters for the sample loader.
"""
# Cache class variables
self.samples = samples
self.labels = labels
self.metadata = metadata
self.sample_weights = sample_weights
self.prepare_images = prepare_images
self.workers = workers
self.sample_loader = loader
self.kwargs = kwargs
self.samples = samples
self.path_imagedir = path_imagedir
self.image_format = image_format
self.grayscale = grayscale
self.subfunctions = subfunctions
self.batch_size = batch_size
self.data_aug = data_aug
self.standardize_mode = standardize_mode
self.resize = resize
self.shuffle = shuffle
self.seed = seed
# Cache keras.Sequence class variables
self.n = len(samples)
self.max_iterations = (self.n + self.batch_size - 1) // self.batch_size
self.iterations = self.max_iterations
self.seed_walk = 0
self.index_array = None

# Initialize Standardization Subfunction
if standardize_mode is not None:
@@ -259,10 +268,6 @@ def __init__(self, samples, path_imagedir, labels=None, metadata=None,
print("A directory for image preparation was created:",
self.prepare_dir)

# Pass initialization parameters to parent Iterator class
size = len(samples)
super(DataGenerator, self).__init__(size, batch_size, shuffle, seed)

#-----------------------------------------------------#
# Batch Generation Function #
#-----------------------------------------------------#
@@ -310,17 +315,19 @@ def _get_batches_of_transformed_samples(self, index_array):
#-----------------------------------------------------#
# Image Preprocessing #
#-----------------------------------------------------#
""" Internal preprocessing function for applying Subfunctions, augmentation, resizing and standardization
on an image given its index.
def preprocess_image(self, index, prepared_image=False, run_aug=True,
run_standardize=True, dump_pickle=False):
""" Internal preprocessing function for applying Subfunctions, augmentation, resizing and standardization
on an image given its index.
Activating the prepared_image option also allows loading a beforehand preprocessed image from disk.
Can be utilized for debugging purposes.
Deactivating the run_aug & run_standardize option to output image without augmentation and standardization.
Activating the prepared_image option also allows loading a beforehand preprocessed image from disk.
Activating dump_pickle will store the preprocessed image as pickle on disk instead of returning.
"""
def preprocess_image(self, index, prepared_image=False, run_aug=True,
run_standardize=True, dump_pickle=False):
Deactivating the run_aug & run_standardize option to output image without augmentation and standardization.
Activating dump_pickle will store the preprocessed image as pickle on disk instead of returning.
"""
# Load prepared image from disk
if prepared_image:
# Load from disk
@@ -359,3 +366,53 @@ def preprocess_image(self, index, prepared_image=False, run_aug=True,
pickle.dump(img, pickle_writer)
# Return preprocessed image
else : return img

#-----------------------------------------------------#
# Sample Generation Function #
#-----------------------------------------------------#
""" Internal function for calling the batch generation process. """
def __getitem__(self, raw_idx):
# Obtain the index based on the passed index offset to allow repetition
idx = raw_idx % self.max_iterations
# Build index array on first call
if self.index_array is None:
self.__set_index_array__()
# Select samples for next batch
index_array = self.index_array[
self.batch_size * idx : self.batch_size * (idx + 1)
]
# Generate batch
return self._get_batches_of_transformed_samples(index_array)

#-----------------------------------------------------#
# Generator Functions #
#-----------------------------------------------------#
""" Internal function for identifying the generator length. """
def __len__(self):
return self.iterations

""" Configuration function for fixing the number of iterations. """
def set_length(self, iterations):
self.iterations = iterations

""" Configuration function for reseting the number of iterations. """
def reset_length(self):
self.iterations = self.max_iterations

""" Internal function for initializing and shuffling the index array. """
def __set_index_array__(self):
# Generate index array
self.index_array = np.arange(self.n)
# Shuffle if needed
if self.shuffle:
# Update seed for repeated permutation of the index_array
if self.seed is not None:
np.random.seed(self.seed + self.seed_walk)
self.seed_walk += 1
# Permute index array
self.index_array = np.random.permutation(self.n)

""" Internal function at the end of an epoch. """
def on_epoch_end(self):
self.__set_index_array__()
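For illustration, a short usage sketch of the reworked epoch mechanics; the constructor signature follows the cached variables above, while the file names and parameter values are assumptions:

from aucmedi import DataGenerator

# Hypothetical 5-sample dataset; file names are illustrative
samples = ["img_0.png", "img_1.png", "img_2.png", "img_3.png", "img_4.png"]
gen = DataGenerator(samples, "images/", batch_size=2, resize=(224, 224),
                    standardize_mode="z-score", shuffle=True, seed=0)

len(gen)            # -> 3 batches per epoch: ceil(5 / 2)
batch = gen[0]      # first batch; indices wrap modulo max_iterations
gen.set_length(12)  # fix an epoch to 12 iterations (batches repeat)
gen.reset_length()  # restore the natural epoch length

Passing a seed makes the per-epoch reshuffling reproducible: seed_walk is incremented before each permutation, so repeated runs walk through the same sequence of shuffles.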
14 changes: 8 additions & 6 deletions aucmedi/neural_network/model.py
@@ -288,6 +288,8 @@ def train(self, training_generator, validation_generator=None, epochs=20,
Returns:
history (dict): A history dictionary from a Keras history object which contains several logs.
"""
# Adjust number of iterations in training DataGenerator to allow repetition
if iterations is not None : training_generator.set_length(iterations)
# Running a standard training process
if not transfer_learning:
# Run training process with the Keras fit function
@@ -301,8 +303,7 @@
max_queue_size=self.batch_queue_size,
verbose=self.verbose)
# Return logged history object
return history.history

history_out = history.history
# Running a transfer learning training process
else:
# Freeze all base model layers (all layers after "avg_pool")
@@ -330,9 +331,6 @@ def train(self, training_generator, validation_generator=None, epochs=20,
# Compile model with lower learning rate
self.model.compile(optimizer=Adam(learning_rate=self.tf_lr_end),
loss=self.loss, metrics=self.metrics)
# Reset data generators
training_generator.reset()
if validation_generator is not None : validation_generator.reset()
# Run second training with unfrozen layers
history_end = self.model.fit(training_generator,
validation_data=validation_generator,
@@ -349,7 +347,11 @@ def train(self, training_generator, validation_generator=None, epochs=20,
he = {"ft_" + k: v for k, v in history_end.history.items()} # prefix : ft for fine tuning
history = {**hs, **he}
# Return combined history objects
return history
history_out = history
# Reset number of iterations of the training DataGenerator
if iterations is not None : training_generator.reset_length()
# Return fitting history
return history_out

#---------------------------------------------#
# Prediction #
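The iterations handling above pairs with the generator's new set_length/reset_length methods. A hedged sketch of a caller relying on it (the NeuralNetwork construction parameters are assumptions based on AUCMEDI's public API, not part of this diff; train_gen and val_gen are DataGenerator instances like the one sketched earlier):

from aucmedi import NeuralNetwork

# Hypothetical model; the architecture choice is illustrative
model = NeuralNetwork(n_labels=2, channels=3, architecture="2D.ResNet50")

# With iterations=150, each epoch runs exactly 150 batches: the training
# generator's length is fixed beforehand and reset again after fitting.
history = model.train(train_gen, validation_generator=val_gen,
                      epochs=10, iterations=150)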
2 changes: 1 addition & 1 deletion aucmedi/xai/methods/occlusion_sensitivity.py
@@ -42,7 +42,7 @@ class OcclusionSensitivity(XAImethod_Base):
This class provides functionality for running the compute_heatmap function,
which computes an Occlusion Sensitivity Map for an image with a model.
"""
def __init__(self, model, layerName=None, patch_size=4):
def __init__(self, model, layerName=None, patch_size=16):
""" Initialization function for creating a Occlusion Sensitivity Map as XAI Method object.
Args:
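For context on the changed default (patch_size 4 to 16): occlusion sensitivity slides an occluding patch across the input and records how much the model's confidence in the target class drops at each position; a larger patch yields a coarser but much cheaper map. A minimal sketch of the idea, not AUCMEDI's implementation (function and variable names are illustrative):

import numpy as np

def occlusion_sensitivity_map(model, image, class_index, patch_size=16, fill=0.0):
    """ Slide a patch_size x patch_size occluder over `image` and record the
        drop in predicted probability for `class_index` at each position. """
    h, w, _ = image.shape
    baseline = model.predict(image[np.newaxis], verbose=0)[0][class_index]
    heatmap = np.zeros((h // patch_size, w // patch_size))
    for i, y in enumerate(range(0, h - patch_size + 1, patch_size)):
        for j, x in enumerate(range(0, w - patch_size + 1, patch_size)):
            occluded = image.copy()
            occluded[y:y + patch_size, x:x + patch_size, :] = fill
            prob = model.predict(occluded[np.newaxis], verbose=0)[0][class_index]
            heatmap[i, j] = baseline - prob   # large drop = important region
    return heatmap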
