Skip to content

Commit

Permalink
feat(Ensemble): Added return_ensemble feature for Stacking and Baggin…
Browse files Browse the repository at this point in the history
…g predictions.
  • Loading branch information
muellerdo committed Jun 4, 2022
1 parent 5ecb785 commit 4f90c5c
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 5 deletions.
13 changes: 10 additions & 3 deletions aucmedi/ensemble/bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def train(self, training_generator, epochs=20, iterations=None,
# Return Bagging history object
return history_bagging

def predict(self, prediction_generator, aggregate="mean"):
def predict(self, prediction_generator, aggregate="mean",
return_ensemble=False):
""" Prediction function for the Bagging models.
The fitted models will predict classifications for the provided [DataGenerator][aucmedi.data_processing.data_generator.DataGenerator].
Expand All @@ -219,17 +220,21 @@ def predict(self, prediction_generator, aggregate="mean"):
- self-initialization with an AUCMEDI Aggregate function,
- use a string key to call an AUCMEDI Aggregate function by name, or
- implementing a custom Aggregate function by extending the [AUCMEDI base class for Aggregate functions][aucmedi.ensemble.aggregate.agg_base.py]
- implementing a custom Aggregate function by extending the [AUCMEDI base class for Aggregate functions][aucmedi.ensemble.aggregate.agg_base]
!!! info
Description and list of implemented Aggregate functions can be found here:
[Aggregate][aucmedi.ensemble.aggregate]
Args:
prediction_generator (DataGenerator): A data generator which will be used for inference.
aggregate (str or aggregate Function): Aggregate function class instance or a string for an AUCMEDI Aggregate function.
return_ensemble (bool): Option, whether gathered ensemble of predictions should be returned.
Returns:
preds (numpy.ndarray): A NumPy array of predictions formatted with shape (n_samples, n_labels).
ensemble (numpy.ndarray): Optional ensemble of predictions: Will be only passed if `return_ensemble=True`.
Shape (n_models, n_samples, n_labels).
"""
# Verify if there is a linked cache dictionary
con_tmp = (isinstance(self.cache_dir, tempfile.TemporaryDirectory) and \
Expand Down Expand Up @@ -301,8 +306,10 @@ def predict(self, prediction_generator, aggregate="mean"):

# Convert prediction list to NumPy
preds_final = np.asarray(preds_final)

# Return ensembled predictions
return preds_final
if return_ensemble : return preds_final, preds_ensemble
else : return preds_final

# Dump model to file
def dump(self, directory_path):
Expand Down
8 changes: 6 additions & 2 deletions aucmedi/ensemble/stacking.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def train_metalearner(self, training_generator):
"metalearner.model.pickle")
self.ml_model.dump(path_metalearner)

def predict(self, prediction_generator):
def predict(self, prediction_generator, return_ensemble=False):
""" Prediction function for Stacking.
The fitted models and selected Metalearner will predict classifications for the provided
Expand All @@ -353,9 +353,12 @@ def predict(self, prediction_generator):
Args:
prediction_generator (DataGenerator): A data generator which will be used for inference.
return_ensemble (bool): Option, whether gathered ensemble of predictions should be returned.
Returns:
preds (numpy.ndarray): A NumPy array of predictions formatted with shape (n_samples, n_labels).
ensemble (numpy.ndarray): Optional ensemble of predictions: Will be only passed if `return_ensemble=True`.
Shape (n_models, n_samples, n_labels).
"""
# Verify if there is a linked cache dictionary
con_tmp = (isinstance(self.cache_dir, tempfile.TemporaryDirectory) and \
Expand Down Expand Up @@ -436,7 +439,8 @@ def predict(self, prediction_generator):
preds_final = np.asarray(preds_final)

# Return ensembled predictions
return preds_final
if return_ensemble : return preds_final, np.swapaxes(preds_ensemble,1,0)
else : return preds_final

# Dump model to file
def dump(self, directory_path):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,10 @@ def test_Bagging_predict(self):
# Run Inference with majority vote aggregation
preds = el.predict(datagen, aggregate="majority_vote")
self.assertTrue(np.array_equal(preds.shape, (3,2)))
# Run Inference with returned ensemble
preds, ensemble = el.predict(datagen, return_ensemble=True)
self.assertTrue(np.array_equal(preds.shape, (3,2)))
self.assertTrue(np.array_equal(ensemble.shape, (2,3,2)))

def test_Bagging_dump(self):
# Initialize training DataGenerator
Expand Down Expand Up @@ -382,6 +386,11 @@ def test_Stacking_predict_metalearner(self):
preds = el.predict(datagen)
self.assertTrue(np.array_equal(preds.shape, (12,2)))

# Run Inference with returned ensemble
preds, ensemble = el.predict(datagen, return_ensemble=True)
self.assertTrue(np.array_equal(preds.shape, (12,2)))
self.assertTrue(np.array_equal(ensemble.shape, (2,12,2)))

def test_Stacking_predict_aggregate(self):
# Initialize training DataGenerator
datagen = DataGenerator(np.repeat(self.sampleList2D, 4),
Expand All @@ -403,6 +412,11 @@ def test_Stacking_predict_aggregate(self):
preds = el.predict(datagen)
self.assertTrue(np.array_equal(preds.shape, (12,2)))

# Run Inference with returned ensemble
preds, ensemble = el.predict(datagen, return_ensemble=True)
self.assertTrue(np.array_equal(preds.shape, (12,2)))
self.assertTrue(np.array_equal(ensemble.shape, (2,12,2)))

def test_Stacking_dump(self):
# Initialize training DataGenerator
datagen = DataGenerator(np.repeat(self.sampleList2D, 4),
Expand Down

0 comments on commit 4f90c5c

Please sign in to comment.