-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
executable file
·166 lines (146 loc) · 5.27 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import cv2
from model import l2_norm
import torch
from data.data_pipe import de_preprocess
from torchvision import transforms as trans
import io
from datetime import datetime
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
# Select matplotlib's non-interactive 'agg' backend at import time so that
# gen_plot() can render figures to memory without needing a display.
plt.switch_backend('agg')
def separate_bn_paras(modules):
    """Split parameters into batch-norm parameters and all the others.

    Args:
        modules: either a list of layers, or an object exposing ``.modules()``
            (e.g. a ``torch.nn.Module``), in which case all of its
            sub-modules are inspected.

    Returns:
        A ``(paras_only_bn, paras_wo_bn)`` tuple of parameter lists.

    Layers whose class string contains ``'model'`` or ``'container'`` are
    skipped entirely; layers whose class string contains ``'batchnorm'`` go
    to the first list, everything else to the second.
    """
    layers = modules if isinstance(modules, list) else list(modules.modules())

    bn_params, other_params = [], []
    for layer in layers:
        class_repr = str(layer.__class__)
        # Wrapper classes are skipped: their children are visited on their own.
        if 'model' in class_repr or 'container' in class_repr:
            continue
        target = bn_params if 'batchnorm' in class_repr else other_params
        target.extend(layer.parameters())
    return bn_params, other_params
def prepare_facebank(conf, model, mtcnn, tta=True):
    """Build the face bank from the image folders under ``conf.facebank_path``.

    Every sub-directory of ``conf.facebank_path`` is treated as one identity
    (the directory name is the identity name). Each readable image file in it
    is embedded with ``model``; images that are not already 112x112 are first
    passed through ``mtcnn.align``. The per-identity embedding is the mean of
    its image embeddings.

    Args:
        conf: configuration object; this function reads ``facebank_path``
            (a ``pathlib.Path``-like directory), ``test_transform`` and
            ``device``.
        model: embedding network; it is switched to eval mode.
        mtcnn: detector/aligner exposing ``align(img)``.
        tta: when true, the embedding of each image is the L2-normalised sum
            of the embeddings of the image and of its horizontal mirror.

    Returns:
        ``(embeddings, names)`` where ``embeddings`` is the concatenation of
        the per-identity embeddings and ``names`` is a numpy array whose first
        entry is ``'Unknown'`` followed by one name per embedding row.

    Side effects:
        Writes ``facebank.pth`` and ``names.npy`` into ``conf.facebank_path``,
        which is exactly where ``load_facebank`` reads them from.

    Raises:
        RuntimeError: from ``torch.cat`` when no identity produced any
            embedding (empty face bank).
    """
    model.eval()
    embeddings = []
    names = ['Unknown']  # index 0 is reserved; it has no embedding row
    for path in conf.facebank_path.iterdir():
        if path.is_file():
            # Only sub-directories describe identities.
            continue
        embs = []
        for file in path.iterdir():
            if not file.is_file():
                continue
            try:
                img = Image.open(file)
            except Exception:
                # Best effort: skip anything that cannot be opened as an
                # image, but let KeyboardInterrupt/SystemExit propagate.
                continue
            if img.size != (112, 112):
                img = mtcnn.align(img)
            with torch.no_grad():
                if tta:
                    mirror = trans.functional.hflip(img)
                    emb = model(conf.test_transform(
                        img).to(conf.device).unsqueeze(0))
                    emb_mirror = model(conf.test_transform(
                        mirror).to(conf.device).unsqueeze(0))
                    embs.append(l2_norm(emb + emb_mirror))
                else:
                    embs.append(model(conf.test_transform(
                        img).to(conf.device).unsqueeze(0)))
        if not embs:
            # No usable image for this identity: leave it out of the bank.
            continue
        embedding = torch.cat(embs).mean(0, keepdim=True)
        embeddings.append(embedding)
        names.append(path.name)
    embeddings = torch.cat(embeddings)
    names = np.array(names)
    # Save next to the names file so that load_facebank(conf) finds both
    # artefacts in conf.facebank_path (previously a hard-coded path was used
    # for the embeddings only).
    torch.save(embeddings, str(conf.facebank_path/'facebank.pth'))
    np.save(conf.facebank_path/'names', names)
    return embeddings, names
def load_facebank(conf):
    """Load a previously saved face bank from ``conf.facebank_path``.

    Args:
        conf: configuration object whose ``facebank_path`` is a
            ``pathlib.Path``-like directory containing ``facebank.pth`` and
            ``names.npy``.

    Returns:
        ``(embeddings, names)``: the object stored in ``facebank.pth`` (loaded
        with ``torch.load``) and the array stored in ``names.npy``.
    """
    bank_dir = conf.facebank_path
    embeddings_file = str(bank_dir/'facebank.pth')
    names_file = bank_dir/'names.npy'
    return torch.load(embeddings_file), np.load(names_file)
def _write_padded(dest, values, pad):
    """Copy ``values`` into ``dest`` element by element, padding the tail."""
    count = len(values)
    for i in range(len(dest)):
        dest[i] = values[i] if i < count else pad


def face_reader(conf, conn, flag, boxes_arr,
                result_arr, learner, mtcnn, targets, tta):
    """Worker loop: receive frames, detect/identify faces, publish results.

    For every image received on ``conn`` the faces are detected with
    ``mtcnn.align_multi`` and identified with ``learner.infer``. The flattened
    bounding boxes are written into ``boxes_arr`` (unused slots set to 0) and
    the identification results into ``result_arr`` (unused slots set to -1).
    ``flag.value`` is then reset to 0.

    Args:
        conf: configuration object; ``face_limit`` is read here and ``conf``
            is forwarded to ``learner.infer``.
        conn: connection-like object with a blocking ``recv()``.
        flag: object with a writable ``value`` attribute.
        boxes_arr: indexable, fixed-length output buffer for box coordinates.
        result_arr: indexable, fixed-length output buffer for results.
        learner: object exposing ``infer(conf, faces, targets, tta)``.
        mtcnn: detector exposing ``align_multi(image, limit=...)`` returning
            ``(bboxes, faces)``.
        targets: forwarded unchanged to ``learner.infer``.
        tta: forwarded unchanged to ``learner.infer``.

    Returns:
        ``None`` once the connection is closed; otherwise loops forever.

    Raises:
        AssertionError: if the number of boxes and of results differ.
    """
    while True:
        try:
            image = conn.recv()
        except (EOFError, OSError):
            # The connection is closed: nothing more will ever arrive, so
            # stop instead of spinning on a dead pipe.
            return
        except Exception:
            # Best effort for a single bad message (e.g. unpicklable data).
            continue
        try:
            bboxes, faces = mtcnn.align_multi(image, limit=conf.face_limit)
        except Exception:
            # Detection failed for this frame: treat it as "no faces".
            # Both names are reset so stale faces from a previous frame are
            # never reused and `faces` is never left unbound.
            bboxes, faces = [], []
        if len(bboxes) > 0:
            results = learner.infer(conf, faces, targets, tta)
            print('bboxes in reader : {}'.format(bboxes))
            # Drop the last column of every box row, keeping the four
            # coordinates (presumably the dropped column is the detection
            # score -- confirm against mtcnn.align_multi).
            bboxes = bboxes[:, :-1]
            bboxes = bboxes.astype(int)
            bboxes = bboxes + [-1, -1, 1, 1]  # personal choice: grow box by 1px
            assert bboxes.shape[0] == results.shape[0], 'bbox and faces number not same'
            _write_padded(boxes_arr, bboxes.reshape([-1]), 0)
            _write_padded(result_arr, results, -1)
        else:
            _write_padded(boxes_arr, (), 0)    # by default, it's all 0
            _write_padded(result_arr, (), -1)  # by default, it's all -1
        print('boxes_arr : {}'.format(boxes_arr[:4]))
        print('result_arr : {}'.format(result_arr[:4]))
        flag.value = 0
# Per-image horizontal flip for an already pre-processed tensor, expressed as
# a torchvision pipeline: de_preprocess -> PIL image -> flip -> tensor ->
# normalise with mean 0.5 / std 0.5 per channel.
# NOTE(review): de_preprocess presumably undoes that same normalisation so
# ToPILImage receives values in [0, 1] -- confirm in data.data_pipe.
_hflip_steps = [
    de_preprocess,
    trans.ToPILImage(),
    trans.functional.hflip,
    trans.ToTensor(),
    trans.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
hflip = trans.Compose(_hflip_steps)
def hflip_batch(imgs_tensor):
    """Return a horizontally flipped copy of a batch of image tensors.

    Each element along the first dimension of ``imgs_tensor`` is passed
    through the module-level ``hflip`` pipeline and stored at the same index
    of a new tensor created with ``torch.empty_like(imgs_tensor)``.
    """
    flipped = torch.empty_like(imgs_tensor)
    for idx in range(len(imgs_tensor)):
        flipped[idx] = hflip(imgs_tensor[idx])
    return flipped
def get_time():
    """Return the current local time as ``YYYY-MM-DD-HH-MM``.

    The value is formatted explicitly with ``strftime`` rather than by slicing
    ``str(datetime.now())``: the string form omits the fractional part when
    the microsecond field is exactly 0, which made a fixed-width slice cut
    into the date itself.
    """
    return datetime.now().strftime('%Y-%m-%d-%H-%M')
def gen_plot(fpr, tpr):
    """Create a pyplot ROC plot and save it to an in-memory buffer.

    Args:
        fpr: x values of the curve (false positive rates).
        tpr: y values of the curve (true positive rates).

    Returns:
        An ``io.BytesIO`` holding the figure encoded as JPEG, rewound to
        position 0 so it can be read immediately.
    """
    fig = plt.figure()
    try:
        plt.xlabel("FPR", fontsize=14)
        plt.ylabel("TPR", fontsize=14)
        plt.title("ROC Curve", fontsize=14)
        # Actually draw the curve; without this the saved image is an empty
        # pair of axes and both arguments are ignored.
        plt.plot(fpr, tpr, linewidth=2)
        buf = io.BytesIO()
        plt.savefig(buf, format='jpeg')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
    buf.seek(0)
    return buf
def draw_box_name(bbox, name, frame):
    """Draw a labelled bounding box on ``frame`` and return the result.

    Args:
        bbox: indexable with at least four entries; ``bbox[0], bbox[1]`` is
            one corner of the rectangle and ``bbox[2], bbox[3]`` the opposite
            corner.
        name: text written at the ``bbox[0], bbox[1]`` corner.
        frame: image passed to the OpenCV drawing calls.

    Returns:
        The image returned by ``cv2.putText`` after both drawing calls.
    """
    anchor = (bbox[0], bbox[1])
    opposite = (bbox[2], bbox[3])

    # Rectangle: colour (0, 0, 255), thickness 6.
    boxed = cv2.rectangle(frame, anchor, opposite, (0, 0, 255), 6)

    # Label: Hershey simplex font, scale 2, colour (0, 255, 0), thickness 3,
    # anti-aliased line type.
    labelled = cv2.putText(boxed,
                           name,
                           anchor,
                           cv2.FONT_HERSHEY_SIMPLEX,
                           2,
                           (0, 255, 0),
                           3,
                           cv2.LINE_AA)
    return labelled