## utils.py
# some useful functions!
import argparse
import logging
import math
import os
import shutil
import sys
from itertools import cycle

import numpy as np
import torch
def getFiles(targetdir):
    """Return the names of the regular files (non-directories) directly inside targetdir."""
    ls = []
    for fname in os.listdir(targetdir):
        path = os.path.join(targetdir, fname)
        if os.path.isdir(path):
            continue
        ls.append(fname)
    return ls
def getDirs(parent_dir):
    """Return the names of the immediate subdirectories of parent_dir."""
    ls = []
    for dir_name in os.listdir(parent_dir):
        path = os.path.join(parent_dir, dir_name)
        if os.path.isdir(path):
            ls.append(dir_name)
    return ls
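
# Example usage (a sketch; 'data/' is a hypothetical directory):
#   file_names = getFiles('data/')  # e.g. ['scan_01.npy', 'scan_02.npy']
#   dir_names = getDirs('data/')    # e.g. ['fold_1', 'fold_2']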
def windowLevelNormalize(image, level, window):
    """Window/level normalisation: clip image to [level - window/2, level + window/2]
    and rescale that range to [0, 1]. Assumes image is a float array (the in-place
    scaling below would fail on an integer array)."""
    minval = level - window/2
    maxval = level + window/2
    wld = np.clip(image, minval, maxval)
    wld -= minval
    wld *= (1 / window)
    return wld
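
# A minimal sketch of window/level normalisation on a fake CT slice
# (level=40, window=400 is a common soft-tissue window in Hounsfield units):
#   ct_slice = np.random.uniform(-1000, 1000, size=(512, 512))
#   normed = windowLevelNormalize(ct_slice, level=40, window=400)
#   assert normed.min() >= 0.0 and normed.max() <= 1.0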
def try_mkdir(dir_name):
    """Create dir_name, silently doing nothing if it already exists."""
    try:
        os.mkdir(dir_name)
    except OSError:
        pass
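
# Note: the standard library's os.makedirs(dir_name, exist_ok=True) is a
# near-equivalent that also creates any missing intermediate directories.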
def get_logger(name, level=logging.INFO):
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # Logging to console
    stream_handler = logging.StreamHandler(sys.stdout)
    formatter = logging.Formatter(
        '%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    return logger
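
# Example usage (a sketch):
#   logger = get_logger('train')
#   logger.info('starting training')
# Note: calling get_logger twice with the same name attaches a second stream
# handler, so each message would then be printed twice.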
def get_number_of_learnable_parameters(model):
    """Count the parameters of model that require gradients."""
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])
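
# Example usage (a sketch with a toy model):
#   model = torch.nn.Linear(10, 2)
#   get_number_of_learnable_parameters(model)  # 10*2 weights + 2 biases = 22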
class RunningAverage:
    """Computes and stores a running (count-weighted) average."""
    def __init__(self):
        self.count = 0
        self.sum = 0
        self.avg = 0

    def update(self, value, n=1):
        self.count += n
        self.sum += value * n
        self.avg = self.sum / self.count
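
# A minimal sketch of RunningAverage tracking a mean loss over uneven batches:
#   loss_avg = RunningAverage()
#   loss_avg.update(0.5, n=8)  # batch of 8 with mean loss 0.5
#   loss_avg.update(0.3, n=4)  # batch of 4 with mean loss 0.3
#   loss_avg.avg               # (0.5*8 + 0.3*4) / 12 ≈ 0.433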
def save_checkpoint(state, is_best, checkpoint_dir, logger=None):
    """Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'.
    If is_best==True, saves '{checkpoint_dir}/best_checkpoint.pytorch' as well.
    Args:
        state (dict): contains model's state_dict, optimizer's state_dict, epoch
            and best evaluation metric value so far
        is_best (bool): if True, state contains the best model seen so far
        checkpoint_dir (string): directory where the checkpoints are to be saved
    """
    def log_info(message):
        if logger is not None:
            logger.info(message)

    if not os.path.exists(checkpoint_dir):
        log_info(
            f"Checkpoint directory does not exist. Creating {checkpoint_dir}")
        os.mkdir(checkpoint_dir)
    last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch')
    log_info(f"Saving last checkpoint to '{last_file_path}'")
    torch.save(state, last_file_path)
    if is_best:
        best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch')
        log_info(f"Saving best checkpoint to '{best_file_path}'")
        shutil.copyfile(last_file_path, best_file_path)
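
# Example usage (a sketch; model, optimizer, epoch etc. are hypothetical):
#   state = {'epoch': epoch,
#            'model_state_dict': model.state_dict(),
#            'optimizer_state_dict': optimizer.state_dict(),
#            'best_metric': best_metric}
#   save_checkpoint(state, is_best=(val_metric > best_metric),
#                   checkpoint_dir='checkpoints', logger=logger)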
def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
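
# Example usage (a sketch): accept flexible boolean flags on the command line.
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--use_gpu', type=str2bool, default=True)
#   args = parser.parse_args(['--use_gpu', 'yes'])  # args.use_gpu == True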
def k_fold_split_train_val_test(dataset_size, fold_num, seed=230597):
    """Split dataset indices into train/val/test lists for fold fold_num (1-indexed).
    The fixed seed keeps the shuffled ordering identical across folds, so the
    test block simply rotates through the dataset as fold_num increases."""
    k = int(fold_num - 1)
    if dataset_size == 72:
        train_ims, val_ims, test_ims = 48, 8, 16
    else:
        train_ims, val_ims, test_ims = math.floor(dataset_size*0.7), math.floor(dataset_size*0.1), math.ceil(dataset_size*0.2)
        #train_ims, val_ims, test_ims = round(dataset_size*0.7), round(dataset_size*0.1), round(dataset_size*0.2)
    if dataset_size - (train_ims + val_ims + test_ims) == 1:
        val_ims += 1  # put the extra image into the val set
    try:
        assert train_ims + val_ims + test_ims == dataset_size
    except AssertionError:
        print("Check the k fold data splitting, something's dodgy...")
        sys.exit(1)
    train_inds, val_inds, test_inds = [], [], []
    # initial shuffle
    np.random.seed(seed)
    shuffled_ind_list = np.random.permutation(dataset_size)
    # allocate dataset indices based upon the fold number
    # --> not the prettiest or most efficient implementation, but functional
    cyclic_ind_list = cycle(shuffled_ind_list)
    for _ in range(k * test_ims):
        next(cyclic_ind_list)  # shift the start position by k test-set lengths
    for _ in range(test_ims):
        test_inds.append(next(cyclic_ind_list))
    for _ in range(train_ims):
        train_inds.append(next(cyclic_ind_list))
    for _ in range(val_ims):
        val_inds.append(next(cyclic_ind_list))
    return train_inds, val_inds, test_inds
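
# Example usage (a sketch): five folds over a hypothetical 80-image dataset;
# each fold rotates the 16-image test block through the same shuffled order.
#   for fold in range(1, 6):
#       train, val, test = k_fold_split_train_val_test(80, fold_num=fold)
#       print(len(train), len(val), len(test))  # 56 8 16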
def k_fold_split_testset_inds(dataset_size, fold_num):
    """Return only the test-set indices for fold fold_num (1-indexed). The
    hard-coded split sizes mean this is only valid for dataset_size == 238."""
    k = int(fold_num - 1)
    train_ims, val_ims, test_ims = 192, 24, 22
    assert train_ims + val_ims + test_ims == dataset_size
    train_inds, val_inds, test_inds = [], [], []
    # initial shuffle
    np.random.seed(2305)
    shuffled_ind_list = np.random.permutation(dataset_size)
    # allocate dataset indices based upon the fold number
    # --> not the prettiest or most efficient implementation, but functional
    cyclic_ind_list = cycle(shuffled_ind_list)
    for _ in range(k * test_ims):
        next(cyclic_ind_list)  # shift the start position
    for _ in range(test_ims):
        test_inds.append(next(cyclic_ind_list))
    # train/val indices are generated for parity with the full split above,
    # but only the test indices are returned
    for _ in range(train_ims):
        train_inds.append(next(cyclic_ind_list))
    for _ in range(val_ims):
        val_inds.append(next(cyclic_ind_list))
    return test_inds
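
# Example usage (a sketch; the hard-coded sizes require dataset_size == 238):
#   test_inds = k_fold_split_testset_inds(238, fold_num=1)
#   len(test_inds)  # 22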