Skip to content

Commit

Permalink
feat(Metalearner): Implemented Metalearner: Naive Bayes
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed May 22, 2022
1 parent 66dedc6 commit 726c6f2
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 24 deletions.
2 changes: 2 additions & 0 deletions aucmedi/ensemble/metalearner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@
#-----------------------------------------------------#
# Import metalearners
from aucmedi.ensemble.metalearner.logistic_regression import Logistic_Regression
from aucmedi.ensemble.metalearner.naive_bayes import Naive_Bayes

#-----------------------------------------------------#
# Access Functions to Metalearners #
#-----------------------------------------------------#
metalearner_dict = {
    "logistic_regression": Logistic_Regression,
    "naive_bayes": Naive_Bayes,
}
""" Dictionary of implemented Metalearners, mapping each string identifier to its class. """
80 changes: 80 additions & 0 deletions aucmedi/ensemble/metalearner/naive_bayes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External libraries
import pickle
from sklearn.naive_bayes import ComplementNB
import numpy as np
# Internal libraries/scripts
from aucmedi.ensemble.metalearner.ml_base import Metalearner_Base

#-----------------------------------------------------#
# Metalearner: Naive Bayes #
#-----------------------------------------------------#
class Naive_Bayes(Metalearner_Base):
    """ A Metalearner backed by a Complement Naive Bayes classifier.

    Pass this class to an Ensemble function like Stacking in order to combine predictions.

    !!! warning
        Can only be utilized for binary and multi-class tasks.

        Does not work on multi-label annotations!
    """
    #---------------------------------------------#
    #                Initialization               #
    #---------------------------------------------#
    def __init__(self):
        """ Create the metalearner with an unfitted ComplementNB model. """
        self.model = ComplementNB()

    #---------------------------------------------#
    #                  Training                   #
    #---------------------------------------------#
    def train(self, x, y):
        """ Fit the underlying model.

        Args:
            x:  Feature data which is handed to the model's `fit()` as-is.
            y:  Class annotations; reduced to one class index per sample by
                taking the argmax over the last axis before fitting.
        """
        # Collapse the annotation matrix to class indices (sparse encoding)
        class_indices = np.argmax(y, axis=-1)
        # Fit and keep whatever estimator fit() hands back
        self.model = self.model.fit(x, class_indices)

    #---------------------------------------------#
    #                  Prediction                 #
    #---------------------------------------------#
    def predict(self, data):
        """ Run inference with the stored model.

        Args:
            data:   Feature data which is handed to the model's `predict_proba()`.

        Returns:
            The output of the model's `predict_proba()` for `data`.
        """
        return self.model.predict_proba(data)

    #---------------------------------------------#
    #              Dump Model to Disk             #
    #---------------------------------------------#
    def dump(self, path):
        """ Serialize the stored model to disk via pickle.

        Args:
            path:   File path to which the pickled model is written.
        """
        with open(path, "wb") as handle:
            pickle.dump(self.model, handle)

    #---------------------------------------------#
    #             Load Model from Disk            #
    #---------------------------------------------#
    def load(self, path):
        """ Replace the stored model with one deserialized from disk via pickle.

        Args:
            path:   File path from which the pickled model is read.
        """
        # NOTE: unpickling can execute arbitrary code - only load trusted files
        with open(path, "rb") as handle:
            restored = pickle.load(handle)
        self.model = restored
55 changes: 31 additions & 24 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,35 @@ def test_Logistic_Regression_usage(self):
# Inference
preds = ml.predict(data=self.pred_data)
# Check
self.assertTrue(np.array_equal(pred.shape, (25,4)))
self.assertTrue(np.array_equal(preds.shape, (25,4)))

# #-------------------------------------------------#
# # Aggregate: Averaging by Median #
# #-------------------------------------------------#
# def test_Aggregate_Median(self):
# agg_func = Averaging_Median()
# pred = agg_func.aggregate(self.pred_data.copy())
# self.assertTrue(np.array_equal(pred.shape, (10,)))
#
# #-------------------------------------------------#
# # Aggregate: Majority Vote (Hard) #
# #-------------------------------------------------#
# def test_Aggregate_MajorityVote(self):
# agg_func = Majority_Vote()
# pred = agg_func.aggregate(self.pred_data.copy())
# self.assertTrue(np.array_equal(pred.shape, (10,)))
#
# #-------------------------------------------------#
# # Aggregate: Softmax #
# #-------------------------------------------------#
# def test_Aggregate_Softmax(self):
# agg_func = Softmax()
# pred = agg_func.aggregate(self.pred_data.copy())
# self.assertTrue(np.array_equal(pred.shape, (10,)))
#-------------------------------------------------#
# Naive Bayes #
#-------------------------------------------------#
def test_Naive_Bayes_create(self):
    # Direct construction has to work
    ml = Naive_Bayes()
    # The metalearner also has to be reachable through the dictionary
    self.assertIn("naive_bayes", metalearner_dict)
    ml = metalearner_dict["naive_bayes"]()
    # Dump the model into the temporary test directory
    dump_file = os.path.join(self.tmp_data.name, "ml_model.pickle")
    self.assertFalse(os.path.exists(dump_file))
    ml.dump(dump_file)
    self.assertTrue(os.path.exists(dump_file))
    # Drop the in-memory model, then restore it from the dumped file
    ml.model = None
    self.assertIsNone(ml.model)
    ml.load(dump_file)
    self.assertIsNotNone(ml.model)
    # Remove the dumped file again
    os.remove(dump_file)

def test_Naive_Bayes_usage(self):
    # Build a fresh metalearner
    ml = Naive_Bayes()
    # Fit it on the prediction fixture against the one-hot encoded labels fixture
    ml.train(x=self.pred_data, y=self.labels_ohe)
    # Predict on the same fixture
    preds = ml.predict(data=self.pred_data)
    # The output has to have the expected shape
    expected_shape = (25, 4)
    self.assertTrue(np.array_equal(preds.shape, expected_shape))

0 comments on commit 726c6f2

Please sign in to comment.