compute_metrics.py
# Written by: Erick Cobos T. (a01184587@itesm.mx)
# Date: July 2016
# Modified: October 2016
"""Calculate evaluation metrics on an evaluation set.

Metrics are calculated per-pixel over the entire data set. They are IOU,
F1-score, G-mean, accuracy, sensitivity, specificity, precision and recall.

Example:
    $ python3 compute_metrics.py model_dir csv_path

    where model_dir is the name of the folder where the checkpoint is and
    csv_path is the path to the csv with image, label filenames.
"""
import tensorflow as tf
import numpy as np
import scipy.misc
import csv
import sys
import os.path

# Import network definition
import model_v4 as model

# Set some parameters
DATA_DIR = "data"  # directory with evaluation data
THRESHOLD_PROB = 0.5  # probability threshold used to generate segmentations

def post(logits, label, threshold):
    """Creates a segmentation assigning 255 to every pixel above the threshold,
    0 to pixels where label == 0, and 127 to everything else.

    Args:
        logits: An array of floats with shape [height, width]. A heatmap of
            logits representing the probability of mass at each pixel.
        label: An array of integers with shape [height, width]. The original
            label used to segment the background from the rest of the image.
        threshold: A float. The logit threshold used for the segmentation.

    Returns:
        An array of integers with shape [height, width]. The produced
        segmentation with labels 0 (background), 127 (breast tissue) and 255
        (breast mass).

    Note: Using the label may seem like cheating but the label background was
    generated by thresholding the original image to zero, so (label == 0) is
    equivalent to comparing the original mammogram to zero.
    """
    thresholded = np.ones(logits.shape, dtype='uint8') * 127
    thresholded[logits >= threshold] = 255
    thresholded[label == 0] = 0

    return thresholded

def compute_confusion_matrix(segmentation, label):
    """Confusion matrix for a mammogram: # of pixels in each category."""
    # Confusion matrix (only over breast area)
    true_positive = np.sum(np.logical_and(segmentation == 255, label == 255))
    false_positive = np.sum(np.logical_and(segmentation == 255, label != 255))
    true_negative = np.sum(np.logical_and(segmentation == 127, label == 127))
    false_negative = np.sum(np.logical_and(segmentation == 127, label != 127))

    cm_values = [true_positive, false_positive, true_negative, false_negative]
    return np.array(cm_values)

def compute_metrics(true_positive, false_positive, true_negative, false_negative):
    """Array with different metrics from the given confusion matrix values."""
    epsilon = 1e-7  # To avoid division by zero

    # Evaluation metrics
    accuracy = (true_positive + true_negative) / (true_positive + true_negative +
                                                  false_positive + false_negative + epsilon)
    sensitivity = true_positive / (true_positive + false_negative + epsilon)
    specificity = true_negative / (false_positive + true_negative + epsilon)
    precision = true_positive / (true_positive + false_positive + epsilon)
    recall = sensitivity
    iou = true_positive / (true_positive + false_positive + false_negative + epsilon)
    f1 = (2 * precision * recall) / (precision + recall + epsilon)
    g_mean = np.sqrt(sensitivity * specificity)

    metrics = [iou, f1, g_mean, accuracy, sensitivity, specificity, precision,
               recall]
    return np.array(metrics)
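
# Quick sanity check with hypothetical pixel counts: tp=90, fp=10, tn=880, fn=20
# gives IOU = 90 / 120 = 0.75, precision = 90 / 100 = 0.90,
# sensitivity (recall) = 90 / 110 ~ 0.82 and accuracy = 970 / 1000 = 0.97.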

def main(data_dir=DATA_DIR, threshold_prob=THRESHOLD_PROB):
    """Loads the network, reads every image and returns the overall metrics."""
    # Model directory and path to the csv passed as arguments
    model_dir = sys.argv[1]
    csv_path = sys.argv[2]

    # Read csv file
    with open(csv_path) as f:
        lines = f.read().splitlines()
    csv_reader = csv.reader(lines)

    # Image as placeholder
    image = tf.placeholder(tf.float32, name='image')
    whitened = tf.image.per_image_whitening(tf.expand_dims(image, 2))
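    # per_image_whitening (later renamed per_image_standardization in TF 1.0)
    # linearly scales the single-channel image to zero mean and unit variance
    # before it is fed to the network.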

    # Define the model
    prediction = model.forward(whitened, drop=tf.constant(False))

    # Get a saver to load the model
    saver = tf.train.Saver()

    # Use CPU only. To enable GPU, delete the config and call tf.Session() with
    # no arguments.
    config = tf.ConfigProto(device_count={'GPU': 0})

    # Initialize some variables
    confusion_matrix = np.zeros(4)  # tp, fp, tn, fn

    # Launch the graph
    with tf.Session(config=config) as sess:
        # Restore variables
        checkpoint_path = tf.train.latest_checkpoint(model_dir)
        print("Restoring model from:", checkpoint_path)
        saver.restore(sess, checkpoint_path)
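
        # The network outputs raw logits; log(p) - log(1 - p) is the logit
        # (inverse sigmoid) of p, so thresholding the logits at this value is
        # equivalent to thresholding sigmoid probabilities at threshold_prob.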
        # Get logit threshold from probability
        threshold = np.log(threshold_prob) - np.log(1 - threshold_prob)
        print("Threshold: {} ({})".format(threshold, threshold_prob))

        # For every example
        for row in csv_reader:
            # Read paths
            image_path = os.path.join(data_dir, row[0])
            label_path = os.path.join(data_dir, row[1])

            # Read image and label
            im = scipy.misc.imread(image_path)
            label = scipy.misc.imread(label_path)

            # Get prediction
            logits = prediction.eval({image: im})

            # Post-process prediction
            segmentation = post(logits, label, threshold)

            # Accumulate confusion matrix values
            confusion_matrix += compute_confusion_matrix(segmentation, label)
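
        # Metrics are computed from the pooled pixel counts over the whole set
        # (per-pixel, as stated in the module docstring), not as a mean of
        # per-image metrics.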
        # Calculate metrics
        metrics = compute_metrics(*confusion_matrix)

        # Report metrics
        metric_names = ['IOU', 'F1-score', 'G-mean', 'Accuracy', 'Sensitivity',
                        'Specificity', 'Precision', 'Recall']
        for name, metric in zip(metric_names, metrics):
            print("{}: {}".format(name, metric))
        print('')

        # Logistic loss (a second pass over the csv to compute the mean loss)
        label = tf.placeholder(tf.uint8, name='label')
        loss = model.loss(prediction, label)
        csv_reader = csv.reader(lines)
        loss_accum = 0
        for row in csv_reader:
            im = scipy.misc.imread(os.path.join(data_dir, row[0]))
            lbl = scipy.misc.imread(os.path.join(data_dir, row[1]))
            loss_accum += loss.eval({image: im, label: lbl})
        print("Logistic loss: ", loss_accum / csv_reader.line_num)

    return metrics, metric_names

if __name__ == "__main__":
    main()