Skip to content

Commit

Permalink
feat(Metalearner): Added MLP Neural Network - Contributes to #129
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed May 22, 2022
1 parent 817e02d commit b66de88
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aucmedi/ensemble/metalearner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from aucmedi.ensemble.metalearner.averaging_mean_weighted import Averaging_WeightedMean
from aucmedi.ensemble.metalearner.random_forest import Random_Forest
from aucmedi.ensemble.metalearner.k_neighbors import KNearestNeighbors
from aucmedi.ensemble.metalearner.mlp_neural_network import MLP_NeuralNetwork

#-----------------------------------------------------#
# Access Functions to Metalearners #
Expand All @@ -68,5 +69,6 @@
"weighted_mean": Averaging_WeightedMean,
"random_forest": Random_Forest,
"k_neighbors": KNearestNeighbors,
"mlp_neural_network": MLP_NeuralNetwork,
}
""" Dictionary of implemented Metalearners. """
82 changes: 82 additions & 0 deletions aucmedi/ensemble/metalearner/mlp_neural_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External libraries
import pickle
from sklearn.neural_network import MLPClassifier
import numpy as np
# Internal libraries/scripts
from aucmedi.ensemble.metalearner.ml_base import Metalearner_Base

#-----------------------------------------------------#
# Metalearner: MLP Neural Network #
#-----------------------------------------------------#
class MLP_NeuralNetwork(Metalearner_Base):
""" A MLP Neural Network (scikit-learn) based Metalearner.
This class should be passed to a Ensemble function like Stacking for combining predictions.
!!! info
Can be utilized for binary, multi-class and multi-label tasks.
???+ abstract "Reference - Implementation"
https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html
Scikit-learn: Machine Learning in Python, Pedregosa et al., JMLR 12, pp. 2825-2830, 2011.
https://jmlr.csail.mit.edu/papers/v12/pedregosa11a.html
"""
#---------------------------------------------#
# Initialization #
#---------------------------------------------#
def __init__(self):
self.model = MLPClassifier(random_state=0)

#---------------------------------------------#
# Training #
#---------------------------------------------#
def train(self, x, y):
# Train model
self.model = self.model.fit(x, y)

#---------------------------------------------#
# Prediction #
#---------------------------------------------#
def predict(self, data):
# Compute prediction probabilities via fitted model
pred = self.model.predict_proba(data)
# Return results as NumPy array
return pred

#---------------------------------------------#
# Dump Model to Disk #
#---------------------------------------------#
def dump(self, path):
# Dump model to disk via pickle
with open(path, "wb") as pickle_writer:
pickle.dump(self.model, pickle_writer)

#---------------------------------------------#
# Load Model from Disk #
#---------------------------------------------#
def load(self, path):
# Load model from disk via pickle
with open(path, "rb") as pickle_reader:
self.model = pickle.load(pickle_reader)
31 changes: 31 additions & 0 deletions tests/test_metalearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,3 +324,34 @@ def test_KNearestNeighbors_usage(self):
preds = ml.predict(data=self.pred_data)
# Check
self.assertTrue(np.array_equal(preds.shape, (25,4)))

#-------------------------------------------------#
# MLP Neural Network #
#-------------------------------------------------#
def test_MLP_NeuralNetwork_create(self):
# Initializations
ml = MLP_NeuralNetwork()
self.assertTrue("mlp_neural_network" in metalearner_dict)
ml = metalearner_dict["mlp_neural_network"]()
# Storage
model_path = os.path.join(self.tmp_data.name, "ml_model.pickle")
self.assertFalse(os.path.exists(model_path))
ml.dump(model_path)
self.assertTrue(os.path.exists(model_path))
# Loading
ml.model = None
self.assertTrue(ml.model is None)
ml.load(model_path)
self.assertFalse(ml.model is None)
# Cleanup
os.remove(model_path)

def test_MLP_NeuralNetwork_usage(self):
# Initializations
ml = MLP_NeuralNetwork()
# Training
ml.train(x=self.pred_data, y=self.labels_ohe)
# Inference
preds = ml.predict(data=self.pred_data)
# Check
self.assertTrue(np.array_equal(preds.shape, (25,4)))

0 comments on commit b66de88

Please sign in to comment.