Skip to content

Commit

Permalink
feat(Bagging): Implemented Bagging load (re-using result directory) f…
Browse files Browse the repository at this point in the history
…unctionality
  • Loading branch information
muellerdo committed May 19, 2022
1 parent 2110666 commit 41c4114
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions aucmedi/ensemble/bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,23 +233,25 @@ def dump(self, directory_path):
self.cache_dir = directory_path

# Load model from file
def load(self, file_path, custom_objects={}):
""" Load neural network model and its weights from a file.
After loading, the model will be compiled.
If loading a model in ".hdf5" format, it is not necessary to define any custom_objects.
def load(self, directory_path):
""" Load a Bagging model directory which can be used for aggregated inference.
Args:
file_path (str): Input path, from which the model will be loaded.
custom_objects (dict): Dictionary of custom objects for compiling
(e.g. non-TensorFlow based loss functions or architectures).
directory_path (str): Input path, from which the Bagging models will be loaded.
"""
# Create model input path
self.model = load_model(file_path, custom_objects, compile=False)
# Compile model
self.model.compile(optimizer=Adam(learning_rate=self.learning_rate),
loss=self.loss, metrics=self.metrics)
# Check directory existence
if not os.path.exists(directory_path):
raise FileNotFoundError("Provided model directory path does not exist!",
directory_path)
# Check model existence
for i in range(self.k_fold):
path_model = os.path.join(directory_path,
"cv_" + str(i) + ".model.hdf5")
if not os.path.exists(path_model):
raise FileNotFoundError("Bagging model for fold " + str(i) + \
" does not exist!", path_model)
# Update model directory
self.cache_dir = directory_path

#-----------------------------------------------------#
# Subroutines #
Expand Down

0 comments on commit 41c4114

Please sign in to comment.