-
Notifications
You must be signed in to change notification settings - Fork 0
/
helpers.py
109 lines (89 loc) · 3.14 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import pickle
import cv2
import numpy as np
import torch
import matplotlib.pyplot as plt
def list_of_distances(X, Y):
    """Return the pairwise squared Euclidean distances between the rows of X and Y.

    Assumes X is (n, d) and Y is (m, d) (TODO confirm with callers); the
    result is (n, m) with entry [i, j] = sum_k (X[i, k] - Y[j, k]) ** 2.
    Note that the full (n, d, m) difference tensor is materialised.
    """
    left = X.unsqueeze(2)            # (n, d, 1)
    right = Y.t().unsqueeze(0)       # (1, d, m)
    squared_gaps = (left - right).pow(2)
    # Reduce over the shared feature dimension.
    return squared_gaps.sum(dim=1)
def make_one_hot(target, target_one_hot):
    """Fill ``target_one_hot`` in place with the one-hot encoding of ``target``.

    ``target`` is flattened to a column of class indices (an integer tensor is
    required by ``scatter_``); ``target_one_hot`` is presumably (N, C) with
    N == target.numel() (TODO confirm). Any previous contents are cleared first.
    Returns None — the result lives in ``target_one_hot``.
    """
    class_column = target.view(-1, 1)
    target_one_hot.zero_()
    # Put a 1 at column target[i] of row i.
    target_one_hot.scatter_(1, class_column, 1.)
def makedir(path: str):
    """Create directory ``path`` (including missing parents) if it does not exist.

    If ``path`` already exists — whether created earlier, by a concurrent
    process between calls, or even as a non-directory — this is a no-op,
    matching the previous "only create when absent" contract.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Checking os.path.exists first would race with concurrent creators
        # (e.g. parallel workers); attempting the creation and ignoring
        # "already exists" is the atomic equivalent.
        pass
def find_high_activation_crop(activation_map, percentile: int = 95):
    """Return the bounding box of the region at or above a percentile threshold.

    The threshold is ``np.percentile(activation_map, percentile)``. An element is
    "kept" when it is not below that threshold. The box is returned as
    half-open bounds ``(y_start, y_end, x_start, x_end)`` along axis 0 (y) and
    axis 1 (x), so ``activation_map[y_start:y_end, x_start:x_end]`` covers every
    kept element. An axis with no kept element yields ``(0, 1)``.
    """
    threshold = np.percentile(activation_map, percentile)
    kept = ~np.asarray(activation_map < threshold)

    def _bounds_along(axis):
        # Collapse every other axis so each index along `axis` becomes one
        # "does this slice contain a kept element?" flag.
        other_axes = tuple(ax for ax in range(kept.ndim) if ax != axis)
        hit_indices = np.flatnonzero(kept.any(axis=other_axes))
        if hit_indices.size == 0:
            return 0, 1
        return int(hit_indices[0]), int(hit_indices[-1]) + 1

    y_start, y_end = _bounds_along(0)
    x_start, x_end = _bounds_along(1)
    return y_start, y_end, x_start, x_end
def imsave_with_bbox(fname: str,
                     img_rgb,
                     bbox_height_start: int,
                     bbox_height_end: int,
                     bbox_width_start: int,
                     bbox_width_end: int,
                     color=(0, 255, 255)):
    """Save ``img_rgb`` to ``fname`` with a rectangle drawn around a bounding box.

    Args:
        fname: output path, written with ``matplotlib.pyplot.imsave``.
        img_rgb: RGB float image, expected in [0, 1] (assumed H x W x 3 —
            confirm with callers). Values outside [0, 1] are clipped before the
            8-bit conversion.
        bbox_height_start, bbox_height_end: half-open row range of the box.
        bbox_width_start, bbox_width_end: half-open column range of the box.
        color: rectangle colour. It is drawn on the BGR copy, so it is in BGR
            order; the default (0, 255, 255) is yellow.
    """
    # Clip before casting: NumPy's float -> uint8 cast of out-of-range values
    # is undefined and typically wraps (e.g. 1.01 * 255 -> 1), turning
    # slightly over-bright pixels nearly black.
    img_uint8 = np.uint8(np.clip(255 * img_rgb, 0, 255))
    img_bgr_uint8 = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2BGR)
    # cv2.rectangle's second corner is inclusive, hence the -1 on the
    # half-open end bounds.
    cv2.rectangle(img_bgr_uint8, (bbox_width_start, bbox_height_start),
                  (bbox_width_end - 1, bbox_height_end - 1),
                  color, thickness=2)
    img_rgb_uint8 = img_bgr_uint8[..., ::-1]
    img_rgb_float = np.float32(img_rgb_uint8) / 255
    plt.imsave(fname, img_rgb_float)
def save_model_w_condition(state: dict, model, save_path: str, to_save: bool,
                           log_wandb: bool = False):
    """Write ``state`` with ``torch.save`` to ``save_path + '.tar'`` if ``to_save``.

    ``model`` (per the original note, the non-multi-GPU model) and
    ``log_wandb`` are accepted for interface compatibility but are currently
    unused. No Weights & Biases artifact logging is performed.
    """
    if not to_save:
        return
    checkpoint_file = save_path + '.tar'
    torch.save(state, checkpoint_file)
def create_logger(log_filename, display=True):
    """Open ``log_filename`` for appending and return ``(logger, close)``.

    ``logger(text)`` prints ``text`` when ``display`` is true and appends
    ``text + '\\n'`` to the file. Every 10th call also flushes Python's buffer
    and ``os.fsync``s the file. Lines written since the last such sync stay
    buffered until the next one or until ``close`` (the file's ``close``
    method) is called, which flushes them. Callers should call ``close``.
    """
    log_file = open(log_filename, 'a')
    lines_written = 0

    # The closure keeps log_file alive after create_logger returns.
    def logger(text):
        nonlocal lines_written
        if display:
            print(text)
        log_file.write(text + '\n')
        lines_written += 1
        if lines_written % 10 == 0:
            # Periodic hard sync bounds how much can be lost if the process dies.
            log_file.flush()
            os.fsync(log_file.fileno())

    return logger, log_file.close
def dump(path, what, **kwargs):
    """Pickle ``what`` into the file at ``path``, replacing any existing content.

    Extra keyword arguments (e.g. ``protocol``) are forwarded to ``pickle.dump``.
    """
    with open(path, mode='wb') as out_stream:
        pickle.dump(what, out_stream, **kwargs)
def load(path, **kwargs):
    """Unpickle and return the object stored in the file at ``path``.

    Extra keyword arguments (e.g. ``encoding``) are forwarded to ``pickle.load``.

    Security: unpickling can execute arbitrary code. Only load files you
    created or otherwise trust.
    """
    with open(path, mode='rb') as in_stream:
        restored = pickle.load(in_stream, **kwargs)
    return restored