Skip to content

Commit

Permalink
feat(Metalearner): Added Metalearner Weighted Mean
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed May 22, 2022
1 parent 54d6f4e commit 162b977
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aucmedi/ensemble/metalearner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from aucmedi.ensemble.metalearner.gaussian_process import Gaussian_Process
from aucmedi.ensemble.metalearner.decision_tree import Decision_Tree
from aucmedi.ensemble.metalearner.best_model import Best_Model
from aucmedi.ensemble.metalearner.averaging_mean_weighted import Averaging_WeightedMean

#-----------------------------------------------------#
# Access Functions to Metalearners #
Expand All @@ -62,5 +63,6 @@
"gaussian_process": Gaussian_Process,
"decision_tree": Decision_Tree,
"best_model": Best_Model,
"weighted_mean": Averaging_WeightedMean,
}
""" Dictionary of implemented Metalearners. """
98 changes: 98 additions & 0 deletions aucmedi/ensemble/metalearner/averaging_mean_weighted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External libraries
import pickle
from sklearn.metrics import roc_auc_score
import numpy as np
# Internal libraries/scripts
from aucmedi.ensemble.metalearner.ml_base import Metalearner_Base

#-----------------------------------------------------#
# Metalearner: Weighted Mean #
#-----------------------------------------------------#
class Averaging_WeightedMean(Metalearner_Base):
""" A Weighted Mean based Metalearner.
This class should be passed to a Ensemble function like Stacking for combining predictions.
This Metalearner computes the Area Under the Receiver Operating Characteristic Curve (ROC AUC)
for each model, and utilizes these scores for a weighted Mean to average predictions.
!!! info
Can be utilized for binary, multi-class and multi-label tasks.
"""
#---------------------------------------------#
# Initialization #
#---------------------------------------------#
def __init__(self):
self.model = {}

#---------------------------------------------#
# Training #
#---------------------------------------------#
def train(self, x, y):
# Identify number of models and classes
n_classes = y.shape[1]
n_models = int(x.shape[1] / n_classes)
# Preprocess data input
data = np.reshape(x, (x.shape[0], n_models, n_classes))

# Compute AUC scores and store them to cache
weights = []
for m in range(n_models):
pred = data[:,m,:]
score = roc_auc_score(y, pred, average="macro")
weights.append(score)

# Store results to cache
self.model["weights"] = weights
self.model["n_classes"] = n_classes
self.model["n_models"] = n_models

#---------------------------------------------#
# Prediction #
#---------------------------------------------#
def predict(self, data):
# Preprocess data input
preds = np.reshape(data, (data.shape[0],
self.model["n_models"],
self.model["n_classes"]))
# Compute weighted mean
pred = np.average(preds, axis=1, weights=self.model["weights"])
# Return results
return pred

#---------------------------------------------#
# Dump Model to Disk #
#---------------------------------------------#
def dump(self, path):
# Dump model to disk via pickle
with open(path, "wb") as pickle_writer:
pickle.dump(self.model, pickle_writer)

#---------------------------------------------#
# Load Model from Disk #
#---------------------------------------------#
def load(self, path):
# Load model from disk via pickle
with open(path, "rb") as pickle_reader:
self.model = pickle.load(pickle_reader)
31 changes: 31 additions & 0 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,3 +231,34 @@ def test_Best_Model_usage(self):
preds = ml.predict(data=self.pred_data)
# Check
self.assertTrue(np.array_equal(preds.shape, (25,4)))

#-------------------------------------------------#
# Weighted Mean #
#-------------------------------------------------#
def test_Averaging_WeightedMean_create(self):
# Initializations
ml = Averaging_WeightedMean()
self.assertTrue("weighted_mean" in metalearner_dict)
ml = metalearner_dict["weighted_mean"]()
# Storage
model_path = os.path.join(self.tmp_data.name, "ml_model.pickle")
self.assertFalse(os.path.exists(model_path))
ml.dump(model_path)
self.assertTrue(os.path.exists(model_path))
# Loading
ml.model = None
self.assertTrue(ml.model is None)
ml.load(model_path)
self.assertFalse(ml.model is None)
# Cleanup
os.remove(model_path)

def test_Averaging_WeightedMean_usage(self):
# Initializations
ml = Averaging_WeightedMean()
# Training
ml.train(x=self.pred_data, y=self.labels_ohe)
# Inference
preds = ml.predict(data=self.pred_data)
# Check
self.assertTrue(np.array_equal(preds.shape, (25,4)))

0 comments on commit 162b977

Please sign in to comment.