-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_utils.py
174 lines (157 loc) · 5.41 KB
/
custom_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# import albumentations as A
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
# from albumentations.pytorch import ToTensorV2
from config import DEVICE, CLASSES
plt.style.use('ggplot')
# This class keeps track of the training and validation loss values
# and helps to get the average for each epoch as well.
class Averager:
    """Accumulate values and expose their running mean.

    ``send`` adds one value to the running total; ``value`` reports the mean
    of everything sent since construction or since the last ``reset``.
    """

    def __init__(self):
        # Same starting state that ``reset`` restores.
        self.current_total = 0.0
        self.iterations = 0.0

    def send(self, value):
        """Add ``value`` to the running total and count one more iteration."""
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        """Return the mean of the sent values, or ``0`` if none were sent."""
        if not self.iterations:
            # Nothing accumulated yet: avoid dividing by zero.
            return 0
        return 1.0 * self.current_total / self.iterations

    def reset(self):
        """Clear the running total and the iteration count."""
        self.current_total = self.iterations = 0.0
class SaveBestModel:
    """
    Checkpoint the model whenever the validation mAP improves.

    The instance remembers the highest validation mAP it has been called
    with. When a call supplies a strictly higher value, the model's
    ``state_dict`` is written to ``<OUT_DIR>/best_model.pth``.
    """

    def __init__(
        self, best_valid_map=float(0)
    ):
        # Highest validation mAP seen so far.
        self.best_valid_map = best_valid_map

    def __call__(
        self,
        model,
        current_valid_map,
        epoch,
        OUT_DIR,
    ):
        """
        Save ``model`` if ``current_valid_map`` beats the stored best.

        :param model: Object exposing ``state_dict()``.
        :param current_valid_map: Validation mAP of the current epoch.
        :param epoch: Epoch index; stored and printed as ``epoch + 1``.
        :param OUT_DIR: Directory that receives ``best_model.pth``.
        """
        # Written as a negated ">" so that values which do not compare
        # greater (including NaN) never trigger a save.
        if not (current_valid_map > self.best_valid_map):
            return

        self.best_valid_map = current_valid_map
        epoch_number = epoch + 1
        print(f"\nBEST VALIDATION mAP: {self.best_valid_map}")
        print(f"\nSAVING BEST MODEL FOR EPOCH: {epoch_number}\n")

        checkpoint = {
            'epoch': epoch_number,
            'model_state_dict': model.state_dict(),
        }
        torch.save(checkpoint, f"{OUT_DIR}/best_model.pth")
def collate_fn(batch):
    """
    Regroup a batch of per-sample tuples into per-field tuples.

    The fields are transposed rather than stacked, so samples whose
    entries differ in size (for example a different number of objects per
    image) can share a batch.

    :param batch: Iterable of equally long per-sample tuples.
    :return: Tuple with one tuple per field, each holding that field's
        value from every sample.
    """
    per_field = zip(*batch)
    return tuple(per_field)
# Define the training transforms.
def get_train_transform():
    """
    Build the augmentation pipeline applied to training samples.

    The pipeline is a random horizontal flip, three low-probability blur
    variants, several colour/brightness perturbations, and a final
    conversion to a tensor. Bounding boxes are declared in ``pascal_voc``
    format, with their classes carried in the ``labels`` field.

    :return: The composed transform.
    """
    # NOTE(review): `A` and `ToTensorV2` are only referenced by the
    # commented-out albumentations imports visible at the top of this file.
    # Unless they are brought into scope elsewhere, calling this function
    # raises NameError.
    augmentations = [
        A.HorizontalFlip(p=0.5),
        A.Blur(blur_limit=3, p=0.1),
        A.MotionBlur(blur_limit=3, p=0.1),
        A.MedianBlur(blur_limit=3, p=0.1),
        A.ToGray(p=0.3),
        A.RandomBrightnessContrast(p=0.3),
        A.ColorJitter(p=0.3),
        A.RandomGamma(p=0.3),
        ToTensorV2(p=1.0),
    ]
    bbox_settings = {
        'format': 'pascal_voc',
        'label_fields': ['labels']
    }
    return A.Compose(augmentations, bbox_params=bbox_settings)
# Define the validation transforms.
def get_valid_transform():
    """
    Build the transform applied to validation samples.

    No augmentation is applied; the sample is only converted to a tensor.
    Bounding boxes are declared in ``pascal_voc`` format, with their
    classes carried in the ``labels`` field.

    :return: The composed transform.
    """
    # NOTE(review): `A` and `ToTensorV2` are only referenced by the
    # commented-out albumentations imports visible at the top of this file.
    # Unless they are brought into scope elsewhere, calling this function
    # raises NameError.
    steps = [ToTensorV2(p=1.0)]
    bbox_settings = {
        'format': 'pascal_voc',
        'label_fields': ['labels']
    }
    return A.Compose(steps, bbox_params=bbox_settings)
def show_tranformed_image(train_loader):
    """
    Display the first transformed sample of the first training batch.

    The sample's boxes and class names are drawn on the image so the
    transformed images and their labels can be checked by eye. The window
    stays open until a key is pressed.

    The original note states this is meant to run only when
    `VISUALIZE_TRANSFORMED_IMAGES = True` in config.py; this function does
    not check that flag itself.

    :param train_loader: Loader whose batches unpack to ``(images, targets)``,
        each target being a dict with ``'boxes'`` and ``'labels'`` tensors.
    """
    if len(train_loader) == 0:
        return

    # Only the first sample of the first batch is shown.
    sample_idx = 0
    images, targets = next(iter(train_loader))
    images = [image.to(DEVICE) for image in images]
    targets = [
        {key: tensor.to(DEVICE) for key, tensor in target.items()}
        for target in targets
    ]

    chosen = targets[sample_idx]
    boxes = chosen['boxes'].cpu().numpy().astype(np.int32)
    labels = chosen['labels'].cpu().numpy().astype(np.int32)

    # Move axis 0 to the end (presumably channels-first -> channels-last;
    # confirm against the dataset) before handing the array to OpenCV.
    sample = images[sample_idx].permute(1, 2, 0).cpu().numpy()
    sample = cv2.cvtColor(sample, cv2.COLOR_RGB2BGR)

    colour = (0, 0, 255)
    thickness = 2
    for box_num, box in enumerate(boxes):
        top_left = (box[0], box[1])
        bottom_right = (box[2], box[3])
        cv2.rectangle(sample, top_left, bottom_right, colour, thickness)
        # The class name is written 10 pixels above the box's top-left corner.
        cv2.putText(
            sample,
            CLASSES[labels[box_num]],
            (box[0], box[1]-10),
            cv2.FONT_HERSHEY_SIMPLEX,
            1.0,
            colour,
            thickness,
        )

    cv2.imshow('Transformed image', sample)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def save_model(epoch, model, optimizer, OUT_DIR='outputs'):
    """
    Save the model and optimizer state reached so far.

    The checkpoint is written to ``<OUT_DIR>/last_model.pth`` and holds the
    epoch number (``epoch + 1``), the model ``state_dict`` and the optimizer
    ``state_dict``.

    :param epoch: Epoch index; stored as ``epoch + 1``.
    :param model: Object exposing ``state_dict()``.
    :param optimizer: Object exposing ``state_dict()``.
    :param OUT_DIR: Directory that receives ``last_model.pth``. Defaults to
        ``'outputs'``, the location that was previously hard-coded, so
        existing three-argument calls keep writing to the same file.
    """
    checkpoint = {
        'epoch': epoch+1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, f"{OUT_DIR}/last_model.pth")
def save_loss_plot(
    OUT_DIR,
    train_loss_list,
    x_label='iterations',
    y_label='train loss',
    save_name='train_loss'
):
    """
    Plot the training loss curve and save it as a PNG.

    The image is written to ``<OUT_DIR>/<save_name>.png``.

    :param OUT_DIR: Path to save the graph.
    :param train_loss_list: List containing the training loss values.
    :param x_label: Label for the x axis.
    :param y_label: Label for the y axis.
    :param save_name: File name of the saved image, without extension.
    """
    output_path = f"{OUT_DIR}/{save_name}.png"

    # Figure number 1 is requested with clear=True on every call, so
    # repeated calls draw on a cleared figure 1 rather than a new one.
    fig = plt.figure(figsize=(10, 7), num=1, clear=True)
    loss_axes = fig.add_subplot()
    loss_axes.plot(train_loss_list, color='tab:blue')
    loss_axes.set_xlabel(x_label)
    loss_axes.set_ylabel(y_label)

    fig.savefig(output_path)
    print('SAVING PLOTS COMPLETE...')
def save_mAP(OUT_DIR, map_05, map):
    """
    Plot mAP@0.5 and mAP@0.5:0.95 per epoch and save the graph.

    Both curves share one set of axes and the image is written to
    ``<OUT_DIR>/map.png``.

    :param OUT_DIR: Path to save the graph.
    :param map_05: List containing mAP values at 0.5 IoU.
    :param map: List containing mAP values at 0.5:0.95 IoU.
    """
    # Figure number 1 is requested with clear=True on every call, so
    # repeated calls draw on a cleared figure 1 rather than a new one.
    fig = plt.figure(figsize=(10, 7), num=1, clear=True)
    map_axes = fig.add_subplot()

    # (values, line colour, legend label) for each curve, in drawing order.
    curves = (
        (map_05, 'tab:orange', 'mAP@0.5'),
        (map, 'tab:red', 'mAP@0.5:0.95'),
    )
    for values, colour, legend_label in curves:
        map_axes.plot(values, color=colour, linestyle='-', label=legend_label)

    map_axes.set_xlabel('Epochs')
    map_axes.set_ylabel('mAP')
    map_axes.legend()
    fig.savefig(f"{OUT_DIR}/map.png")