-
Notifications
You must be signed in to change notification settings - Fork 40
/
Copy pathblazebase.py
513 lines (407 loc) · 19.4 KB
/
blazebase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def resize_pad(img):
    """Letterbox an image to 256x256 and also produce a 128x128 copy.

    The longer side of ``img`` is resized to 256 pixels, the shorter side
    is resized proportionally (integer floor) and then zero-padded on both
    ends so that the result is square. The aspect ratio of the image
    content is therefore kept.

    Arguments:
        img: image array indexed as (height, width, channel); the padding
            spec passed to ``np.pad`` has three axes.

    Returns:
        img1: the padded 256x256 image
        img2: ``img1`` resized to 128x128
        scale: ratio between an original side length and its resized
            length (original pixels per resized pixel)
        pad: ``(pad_y, pad_x)``, the leading padding on each axis
            expressed in original-image pixels
    """
    height, width = img.shape[0], img.shape[1]

    if height >= width:
        # Portrait or square: height fills the target, width is padded.
        new_h = 256
        new_w = 256 * width // height
        scale = width / new_w
    else:
        # Landscape: width fills the target, height is padded.
        new_h = 256 * height // width
        new_w = 256
        scale = height / new_h

    # An odd amount of padding puts the extra pixel on the trailing side.
    top, extra_rows = divmod(256 - new_h, 2)
    left, extra_cols = divmod(256 - new_w, 2)
    bottom = top + extra_rows
    right = left + extra_cols

    resized = cv2.resize(img, (new_w, new_h))
    img1 = np.pad(resized, ((top, bottom), (left, right), (0, 0)))
    img2 = cv2.resize(img1, (128, 128))

    pad = (int(top * scale), int(left * scale))
    return img1, img2, scale, pad
def denormalize_detections(detections, scale, pad):
    """Map normalized detections back to original image coordinates.

    Every coordinate is multiplied by ``scale * 256`` (the side length of
    the padded detector input measured in original pixels) and then the
    padding offset of its axis is subtracted. ``detections`` is modified
    in place and also returned.

    Inputs:
        detections: nxm tensor. n is the number of detections and m is
            at least 4+2*k: the first four columns are the bounding box
            (columns 0 and 2 use ``pad[0]``, columns 1 and 3 use
            ``pad[1]``), followed by k keypoints whose even columns use
            ``pad[1]`` and odd columns use ``pad[0]``.
        scale: scalar that was used to resize the image
        pad: pair of padding offsets in original pixels; index 0 is
            applied to the box columns 0/2, index 1 to columns 1/3

    Returns:
        The same ``detections`` object, now in image coordinates.
    """
    first_axis_pad, second_axis_pad = pad[0], pad[1]

    # (column selector, offset to subtract) in the original update order.
    column_offsets = (
        (0, first_axis_pad),
        (1, second_axis_pad),
        (2, first_axis_pad),
        (3, second_axis_pad),
        (slice(4, None, 2), second_axis_pad),
        (slice(5, None, 2), first_axis_pad),
    )
    for columns, offset in column_offsets:
        detections[:, columns] = detections[:, columns] * scale * 256 - offset

    return detections
class BlazeBlock(nn.Module):
    """Residual block built from a depthwise and a pointwise convolution.

    The main branch applies a depthwise ``kernel_size`` convolution
    followed by a 1x1 convolution. The skip branch is the input itself,
    max-pooled when ``stride == 2``, and then either projected with a 1x1
    convolution (``skip_proj=True``) or zero-padded along the channel
    axis when ``out_channels > in_channels``. Both branches are summed
    and passed through the activation.

    Arguments:
        in_channels: channels of the input tensor
        out_channels: channels of the output tensor
        kernel_size: depthwise kernel size
        stride: depthwise stride; 2 enables the downsampling path
        act: ``'relu'`` or ``'prelu'``
        skip_proj: project the skip branch with a 1x1 convolution

    Raises:
        NotImplementedError: if ``act`` is not a known activation name.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, act='relu', skip_proj=False):
        super().__init__()
        self.stride = stride
        self.kernel_size = kernel_size
        self.channel_pad = out_channels - in_channels

        # TFLite uses slightly different padding than PyTorch
        # on the depthwise conv layer when the stride is 2, so in that
        # case the padding is applied by hand in forward().
        downsample = stride == 2
        if downsample:
            self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride)
        conv_padding = 0 if downsample else (kernel_size - 1) // 2

        depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                              stride=stride, padding=conv_padding,
                              groups=in_channels, bias=True)
        pointwise = nn.Conv2d(in_channels, out_channels, 1,
                              stride=1, padding=0, bias=True)
        self.convs = nn.Sequential(depthwise, pointwise)

        self.skip_proj = (
            nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=True)
            if skip_proj else None
        )

        if act == 'prelu':
            self.act = nn.PReLU(out_channels)
        elif act == 'relu':
            self.act = nn.ReLU(inplace=True)
        else:
            raise NotImplementedError("unknown activation %s"%act)

    def forward(self, x):
        """Run the block on ``x`` and return the activated sum of branches."""
        features = x
        residual = x

        if self.stride == 2:
            # Asymmetric zero padding (left, right, top, bottom).
            border = (0, 2, 0, 2) if self.kernel_size == 3 else (1, 2, 1, 2)
            features = F.pad(x, border, "constant", 0)
            residual = self.max_pool(x)

        if self.skip_proj is not None:
            residual = self.skip_proj(residual)
        elif self.channel_pad > 0:
            # Append zero channels so the skip branch matches the output.
            residual = F.pad(residual, (0, 0, 0, 0, 0, self.channel_pad), "constant", 0)

        return self.act(self.convs(features) + residual)
class FinalBlazeBlock(nn.Module):
    """Stride-2 depthwise + pointwise convolution block without a skip branch.

    Arguments:
        channels: number of input channels, which is also the number of
            output channels
        kernel_size: kernel size of the depthwise convolution
    """

    def __init__(self, channels, kernel_size=3):
        super().__init__()

        # TFLite uses slightly different padding than PyTorch
        # on the depthwise conv layer when the stride is 2, so the conv
        # itself is unpadded and forward() pads by hand.
        depthwise = nn.Conv2d(channels, channels, kernel_size,
                              stride=2, padding=0, groups=channels, bias=True)
        pointwise = nn.Conv2d(channels, channels, 1,
                              stride=1, padding=0, bias=True)
        self.convs = nn.Sequential(depthwise, pointwise)

        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Zero-pad ``x`` on the right and bottom by 2, convolve and apply ReLU."""
        padded = F.pad(x, (0, 2, 0, 2), "constant", 0)
        features = self.convs(padded)
        return self.act(features)
class BlazeBase(nn.Module):
    """Base class for MediaPipe models."""

    def _device(self):
        """Return the device (CPU or GPU) the model's weights live on.

        The device of ``self.classifier_8.weight`` is used when the model
        has such a layer. Models without a ``classifier_8`` attribute fall
        back to the device of their first parameter, and to the CPU when
        they have no parameters at all.
        """
        head = getattr(self, "classifier_8", None)
        if head is not None:
            return head.weight.device
        first_param = next(self.parameters(), None)
        if first_param is None:
            return torch.device("cpu")
        return first_param.device

    def load_weights(self, path):
        """Load a saved state dict from ``path`` and switch to eval mode.

        Arguments:
            path: anything accepted by ``torch.load`` (file name or
                file-like object) that contains this model's state dict.
        """
        # map_location="cpu" lets a checkpoint that was saved from a GPU be
        # read on a machine without CUDA; load_state_dict() then copies the
        # values into the parameters on whatever device they already use.
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        self.load_state_dict(state_dict)
        self.eval()
class BlazeLandmark(BlazeBase):
    """Base class for landmark models.

    The methods read ``self.resolution``, the side length in pixels of
    the square crop, which this class does not define itself.
    """

    def extract_roi(self, frame, xc, yc, theta, scale):
        """Cut one rotated square crop out of ``frame`` per region of interest.

        Arguments:
            frame: image handed to ``cv2.warpAffine``; the crops are later
                permuted as (n, H, W, C), so a channel axis is expected.
            xc, yc: tensors with the ROI centre coordinates, one per ROI
            theta: tensor with the ROI rotation angles (passed to
                ``torch.cos`` / ``torch.sin``), one per ROI
            scale: tensor with the ROI side lengths, one per ROI; its
                device is used for every tensor created here

        Returns:
            imgs: float tensor of crops, (n, C, res, res), divided by 255
            affines: (n, 2, 3) tensor with the inverse of each crop's
                affine transform
            points: (n, 2, 4) tensor with the corner points of each ROI
        """
        device = scale.device

        # Corners of a square centred on the origin, sized per ROI.
        corners = torch.tensor([[-1, -1, 1, 1],
                                [-1, 1, -1, 1]], device=device).view(1, 2, 4)
        corners = corners * scale.view(-1, 1, 1) / 2

        # Rotate the corners by theta and move them to the ROI centre.
        angle = theta.view(-1, 1, 1)
        upper = torch.cat((torch.cos(angle), -torch.sin(angle)), 2)
        lower = torch.cat((torch.sin(angle), torch.cos(angle)), 2)
        rotation = torch.cat((upper, lower), 1)
        center = torch.cat((xc.view(-1, 1, 1), yc.view(-1, 1, 1)), 1)
        points = rotation @ corners + center

        # Three corners of each ROI are enough to define the affine map
        # onto the corresponding corners of the output square.
        res = self.resolution
        target = np.array([[0, 0, res-1],
                           [0, res-1, 0]], dtype=np.float32).T

        crops = []
        inverses = []
        for roi_points in points:
            source = roi_points[:, :3].cpu().numpy().T
            M = cv2.getAffineTransform(source, target)
            crop = cv2.warpAffine(frame, M, (res, res))
            crops.append(torch.tensor(crop, device=device))
            inverse = cv2.invertAffineTransform(M).astype('float32')
            inverses.append(torch.tensor(inverse, device=device))

        if crops:
            imgs = torch.stack(crops).permute(0, 3, 1, 2).float() / 255.
            affines = torch.stack(inverses)
        else:
            # No ROI: return correctly shaped empty tensors.
            imgs = torch.zeros((0, 3, res, res), device=device)
            affines = torch.zeros((0, 2, 3), device=device)

        return imgs, affines, points

    def denormalize_landmarks(self, landmarks, affines):
        """Map landmarks from normalized crop coordinates through ``affines``.

        The first two coordinates of every landmark are scaled by
        ``self.resolution`` and then transformed with the matching 2x3
        affine matrix. ``landmarks`` is modified in place and returned.

        Arguments:
            landmarks: tensor indexed as (roi, landmark, coordinate)
            affines: one 2x3 affine matrix per ROI
        """
        landmarks[:, :, :2] *= self.resolution
        for idx, landmark in enumerate(landmarks):
            affine = affines[idx]
            linear, offset = affine[:, :2], affine[:, 2:]
            landmarks[idx, :, :2] = (linear @ landmark[:, :2].T + offset).T
        return landmarks
class BlazeDetector(BlazeBase):
    """Base class for detector models.

    Based on code from https://github.com/tkat0/PyTorch_BlazeFace/ and
    https://github.com/hollance/BlazeFace-PyTorch and
    https://github.com/google/mediapipe/

    The methods read configuration attributes that this class does not
    define itself and that a concrete model therefore has to provide:
    ``num_anchors``, ``num_coords``, ``num_classes``, ``num_keypoints``,
    ``x_scale``, ``y_scale``, ``w_scale``, ``h_scale``,
    ``score_clipping_thresh``, ``min_score_thresh``,
    ``min_suppression_threshold``, ``detection2roi_method``, ``kp1``,
    ``kp2``, ``dy``, ``dscale`` and ``theta0``.
    """

    def load_anchors(self, path):
        """Load the anchor boxes from a NumPy file into ``self.anchors``.

        The anchors are stored as a float32 tensor on the model's device
        and must have shape ``(self.num_anchors, 4)``.
        """
        self.anchors = torch.tensor(np.load(path), dtype=torch.float32, device=self._device())
        assert(self.anchors.ndimension() == 2)
        assert(self.anchors.shape[0] == self.num_anchors)
        assert(self.anchors.shape[1] == 4)

    def _preprocess(self, x):
        """Converts the image to float and divides the pixels by 255.

        For 8-bit input this yields values in the range [0, 1].
        """
        return x.float() / 255.

    def predict_on_image(self, img):
        """Makes a prediction on a single image.

        Arguments:
            img: a NumPy array of shape (H, W, 3) or a PyTorch tensor of
                 shape (3, H, W). The height and width must equal
                 ``self.y_scale`` and ``self.x_scale``.

        Returns:
            A tensor with the detections for this image.
        """
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img).permute((2, 0, 1))

        return self.predict_on_batch(img.unsqueeze(0))[0]

    def predict_on_batch(self, x):
        """Makes a prediction on a batch of images.

        Arguments:
            x: a NumPy array of shape (b, H, W, 3) or a PyTorch tensor of
               shape (b, 3, H, W). The height and width must equal
               ``self.y_scale`` and ``self.x_scale``.

        Returns:
            A list containing a tensor of detections for each image in
            the batch. If nothing is found for an image, its entry is an
            empty tensor of shape (0, num_coords + 1) on the same device
            and with the same dtype as the other detections.

            Each detection consists of ``num_coords + 1`` numbers:
                - ymin, xmin, ymax, xmax
                - x,y-coordinates for the keypoints
                - confidence score
        """
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x).permute((0, 3, 1, 2))

        assert x.shape[1] == 3
        assert x.shape[2] == self.y_scale
        assert x.shape[3] == self.x_scale

        # 1. Preprocess the images into tensors:
        x = x.to(self._device())
        x = self._preprocess(x)

        # 2. Run the neural network:
        with torch.no_grad():
            out = self(x)

        # 3. Postprocess the raw predictions:
        detections = self._tensors_to_detections(out[0], out[1], self.anchors)

        # 4. Non-maximum suppression to remove overlapping detections:
        filtered_detections = []
        for image_detections in detections:
            faces = self._weighted_non_max_suppression(image_detections)
            if faces:
                faces = torch.stack(faces)
            else:
                # new_zeros keeps the device and dtype of the detections so
                # that every entry of the returned list is consistent.
                faces = image_detections.new_zeros((0, self.num_coords + 1))
            filtered_detections.append(faces)

        return filtered_detections

    def detection2roi(self, detection):
        """ Convert detections from detector to an oriented bounding box.

        Adapted from:
        # mediapipe/modules/face_landmark/face_detection_front_detection_to_roi.pbtxt

        With ``detection2roi_method == 'box'`` the center and size come
        from the detected box; with ``'alignment'`` the center is keypoint
        ``kp1`` and the size is twice its distance to keypoint ``kp2``.
        Rotation is calculated from the vector between kp1 and kp2
        relative to theta0. The box is scaled and shifted by dscale
        and dy.

        ``detection`` is not modified.

        Returns:
            xc, yc, scale, theta: one value per detection.

        Raises:
            NotImplementedError: for an unknown ``detection2roi_method``.
        """
        # keypoints that define the box orientation (and, for the
        # 'alignment' method, also its center and size)
        x0 = detection[:,4+2*self.kp1]
        y0 = detection[:,4+2*self.kp1+1]
        x1 = detection[:,4+2*self.kp2]
        y1 = detection[:,4+2*self.kp2+1]

        if self.detection2roi_method == 'box':
            # compute box center and scale
            # use mediapipe/calculators/util/detections_to_rects_calculator.cc
            xc = (detection[:,1] + detection[:,3]) / 2
            yc = (detection[:,0] + detection[:,2]) / 2
            scale = (detection[:,3] - detection[:,1]) # assumes square boxes

        elif self.detection2roi_method == 'alignment':
            # compute box center and scale
            # use mediapipe/calculators/util/alignment_points_to_rects_calculator.cc
            xc = x0
            yc = y0
            scale = ((xc-x1)**2 + (yc-y1)**2).sqrt() * 2
        else:
            raise NotImplementedError(
                "detection2roi_method [%s] not supported"%self.detection2roi_method)

        # Out-of-place on purpose: with the 'alignment' method yc is a view
        # into `detection`, so an in-place `+=` would overwrite the caller's
        # keypoint and shift y0 before the rotation below is computed.
        yc = yc + self.dy * scale
        scale = scale * self.dscale

        # compute box rotation
        theta = torch.atan2(y0-y1, x0-x1) - self.theta0
        return xc, yc, scale, theta

    def _tensors_to_detections(self, raw_box_tensor, raw_score_tensor, anchors):
        """Convert the raw network outputs into per-image detections.

        The output of the neural network is a tensor of shape
        (b, num_anchors, num_coords) containing the bounding box regressor
        predictions, as well as a tensor of shape
        (b, num_anchors, num_classes) with the classification confidences.

        Returns a list of (num_detections, num_coords + 1) tensors, one
        for each image in the batch.

        This is based on the source code from:
        mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.cc
        mediapipe/calculators/tflite/tflite_tensors_to_detections_calculator.proto
        """
        assert raw_box_tensor.ndimension() == 3
        assert raw_box_tensor.shape[1] == self.num_anchors
        assert raw_box_tensor.shape[2] == self.num_coords

        assert raw_score_tensor.ndimension() == 3
        assert raw_score_tensor.shape[1] == self.num_anchors
        assert raw_score_tensor.shape[2] == self.num_classes

        assert raw_box_tensor.shape[0] == raw_score_tensor.shape[0]

        detection_boxes = self._decode_boxes(raw_box_tensor, anchors)

        # Clamp the logits before the sigmoid, as the MediaPipe code does.
        thresh = self.score_clipping_thresh
        raw_score_tensor = raw_score_tensor.clamp(-thresh, thresh)
        detection_scores = raw_score_tensor.sigmoid().squeeze(dim=-1)

        # Note: we stripped off the last dimension from the scores tensor
        # because there is only one class. Now we can simply use a mask
        # to filter out the boxes with too low confidence.
        mask = detection_scores >= self.min_score_thresh

        # Because each image from the batch can have a different number of
        # detections, process them one at a time using a loop.
        output_detections = []
        for i in range(raw_box_tensor.shape[0]):
            boxes = detection_boxes[i, mask[i]]
            scores = detection_scores[i, mask[i]].unsqueeze(dim=-1)
            output_detections.append(torch.cat((boxes, scores), dim=-1))

        return output_detections

    def _decode_boxes(self, raw_boxes, anchors):
        """Converts the predictions into actual coordinates using
        the anchor boxes. Processes the entire batch at once.

        Anchor columns 0/1 are used as the x/y center and columns 2/3 as
        the width/height. The result has the same shape as ``raw_boxes``
        with columns ymin, xmin, ymax, xmax followed by x,y keypoint pairs.
        """
        boxes = torch.zeros_like(raw_boxes)

        x_center = raw_boxes[..., 0] / self.x_scale * anchors[:, 2] + anchors[:, 0]
        y_center = raw_boxes[..., 1] / self.y_scale * anchors[:, 3] + anchors[:, 1]

        w = raw_boxes[..., 2] / self.w_scale * anchors[:, 2]
        h = raw_boxes[..., 3] / self.h_scale * anchors[:, 3]

        boxes[..., 0] = y_center - h / 2.  # ymin
        boxes[..., 1] = x_center - w / 2.  # xmin
        boxes[..., 2] = y_center + h / 2.  # ymax
        boxes[..., 3] = x_center + w / 2.  # xmax

        for k in range(self.num_keypoints):
            offset = 4 + k*2
            keypoint_x = raw_boxes[..., offset    ] / self.x_scale * anchors[:, 2] + anchors[:, 0]
            keypoint_y = raw_boxes[..., offset + 1] / self.y_scale * anchors[:, 3] + anchors[:, 1]
            boxes[..., offset    ] = keypoint_x
            boxes[..., offset + 1] = keypoint_y

        return boxes

    def _weighted_non_max_suppression(self, detections):
        """The alternative NMS method as mentioned in the BlazeFace paper:

        "We replace the suppression algorithm with a blending strategy that
        estimates the regression parameters of a bounding box as a weighted
        mean between the overlapping predictions."

        The original MediaPipe code assigns the score of the most confident
        detection to the weighted detection, but we take the average score
        of the overlapping detections.

        The input detections should be a Tensor of shape
        (count, num_coords + 1).

        Returns a list of PyTorch tensors, one for each detected face.

        This is based on the source code from:
        mediapipe/calculators/util/non_max_suppression_calculator.cc
        mediapipe/calculators/util/non_max_suppression_calculator.proto
        """
        if len(detections) == 0:
            return []

        output_detections = []

        # Sort the detections from highest to lowest score.
        remaining = torch.argsort(detections[:, self.num_coords], descending=True)

        while len(remaining) > 0:
            detection = detections[remaining[0]]

            # Compute the overlap between the first box and the other
            # remaining boxes. (Note that the other_boxes also include
            # the first_box.)
            first_box = detection[:4]
            other_boxes = detections[remaining, :4]
            ious = overlap_similarity(first_box, other_boxes)

            # If two detections don't overlap enough, they are considered
            # to be from different faces.
            mask = ious > self.min_suppression_threshold
            # The leading detection always belongs to its own group. Without
            # this, a degenerate box (zero or negative area gives a NaN or
            # non-positive IoU with itself) is never removed from
            # `remaining` and the loop does not terminate.
            mask[0] = True
            overlapping = remaining[mask]
            remaining = remaining[~mask]

            # Take an average of the coordinates from the overlapping
            # detections, weighted by their confidence scores.
            weighted_detection = detection.clone()
            if len(overlapping) > 1:
                coordinates = detections[overlapping, :self.num_coords]
                scores = detections[overlapping, self.num_coords:self.num_coords+1]
                total_score = scores.sum()
                weighted = (coordinates * scores).sum(dim=0) / total_score
                weighted_detection[:self.num_coords] = weighted
                weighted_detection[self.num_coords] = total_score / len(overlapping)

            output_detections.append(weighted_detection)

        return output_detections
# IoU code from https://github.com/amdegroot/ssd.pytorch/blob/master/layers/box_utils.py
def intersect(box_a, box_b):
    """Compute the pairwise intersection area of two sets of boxes.

    Both inputs are broadcast against each other without copying:
      [A,2] -> [A,1,2] -> [A,B,2]
      [B,2] -> [1,B,2] -> [A,B,2]
    Columns 0-1 of a box are treated as its minimum corner and
    columns 2-3 as its maximum corner.

    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    # Far corner of the overlap: the smaller of the two maximum corners.
    upper = torch.min(box_a[:, None, 2:], box_b[None, :, 2:])
    # Near corner of the overlap: the larger of the two minimum corners.
    lower = torch.max(box_a[:, None, :2], box_b[None, :, :2])

    # Boxes that do not overlap give a negative extent; clip it to zero.
    extent = (upper - lower).clamp(min=0)
    return extent[..., 0] * extent[..., 1]
def jaccard(box_a, box_b):
    """Compute the pairwise jaccard overlap (intersection over union).

    For every pair of one box from ``box_a`` and one box from ``box_b``:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)

    Args:
        box_a: (tensor) bounding boxes, Shape: [A,4]
        box_b: (tensor) bounding boxes, Shape: [B,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    overlap = intersect(box_a, box_b)  # [A,B]

    # Box areas, shaped so that they broadcast against `overlap`.
    area_a = (box_a[:, 2] - box_a[:, 0]) * (box_a[:, 3] - box_a[:, 1])
    area_b = (box_b[:, 2] - box_b[:, 0]) * (box_b[:, 3] - box_b[:, 1])
    area_a = area_a.unsqueeze(1)  # [A,1]
    area_b = area_b.unsqueeze(0)  # [1,B]

    union = area_a + area_b - overlap
    return overlap / union  # [A,B]
def overlap_similarity(box, other_boxes):
    """Return the IOU of a single box with each box in ``other_boxes``."""
    single = box.unsqueeze(0)  # one box as a batch of size one
    pairwise = jaccard(single, other_boxes)
    return pairwise.squeeze(0)