Commit

feat(Metalearner): Added k-Nearest Neighbors - Contributes to #129
muellerdo committed May 22, 2022
1 parent be429f6 commit 817e02d
Showing 3 changed files with 118 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aucmedi/ensemble/metalearner/__init__.py
@@ -54,6 +54,7 @@
from aucmedi.ensemble.metalearner.best_model import Best_Model
from aucmedi.ensemble.metalearner.averaging_mean_weighted import Averaging_WeightedMean
from aucmedi.ensemble.metalearner.random_forest import Random_Forest
from aucmedi.ensemble.metalearner.k_neighbors import KNearestNeighbors

#-----------------------------------------------------#
# Access Functions to Metalearners #
@@ -66,5 +67,6 @@
"best_model": Best_Model,
"weighted_mean": Averaging_WeightedMean,
"random_forest": Random_Forest,
"k_neighbors": KNearestNeighbors,
}
""" Dictionary of implemented Metalearners. """
85 changes: 85 additions & 0 deletions aucmedi/ensemble/metalearner/k_neighbors.py
@@ -0,0 +1,85 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External libraries
import pickle
from sklearn.neighbors import KNeighborsClassifier
import numpy as np
# Internal libraries/scripts
from aucmedi.ensemble.metalearner.ml_base import Metalearner_Base

#-----------------------------------------------------#
# Metalearner: k-Nearest Neighbors #
#-----------------------------------------------------#
class KNearestNeighbors(Metalearner_Base):
""" A k-Nearest Neighbors based Metalearner.
This class should be passed to a Ensemble function like Stacking for combining predictions.
!!! info
Can be utilized for binary, multi-class and multi-label tasks.
???+ abstract "Reference - Implementation"
https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KNeighborsClassifier.html
Scikit-learn: Machine Learning in Python, Pedregosa et al., JMLR 12, pp. 2825-2830, 2011.
https://jmlr.csail.mit.edu/papers/v12/pedregosa11a.html
"""
    #---------------------------------------------#
    #               Initialization                #
    #---------------------------------------------#
    def __init__(self):
        self.model = KNeighborsClassifier()

    #---------------------------------------------#
    #                  Training                   #
    #---------------------------------------------#
    def train(self, x, y):
        # Train model
        self.model = self.model.fit(x, y)

    #---------------------------------------------#
    #                 Prediction                  #
    #---------------------------------------------#
    def predict(self, data):
        # Compute prediction probabilities via fitted model
        pred = self.model.predict_proba(data)
        # Postprocess k-Nearest Neighbors predictions:
        # predict_proba returns one (n_samples, 2) array per class ->
        # keep the positive-class column and reorder to (n_samples, n_classes)
        pred = np.asarray(pred)
        pred = np.swapaxes(pred[:,:,1], 0, 1)
        # Return results as NumPy array
        return pred

    #---------------------------------------------#
    #             Dump Model to Disk              #
    #---------------------------------------------#
    def dump(self, path):
        # Dump model to disk via pickle
        with open(path, "wb") as pickle_writer:
            pickle.dump(self.model, pickle_writer)

    #---------------------------------------------#
    #            Load Model from Disk             #
    #---------------------------------------------#
    def load(self, path):
        # Load model from disk via pickle
        with open(path, "rb") as pickle_reader:
            self.model = pickle.load(pickle_reader)
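
For readers unfamiliar with scikit-learn's multilabel output format: when KNeighborsClassifier is fitted on one-hot (multilabel-indicator) targets, predict_proba returns one (n_samples, 2) array per class, and the predict() method above stacks these and keeps the positive-class column. A standalone sketch of that reshaping with synthetic data (shapes and names are illustrative, not part of this commit):

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

n_samples, n_classes = 25, 4
rng = np.random.default_rng(0)

# Synthetic ensemble predictions as features and one-hot targets
x = rng.random((n_samples, n_classes))
labels = np.arange(n_samples) % n_classes      # every class occurs at least once
y = np.eye(n_classes, dtype=int)[labels]       # multilabel-indicator matrix

model = KNeighborsClassifier().fit(x, y)

# List of n_classes arrays, each (n_samples, 2): [P(absent), P(present)]
proba = model.predict_proba(x)

pred = np.asarray(proba)                 # (n_classes, n_samples, 2)
pred = np.swapaxes(pred[:, :, 1], 0, 1)  # keep P(present) -> (n_samples, n_classes)
print(pred.shape)                        # (25, 4)
```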
31 changes: 31 additions & 0 deletions tests/test_metalearner.py
@@ -293,3 +293,34 @@ def test_Random_Forest_usage(self):
        preds = ml.predict(data=self.pred_data)
        # Check
        self.assertTrue(np.array_equal(preds.shape, (25,4)))

    #-------------------------------------------------#
    #               k-Nearest Neighbors                #
    #-------------------------------------------------#
    def test_KNearestNeighbors_create(self):
        # Initializations
        ml = KNearestNeighbors()
        self.assertTrue("k_neighbors" in metalearner_dict)
        ml = metalearner_dict["k_neighbors"]()
        # Storage
        model_path = os.path.join(self.tmp_data.name, "ml_model.pickle")
        self.assertFalse(os.path.exists(model_path))
        ml.dump(model_path)
        self.assertTrue(os.path.exists(model_path))
        # Loading
        ml.model = None
        self.assertTrue(ml.model is None)
        ml.load(model_path)
        self.assertFalse(ml.model is None)
        # Cleanup
        os.remove(model_path)

    def test_KNearestNeighbors_usage(self):
        # Initializations
        ml = KNearestNeighbors()
        # Training
        ml.train(x=self.pred_data, y=self.labels_ohe)
        # Inference
        preds = ml.predict(data=self.pred_data)
        # Check
        self.assertTrue(np.array_equal(preds.shape, (25,4)))
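
The fixtures self.pred_data, self.labels_ohe, and self.tmp_data are defined outside this hunk and are not part of the diff; the shape assertion (25, 4) implies 25 samples and 4 classes. A hypothetical setUpClass consistent with those assumptions could look like this (class and variable names are illustrative only):

```python
import tempfile
import unittest

import numpy as np


class MetalearnerFixtureSketch(unittest.TestCase):
    """Hypothetical fixture setup; the real test class in tests/test_metalearner.py is not shown here."""

    @classmethod
    def setUpClass(cls):
        np.random.seed(1234)
        # 25 samples x 4 classes, matching the shape assertion (25, 4)
        cls.pred_data = np.random.rand(25, 4)                  # ensemble probabilities as features
        class_idx = np.random.randint(0, 4, size=25)
        cls.labels_ohe = np.eye(4, dtype=np.uint8)[class_idx]  # one-hot ground-truth labels
        cls.tmp_data = tempfile.TemporaryDirectory(prefix="tmp.aucmedi.")

    @classmethod
    def tearDownClass(cls):
        cls.tmp_data.cleanup()
```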
