feat(Evaluation): Started implementing performance evaluation
Showing 3 changed files with 291 additions and 0 deletions.
@@ -0,0 +1,90 @@
import numpy as np
import pandas as pd
from sklearn.metrics import roc_curve, roc_auc_score


def compute_metrics(preds, labels, n_labels, threshold=None):
    df_list = []
    for c in range(0, n_labels):
        # Initialize variables
        data_dict = {}

        # Identify truth and prediction for class c
        truth = labels[:, c]
        if threshold is None:
            pred_argmax = np.argmax(preds, axis=-1)
            pred = (pred_argmax == c).astype(int)
            pred_prob = np.max(preds, axis=-1)
        else:
            pred = np.where(preds[:, c] >= threshold, 1, 0)
            pred_prob = preds[:, c]

        # Compute the confusion matrix
        tp, tn, fp, fn = compute_CM(truth, pred)
        data_dict["TP"] = tp
        data_dict["TN"] = tn
        data_dict["FP"] = fp
        data_dict["FN"] = fn

        # Compute several metrics based on confusion matrix
        data_dict["Sensitivity"] = np.divide(tp, tp+fn)
        data_dict["Specificity"] = np.divide(tn, tn+fp)
        data_dict["Precision"] = np.divide(tp, tp+fp)
        data_dict["FPR"] = np.divide(fp, fp+tn)
        data_dict["FNR"] = np.divide(fn, fn+tp)
        data_dict["FDR"] = np.divide(fp, fp+tp)
        data_dict["Accuracy"] = np.divide(tp+tn, tp+tn+fp+fn)
        data_dict["F1"] = np.divide(2*tp, 2*tp+fp+fn)

        # Compute area under the ROC curve
        data_dict["AUC"] = roc_auc_score(truth, pred_prob)

        # Parse metrics to dataframe
        df = pd.DataFrame.from_dict(data_dict, orient="index",
                                    columns=["score"])
        df = df.reset_index()
        df.rename(columns={"index": "metric"}, inplace=True)
        df["class"] = c

        # Append dataframe to list
        df_list.append(df)

    # Combine dataframes
    df_final = pd.concat(df_list, axis=0, ignore_index=True)
    # Return final dataframe
    return df_final


def compute_confusion_matrix(preds, labels, n_labels):
    preds_argmax = np.argmax(preds, axis=-1)
    labels_argmax = np.argmax(labels, axis=-1)
    rawcm = np.zeros((n_labels, n_labels))
    for i in range(0, labels.shape[0]):
        rawcm[labels_argmax[i]][preds_argmax[i]] += 1
    return rawcm


def compute_roc(preds, labels, n_labels):
    fpr_list = []
    tpr_list = []
    for i in range(0, n_labels):
        truth_class = labels[:, i].astype(int)
        pdprob_class = preds[:, i]
        fpr, tpr, _ = roc_curve(truth_class, pdprob_class)
        fpr_list.append(fpr)
        tpr_list.append(tpr)
    return fpr_list, tpr_list


# Compute confusion matrix counts (TP/TN/FP/FN) for a single class
def compute_CM(gt, pred):
    tp = 0
    tn = 0
    fp = 0
    fn = 0
    for i in range(0, len(gt)):
        if gt[i] == 1 and pred[i] == 1 : tp += 1
        elif gt[i] == 1 and pred[i] == 0 : fn += 1
        elif gt[i] == 0 and pred[i] == 0 : tn += 1
        elif gt[i] == 0 and pred[i] == 1 : fp += 1
        else : print("ERROR at confusion matrix", i)
    return tp, tn, fp, fn
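
As a quick sanity check (not part of the commit), the helpers above can be exercised on a toy dataset. All values below are illustrative assumptions: 4 samples, 3 classes, one-hot labels, and softmax-like prediction scores.

import numpy as np

# Hypothetical toy data
labels = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [1, 0, 0]])
preds = np.array([[0.8, 0.1, 0.1],
                  [0.2, 0.7, 0.1],
                  [0.1, 0.2, 0.7],
                  [0.4, 0.5, 0.1]])

# Long-format dataframe with one row per (metric, class) pair
metrics = compute_metrics(preds, labels, n_labels=3)
# Raw count matrix: rows = ground truth, columns = prediction
cm = compute_confusion_matrix(preds, labels, n_labels=3)
# Per-class ROC coordinates for plotting
fpr_list, tpr_list = compute_roc(preds, labels, n_labels=3)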
@@ -0,0 +1,200 @@
#==============================================================================#
#  Author:       Dominik Müller                                                #
#  Copyright:    2022 IT-Infrastructure for Translational Medical Research,    #
#                University of Augsburg                                        #
#                                                                              #
#  This program is free software: you can redistribute it and/or modify        #
#  it under the terms of the GNU General Public License as published by        #
#  the Free Software Foundation, either version 3 of the License, or           #
#  (at your option) any later version.                                         #
#                                                                              #
#  This program is distributed in the hope that it will be useful,             #
#  but WITHOUT ANY WARRANTY; without even the implied warranty of              #
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the                #
#  GNU General Public License for more details.                                #
#                                                                              #
#  You should have received a copy of the GNU General Public License           #
#  along with this program. If not, see <http://www.gnu.org/licenses/>.        #
#==============================================================================#
#-----------------------------------------------------#
#                   Library imports                    #
#-----------------------------------------------------#
# External Libraries
import numpy as np
import pandas as pd
import os
from plotnine import *
# Internal libraries/scripts
from aucmedi.evaluation.metrics import *

#-----------------------------------------------------#
#            Evaluation - Plot Performance             #
#-----------------------------------------------------#
def evaluate_performance(preds,
                         labels,
                         out_path,
                         class_names=None,
                         multi_label=False,
                         metrics_threshold=0.5,
                         suffix=None,
                         plot_barplot=True,
                         plot_confusion_matrix=True,
                         plot_roc_curve=True):
    """ Function for automatic generation of performance evaluation plots based on model predictions.

    Args:
        preds (numpy.ndarray):          A NumPy array of predictions formatted with shape (n_samples, n_labels). Provided by
                                        [Neural_Network][aucmedi.neural_network.model].
        labels (numpy.ndarray):         Classification list with One-Hot Encoding. Provided by
                                        [input_interface][aucmedi.data_processing.io_data.input_interface].
        out_path (str):                 Path to directory in which plotted figures are stored.
        class_names (list of str):      List of names for corresponding classes. Used for evaluation. Provided by
                                        [input_interface][aucmedi.data_processing.io_data.input_interface].
                                        If not provided (`None` provided), class indices will be used.
        multi_label (bool):             Option, whether task is multi-label based (has impact on evaluation).
        metrics_threshold (float):      Only required if `multi_label==True`. Threshold above which a prediction
                                        is considered positive. Used in metric computation for CSV and bar plot.
        suffix (str):                   Special suffix to add in the created figure filename.
        plot_barplot (bool):            Option, whether to generate a bar plot of various metrics.
        plot_confusion_matrix (bool):   Option, whether to generate a confusion matrix plot.
        plot_roc_curve (bool):          Option, whether to generate a ROC curve plot.
    """
    # Identify number of labels
    n_labels = labels.shape[-1]
    # Identify prediction threshold
    if multi_label : threshold = metrics_threshold
    else : threshold = None

    # Compute metrics
    metrics = compute_metrics(preds, labels, n_labels, threshold)
    cm = compute_confusion_matrix(preds, labels, n_labels)
    fpr_list, tpr_list = compute_roc(preds, labels, n_labels)

    # Generate bar plot
    if plot_barplot:
        evalby_barplot(metrics, out_path, class_names, suffix)

    # Generate confusion matrix plot (not meaningful for multi-label tasks)
    if plot_confusion_matrix and not multi_label:
        evalby_confusion_matrix(cm, out_path, class_names, suffix)

    # Generate ROC curve
    if plot_roc_curve:
        evalby_rocplot(fpr_list, tpr_list, out_path, class_names, suffix)


#-----------------------------------------------------#
#      Evaluation Performance - Confusion Matrix       #
#-----------------------------------------------------#
def evalby_confusion_matrix(confusion_matrix, out_path, class_names,
                            suffix=None):
    # Convert confusion matrix to a Pandas dataframe
    rawcm = pd.DataFrame(confusion_matrix)
    # Tidy dataframe
    if class_names is None or len(class_names) != confusion_matrix.shape[0]:
        class_names = list(range(0, confusion_matrix.shape[0]))
    rawcm.index = class_names
    rawcm.columns = class_names

    # Preprocess dataframe
    dt = rawcm.div(rawcm.sum(axis=0), axis=1) * 100
    dt = dt.round(decimals=2)
    dt.reset_index(drop=False, inplace=True)
    # Rows of rawcm hold the ground truth, columns hold the predictions
    dt = dt.melt(id_vars=["index"], var_name="pd", value_name="score")
    dt.rename(columns={"index": "gt"}, inplace=True)

    # Plot confusion matrix
    fig = (ggplot(dt, aes("pd", "gt", fill="score"))
              + geom_tile()
              + geom_text(aes("pd", "gt", label="score"), color="black",
                          size=28)
              + ggtitle("Performance Evaluation: Confusion Matrix")
              + xlab("Prediction")
              + ylab("Ground Truth")
              + scale_fill_gradient(low="white", high="royalblue",
                                    limits=[0, 100])
              + theme_bw(base_size=28)
              + theme(axis_text_x = element_text(angle = 45, vjust = 1,
                                                 hjust = 1)))

    # Store figure to disk
    filename = "plot.performance.confusion_matrix"
    if suffix is not None : filename += "." + str(suffix)
    filename += ".png"
    fig.save(filename=filename, path=out_path, width=10, height=9, dpi=200)

#-----------------------------------------------------#
#          Evaluation Performance - Barplots           #
#-----------------------------------------------------#
def evalby_barplot(metrics, out_path, class_names, suffix=None):
    # Rename class indices to class names if provided
    class_mapping = {}
    if class_names is not None:
        for c in range(len(class_names)):
            class_mapping[c] = class_names[c]
    metrics["class"].replace(class_mapping, inplace=True)
    if class_names is None : metrics["class"] = pd.Categorical(metrics["class"])
    # Remove confusion matrix from metric dataframe
    df_metrics = metrics[~metrics["metric"].isin(["TN", "FN", "FP", "TP"])]

    # Plot metric results
    fig = (ggplot(df_metrics, aes("class", "score", fill="class"))
              + geom_col(width=0.6, color="black",
                         position = position_dodge(width=0.6))
              + ggtitle("Performance Evaluation: Metric Overview")
              + facet_wrap("metric")
              + coord_flip()
              + xlab("")
              + ylab("Score")
              + scale_y_continuous(limits=[0, 1], breaks=np.arange(0, 1.1, 0.1))
              + scale_fill_discrete(name="Classes")
              + theme_bw())

    # Store figure to disk
    filename = "plot.performance.barplot"
    if suffix is not None : filename += "." + str(suffix)
    filename += ".png"
    fig.save(filename=filename, path=out_path, width=12, height=9, dpi=200)

#-----------------------------------------------------#
#          Evaluation Performance - ROC plot           #
#-----------------------------------------------------#
def evalby_rocplot(fpr_list, tpr_list, out_path, class_names, suffix=None):
    # Initialize result dataframe
    df_roc = pd.DataFrame(data=[fpr_list, tpr_list], dtype=np.float64)
    # Preprocess dataframe
    df_roc = df_roc.transpose()
    df_roc = df_roc.apply(pd.Series.explode)
    # Rename class indices to class names if provided
    class_mapping = {}
    if class_names is not None:
        for c in range(len(class_names)):
            class_mapping[c] = class_names[c]
    df_roc.rename(index=class_mapping, inplace=True)
    df_roc = df_roc.reset_index()
    df_roc.rename(columns={"index": "class", 0: "FPR", 1: "TPR"}, inplace=True)
    if class_names is None : df_roc["class"] = pd.Categorical(df_roc["class"])
    # Convert from object to float
    df_roc["FPR"] = df_roc["FPR"].astype(float)
    df_roc["TPR"] = df_roc["TPR"].astype(float)

    # Plot roc results
    fig = (ggplot(df_roc, aes("FPR", "TPR", color="class"))
              + geom_line(size=1.5)
              + geom_abline(intercept=0, slope=1, color="black",
                            linetype="dashed")
              + ggtitle("Performance Evaluation: ROC Curves")
              + xlab("False Positive Rate")
              + ylab("True Positive Rate")
              + scale_x_continuous(limits=[0, 1], breaks=np.arange(0, 1.1, 0.1))
              + scale_y_continuous(limits=[0, 1], breaks=np.arange(0, 1.1, 0.1))
              + scale_color_discrete(name="Classes")
              + theme_bw())

    # Store figure to disk
    filename = "plot.performance.roc"
    if suffix is not None : filename += "." + str(suffix)
    filename += ".png"
    fig.save(filename=filename, path=out_path, width=10, height=9, dpi=200)

#-----------------------------------------------------#
#           Evaluation Performance - CSV file          #
#-----------------------------------------------------#
def evalby_csv():
    pass
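
A minimal end-to-end sketch of the new entry point, reusing the toy preds and labels from the example above. The output directory, class names, and suffix are hypothetical and only serve as illustration:

import os

out_dir = "eval_results"              # hypothetical output directory
os.makedirs(out_dir, exist_ok=True)   # figures are saved into this path

# Writes plot.performance.barplot.test.png,
# plot.performance.confusion_matrix.test.png and
# plot.performance.roc.test.png into out_dir
# (the suffix is inserted before the file extension)
evaluate_performance(preds, labels, out_path=out_dir,
                     class_names=["classA", "classB", "classC"],
                     suffix="test")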