forked from matterport/Mask_RCNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
2868 lines (2475 loc) · 124 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Mask R-CNN
The main Mask R-CNN model implementation.
Copyright (c) 2017 Matterport, Inc.
Licensed under the MIT License (see LICENSE for details)
Written by Waleed Abdulla
"""
import os
import random
import datetime
import re
import math
import logging
from collections import OrderedDict
import multiprocessing
import numpy as np
import tensorflow as tf
import keras
import keras.backend as K
import keras.layers as KL
import keras.engine as KE
import keras.models as KM
from mrcnn import utils
# Requires TensorFlow 1.3+ and Keras 2.0.8+.
from distutils.version import LooseVersion
# Fail at import time if the installed TensorFlow / Keras are older than the
# minimum versions stated above. NOTE: `assert` statements are skipped when
# Python runs with -O, in which case these checks do not run.
assert LooseVersion(tf.__version__) >= LooseVersion("1.3")
assert LooseVersion(keras.__version__) >= LooseVersion('2.0.8')
############################################################
# Utility Functions
############################################################
def log(text, array=None):
    """Print a message, optionally followed by a summary of a Numpy array.

    text: the message to print.
    array: optional array. When provided, the message is left-justified to
        25 characters and followed by the array's shape, its min and max
        values (left blank when the array has no elements) and its dtype.
    """
    if array is None:
        print(text)
        return
    fragments = [text.ljust(25),
                 "shape: {:20} ".format(str(array.shape))]
    if array.size:
        fragments.append(
            "min: {:10.5f} max: {:10.5f}".format(array.min(), array.max()))
    else:
        # Nothing to take a min/max of, so print blank columns instead.
        fragments.append("min: {:10} max: {:10}".format("", ""))
    fragments.append(" {}".format(array.dtype))
    print("".join(fragments))
class BatchNorm(KL.BatchNormalization):
    """Extends the Keras BatchNormalization class to allow a central place
    to make changes if needed.

    Batch normalization has a negative effect on training if batches are small
    so this layer is often frozen (via setting in Config class) and functions
    as linear layer.
    """

    def call(self, inputs, training=None):
        """Forward `inputs` to the Keras BatchNormalization implementation.

        Note about training values:
            None: Train BN layers. This is the normal mode
            False: Freeze BN layers. Good when batch size is small
            True: (don't use). Set layer in training mode even when making inferences
        """
        # Name the class explicitly. super(self.__class__, self) resolves
        # relative to the runtime type, so in any subclass of BatchNorm it
        # would look up BatchNorm.call again and recurse without end.
        return super(BatchNorm, self).call(inputs, training=training)
def compute_backbone_shapes(config, image_shape):
    """Compute the height and width of each stage of the backbone network.

    config: object providing BACKBONE and either COMPUTE_BACKBONE_SHAPE
        (used when BACKBONE is callable) or BACKBONE_STRIDES.
    image_shape: sequence whose first two entries are the image height and
        width.

    Returns:
        [N, (height, width)]. Where N is the number of stages. When
        config.BACKBONE is callable, the result of
        config.COMPUTE_BACKBONE_SHAPE(image_shape) is returned unchanged.
    """
    # A custom backbone supplies its own shape computation.
    if callable(config.BACKBONE):
        return config.COMPUTE_BACKBONE_SHAPE(image_shape)

    # Currently supports ResNet only
    assert config.BACKBONE in ["resnet50", "resnet101"]

    stage_shapes = []
    for stride in config.BACKBONE_STRIDES:
        # Round up so a partially covered cell still counts.
        stage_height = int(math.ceil(image_shape[0] / stride))
        stage_width = int(math.ceil(image_shape[1] / stride))
        stage_shapes.append([stage_height, stage_width])
    return np.array(stage_shapes)
############################################################
# Resnet Graph
############################################################
# Code adapted from:
# https://github.com/fchollet/deep-learning-models/blob/master/resnet50.py
def identity_block(input_tensor, kernel_size, filters, stage, block,
                   use_bias=True, train_bn=True):
    """Residual block whose shortcut is the unmodified input tensor.

    The main path is a 1x1 conv, a kernel_size x kernel_size conv with 'same'
    padding and another 1x1 conv. Each conv is followed by batch norm, and
    the first two also by ReLU. The input is added to the main path output
    and a final ReLU is applied.

    # Arguments
        input_tensor: input tensor
        kernel_size: the kernel size of middle conv layer at main path
        filters: list of 3 integers, the filter counts of the 3 conv layers
            at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        use_bias: Boolean. To use or not use a bias in conv layers.
        train_bn: Boolean. Train or freeze Batch Norm layers
    """
    filters_a, filters_b, filters_c = filters
    conv_prefix = 'res' + str(stage) + block + '_branch'
    bn_prefix = 'bn' + str(stage) + block + '_branch'

    # First 1x1 conv
    out = KL.Conv2D(filters_a, (1, 1), use_bias=use_bias,
                    name=conv_prefix + '2a')(input_tensor)
    out = BatchNorm(name=bn_prefix + '2a')(out, training=train_bn)
    out = KL.Activation('relu')(out)

    # Middle conv
    out = KL.Conv2D(filters_b, (kernel_size, kernel_size), padding='same',
                    use_bias=use_bias, name=conv_prefix + '2b')(out)
    out = BatchNorm(name=bn_prefix + '2b')(out, training=train_bn)
    out = KL.Activation('relu')(out)

    # Last 1x1 conv. No activation until after the shortcut is added.
    out = KL.Conv2D(filters_c, (1, 1), use_bias=use_bias,
                    name=conv_prefix + '2c')(out)
    out = BatchNorm(name=bn_prefix + '2c')(out, training=train_bn)

    # Merge with the identity shortcut
    out = KL.Add()([out, input_tensor])
    return KL.Activation(
        'relu', name='res' + str(stage) + block + '_out')(out)
def conv_block(input_tensor, kernel_size, filters, stage, block,
               strides=(2, 2), use_bias=True, train_bn=True):
    """Residual block whose shortcut is a 1x1 conv followed by batch norm.

    The main path has the same three conv layers as identity_block, except
    that the first 1x1 conv is applied with `strides`. The shortcut conv uses
    the same `strides` so both branches can be added.

    # Arguments
        input_tensor: input tensor
        kernel_size: the kernel size of middle conv layer at main path
        filters: list of 3 integers, the filter counts of the 3 conv layers
            at main path
        stage: integer, current stage label, used for generating layer names
        block: 'a','b'..., current block label, used for generating layer names
        strides: strides of the first main-path conv and of the shortcut conv
        use_bias: Boolean. To use or not use a bias in conv layers.
        train_bn: Boolean. Train or freeze Batch Norm layers
    """
    filters_a, filters_b, filters_c = filters
    conv_prefix = 'res' + str(stage) + block + '_branch'
    bn_prefix = 'bn' + str(stage) + block + '_branch'

    # Main path: strided 1x1 conv
    out = KL.Conv2D(filters_a, (1, 1), strides=strides, use_bias=use_bias,
                    name=conv_prefix + '2a')(input_tensor)
    out = BatchNorm(name=bn_prefix + '2a')(out, training=train_bn)
    out = KL.Activation('relu')(out)

    # Main path: middle conv
    out = KL.Conv2D(filters_b, (kernel_size, kernel_size), padding='same',
                    use_bias=use_bias, name=conv_prefix + '2b')(out)
    out = BatchNorm(name=bn_prefix + '2b')(out, training=train_bn)
    out = KL.Activation('relu')(out)

    # Main path: last 1x1 conv
    out = KL.Conv2D(filters_c, (1, 1), use_bias=use_bias,
                    name=conv_prefix + '2c')(out)
    out = BatchNorm(name=bn_prefix + '2c')(out, training=train_bn)

    # Shortcut path: strided 1x1 conv on the block input
    projected = KL.Conv2D(filters_c, (1, 1), strides=strides,
                          use_bias=use_bias,
                          name=conv_prefix + '1')(input_tensor)
    projected = BatchNorm(name=bn_prefix + '1')(projected, training=train_bn)

    out = KL.Add()([out, projected])
    return KL.Activation(
        'relu', name='res' + str(stage) + block + '_out')(out)
def resnet_graph(input_image, architecture, stage5=False, train_bn=True):
    """Build a ResNet graph.

    input_image: input tensor of the network
    architecture: Can be resnet50 or resnet101
    stage5: Boolean. If False, stage5 of the network is not created
    train_bn: Boolean. Train or freeze Batch Norm layers

    Returns [C1, C2, C3, C4, C5], the output tensor of each stage. C5 is None
    when stage5 is False.
    """
    assert architecture in ["resnet50", "resnet101"]

    # Stage 1
    net = KL.ZeroPadding2D((3, 3))(input_image)
    net = KL.Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=True)(net)
    net = BatchNorm(name='bn_conv1')(net, training=train_bn)
    net = KL.Activation('relu')(net)
    net = KL.MaxPooling2D((3, 3), strides=(2, 2), padding="same")(net)
    C1 = net

    # Stage 2. The conv block keeps the resolution (strides of 1).
    net = conv_block(net, 3, [64, 64, 256], stage=2, block='a',
                     strides=(1, 1), train_bn=train_bn)
    for label in ('b', 'c'):
        net = identity_block(net, 3, [64, 64, 256], stage=2, block=label,
                             train_bn=train_bn)
    C2 = net

    # Stage 3
    net = conv_block(net, 3, [128, 128, 512], stage=3, block='a',
                     train_bn=train_bn)
    for label in ('b', 'c', 'd'):
        net = identity_block(net, 3, [128, 128, 512], stage=3, block=label,
                             train_bn=train_bn)
    C3 = net

    # Stage 4. The number of identity blocks depends on the architecture.
    net = conv_block(net, 3, [256, 256, 1024], stage=4, block='a',
                     train_bn=train_bn)
    identity_blocks = {"resnet50": 5, "resnet101": 22}[architecture]
    for offset in range(identity_blocks):
        # Block labels continue from 'b' (character code 98).
        net = identity_block(net, 3, [256, 256, 1024], stage=4,
                             block=chr(98 + offset), train_bn=train_bn)
    C4 = net

    # Stage 5
    if not stage5:
        C5 = None
    else:
        net = conv_block(net, 3, [512, 512, 2048], stage=5, block='a',
                         train_bn=train_bn)
        net = identity_block(net, 3, [512, 512, 2048], stage=5, block='b',
                             train_bn=train_bn)
        net = identity_block(net, 3, [512, 512, 2048], stage=5, block='c',
                             train_bn=train_bn)
        C5 = net
    return [C1, C2, C3, C4, C5]
############################################################
# Proposal Layer
############################################################
def apply_box_deltas_graph(boxes, deltas):
    """Apply refinement deltas to boxes and return the refined boxes.

    boxes: [N, (y1, x1, y2, x2)] boxes to update
    deltas: [N, (dy, dx, log(dh), log(dw))] refinements to apply

    Returns a [N, (y1, x1, y2, x2)] tensor named "apply_box_deltas_out".
    """
    # Corner form -> centre / size form
    box_h = boxes[:, 2] - boxes[:, 0]
    box_w = boxes[:, 3] - boxes[:, 1]
    ctr_y = boxes[:, 0] + 0.5 * box_h
    ctr_x = boxes[:, 1] + 0.5 * box_w

    # Centre shifts are expressed relative to the original box size, so they
    # are applied before the size is rescaled.
    ctr_y = ctr_y + deltas[:, 0] * box_h
    ctr_x = ctr_x + deltas[:, 1] * box_w
    box_h = box_h * tf.exp(deltas[:, 2])
    box_w = box_w * tf.exp(deltas[:, 3])

    # Centre / size form -> corner form
    top = ctr_y - 0.5 * box_h
    left = ctr_x - 0.5 * box_w
    bottom = top + box_h
    right = left + box_w
    return tf.stack([top, left, bottom, right], axis=1,
                    name="apply_box_deltas_out")
def clip_boxes_graph(boxes, window):
    """Clamp every box coordinate into the given window.

    boxes: [N, (y1, x1, y2, x2)]
    window: [4] in the form y1, x1, y2, x2

    Returns a [N, 4] tensor named "clipped_boxes".
    """
    win_y1, win_x1, win_y2, win_x2 = tf.split(window, 4)
    coords = tf.split(boxes, 4, axis=1)

    # (lower, upper) bound for each of y1, x1, y2, x2 in that order
    limits = [(win_y1, win_y2), (win_x1, win_x2),
              (win_y1, win_y2), (win_x1, win_x2)]
    clamped = [tf.maximum(tf.minimum(coord, upper), lower)
               for coord, (lower, upper) in zip(coords, limits)]

    clipped = tf.concat(clamped, axis=1, name="clipped_boxes")
    # Record the static column count of 4 on the result.
    clipped.set_shape((clipped.shape[0], 4))
    return clipped
class ProposalLayer(KE.Layer):
    """Turns scored anchors into a fixed-size set of region proposals.

    Keeps the highest scoring anchors, applies the bounding box refinement
    deltas to them, clips the result to the image and runs non-max
    suppression to remove overlaps.

    Inputs:
        rpn_probs: [batch, num_anchors, (bg prob, fg prob)]
        rpn_bbox: [batch, num_anchors, (dy, dx, log(dh), log(dw))]
        anchors: [batch, num_anchors, (y1, x1, y2, x2)] anchors in normalized coordinates

    Returns:
        Proposals in normalized coordinates [batch, rois, (y1, x1, y2, x2)]
    """

    def __init__(self, proposal_count, nms_threshold, config=None, **kwargs):
        """Store the layer settings.

        proposal_count: number of proposals returned per image
        nms_threshold: IoU threshold passed to non-max suppression
        config: configuration object. call() reads RPN_BBOX_STD_DEV,
            PRE_NMS_LIMIT and IMAGES_PER_GPU from it.
        """
        super(ProposalLayer, self).__init__(**kwargs)
        self.config = config
        self.proposal_count = proposal_count
        self.nms_threshold = nms_threshold

    def call(self, inputs):
        # Foreground confidence of every anchor
        fg_scores = inputs[0][:, :, 1]
        # Box deltas [batch, num_rois, 4], rescaled by RPN_BBOX_STD_DEV
        box_deltas = inputs[1]
        box_deltas = box_deltas * np.reshape(
            self.config.RPN_BBOX_STD_DEV, [1, 1, 4])
        anchors = inputs[2]

        # Trim to the top scoring anchors first so the remaining steps work
        # on a smaller set.
        keep_limit = tf.minimum(self.config.PRE_NMS_LIMIT,
                                tf.shape(anchors)[1])
        top_ix = tf.nn.top_k(fg_scores, keep_limit, sorted=True,
                             name="top_anchors").indices

        def gather_rows(values, indices):
            return tf.gather(values, indices)

        fg_scores = utils.batch_slice([fg_scores, top_ix], gather_rows,
                                      self.config.IMAGES_PER_GPU)
        box_deltas = utils.batch_slice([box_deltas, top_ix], gather_rows,
                                       self.config.IMAGES_PER_GPU)
        top_anchors = utils.batch_slice([anchors, top_ix], gather_rows,
                                        self.config.IMAGES_PER_GPU,
                                        names=["pre_nms_anchors"])

        # Apply deltas to anchors to get refined anchors.
        # [batch, N, (y1, x1, y2, x2)]
        refined = utils.batch_slice([top_anchors, box_deltas],
                                    apply_box_deltas_graph,
                                    self.config.IMAGES_PER_GPU,
                                    names=["refined_anchors"])

        # Coordinates are normalized, so the image boundary is the 0..1
        # range. [batch, N, (y1, x1, y2, x2)]
        unit_window = np.array([0, 0, 1, 1], dtype=np.float32)

        def clip_to_image(image_boxes):
            return clip_boxes_graph(image_boxes, unit_window)

        refined = utils.batch_slice(refined, clip_to_image,
                                    self.config.IMAGES_PER_GPU,
                                    names=["refined_anchors_clipped"])

        # Filter out small boxes
        # According to Xinlei Chen's paper, this reduces detection accuracy
        # for small objects, so we're skipping it.

        def suppress_and_pad(image_boxes, image_scores):
            """Run NMS for one image and zero-pad to proposal_count rows."""
            selected = tf.image.non_max_suppression(
                image_boxes, image_scores, self.proposal_count,
                self.nms_threshold, name="rpn_non_max_suppression")
            kept = tf.gather(image_boxes, selected)
            # Pad if needed
            missing = tf.maximum(self.proposal_count - tf.shape(kept)[0], 0)
            return tf.pad(kept, [(0, missing), (0, 0)])

        return utils.batch_slice([refined, fg_scores], suppress_and_pad,
                                 self.config.IMAGES_PER_GPU)

    def compute_output_shape(self, input_shape):
        return (None, self.proposal_count, 4)
############################################################
# ROIAlign Layer
############################################################
def log2_graph(x):
    """Return the base-2 logarithm of x, computed as ln(x) / ln(2)."""
    natural_log = tf.log(x)
    log_of_two = tf.log(2.0)
    return natural_log / log_of_two
class PyramidROIAlign(KE.Layer):
    """Implements ROI Pooling on multiple levels of the feature pyramid.

    Params:
    - pool_shape: [pool_height, pool_width] of the output pooled regions. Usually [7, 7]

    Inputs:
    - boxes: [batch, num_boxes, (y1, x1, y2, x2)] in normalized
             coordinates. Possibly padded with zeros if not enough
             boxes to fill the array.
    - image_meta: [batch, (meta data)] Image details. See compose_image_meta()
    - feature_maps: List of feature maps from different levels of the pyramid.
                    Each is [batch, height, width, channels]

    Output:
    Pooled regions in the shape: [batch, num_boxes, pool_height, pool_width, channels].
    The width and height are those specified in the pool_shape in the layer
    constructor.
    """

    def __init__(self, pool_shape, **kwargs):
        super(PyramidROIAlign, self).__init__(**kwargs)
        self.pool_shape = tuple(pool_shape)

    def call(self, inputs):
        # Crop boxes [batch, num_boxes, (y1, x1, y2, x2)] in normalized coords
        boxes = inputs[0]

        # Image meta
        # Holds details about the image. See compose_image_meta()
        image_meta = inputs[1]

        # Feature Maps. List of feature maps from different level of the
        # feature pyramid. Each is [batch, height, width, channels]
        feature_maps = inputs[2:]

        # Assign each ROI to a level in the pyramid based on the ROI area.
        y1, x1, y2, x2 = tf.split(boxes, 4, axis=2)
        h = y2 - y1
        w = x2 - x1
        # Use shape of first image. Images in a batch must have the same size.
        image_shape = parse_image_meta_graph(image_meta)['image_shape'][0]
        # Equation 1 in the Feature Pyramid Networks paper. Account for
        # the fact that our coordinates are normalized here.
        # e.g. a 224x224 ROI (in pixels) maps to P4
        image_area = tf.cast(image_shape[0] * image_shape[1], tf.float32)
        roi_level = log2_graph(tf.sqrt(h * w) / (224.0 / tf.sqrt(image_area)))
        # Clamp the level to the range 2..5
        roi_level = tf.minimum(5, tf.maximum(
            2, 4 + tf.cast(tf.round(roi_level), tf.int32)))
        roi_level = tf.squeeze(roi_level, 2)

        # Loop through levels and apply ROI pooling to each. P2 to P5.
        pooled = []
        box_to_level = []
        for i, level in enumerate(range(2, 6)):
            ix = tf.where(tf.equal(roi_level, level))
            level_boxes = tf.gather_nd(boxes, ix)

            # Box indices for crop_and_resize.
            box_indices = tf.cast(ix[:, 0], tf.int32)

            # Keep track of which box is mapped to which level
            box_to_level.append(ix)

            # Stop gradient propagation to ROI proposals
            level_boxes = tf.stop_gradient(level_boxes)
            box_indices = tf.stop_gradient(box_indices)

            # Crop and Resize
            # From Mask R-CNN paper: "We sample four regular locations, so
            # that we can evaluate either max or average pooling. In fact,
            # interpolating only a single value at each bin center (without
            # pooling) is nearly as effective."
            #
            # Here we use the simplified approach of a single value per bin,
            # which is how it's done in tf.crop_and_resize()
            # Result: [batch * num_boxes, pool_height, pool_width, channels]
            pooled.append(tf.image.crop_and_resize(
                feature_maps[i], level_boxes, box_indices, self.pool_shape,
                method="bilinear"))

        # Pack pooled features into one tensor
        pooled = tf.concat(pooled, axis=0)

        # Pack box_to_level mapping into one array and add another
        # column representing the order of pooled boxes
        box_to_level = tf.concat(box_to_level, axis=0)
        box_range = tf.expand_dims(tf.range(tf.shape(box_to_level)[0]), 1)
        box_to_level = tf.concat([tf.cast(box_to_level, tf.int32), box_range],
                                 axis=1)

        # Rearrange pooled features to match the order of the original boxes
        # Sort box_to_level by batch then box index
        # TF doesn't have a way to sort by two columns, so merge them into a
        # single key and sort that. The batch index is scaled by the number
        # of boxes per image, which is always larger than any box index, so
        # the key stays unique and ordered however many boxes there are
        # (a fixed multiplier would mis-order once the box count exceeded it).
        boxes_per_image = tf.shape(boxes)[1]
        sorting_tensor = (box_to_level[:, 0] * boxes_per_image +
                          box_to_level[:, 1])
        # top_k sorts in descending order; reverse it to get ascending.
        ix = tf.nn.top_k(sorting_tensor, k=tf.shape(
            box_to_level)[0]).indices[::-1]
        ix = tf.gather(box_to_level[:, 2], ix)
        pooled = tf.gather(pooled, ix)

        # Re-add the batch dimension
        shape = tf.concat([tf.shape(boxes)[:2], tf.shape(pooled)[1:]], axis=0)
        pooled = tf.reshape(pooled, shape)
        return pooled

    def compute_output_shape(self, input_shape):
        # (batch, num_boxes) + pool shape + channels of the first feature map
        return input_shape[0][:2] + self.pool_shape + (input_shape[2][-1], )
############################################################
# Detection Target Layer
############################################################
def overlaps_graph(boxes1, boxes2):
    """Compute the IoU of every box in boxes1 with every box in boxes2.

    boxes1, boxes2: [N, (y1, x1, y2, x2)].

    Returns a tensor of shape [number of boxes1, number of boxes2].
    """
    # 1. Build two aligned lists holding every (boxes1, boxes2) pair: each
    # box of boxes1 is repeated once per box of boxes2 (tf.tile plus
    # tf.reshape stand in for np.repeat), and boxes2 is tiled as a whole.
    expanded = tf.expand_dims(boxes1, 1)
    repeated = tf.tile(expanded, [1, 1, tf.shape(boxes2)[0]])
    first = tf.reshape(repeated, [-1, 4])
    second = tf.tile(boxes2, [tf.shape(boxes1)[0], 1])

    # 2. Intersection of each pair, clamped at zero for disjoint boxes
    f_y1, f_x1, f_y2, f_x2 = tf.split(first, 4, axis=1)
    s_y1, s_x1, s_y2, s_x2 = tf.split(second, 4, axis=1)
    inner_y1 = tf.maximum(f_y1, s_y1)
    inner_x1 = tf.maximum(f_x1, s_x1)
    inner_y2 = tf.minimum(f_y2, s_y2)
    inner_x2 = tf.minimum(f_x2, s_x2)
    inner_area = (tf.maximum(inner_x2 - inner_x1, 0) *
                  tf.maximum(inner_y2 - inner_y1, 0))

    # 3. Union of each pair
    first_area = (f_y2 - f_y1) * (f_x2 - f_x1)
    second_area = (s_y2 - s_y1) * (s_x2 - s_x1)
    outer_area = first_area + second_area - inner_area

    # 4. IoU, reshaped to [boxes1, boxes2]
    ratio = inner_area / outer_area
    return tf.reshape(ratio, [tf.shape(boxes1)[0], tf.shape(boxes2)[0]])
def detection_targets_graph(proposals, gt_class_ids, gt_boxes, gt_masks, config):
    """Build the detection training targets for one image.

    Subsamples the proposals into positive and negative ROIs and produces the
    target class ID, bounding box deltas and mask for each.

    Inputs:
    proposals: [POST_NMS_ROIS_TRAINING, (y1, x1, y2, x2)] in normalized coordinates. Might
               be zero padded if there are not enough proposals.
    gt_class_ids: [MAX_GT_INSTANCES] int class IDs
    gt_boxes: [MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized coordinates.
    gt_masks: [height, width, MAX_GT_INSTANCES] of boolean type.
    config: configuration object. Reads TRAIN_ROIS_PER_IMAGE,
            ROI_POSITIVE_RATIO, BBOX_STD_DEV, USE_MINI_MASK and MASK_SHAPE.

    Returns: Target ROIs and corresponding class IDs, bounding box shifts,
    and masks.
    rois: [TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized coordinates
    class_ids: [TRAIN_ROIS_PER_IMAGE]. Integer class IDs. Zero padded.
    deltas: [TRAIN_ROIS_PER_IMAGE, (dy, dx, log(dh), log(dw))]
    masks: [TRAIN_ROIS_PER_IMAGE, height, width]. Masks cropped to bbox
           boundaries and resized to neural network output size.

    Note: Returned arrays might be zero padded if not enough target ROIs.
    """
    # Require at least one proposal row before doing anything else.
    asserts = [
        tf.Assert(tf.greater(tf.shape(proposals)[0], 0), [proposals],
                  name="roi_assertion"),
    ]
    with tf.control_dependencies(asserts):
        proposals = tf.identity(proposals)

    # Remove zero padding from the proposals and the ground truth. The mask
    # returned for gt_boxes is reused to trim the class IDs and masks.
    proposals, _ = trim_zeros_graph(proposals, name="trim_proposals")
    gt_boxes, gt_valid = trim_zeros_graph(gt_boxes, name="trim_gt_boxes")
    gt_class_ids = tf.boolean_mask(gt_class_ids, gt_valid,
                                   name="trim_gt_class_ids")
    gt_masks = tf.gather(gt_masks, tf.where(gt_valid)[:, 0], axis=2,
                         name="trim_gt_masks")

    # Handle COCO crowds
    # A crowd box in COCO is a bounding box around several instances. Exclude
    # them from training. A crowd box is given a negative class ID.
    crowd_rows = tf.where(gt_class_ids < 0)[:, 0]
    instance_rows = tf.where(gt_class_ids > 0)[:, 0]
    crowd_boxes = tf.gather(gt_boxes, crowd_rows)
    gt_class_ids = tf.gather(gt_class_ids, instance_rows)
    gt_boxes = tf.gather(gt_boxes, instance_rows)
    gt_masks = tf.gather(gt_masks, instance_rows, axis=2)

    # IoU of every proposal with every GT box: [proposals, gt_boxes]
    overlaps = overlaps_graph(proposals, gt_boxes)

    # IoU of every proposal with every crowd box: [proposals, crowd_boxes]
    crowd_overlaps = overlaps_graph(proposals, crowd_boxes)
    best_crowd_iou = tf.reduce_max(crowd_overlaps, axis=1)
    avoids_crowds = (best_crowd_iou < 0.001)

    # Determine positive and negative ROIs
    best_gt_iou = tf.reduce_max(overlaps, axis=1)
    # 1. Positive ROIs are those with >= 0.5 IoU with a GT box
    is_positive = (best_gt_iou >= 0.5)
    positive_indices = tf.where(is_positive)[:, 0]
    # 2. Negative ROIs are those with < 0.5 with every GT box. Skip crowds.
    is_negative = tf.logical_and(best_gt_iou < 0.5, avoids_crowds)
    negative_indices = tf.where(is_negative)[:, 0]

    # Subsample ROIs. Positives are capped at ROI_POSITIVE_RATIO of
    # TRAIN_ROIS_PER_IMAGE.
    positive_cap = int(config.TRAIN_ROIS_PER_IMAGE *
                       config.ROI_POSITIVE_RATIO)
    positive_indices = tf.random_shuffle(positive_indices)[:positive_cap]
    positive_count = tf.shape(positive_indices)[0]
    # Negative ROIs. Add enough to maintain positive:negative ratio.
    inverse_ratio = 1.0 / config.ROI_POSITIVE_RATIO
    negative_count = tf.cast(
        inverse_ratio * tf.cast(positive_count, tf.float32),
        tf.int32) - positive_count
    negative_indices = tf.random_shuffle(negative_indices)[:negative_count]

    # Gather selected ROIs
    positive_rois = tf.gather(proposals, positive_indices)
    negative_rois = tf.gather(proposals, negative_indices)

    # Assign each positive ROI to the GT box it overlaps most. With no GT
    # boxes there is nothing to take the argmax over, so use an empty
    # assignment instead.
    positive_overlaps = tf.gather(overlaps, positive_indices)
    roi_gt_box_assignment = tf.cond(
        tf.greater(tf.shape(positive_overlaps)[1], 0),
        true_fn=lambda: tf.argmax(positive_overlaps, axis=1),
        false_fn=lambda: tf.cast(tf.constant([]), tf.int64)
    )
    roi_gt_boxes = tf.gather(gt_boxes, roi_gt_box_assignment)
    roi_gt_class_ids = tf.gather(gt_class_ids, roi_gt_box_assignment)

    # Compute bbox refinement for positive ROIs
    deltas = utils.box_refinement_graph(positive_rois, roi_gt_boxes)
    deltas = deltas / config.BBOX_STD_DEV

    # Assign positive ROIs to GT masks
    # Permute masks to [N, height, width, 1]
    masks_by_instance = tf.expand_dims(
        tf.transpose(gt_masks, [2, 0, 1]), -1)
    # Pick the right mask for each ROI
    roi_masks = tf.gather(masks_by_instance, roi_gt_box_assignment)

    # Compute mask targets
    crop_boxes = positive_rois
    if config.USE_MINI_MASK:
        # Transform ROI coordinates from normalized image space
        # to normalized mini-mask space.
        roi_y1, roi_x1, roi_y2, roi_x2 = tf.split(positive_rois, 4, axis=1)
        gt_y1, gt_x1, gt_y2, gt_x2 = tf.split(roi_gt_boxes, 4, axis=1)
        gt_h = gt_y2 - gt_y1
        gt_w = gt_x2 - gt_x1
        roi_y1 = (roi_y1 - gt_y1) / gt_h
        roi_x1 = (roi_x1 - gt_x1) / gt_w
        roi_y2 = (roi_y2 - gt_y1) / gt_h
        roi_x2 = (roi_x2 - gt_x1) / gt_w
        crop_boxes = tf.concat([roi_y1, roi_x1, roi_y2, roi_x2], 1)
    mask_rows = tf.range(0, tf.shape(roi_masks)[0])
    masks = tf.image.crop_and_resize(tf.cast(roi_masks, tf.float32),
                                     crop_boxes, mask_rows,
                                     config.MASK_SHAPE)
    # Remove the extra dimension from masks.
    masks = tf.squeeze(masks, axis=3)

    # Threshold mask pixels at 0.5 to have GT masks be 0 or 1 to use with
    # binary cross entropy loss.
    masks = tf.round(masks)

    # Append negative ROIs and pad bbox deltas and masks that
    # are not used for negative ROIs with zeros.
    rois = tf.concat([positive_rois, negative_rois], axis=0)
    neg_rows = tf.shape(negative_rois)[0]
    pad_rows = tf.maximum(config.TRAIN_ROIS_PER_IMAGE - tf.shape(rois)[0], 0)
    rois = tf.pad(rois, [(0, pad_rows), (0, 0)])
    roi_gt_boxes = tf.pad(roi_gt_boxes, [(0, neg_rows + pad_rows), (0, 0)])
    roi_gt_class_ids = tf.pad(roi_gt_class_ids, [(0, neg_rows + pad_rows)])
    deltas = tf.pad(deltas, [(0, neg_rows + pad_rows), (0, 0)])
    masks = tf.pad(masks, [[0, neg_rows + pad_rows], (0, 0), (0, 0)])

    return rois, roi_gt_class_ids, deltas, masks
class DetectionTargetLayer(KE.Layer):
    """Runs detection_targets_graph on every image of the batch.

    Subsamples proposals and generates target box refinement, class_ids,
    and masks for each.

    Inputs:
    proposals: [batch, N, (y1, x1, y2, x2)] in normalized coordinates. Might
               be zero padded if there are not enough proposals.
    gt_class_ids: [batch, MAX_GT_INSTANCES] Integer class IDs.
    gt_boxes: [batch, MAX_GT_INSTANCES, (y1, x1, y2, x2)] in normalized
              coordinates.
    gt_masks: [batch, height, width, MAX_GT_INSTANCES] of boolean type

    Returns: Target ROIs and corresponding class IDs, bounding box shifts,
    and masks.
    rois: [batch, TRAIN_ROIS_PER_IMAGE, (y1, x1, y2, x2)] in normalized
          coordinates
    target_class_ids: [batch, TRAIN_ROIS_PER_IMAGE]. Integer class IDs.
    target_deltas: [batch, TRAIN_ROIS_PER_IMAGE, (dy, dx, log(dh), log(dw)]
    target_mask: [batch, TRAIN_ROIS_PER_IMAGE, height, width]
                 Masks cropped to bbox boundaries and resized to neural
                 network output size.

    Note: Returned arrays might be zero padded if not enough target ROIs.
    """

    def __init__(self, config, **kwargs):
        super(DetectionTargetLayer, self).__init__(**kwargs)
        self.config = config

    def call(self, inputs):
        proposals, gt_class_ids = inputs[0], inputs[1]
        gt_boxes, gt_masks = inputs[2], inputs[3]

        def targets_for_image(image_proposals, image_class_ids,
                              image_boxes, image_masks):
            return detection_targets_graph(
                image_proposals, image_class_ids, image_boxes, image_masks,
                self.config)

        # Slice the batch and run a graph for each slice
        # TODO: Rename target_bbox to target_deltas for clarity
        output_names = ["rois", "target_class_ids", "target_bbox",
                        "target_mask"]
        return utils.batch_slice(
            [proposals, gt_class_ids, gt_boxes, gt_masks],
            targets_for_image,
            self.config.IMAGES_PER_GPU, names=output_names)

    def compute_output_shape(self, input_shape):
        rois_per_image = self.config.TRAIN_ROIS_PER_IMAGE
        mask_height = self.config.MASK_SHAPE[0]
        mask_width = self.config.MASK_SHAPE[1]
        return [
            (None, rois_per_image, 4),  # rois
            (None, rois_per_image),  # class_ids
            (None, rois_per_image, 4),  # deltas
            (None, rois_per_image, mask_height, mask_width)  # masks
        ]

    def compute_mask(self, inputs, mask=None):
        # One entry per output; none of them carries a mask.
        return [None, None, None, None]
############################################################
# Detection Layer
############################################################
def refine_detections_graph(rois, probs, deltas, window, config):
    """Turn classified proposals of one image into final detections.

    Picks the top class of each ROI, applies that class's box deltas, clips
    to the window, drops background and low confidence ROIs, runs per-class
    non-max suppression and keeps the highest scoring results.

    Inputs:
        rois: [N, (y1, x1, y2, x2)] in normalized coordinates
        probs: [N, num_classes]. Class probabilities.
        deltas: [N, num_classes, (dy, dx, log(dh), log(dw))]. Class-specific
                bounding box deltas.
        window: (y1, x1, y2, x2) in normalized coordinates. The part of the image
            that contains the image excluding the padding.
        config: configuration object. Reads BBOX_STD_DEV,
            DETECTION_MIN_CONFIDENCE, DETECTION_MAX_INSTANCES and
            DETECTION_NMS_THRESHOLD.

    Returns detections shaped: [num_detections, (y1, x1, y2, x2, class_id, score)] where
        coordinates are normalized. Zero padded to DETECTION_MAX_INSTANCES rows.
    """
    def intersect_1d(first, second):
        """Return the values common to two 1-D index tensors."""
        shared = tf.sets.set_intersection(tf.expand_dims(first, 0),
                                          tf.expand_dims(second, 0))
        return tf.sparse_tensor_to_dense(shared)[0]

    # Class IDs per ROI
    class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)
    # (row, top class) pairs used to pick per-ROI values out of probs/deltas
    row_class_pairs = tf.stack([tf.range(probs.shape[0]), class_ids], axis=1)
    # Class probability of the top class of each ROI
    class_scores = tf.gather_nd(probs, row_class_pairs)
    # Class-specific bounding box deltas
    top_class_deltas = tf.gather_nd(deltas, row_class_pairs)
    # Apply bounding box deltas
    # Shape: [boxes, (y1, x1, y2, x2)] in normalized coordinates
    refined_rois = apply_box_deltas_graph(
        rois, top_class_deltas * config.BBOX_STD_DEV)
    # Clip boxes to image window
    refined_rois = clip_boxes_graph(refined_rois, window)

    # TODO: Filter out boxes with zero area

    # Filter out background boxes
    keep = tf.where(class_ids > 0)[:, 0]
    # Filter out low confidence boxes
    if config.DETECTION_MIN_CONFIDENCE:
        confident = tf.where(
            class_scores >= config.DETECTION_MIN_CONFIDENCE)[:, 0]
        keep = intersect_1d(keep, confident)

    # Apply per-class NMS
    # 1. Prepare variables
    kept_class_ids = tf.gather(class_ids, keep)
    kept_scores = tf.gather(class_scores, keep)
    kept_rois = tf.gather(refined_rois, keep)
    distinct_class_ids = tf.unique(kept_class_ids)[0]

    def nms_for_class(class_id):
        """Apply Non-Maximum Suppression on ROIs of the given class."""
        # Indices of ROIs of the given class
        members = tf.where(tf.equal(kept_class_ids, class_id))[:, 0]
        # Apply NMS
        survivors = tf.image.non_max_suppression(
            tf.gather(kept_rois, members),
            tf.gather(kept_scores, members),
            max_output_size=config.DETECTION_MAX_INSTANCES,
            iou_threshold=config.DETECTION_NMS_THRESHOLD)
        # Map NMS positions back to indices into the original ROI list
        survivors = tf.gather(keep, tf.gather(members, survivors))
        # Pad with -1 so returned tensors have the same shape
        shortfall = config.DETECTION_MAX_INSTANCES - tf.shape(survivors)[0]
        survivors = tf.pad(survivors, [(0, shortfall)],
                           mode='CONSTANT', constant_values=-1)
        # Set shape so map_fn() can infer result shape
        survivors.set_shape([config.DETECTION_MAX_INSTANCES])
        return survivors

    # 2. Map over class IDs
    nms_keep = tf.map_fn(nms_for_class, distinct_class_ids,
                         dtype=tf.int64)
    # 3. Merge results into one list, and remove -1 padding
    nms_keep = tf.reshape(nms_keep, [-1])
    nms_keep = tf.gather(nms_keep, tf.where(nms_keep > -1)[:, 0])
    # 4. Compute intersection between keep and nms_keep
    keep = intersect_1d(keep, nms_keep)

    # Keep top detections
    max_detections = config.DETECTION_MAX_INSTANCES
    surviving_scores = tf.gather(class_scores, keep)
    top_count = tf.minimum(tf.shape(surviving_scores)[0], max_detections)
    best_first = tf.nn.top_k(surviving_scores, k=top_count,
                             sorted=True).indices
    keep = tf.gather(keep, best_first)

    # Arrange output as [N, (y1, x1, y2, x2, class_id, score)]
    # Coordinates are normalized.
    columns = [
        tf.gather(refined_rois, keep),
        tf.to_float(tf.gather(class_ids, keep))[..., tf.newaxis],
        tf.gather(class_scores, keep)[..., tf.newaxis]
    ]
    detections = tf.concat(columns, axis=1)

    # Pad with zeros if detections < DETECTION_MAX_INSTANCES
    missing_rows = config.DETECTION_MAX_INSTANCES - tf.shape(detections)[0]
    return tf.pad(detections, [(0, missing_rows), (0, 0)], "CONSTANT")
class DetectionLayer(KE.Layer):
    """Keras layer that converts classified proposals and their bounding box
    deltas into the final detection boxes.

    Expected inputs, in order: rois, mrcnn_class, mrcnn_bbox, image_meta.

    Returns:
    [batch, num_detections, (y1, x1, y2, x2, class_id, class_score)] where
    coordinates are normalized.
    """

    def __init__(self, config=None, **kwargs):
        super(DetectionLayer, self).__init__(**kwargs)
        # Configuration object; call() and compute_output_shape() read
        # IMAGES_PER_GPU, BATCH_SIZE and DETECTION_MAX_INSTANCES from it.
        self.config = config

    def call(self, inputs):
        rois, mrcnn_class, mrcnn_bbox, image_meta = (
            inputs[0], inputs[1], inputs[2], inputs[3])

        # Windows are the image areas that exclude the padding. They are
        # normalized with the shape of the first image in the batch, since
        # all images get resized to the same size.
        meta = parse_image_meta_graph(image_meta)
        first_image_shape = meta['image_shape'][0]
        window = norm_boxes_graph(meta['window'], first_image_shape[:2])

        def refine_single_image(image_rois, image_probs, image_deltas,
                                image_window):
            # Refinement graph applied to one item of the batch.
            return refine_detections_graph(
                image_rois, image_probs, image_deltas, image_window,
                self.config)

        # Run the refinement graph on each item in the batch.
        detections_batch = utils.batch_slice(
            [rois, mrcnn_class, mrcnn_bbox, window],
            refine_single_image,
            self.config.IMAGES_PER_GPU)

        # Reshape output to
        # [batch, num_detections, (y1, x1, y2, x2, class_id, class_score)]
        # in normalized coordinates.
        output_shape = [self.config.BATCH_SIZE,
                        self.config.DETECTION_MAX_INSTANCES,
                        6]
        return tf.reshape(detections_batch, output_shape)

    def compute_output_shape(self, input_shape):
        # Batch dimension is left unspecified; each detection has 6 values.
        max_instances = self.config.DETECTION_MAX_INSTANCES
        return (None, max_instances, 6)
############################################################
# Region Proposal Network (RPN)
############################################################
def rpn_graph(feature_map, anchors_per_location, anchor_stride):
    """Builds the computation graph of Region Proposal Network.

    feature_map: backbone features [batch, height, width, depth]
    anchors_per_location: number of anchors per pixel in the feature map
    anchor_stride: Controls the density of anchors. Typically 1 (anchors for
                   every pixel in the feature map), or 2 (every other pixel).

    Returns:
        rpn_class_logits: [batch, H * W * anchors_per_location, 2] Anchor
                          classifier logits (before softmax)
        rpn_probs: [batch, H * W * anchors_per_location, 2] Anchor classifier
                   probabilities.
        rpn_bbox: [batch, H * W * anchors_per_location,
                  (dy, dx, log(dh), log(dw))] Deltas to be applied to anchors.
    """
    # TODO: check if stride of 2 causes alignment issues if the feature map
    # is not even.

    # Convolutional base shared by the class and bbox branches.
    shared_conv = KL.Conv2D(
        512, (3, 3), strides=anchor_stride, padding='same',
        activation='relu', name='rpn_conv_shared')
    shared = shared_conv(feature_map)

    # Anchor scores: [batch, height, width, anchors per location * 2].
    class_conv = KL.Conv2D(
        2 * anchors_per_location, (1, 1), padding='valid',
        activation='linear', name='rpn_class_raw')
    class_raw = class_conv(shared)

    # Flatten the spatial/anchor dimensions: [batch, anchors, 2].
    rpn_class_logits = KL.Lambda(
        lambda t: tf.reshape(t, [tf.shape(t)[0], -1, 2]))(class_raw)

    # Softmax over the last (BG/FG) dimension.
    rpn_probs = KL.Activation(
        "softmax", name="rpn_class_xxx")(rpn_class_logits)

    # Bounding box refinement: [batch, H, W, anchors per location * 4],
    # where the 4 values are (dy, dx, log(dh), log(dw)).
    bbox_conv = KL.Conv2D(
        anchors_per_location * 4, (1, 1), padding="valid",
        activation='linear', name='rpn_bbox_pred')
    bbox_raw = bbox_conv(shared)

    # Flatten the spatial/anchor dimensions: [batch, anchors, 4].
    rpn_bbox = KL.Lambda(
        lambda t: tf.reshape(t, [tf.shape(t)[0], -1, 4]))(bbox_raw)

    return [rpn_class_logits, rpn_probs, rpn_bbox]
def build_rpn_model(anchor_stride, anchors_per_location, depth):
    """Builds a Keras model of the Region Proposal Network.

    The RPN graph is wrapped in a model so it can be used multiple times
    with shared weights.

    anchor_stride: Controls the density of anchors. Typically 1 (anchors for
                   every pixel in the feature map), or 2 (every other pixel).
    anchors_per_location: number of anchors per pixel in the feature map
    depth: Depth of the backbone feature map.

    Returns a Keras Model object. The model outputs, when called, are:
        rpn_class_logits: [batch, H * W * anchors_per_location, 2] Anchor
                          classifier logits (before softmax)
        rpn_probs: [batch, H * W * anchors_per_location, 2] Anchor classifier
                   probabilities.
        rpn_bbox: [batch, H * W * anchors_per_location,
                  (dy, dx, log(dh), log(dw))] Deltas to be applied to anchors.
    """
    # Height and width are left unspecified; only the channel depth is fixed.
    feature_map_input = KL.Input(
        shape=[None, None, depth], name="input_rpn_feature_map")
    rpn_outputs = rpn_graph(
        feature_map_input, anchors_per_location, anchor_stride)
    return KM.Model([feature_map_input], rpn_outputs, name="rpn_model")
############################################################
# Feature Pyramid Network Heads
############################################################
def fpn_classifier_graph(rois, feature_maps, image_meta,
                         pool_size, num_classes, train_bn=True,
                         fc_layers_size=1024):
    """Builds the computation graph of the feature pyramid network classifier
    and regressor heads.

    rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
          coordinates.
    feature_maps: Sequence of feature maps from different layers of the
                  pyramid, [P2, P3, P4, P5]. Each has a different resolution.
    image_meta: [batch, (meta data)] Image details. See compose_image_meta()
    pool_size: The width of the square feature map generated from ROI Pooling.
    num_classes: number of classes, which determines the depth of the results
    train_bn: Boolean. Train or freeze Batch Norm layers
    fc_layers_size: Size of the 2 FC layers

    Returns:
        logits: [batch, num_rois, NUM_CLASSES] classifier logits (before
                softmax)
        probs: [batch, num_rois, NUM_CLASSES] classifier probabilities
        bbox_deltas: [batch, num_rois, NUM_CLASSES,
                     (dy, dx, log(dh), log(dw))] Deltas to apply to
                     proposal boxes
    """
    # ROI Pooling
    # Shape: [batch, num_rois, POOL_SIZE, POOL_SIZE, channels]
    # list() lets callers pass feature_maps as a tuple as well as a list.
    x = PyramidROIAlign(
        [pool_size, pool_size],
        name="roi_align_classifier")([rois, image_meta] + list(feature_maps))

    # Two FC layers of size fc_layers_size (implemented with Conv2D for
    # consistency). The first kernel spans the whole pooled map, so its
    # output is 1x1 spatially.
    x = KL.TimeDistributed(
        KL.Conv2D(fc_layers_size, (pool_size, pool_size), padding="valid"),
        name="mrcnn_class_conv1")(x)
    x = KL.TimeDistributed(BatchNorm(), name='mrcnn_class_bn1')(
        x, training=train_bn)
    x = KL.Activation('relu')(x)
    x = KL.TimeDistributed(KL.Conv2D(fc_layers_size, (1, 1)),
                           name="mrcnn_class_conv2")(x)
    x = KL.TimeDistributed(BatchNorm(), name='mrcnn_class_bn2')(
        x, training=train_bn)
    x = KL.Activation('relu')(x)

    # Drop the two singleton spatial axes (axis 3 first, then axis 2).
    shared = KL.Lambda(lambda t: K.squeeze(K.squeeze(t, 3), 2),
                       name="pool_squeeze")(x)

    # Classifier head
    mrcnn_class_logits = KL.TimeDistributed(KL.Dense(num_classes),
                                            name='mrcnn_class_logits')(shared)
    mrcnn_probs = KL.TimeDistributed(KL.Activation("softmax"),
                                     name="mrcnn_class")(mrcnn_class_logits)

    # BBox head
    # [batch, num_rois, NUM_CLASSES * (dy, dx, log(dh), log(dw))]
    x = KL.TimeDistributed(KL.Dense(num_classes * 4, activation='linear'),
                           name='mrcnn_bbox_fc')(shared)

    # Reshape to [batch, num_rois, NUM_CLASSES, (dy, dx, log(dh), log(dw))]
    # When the number of ROIs is not statically known, int_shape reports
    # None for that axis; Reshape cannot take None as a target dimension,
    # so let it infer the axis with -1 instead.
    s = K.int_shape(x)
    num_rois = s[1] if s[1] is not None else -1
    mrcnn_bbox = KL.Reshape((num_rois, num_classes, 4), name="mrcnn_bbox")(x)

    return mrcnn_class_logits, mrcnn_probs, mrcnn_bbox
def build_fpn_mask_graph(rois, feature_maps, image_meta,
pool_size, num_classes, train_bn=True):
"""Builds the computation graph of the mask head of Feature Pyramid Network.
rois: [batch, num_rois, (y1, x1, y2, x2)] Proposal boxes in normalized
coordinates.
feature_maps: List of feature maps from different layers of the pyramid,
[P2, P3, P4, P5]. Each has a different resolution.
image_meta: [batch, (meta data)] Image details. See compose_image_meta()
pool_size: The width of the square feature map generated from ROI Pooling.
num_classes: number of classes, which determines the depth of the results
train_bn: Boolean. Train or freeze Batch Norm layers
Returns: Masks [batch, num_rois, MASK_POOL_SIZE, MASK_POOL_SIZE, NUM_CLASSES]
"""
# ROI Pooling
# Shape: [batch, num_rois, MASK_POOL_SIZE, MASK_POOL_SIZE, channels]
x = PyramidROIAlign([pool_size, pool_size],
name="roi_align_mask")([rois, image_meta] + feature_maps)
# Conv layers
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
name="mrcnn_mask_conv1")(x)
x = KL.TimeDistributed(BatchNorm(),
name='mrcnn_mask_bn1')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
name="mrcnn_mask_conv2")(x)
x = KL.TimeDistributed(BatchNorm(),
name='mrcnn_mask_bn2')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
name="mrcnn_mask_conv3")(x)
x = KL.TimeDistributed(BatchNorm(),
name='mrcnn_mask_bn3')(x, training=train_bn)
x = KL.Activation('relu')(x)
x = KL.TimeDistributed(KL.Conv2D(256, (3, 3), padding="same"),
name="mrcnn_mask_conv4")(x)
x = KL.TimeDistributed(BatchNorm(),
name='mrcnn_mask_bn4')(x, training=train_bn)
x = KL.Activation('relu')(x)