Skip to content

Commit

Permalink
feat(Metalearner): Added Metalearner Gaussian Process
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed May 22, 2022
1 parent b7d5548 commit d07b12d
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aucmedi/ensemble/metalearner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,14 @@
from aucmedi.ensemble.metalearner.logistic_regression import Logistic_Regression
from aucmedi.ensemble.metalearner.naive_bayes import Naive_Bayes
from aucmedi.ensemble.metalearner.support_vector_machine import SupportVectorMachine
from aucmedi.ensemble.metalearner.gaussian_process import Gaussian_Process

#-----------------------------------------------------#
# Access Functions to Metalearners #
#-----------------------------------------------------#
# Registry that maps the string identifier of a Metalearner to its class.
# The values are the classes themselves (as imported above), not instances,
# so a lookup has to be called to obtain a usable Metalearner object.
metalearner_dict = {"logistic_regression": Logistic_Regression,
"naive_bayes": Naive_Bayes,
"support_vector_machine": SupportVectorMachine,
"gaussian_process": Gaussian_Process,
}
""" Dictionary of implemented Metalearners. """
81 changes: 81 additions & 0 deletions aucmedi/ensemble/metalearner/gaussian_process.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External libraries
import pickle
from sklearn.gaussian_process import GaussianProcessClassifier
import numpy as np
# Internal libraries/scripts
from aucmedi.ensemble.metalearner.ml_base import Metalearner_Base

#-----------------------------------------------------#
# Metalearner: Gaussian Process #
#-----------------------------------------------------#
class Gaussian_Process(Metalearner_Base):
    """ Metalearner built on top of a Gaussian Process classifier.

    Wraps the scikit-learn `GaussianProcessClassifier`, configured with a
    fixed random state and the one-vs-rest multi-class strategy.

    This class should be passed to an Ensemble function like Stacking for
    combining predictions.

    !!! warning
        Can only be utilized for binary and multi-class tasks.
        Does not work on multi-label annotations!
    """
    #---------------------------------------------#
    #                Initialization               #
    #---------------------------------------------#
    def __init__(self):
        """ Create the (still unfitted) Gaussian Process classifier. """
        self.model = GaussianProcessClassifier(multi_class="one_vs_rest",
                                               random_state=0)

    #---------------------------------------------#
    #                  Training                   #
    #---------------------------------------------#
    def train(self, x, y):
        """ Fit the classifier on the provided features and labels.

        Args:
            x:      Features passed unchanged to the classifier's `fit()`.
            y:      Class annotations. They are reduced along the last axis
                    via argmax to class indices before fitting
                    (presumably one-hot encoded - confirm against caller).
        """
        # Collapse the last axis into a single class index per sample
        class_indices = np.argmax(y, axis=-1)
        # fit() hands back the fitted estimator, which is stored as the model
        fitted_model = self.model.fit(x, class_indices)
        self.model = fitted_model

    #---------------------------------------------#
    #                  Prediction                 #
    #---------------------------------------------#
    def predict(self, data):
        """ Compute class probabilities for the provided data.

        Args:
            data:   Features passed unchanged to the classifier's
                    `predict_proba()`.

        Returns:
            The result of `predict_proba()` of the underlying model.
        """
        return self.model.predict_proba(data)

    #---------------------------------------------#
    #              Dump Model to Disk             #
    #---------------------------------------------#
    def dump(self, path):
        """ Store the underlying model on disk via pickle.

        Args:
            path:   File path which is opened in binary write mode.
        """
        with open(path, "wb") as model_file:
            pickle.dump(self.model, model_file)

    #---------------------------------------------#
    #             Load Model from Disk            #
    #---------------------------------------------#
    def load(self, path):
        """ Replace the underlying model with one unpickled from disk.

        Args:
            path:   File path which is opened in binary read mode.
        """
        # NOTE(review): unpickling can execute arbitrary code, so only load
        # model files which originate from a trusted source.
        with open(path, "rb") as model_file:
            restored_model = pickle.load(model_file)
        self.model = restored_model
31 changes: 31 additions & 0 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,34 @@ def test_SVM_usage(self):
preds = ml.predict(data=self.pred_data)
# Check
self.assertTrue(np.array_equal(preds.shape, (25,4)))

#-------------------------------------------------#
# Gaussian Process #
#-------------------------------------------------#
def test_Gaussian_Process_create(self):
    """ Check construction, registry access and the dump/load round trip
    of the Gaussian Process Metalearner.
    """
    # Initializations
    # Direct construction has to work without any arguments. The instance
    # is deliberately not bound, as the registry instance is used below.
    Gaussian_Process()
    self.assertIn("gaussian_process", metalearner_dict)
    ml = metalearner_dict["gaussian_process"]()
    # The registry entry has to resolve to the Gaussian Process class
    self.assertIsInstance(ml, Gaussian_Process)
    # Storage
    model_path = os.path.join(self.tmp_data.name, "ml_model.pickle")
    self.assertFalse(os.path.exists(model_path))
    ml.dump(model_path)
    self.assertTrue(os.path.exists(model_path))
    # From here on the file exists, so it is removed even if one of the
    # following assertions fails. Otherwise it would be left behind in
    # the temporary directory (presumably shared with other tests).
    try:
        # Loading
        ml.model = None
        self.assertIsNone(ml.model)
        ml.load(model_path)
        self.assertIsNotNone(ml.model)
    finally:
        # Cleanup
        os.remove(model_path)

def test_Gaussian_Process_usage(self):
    """ Train the Gaussian Process Metalearner and run inference on the
    same prediction fixture, then verify the shape of the output.
    """
    # Initializations
    metalearner = Gaussian_Process()
    features = self.pred_data
    # Training
    metalearner.train(x=features, y=self.labels_ohe)
    # Inference
    probabilities = metalearner.predict(data=features)
    # Check
    # Expected result shape (presumably 25 samples x 4 classes - confirm
    # against the fixture setup)
    expected_shape = (25, 4)
    shape_matches = np.array_equal(probabilities.shape, expected_shape)
    self.assertTrue(shape_matches)

0 comments on commit d07b12d

Please sign in to comment.