-
Notifications
You must be signed in to change notification settings - Fork 0
/
result.py
123 lines (91 loc) · 3.37 KB
/
result.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
def class_level_accuracy(model, loader, device, classes):
    """Print test accuracy for each class in dataset.

    Args:
        model: Callable (e.g. a torch module) mapping a batch of images to
            per-class scores; the predicted class is the argmax over dim 1.
        loader: Iterable of ``(images, labels)`` batches.
        device: Device where data will be loaded.
        classes: List of classes in the dataset. One line is printed per
            entry, in this order.
    """
    num_classes = len(classes)
    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            # Compare as plain Python ints: this works for any batch size
            # (squeeze() would turn a batch of 1 into an unindexable 0-d
            # tensor) and lets the labels index the count lists directly.
            for label, pred in zip(labels.tolist(), predicted.tolist()):
                class_total[label] += 1
                class_correct[label] += int(pred == label)
    # Report every class actually passed in, not a hard-coded 10.
    for name, correct, total in zip(classes, class_correct, class_total):
        if total == 0:
            # No samples of this class were seen; avoid dividing by zero.
            print('Accuracy of %5s : N/A' % name)
        else:
            print('Accuracy of %5s : %2d %%' % (name, 100 * correct / total))
def plot_metric(values, metric):
    """Plot a validation metric against epoch and save it as a PNG.

    Args:
        values: Sequence of metric values; the i-th value is plotted at
            x = i on the 'Epoch' axis.
        metric: Name of the metric (e.g. 'Loss'). It is used for the title,
            the y-axis label, the legend and the output filename
            '<metric lowercased>_change.png' (written to the current
            working directory).

    The figure is not closed, so callers can still display it afterwards.
    """
    fig = plt.figure(figsize=(7, 5))
    ax = fig.add_subplot()
    ax.plot(values, label=metric)
    ax.set_title(f'Validation {metric}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric)
    # The legend location was previously computed but never applied.
    # A falling loss curve leaves the upper-right corner free; other metrics
    # (presumably rising, like accuracy) leave the lower-right corner free.
    location = 'upper' if metric == 'Loss' else 'lower'
    ax.legend(loc=f'{location} right')
    fig.savefig(f'{metric.lower()}_change.png')
def plot_predictions(data, classes, plot_title, plot_path):
    """Plot up to 25 samples in a 5x5 grid with their labels and predictions.

    Args:
        data: Iterable of dicts with keys 'image', 'label' and 'prediction'.
            'image' is transposed with axes (1, 2, 0), so it is assumed to
            be channel-first (C, H, W). 'label' and 'prediction' must expose
            ``.item()`` returning an index into ``classes``. Only the first
            25 entries are plotted.
        classes: List of classes in the dataset.
        plot_title: Title for the plot.
        plot_path: Complete path for saving the plot.

    The figure is not closed, so callers can still display it afterwards.
    """
    fig, axs = plt.subplots(5, 5, figsize=(10, 10))
    fig.suptitle(plot_title)
    grid = axs.ravel()
    # Hide every cell up front so that, with fewer than 25 samples, the
    # unused cells stay blank instead of showing empty frames with ticks.
    for ax in grid:
        ax.axis('off')
    # zip stops at the end of the 25-cell grid, capping the sample count.
    for ax, result in zip(grid, data):
        # CHW -> HWC for imshow; x / 2 + 0.5 maps [-1, 1] back to [0, 1].
        # NOTE(review): assumes images were normalized with mean=std=0.5.
        rgb_image = np.transpose(result['image'], (1, 2, 0)) / 2 + 0.5
        label = result['label'].item()
        prediction = result['prediction'].item()
        ax.set_title(f'Label: {classes[label]}\nPrediction: {classes[prediction]}')
        ax.imshow(rgb_image)
    fig.tight_layout()
    # Leave room at the top for the figure title.
    fig.subplots_adjust(top=0.88)
    fig.savefig(f'{plot_path}', bbox_inches='tight')
def save_and_show_result(correct_pred, incorrect_pred, classes):
    """Save grids of correct and incorrect network predictions.

    Both plots are written to a 'predictions' directory next to this source
    file, which is created if it does not exist yet.

    Args:
        correct_pred: Contains correct model predictions and labels.
        incorrect_pred: Contains incorrect model predictions and labels.
        classes: List of classes in the dataset.
    """
    path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'predictions'
    )
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(path, exist_ok=True)
    # Plot correct predictions
    plot_predictions(
        correct_pred, classes, 'Correct Predictions',
        os.path.join(path, 'correct_predictions.png')
    )
    # Plot incorrect predictions
    plot_predictions(
        incorrect_pred, classes, '\nIncorrect Predictions',
        os.path.join(path, 'incorrect_predictions.png')
    )