Skip to content

Commit

Permalink
feat(Metalearner): Implemented abstract base class for Metalearner
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed May 21, 2022
1 parent 7014130 commit cdf9f54
Showing 1 changed file with 32 additions and 24 deletions.
56 changes: 32 additions & 24 deletions aucmedi/ensemble/metalearner/ml_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
class Metalearner_Base(ABC):
""" An abstract base class for a Metalearner class.
Metaleaner are similar to [aggregate functions][aucmedi.ensemble.aggregate],
however metaleaners are models which are trained before usage.
Metalearner are similar to [aggregate functions][aucmedi.ensemble.aggregate],
however Metalearners are models which are trained before usage.
Metaleaners are utilized in [Stacking][aucmedi.ensemble.stacking] pipelines.
Metalearners are utilized in [Stacking][aucmedi.ensemble.stacking] pipelines.
A Metaleaner act as combiner algorithm which is trained to make a final prediction
A Metalearner act as combiner algorithm which is trained to make a final prediction
using predictions of the other algorithms (Neural_Networks) as inputs.
```
Expand All @@ -48,51 +48,47 @@ class Metalearner_Base(ABC):
-> shape (1, 3)
```
???+ example "Create a custom Metaleaner class"
```python
from aucmedi.ensemble.aggregate.agg_base import Aggregate_Base
class My_custom_Aggregate(Aggregate_Base):
def __init__(self): # you can pass class variables here
pass
def aggregate(self, preds):
preds_combined = np.mean(preds, axis=0) # do some combination operation
return preds_combined # return combined predictions
```
!!! info "Required Functions"
| Function | Description |
| ------------------- | ---------------------------------------------------------- |
| `__init__()` | Object creation function. |
| `training()` | Fit Metaleaner model. |
| `training()` | Fit Metalearner model. |
| `prediction()` | Merge multiple class predictions into a single prediction. |
| `dump()` | Store Metaleaner model to disk. |
| `load()` | Load Metaleaner model from disk. |
| `dump()` | Store Metalearner model to disk. |
| `load()` | Load Metalearner model from disk. |
"""
#---------------------------------------------#
# Initialization #
#---------------------------------------------#
@abstractmethod
def __init__(self):
""" Initialization function which will be called during the Aggregation object creation.
""" Initialization function which will be called during the Metalearner object creation.
This function can be used to pass variables and options in the Aggregation instance.
This function can be used to pass variables and options in the Metalearner instance.
The are no mandatory required parameters for the initialization.
"""
pass

#---------------------------------------------#
# Training #
#---------------------------------------------#
def training(self, train_x, train_y):
@abstractmethod
def training(self, x, y):
""" Training function to fit the Metalearner model.
Args:
x (numpy.ndarray): Ensembled predictions encoded in a NumPy Matrix with shape (N_models, N_classes).
y (numpy.ndarray): Classification list with One-Hot Encoding. Provided by
[input_interface][aucmedi.data_processing.io_data.input_interface].
"""
pass

#---------------------------------------------#
# Prediction #
#---------------------------------------------#
@abstractmethod
def prediction(self, data):
""" Aggregate the image by merging multiple predictions into a single one.
""" Merge multiple predictions into a single one.
It is required to return the merged predictions (as NumPy matrix).
It is possible to pass configurations through the initialization function for this class.
Expand All @@ -107,11 +103,23 @@ def prediction(self, data):
#---------------------------------------------#
# Dump Model to Disk #
#---------------------------------------------#
@abstractmethod
def dump(self, path):
""" Store metalearner model to disk.
Args:
file_path (str): Path to store the model on disk.
"""
pass

#---------------------------------------------#
# Load Model from Disk #
#---------------------------------------------#
@abstractmethod
def load(self, path):
""" Load metalearner model and its weights from a file.
Args:
file_path (str): Input path, from which the model will be loaded.
"""
pass

0 comments on commit cdf9f54

Please sign in to comment.