feat(performance): add display option for charts
SherlockMones committed Jun 30, 2022
1 parent c31355d commit 7f66f48
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions aucmedi/evaluation/performance.py
@@ -33,6 +33,7 @@
 def evaluate_performance(preds,
                          labels,
                          out_path,
+                         show=False,
                          class_names=None,
                          multi_label=False,
                          metrics_threshold=0.5,
@@ -99,6 +100,7 @@ def evaluate_performance(preds,
         labels (numpy.ndarray):        Classification list with One-Hot Encoding. Provided by
                                        [input_interface][aucmedi.data_processing.io_data.input_interface].
         out_path (str):                Path to directory in which plotted figures are stored.
+        show (bool):                   Option, whether to also display the generated charts.
         class_names (list of str):     List of names for corresponding classes. Used for evaluation. Provided by
                                        [input_interface][aucmedi.data_processing.io_data.input_interface].
                                        If not provided (`None` provided), class indices will be used.
@@ -135,19 +137,19 @@ def evaluate_performance(preds,
 
     # Store metrics to CSV
     if store_csv:
-        evalby_csv(metrics, out_path, class_names, suffix)
+        evalby_csv(metrics, out_path, class_names, suffix=suffix)
 
     # Generate bar plot
     if plot_barplot:
-        evalby_barplot(metrics, out_path, class_names, suffix)
+        evalby_barplot(metrics, out_path, class_names, show=show, suffix=suffix)
 
     # Generate confusion matrix plot
     if plot_confusion_matrix and not multi_label:
-        evalby_confusion_matrix(cm, out_path, class_names, suffix)
+        evalby_confusion_matrix(cm, out_path, class_names, show=show, suffix=suffix)
 
     # Generate ROC curve
     if plot_roc_curve:
-        evalby_rocplot(fpr_list, tpr_list, out_path, class_names, suffix)
+        evalby_rocplot(fpr_list, tpr_list, out_path, class_names, show=show, suffix=suffix)
 
     # Return metrics
     return metrics
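
For orientation, a minimal usage sketch of the new `show` option (not part of this commit; the arrays, class names, and the output directory "evaluation" are illustrative placeholders, and the directory must already exist before the figures are saved into it):

    import numpy as np
    from aucmedi.evaluation.performance import evaluate_performance

    # Hypothetical softmax predictions and one-hot labels for 3 classes
    preds = np.array([[0.8, 0.1, 0.1],
                      [0.2, 0.7, 0.1],
                      [0.1, 0.2, 0.7]])
    labels = np.array([[1, 0, 0],
                       [0, 1, 0],
                       [0, 0, 1]])

    # show=True saves all figures to out_path and additionally displays them
    metrics = evaluate_performance(preds, labels, out_path="evaluation",
                                   show=True,
                                   class_names=["classA", "classB", "classC"])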
@@ -156,7 +158,9 @@ def evaluate_performance(preds,
 #    Evaluation Performance - Confusion Matrix        #
 #-----------------------------------------------------#
 def evalby_confusion_matrix(confusion_matrix, out_path, class_names,
+                            show=False,
                             suffix=None):
+
     # Convert confusion matrix to a Pandas dataframe
     rawcm = pd.DataFrame(confusion_matrix)
     # Tidy dataframe
@@ -172,7 +176,7 @@ def evalby_confusion_matrix(confusion_matrix, out_path, class_names,
     dt = dt.melt(id_vars=["index"], var_name="gt", value_name="score")
     dt.rename(columns={"index": "pd"}, inplace=True)
 
-    # Plot confusion matrix
+    # Generate confusion matrix
     fig = (ggplot(dt, aes("pd", "gt", fill="score"))
            + geom_tile()
            + geom_text(aes("pd", "gt", label="score"), color="black")
@@ -191,15 +195,18 @@ def evalby_confusion_matrix(confusion_matrix, out_path, class_names,
     filename += ".png"
     fig.save(filename=filename, path=out_path, width=10, height=9, dpi=200)
 
+    # Plot figure
+    if show: print(fig)
+
 #-----------------------------------------------------#
 #         Evaluation Performance - Barplots           #
 #-----------------------------------------------------#
-def evalby_barplot(metrics, out_path, class_names, suffix=None):
+def evalby_barplot(metrics, out_path, class_names, show=False, suffix=None):
     # Remove confusion matrix from metric dataframe
     df_metrics = metrics[~metrics["metric"].isin(["TN", "FN", "FP", "TP"])]
     df_metrics["class"] = pd.Categorical(df_metrics["class"])
 
-    # Plot metric results
+    # Generate metric results
     fig = (ggplot(df_metrics, aes("class", "score", fill="class"))
            + geom_col(stat='identity', width=0.6, color="black",
                       position = position_dodge(width=0.6))
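
A note on the display mechanism: plotnine renders a ggplot object when it is printed, so `print(fig)` draws the saved chart on screen (for example inline in a Jupyter notebook or in a plot window). This is why the new `show` flag is implemented as a simple print call after `fig.save()`.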
@@ -218,10 +225,13 @@ def evalby_barplot(metrics, out_path, class_names, suffix=None):
     filename += ".png"
     fig.save(filename=filename, path=out_path, width=12, height=9, dpi=200)
 
+    # Plot figure
+    if show: print(fig)
+
 #-----------------------------------------------------#
 #         Evaluation Performance - ROC plot           #
 #-----------------------------------------------------#
-def evalby_rocplot(fpr_list, tpr_list, out_path, class_names, suffix=None):
+def evalby_rocplot(fpr_list, tpr_list, out_path, class_names, show=False, suffix=None):
     # Initialize result dataframe
     df_roc = pd.DataFrame(data=[fpr_list, tpr_list], dtype=np.float64)
     # Preprocess dataframe
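
With the new flag, `evalby_barplot` can also be driven directly; a minimal sketch with hypothetical values (the long-format column layout with "metric", "score", and "class" is inferred from the diff context above, and the "evaluation" output directory must already exist):

    import pandas as pd
    from aucmedi.evaluation.performance import evalby_barplot

    # Hypothetical per-class metric scores in long format
    metrics = pd.DataFrame({
        "metric": ["F1", "F1", "Accuracy", "Accuracy"],
        "score":  [0.91, 0.87, 0.94, 0.90],
        "class":  ["classA", "classB", "classA", "classB"],
    })
    # Save the bar plot to disk and display it
    evalby_barplot(metrics, "evaluation", ["classA", "classB"], show=True)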
@@ -240,7 +250,7 @@ def evalby_rocplot(fpr_list, tpr_list, out_path, class_names, suffix=None):
     df_roc["FPR"] = df_roc["FPR"].astype(float)
     df_roc["TPR"] = df_roc["TPR"].astype(float)
 
-    # Plot roc results
+    # Generate roc results
     fig = (ggplot(df_roc, aes("FPR", "TPR", color="class"))
            + geom_line(size=1.0)
            + geom_abline(intercept=0, slope=1, color="black",
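
The ROC helper can likewise be called on its own; a minimal sketch with hypothetical data, assuming `fpr_list` and `tpr_list` hold one curve per class as produced by sklearn's `roc_curve`, and that the "evaluation" directory exists:

    import numpy as np
    from sklearn.metrics import roc_curve
    from aucmedi.evaluation.performance import evalby_rocplot

    # Hypothetical one-hot labels and predicted probabilities for 2 classes
    labels = np.array([[1, 0], [0, 1], [1, 0], [0, 1]])
    preds  = np.array([[0.9, 0.1], [0.3, 0.7], [0.6, 0.4], [0.2, 0.8]])

    # Compute one ROC curve per class (one-vs-rest)
    fpr_list, tpr_list = [], []
    for c in range(labels.shape[1]):
        fpr, tpr, _ = roc_curve(labels[:, c], preds[:, c])
        fpr_list.append(fpr.tolist())
        tpr_list.append(tpr.tolist())

    # Save the ROC plot to disk and display it
    evalby_rocplot(fpr_list, tpr_list, "evaluation", ["classA", "classB"], show=True)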
@@ -259,6 +269,9 @@ def evalby_rocplot(fpr_list, tpr_list, out_path, class_names, suffix=None):
     filename += ".png"
     fig.save(filename=filename, path=out_path, width=10, height=9, dpi=200)
 
+    # Plot figure
+    if show: print(fig)
+
 #-----------------------------------------------------#
 #         Evaluation Performance - CSV file           #
 #-----------------------------------------------------#
@@ -271,3 +284,4 @@ def evalby_csv(metrics, out_path, class_names, suffix=None):
 
     # Store file to disk
     metrics.to_csv(path_csv, index=False)
+