Skip to content

Commit

Permalink
feat(Aggregate): Added Global Argmax - Contributes to #129
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerdo committed May 22, 2022
1 parent 4c7643f commit 205dd12
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
5 changes: 4 additions & 1 deletion aucmedi/ensemble/aggregate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,15 @@
from aucmedi.ensemble.aggregate.averaging_median import Averaging_Median
from aucmedi.ensemble.aggregate.majority_vote import Majority_Vote
from aucmedi.ensemble.aggregate.softmax import Softmax
from aucmedi.ensemble.aggregate.global_argmax import Global_Argmax

#-----------------------------------------------------#
# Access Functions to Aggregate Functions #
#-----------------------------------------------------#
aggregate_dict = {"mean": Averaging_Mean,
"median": Averaging_Median,
"majority_vote": Majority_Vote,
"softmax": Softmax}
"softmax": Softmax,
"global_argmax": Global_Argmax,
}
""" Dictionary of implemented Aggregate functions. """
58 changes: 58 additions & 0 deletions aucmedi/ensemble/aggregate/global_argmax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#==============================================================================#
# Author: Dominik Müller #
# Copyright: 2022 IT-Infrastructure for Translational Medical Research, #
# University of Augsburg #
# #
# This program is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation, either version 3 of the License, or #
# (at your option) any later version. #
# #
# This program is distributed in the hope that it will be useful, #
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
# GNU General Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
#==============================================================================#
#-----------------------------------------------------#
# Library imports #
#-----------------------------------------------------#
# External libraries
import numpy as np
# Internal libraries/scripts
from aucmedi.ensemble.aggregate.agg_base import Aggregate_Base

#-----------------------------------------------------#
# Aggregate: Global Argmax #
#-----------------------------------------------------#
class Global_Argmax(Aggregate_Base):
""" Aggregate function based on Global Argmax.
This class should be passed to a ensemble function/class for combining predictions.
"""
#---------------------------------------------#
# Initialization #
#---------------------------------------------#
def __init__(self):
# No hyperparameter adjustment required for this method, therefore skip
pass

#---------------------------------------------#
# Aggregate #
#---------------------------------------------#
def aggregate(self, preds):
# Identify global argmax
max = np.amax(preds)
argmax_flatten = np.argmax(preds)
argmax = np.unravel_index(argmax_flatten, preds.shape)[-1]

# Compute prediction by global argmax and equally distributed remaining
# probability for other classes
prob_remaining = np.divide(1-max, preds.shape[1]-1)
pred = np.full((preds.shape[1],), fill_value=prob_remaining)
pred[argmax] = max

# Return prediction
return pred

0 comments on commit 205dd12

Please sign in to comment.