Skip to content

Commit

Permalink
feat(Metalearner): Implemented Metalearner Logistic Regression
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed May 22, 2022
1 parent 1e7056e commit 633af82
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 30 deletions.
38 changes: 11 additions & 27 deletions aucmedi/ensemble/metalearner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
#-----------------------------------------------------#
# Documentation #
#-----------------------------------------------------#
""" Library of implemented Aggregate functions in AUCMEDI.
""" Library of implemented Metalearners in AUCMEDI.
An Aggregate function can be passed to an Ensemble and merges multiple class predictions
into a single prediction.
A Metalearner can be passed to an Ensemble like Stacking and merges multiple class
predictions into a single prediction.
```
Ensembled predictions encoded in a NumPy Matrix with shape (N_models, N_classes).
Expand All @@ -36,34 +36,18 @@
-> shape (1, 3)
```
???+ example "Example"
```python
# Recommended: Apply an Ensemble like Augmenting (test-time augmentation) with Majority Vote
preds = predict_augmenting(model, test_datagen, n_cycles=5, aggregate="majority_vote")
# Manual: Apply an Ensemble like Augmenting (test-time augmentation) with Majority Vote
from aucmedi.ensemble.aggregate import Majority_Vote
my_agg = Majority_Vote()
preds = predict_augmenting(model, test_datagen, n_cycles=5, aggregate=my_agg)
```
Aggregate functions are based on the abstract base class [Aggregate_Base][aucmedi.ensemble.aggregate.agg_base.Aggregate_Base],
which allow simple integration of custom aggregate methods for Ensemble.
Metalearners are based on the abstract base class [Metalearner_Base][aucmedi.ensemble.metalearner.ml_base],
which allow simple integration of custom Metalearners for Ensemble.
"""
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# Import aggregate functions
from aucmedi.ensemble.aggregate.averaging_mean import Averaging_Mean
from aucmedi.ensemble.aggregate.averaging_median import Averaging_Median
from aucmedi.ensemble.aggregate.majority_vote import Majority_Vote
from aucmedi.ensemble.aggregate.softmax import Softmax
# Import metalearners
from aucmedi.ensemble.metalearner.logistic_regression import Logistic_Regression

#-----------------------------------------------------#
# Access Functions to Aggregate Functions #
# Access Functions to Metalearners #
#-----------------------------------------------------#
aggregate_dict = {"mean": Averaging_Mean,
"median": Averaging_Median,
"majority_vote": Majority_Vote,
"softmax": Softmax}
""" Dictionary of implemented Aggregate functions. """
metalearner_dict = {"logistic_regression": Logistic_Regression,
}
""" Dictionary of implemented Metalearners. """
82 changes: 82 additions & 0 deletions aucmedi/ensemble/metalearner/logistic_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External libraries
import pickle
from sklearn.linear_model import LogisticRegression
import numpy as np
# Internal libraries/scripts
from aucmedi.ensemble.metalearner.ml_base import Metalearner_Base

#-----------------------------------------------------#
# Metalearner: Logistic Regression #
#-----------------------------------------------------#
class Logistic_Regression(Metalearner_Base):
""" A Logistic Regression based Metalearner.
This class should be passed to a Ensemble function like Stacking for combining predictions.
!!! warning
Can only be utilized for binary and multi-class tasks.
Does not work on multi-label annotations!
"""
#---------------------------------------------#
# Initialization #
#---------------------------------------------#
def __init__(self):
self.model = LogisticRegression(random_state=0,
solver="newton-cg",
multi_class="multinomial")

#---------------------------------------------#
# Training #
#---------------------------------------------#
def training(self, x, y):
# Preprocess to sparse encoding
y = np.argmax(y, axis=-1)
# Train model
self.model = self.model.fit(x, y)

#---------------------------------------------#
# Prediction #
#---------------------------------------------#
def prediction(self, data):
# Compute prediction probabilities via fitted model
pred = self.model.predict_proba(data)
# Return results as NumPy array
return pred

#---------------------------------------------#
# Dump Model to Disk #
#---------------------------------------------#
def dump(self, path):
# Dump model to disk via pickle
with open(path, "wb") as pickle_writer:
pickle.dump(self.model, pickle_writer)

#---------------------------------------------#
# Load Model from Disk #
#---------------------------------------------#
def load(self, path):
# Load model from disk via pickle
with open(path, "rb") as pickle_reader:
self.model = pickle.load(pickle_reader)
6 changes: 3 additions & 3 deletions aucmedi/ensemble/metalearner/ml_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def prediction(self, data):
It is possible to pass configurations through the initialization function for this class.
Args:
preds (numpy.ndarray): Ensembled predictions encoded in a NumPy Matrix with shape (N_models, N_classes).
data (numpy.ndarray): Ensembled predictions encoded in a NumPy Matrix with shape (N_models, N_classes).
Returns:
pred (numpy.ndarray): Merged prediction encoded in a NumPy Matrix with shape (1, N_classes).
"""
Expand All @@ -108,7 +108,7 @@ def dump(self, path):
""" Store metalearner model to disk.
Args:
file_path (str): Path to store the model on disk.
path (str): Path to store the model on disk.
"""
pass

Expand All @@ -120,6 +120,6 @@ def load(self, path):
""" Load metalearner model and its weights from a file.
Args:
file_path (str): Input path, from which the model will be loaded.
path (str): Input path, from which the model will be loaded.
"""
pass

0 comments on commit 633af82

Please sign in to comment.