-
Notifications
You must be signed in to change notification settings - Fork 14
/
boxlist.py
452 lines (384 loc) · 15.1 KB
/
boxlist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
import torch
from torchvision import ops
# Flip method identifiers; these are the only two values that
# BoxList.transpose accepts for its `method` argument.
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1
class BoxList(object):
    """
    This class represents a set of bounding boxes.

    The bounding boxes are stored as a float32 tensor whose last dimension
    holds the four box coordinates, either as (xmin, ymin, xmax, ymax) in
    "xyxy" mode or as (xmin, ymin, width, height) in "xywh" mode.
    In order to uniquely determine the bounding boxes with respect
    to an image, we also store the corresponding image dimensions.
    Boxes can carry extra per-box information (such as labels) in
    ``extra_fields``.

    Coordinates are treated as inclusive pixel indices: a box spanning
    xmin..xmax has width ``xmax - xmin + 1`` (the ``TO_REMOVE = 1`` terms).
    """

    def __init__(self, bbox, image_size, mode="xyxy"):
        """
        :param bbox: Tensor or array-like of boxes; converted to float32.
            A tensor keeps its device, anything else is placed on the CPU.
        :param image_size: (image_width, image_height)
        :param mode: "xyxy" or "xywh"
        :raises ValueError: if ``mode`` is not one of the two supported modes
        """
        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
        bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
        # NOTE: the shape checks (2 dimensions, last dimension of size 4) are
        # disabled in this version, so other leading shapes are not rejected.
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")

        self.bbox = bbox
        self.size = image_size  # (image_width, image_height)
        self.mode = mode
        self.extra_fields = {}

    def add_field(self, field, field_data):
        """Attach (or overwrite) the extra field ``field``."""
        self.extra_fields[field] = field_data

    def get_field(self, field):
        """Return the extra field ``field``; raises KeyError if absent."""
        return self.extra_fields[field]

    def has_field(self, field):
        """Return True if an extra field named ``field`` is stored."""
        return field in self.extra_fields

    def fields(self):
        """Return the names of all stored extra fields as a list."""
        return list(self.extra_fields.keys())

    def _copy_extra_fields(self, bbox):
        """Copy (by reference) every extra field of ``bbox`` into this BoxList."""
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v

    def _propagate_fields(self, target, method_name, *args, **kwargs):
        """
        Add every extra field of this BoxList to ``target``.

        Tensor fields are passed through untouched; any other field has its
        method ``method_name`` called with the given arguments so that it
        follows the same geometric transformation as the boxes.
        Returns ``target``.
        """
        for name, value in self.extra_fields.items():
            if not isinstance(value, torch.Tensor):
                value = getattr(value, method_name)(*args, **kwargs)
            target.add_field(name, value)
        return target

    def convert(self, mode):
        """
        Return the boxes expressed in ``mode``.

        Returns ``self`` (not a copy) when the mode already matches; otherwise
        a new BoxList sharing the same extra fields.

        :raises ValueError: if ``mode`` is not "xyxy" or "xywh"
        """
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")
        if mode == self.mode:
            return self
        # we only have two modes, so don't need to check
        # self.mode
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if mode == "xyxy":
            bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
        else:
            TO_REMOVE = 1
            bbox = torch.cat(
                (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1
            )
        bbox = BoxList(bbox, self.size, mode=mode)
        bbox._copy_extra_fields(self)
        return bbox

    def _split_into_xyxy(self):
        """
        Return (xmin, ymin, xmax, ymax), each keeping a trailing dimension
        of size 1, regardless of the current mode.
        """
        if self.mode == "xyxy":
            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmax, ymax
        elif self.mode == "xywh":
            TO_REMOVE = 1
            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
            return (
                xmin,
                ymin,
                # clamp so that a width/height below 1 never yields max < min
                xmin + (w - TO_REMOVE).clamp(min=0),
                ymin + (h - TO_REMOVE).clamp(min=0),
            )
        else:
            raise RuntimeError("Should not be here")

    def resize(self, size, *args, **kwargs):
        """
        Returns a resized copy of this bounding box

        :param size: The requested size in pixels, as a 2-tuple:
            (width, height).

        Extra positional/keyword arguments are forwarded to the ``resize``
        method of every non-tensor extra field.
        """
        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        if ratios[0] == ratios[1]:
            # Uniform scaling: every coordinate is multiplied by the same
            # factor, so the stored tensor can be scaled directly in its mode.
            resized = BoxList(self.bbox * ratios[0], size, mode=self.mode)
            return self._propagate_fields(resized, "resize", size, *args, **kwargs)

        ratio_width, ratio_height = ratios
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        scaled_box = torch.cat(
            (
                xmin * ratio_width,
                ymin * ratio_height,
                xmax * ratio_width,
                ymax * ratio_height,
            ),
            dim=-1,
        )
        resized = BoxList(scaled_box, size, mode="xyxy")
        self._propagate_fields(resized, "resize", size, *args, **kwargs)
        return resized.convert(self.mode)

    def transpose(self, method):
        """
        Flip the bounding boxes within the image.

        :param method: ``FLIP_LEFT_RIGHT`` or ``FLIP_TOP_BOTTOM`` (the
            module-level constants). Non-tensor extra fields are flipped via
            their own ``transpose(method)``.
        :raises NotImplementedError: for any other ``method``
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        image_width, image_height = self.size
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if method == FLIP_LEFT_RIGHT:
            TO_REMOVE = 1
            transposed_xmin = image_width - xmax - TO_REMOVE
            transposed_xmax = image_width - xmin - TO_REMOVE
            transposed_ymin = ymin
            transposed_ymax = ymax
        elif method == FLIP_TOP_BOTTOM:
            # NOTE(review): unlike the horizontal flip, no TO_REMOVE is
            # subtracted here; kept as-is because other code may rely on
            # it - confirm whether the asymmetry is intended.
            transposed_xmin = xmin
            transposed_xmax = xmax
            transposed_ymin = image_height - ymax
            transposed_ymax = image_height - ymin

        transposed_boxes = torch.cat(
            (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
        )
        flipped = BoxList(transposed_boxes, self.size, mode="xyxy")
        self._propagate_fields(flipped, "transpose", method)
        return flipped.convert(self.mode)

    def crop(self, box):
        """
        Crops a rectangular region from this bounding box. The box is a
        4-tuple defining the left, upper, right, and lower pixel
        coordinate.

        Box coordinates are shifted into the crop's frame and clamped to
        [0, w] x [0, h]; boxes that become empty are NOT removed. The result
        has image size (w, h) and keeps the original mode.
        """
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        w, h = box[2] - box[0], box[3] - box[1]
        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)

        # TODO should I filter empty boxes here?
        cropped_box = torch.cat(
            (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
        )
        cropped = BoxList(cropped_box, (w, h), mode="xyxy")
        self._propagate_fields(cropped, "crop", box)
        return cropped.convert(self.mode)

    # Tensor-like methods

    def to(self, device):
        """Return a copy with the boxes and every field that has ``.to`` moved to ``device``."""
        bbox = BoxList(self.bbox.to(device), self.size, self.mode)
        for k, v in self.extra_fields.items():
            if hasattr(v, "to"):
                v = v.to(device)
            bbox.add_field(k, v)
        return bbox

    def __getitem__(self, item):
        """Index the boxes and every extra field with ``item``."""
        bbox = BoxList(self.bbox[item], self.size, self.mode)
        for k, v in self.extra_fields.items():
            bbox.add_field(k, v[item])
        return bbox

    def __len__(self):
        return self.bbox.shape[0]

    def clip_to_image(self, remove_empty=True):
        """
        Clip the boxes, in place, so that they lie inside the image.

        Corner coordinates are clamped to [0, width - 1] x [0, height - 1].
        In "xywh" mode the corners are clipped and the stored width/height
        are recomputed from them, rather than clamping width/height as if
        they were coordinates.

        :param remove_empty: if True, return a new BoxList without the boxes
            whose clipped corners satisfy xmax <= xmin or ymax <= ymin;
            otherwise return ``self``.
        """
        TO_REMOVE = 1
        max_x = self.size[0] - TO_REMOVE
        max_y = self.size[1] - TO_REMOVE
        if self.mode == "xyxy":
            self.bbox[:, 0].clamp_(min=0, max=max_x)
            self.bbox[:, 1].clamp_(min=0, max=max_y)
            self.bbox[:, 2].clamp_(min=0, max=max_x)
            self.bbox[:, 3].clamp_(min=0, max=max_y)
            corners = self.bbox
        else:
            xmin, ymin, xmax, ymax = self._split_into_xyxy()
            xmin = xmin.clamp(min=0, max=max_x)
            ymin = ymin.clamp(min=0, max=max_y)
            xmax = xmax.clamp(min=0, max=max_x)
            ymax = ymax.clamp(min=0, max=max_y)
            corners = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
            # write back in place so the mutation contract matches xyxy mode
            self.bbox.copy_(
                torch.cat(
                    (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE),
                    dim=-1,
                )
            )
        if remove_empty:
            keep = (corners[:, 3] > corners[:, 1]) & (corners[:, 2] > corners[:, 0])
            return self[keep]
        return self

    def area(self):
        """Return the per-box area (inclusive-pixel convention) as a 1-D tensor."""
        box = self.bbox
        if self.mode == "xyxy":
            TO_REMOVE = 1
            area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
        elif self.mode == "xywh":
            area = box[:, 2] * box[:, 3]
        else:
            raise RuntimeError("Should not be here")

        return area

    def box_span(self):
        """Return, per box, the larger of its width and height."""
        box = self.bbox
        if self.mode == "xyxy":
            TO_REMOVE = 1
            span = torch.max(
                (box[:, 2] - box[:, 0] + TO_REMOVE), (box[:, 3] - box[:, 1] + TO_REMOVE)
            )
        elif self.mode == "xywh":
            span = torch.max(box[:, 2], box[:, 3])
        else:
            raise RuntimeError("Should not be here")

        return span

    def copy_with_fields(self, fields, skip_missing=False):
        """
        Return a new BoxList sharing this one's box tensor but carrying only
        the requested extra fields.

        :param fields: a field name or a list/tuple of field names
        :param skip_missing: if False, a missing field raises KeyError
        """
        bbox = BoxList(self.bbox, self.size, self.mode)
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            if self.has_field(field):
                bbox.add_field(field, self.get_field(field))
            elif not skip_missing:
                raise KeyError("Field '{}' not found in {}".format(field, self))
        return bbox

    def __repr__(self):
        return "{}(num_boxes={}, image_width={}, image_height={}, mode={})".format(
            self.__class__.__name__, len(self), self.size[0], self.size[1], self.mode
        )
def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
    """
    Apply class-agnostic non-maximum suppression to a boxlist.

    Arguments:
        boxlist (BoxList): boxes to filter; scores are read from the
            extra field named by ``score_field``
        nms_thresh (float): IoU threshold handed to ``ops.nms``; when it is
            <= 0 the input boxlist is returned untouched
        max_proposals (int): if > 0, only the first ``max_proposals``
            indices returned by NMS are kept
        score_field (str): name of the extra field holding the scores

    Returns the surviving boxes, in the same mode as the input.
    """
    if nms_thresh <= 0:
        return boxlist

    original_mode = boxlist.mode
    # ops.nms is fed corner-format boxes
    xyxy = boxlist.convert("xyxy")
    kept = ops.nms(xyxy.bbox, xyxy.get_field(score_field), nms_thresh)
    if max_proposals > 0:
        kept = kept[:max_proposals]
    return xyxy[kept].convert(original_mode)
def batched_nms(boxes, scores, idxs, iou_threshold):
    """
    Run non-maximum suppression independently per category.

    Boxes that carry different values in ``idxs`` never suppress
    each other.

    Parameters
    ----------
    boxes : Tensor[N, 4]
        boxes to filter, expected in (x1, y1, x2, y2) format
    scores : Tensor[N]
        score of each box
    idxs : Tensor[N]
        category index of each box
    iou_threshold : float
        threshold passed to ``ops.nms``; a box is discarded when it
        overlaps a higher-scoring box of the same category with
        IoU > iou_threshold

    Returns
    -------
    keep : Tensor
        indices of the boxes kept by NMS, as returned by ``ops.nms``
        (an empty int64 tensor when there are no boxes)
    """
    if boxes.numel() == 0:
        return torch.empty(0, dtype=torch.int64, device=boxes.device)

    # Shift every category into its own disjoint coordinate range: the
    # shift depends only on the category index and exceeds the largest
    # coordinate, so boxes of different categories cannot overlap and a
    # single NMS call handles all categories at once.
    per_box_shift = idxs.to(boxes) * (boxes.max() + 1)
    shifted_boxes = boxes + per_box_shift[:, None]
    return ops.nms(shifted_boxes, scores, iou_threshold)
def boxlist_ml_nms(boxlist, nms_thresh, max_proposals=-1,
                   score_field="scores", label_field="labels"):
    """
    Apply per-label non-maximum suppression to a boxlist: boxes only
    compete with boxes that share the same label.

    Arguments:
        boxlist (BoxList): boxes to filter
        nms_thresh (float): IoU threshold; when it is <= 0 the input
            boxlist is returned untouched
        max_proposals (int): if > 0, only the first ``max_proposals``
            indices returned by NMS are kept
        score_field (str): name of the extra field holding the scores
        label_field (str): name of the extra field holding the labels

    Returns the surviving boxes, in the same mode as the input.
    """
    if nms_thresh <= 0:
        return boxlist

    original_mode = boxlist.mode
    xyxy = boxlist.convert("xyxy")
    box_tensor = xyxy.bbox
    score_tensor = xyxy.get_field(score_field)
    label_tensor = xyxy.get_field(label_field)
    kept = batched_nms(box_tensor, score_tensor, label_tensor.float(), nms_thresh)
    if max_proposals > 0:
        kept = kept[:max_proposals]
    return xyxy[kept].convert(original_mode)
def remove_small_boxes(boxlist, min_size):
    """
    Drop every box whose width or height is below ``min_size``.

    Arguments:
        boxlist (BoxList)
        min_size (int)

    Returns the boxlist indexed by the positions of the boxes with both
    sides >= min_size.
    """
    # TODO maybe add an API for querying the ws / hs
    boxes_xywh = boxlist.convert("xywh").bbox
    _, _, widths, heights = boxes_xywh.unbind(dim=1)
    large_enough = (widths >= min_size) & (heights >= min_size)
    keep = torch.nonzero(large_enough).squeeze(1)
    return boxlist[keep]
# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def boxlist_iou(boxlist1, boxlist2):
    """Compute the intersection over union of two sets of boxes.

    Both inputs are converted to (xmin, ymin, xmax, ymax) before the
    overlap is computed, so either box mode is accepted. Coordinates are
    treated as inclusive pixel indices (hence the +1 on widths/heights).

    Arguments:
      boxlist1: (BoxList) bounding boxes, sized [N,4].
      boxlist2: (BoxList) bounding boxes, sized [M,4].

    Returns:
      (tensor) iou, sized [N,M].

    Raises:
      RuntimeError: if the two boxlists do not share the same image size.

    Reference:
      https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
    """
    if boxlist1.size != boxlist2.size:
        raise RuntimeError(
            "boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2))

    # The corner arithmetic below is only valid for xyxy boxes; convert()
    # returns the same object when the mode already matches.
    boxlist1 = boxlist1.convert("xyxy")
    boxlist2 = boxlist2.convert("xyxy")

    area1 = boxlist1.area()
    area2 = boxlist2.area()

    box1, box2 = boxlist1.bbox, boxlist2.bbox

    # Broadcast [N,1,2] against [M,2] to get every pairwise corner.
    lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]

    TO_REMOVE = 1

    # Non-overlapping pairs give a negative extent; clamp it to zero.
    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    iou = inter / (area1[:, None] + area2 - inter)
    return iou
# TODO redundant, remove
def _cat(tensors, dim=0):
"""
Efficient version of torch.cat that avoids a copy if there is only a single element in a list
"""
assert isinstance(tensors, (list, tuple))
if len(tensors) == 1:
return tensors[0]
return torch.cat(tensors, dim)
def cat_boxlist(bboxes):
    """
    Merge several BoxList objects into one.

    All inputs must share the same image size, the same mode and the same
    set of extra-field names; boxes and every extra field are concatenated
    along dimension 0.

    Arguments:
        bboxes (list[BoxList])
    """
    assert isinstance(bboxes, (list, tuple))
    assert all(isinstance(bbox, BoxList) for bbox in bboxes)

    # The first element is the reference that all others must agree with.
    size = bboxes[0].size
    assert all(bbox.size == size for bbox in bboxes)

    mode = bboxes[0].mode
    assert all(bbox.mode == mode for bbox in bboxes)

    fields = set(bboxes[0].fields())
    assert all(set(bbox.fields()) == fields for bbox in bboxes)

    stacked_boxes = _cat([bbox.bbox for bbox in bboxes], dim=0)
    merged = BoxList(stacked_boxes, size, mode)

    for field in fields:
        per_list_values = [bbox.get_field(field) for bbox in bboxes]
        merged.add_field(field, _cat(per_list_values, dim=0))

    return merged