Skip to content

Commit

Permalink
Support bbox_clip_border for the augmentations of YOLOX (#6730)
Browse files Browse the repository at this point in the history
* support 'bbox_clip_border' for the augmentations of YOLOX

* update based on 1-st comments

* add comments

* fix typos

* rename remove_ouside_bboxes to find_inside_bboxes

* move comments to docstring
  • Loading branch information
GT9505 authored and ZwwWayne committed Dec 15, 2021
1 parent 5612624 commit 6ead450
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 27 deletions.
4 changes: 2 additions & 2 deletions mmdet/core/bbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .transforms import (bbox2distance, bbox2result, bbox2roi,
bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
distance2bbox, roi2bbox)
distance2bbox, find_inside_bboxes, roi2bbox)

__all__ = [
'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
Expand All @@ -24,5 +24,5 @@
'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder',
'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy',
'bbox_xyxy_to_cxcywh', 'RegionAssigner'
'bbox_xyxy_to_cxcywh', 'RegionAssigner', 'find_inside_bboxes'
]
16 changes: 16 additions & 0 deletions mmdet/core/bbox/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,22 @@
import torch


def find_inside_bboxes(bboxes, img_h, img_w):
"""Find bboxes as long as a part of bboxes is inside the image.
Args:
bboxes (Tensor): Shape (N, 4).
img_h (int): Image height.
img_w (int): Image width.
Returns:
Tensor: Index of the remaining bboxes.
"""
inside_inds = (bboxes[:, 0] < img_w) & (bboxes[:, 2] > 0) \
& (bboxes[:, 1] < img_h) & (bboxes[:, 3] > 0)
return inside_inds


def bbox_flip(bboxes, img_shape, direction='horizontal'):
"""Flip bboxes horizontally or vertically.
Expand Down
96 changes: 71 additions & 25 deletions mmdet/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from numpy import random

from mmdet.core import PolygonMasks
from mmdet.core import PolygonMasks, find_inside_bboxes
from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from ..builder import PIPELINES

Expand Down Expand Up @@ -54,8 +54,10 @@ class Resize:
ratio_range (tuple[float]): (min_ratio, max_ratio)
keep_ratio (bool): Whether to keep the aspect ratio when resizing the
image.
bbox_clip_border (bool, optional): Whether clip the objects outside
the border of the image. Defaults to True.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
These two backends generates slightly different results. Defaults
to 'cv2'.
Expand Down Expand Up @@ -1982,6 +1984,10 @@ class Mosaic:
output. Default to (0.5, 1.5).
min_bbox_size (int | float): The minimum pixel for filtering
invalid bboxes after the mosaic pipeline. Default to 0.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
skip_filter (bool): Whether to skip filtering rules. If it
is True, the filter rule will not be applied, and the
`min_bbox_size` is invalid. Default to True.
Expand All @@ -1992,12 +1998,14 @@ def __init__(self,
img_scale=(640, 640),
center_ratio_range=(0.5, 1.5),
min_bbox_size=0,
bbox_clip_border=True,
skip_filter=True,
pad_val=114):
assert isinstance(img_scale, tuple)
self.img_scale = img_scale
self.center_ratio_range = center_ratio_range
self.min_bbox_size = min_bbox_size
self.bbox_clip_border = bbox_clip_border
self.skip_filter = skip_filter
self.pad_val = pad_val

Expand Down Expand Up @@ -2099,16 +2107,24 @@ def _mosaic_transform(self, results):

if len(mosaic_labels) > 0:
mosaic_bboxes = np.concatenate(mosaic_bboxes, 0)
mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0,
2 * self.img_scale[1])
mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0,
2 * self.img_scale[0])
mosaic_labels = np.concatenate(mosaic_labels, 0)

if self.bbox_clip_border:
mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0,
2 * self.img_scale[1])
mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0,
2 * self.img_scale[0])

if not self.skip_filter:
mosaic_bboxes, mosaic_labels = \
self._filter_box_candidates(mosaic_bboxes, mosaic_labels)

# remove outside bboxes
inside_inds = find_inside_bboxes(mosaic_bboxes, 2 * self.img_scale[0],
2 * self.img_scale[1])
mosaic_bboxes = mosaic_bboxes[inside_inds]
mosaic_labels = mosaic_labels[inside_inds]

results['img'] = mosaic_img
results['img_shape'] = mosaic_img.shape
results['gt_bboxes'] = mosaic_bboxes
Expand Down Expand Up @@ -2243,6 +2259,10 @@ class MixUp:
max_aspect_ratio (float): Aspect ratio of width and height
threshold to filter bboxes. If max(h/w, w/h) larger than this
value, the box will be removed. Default: 20.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
skip_filter (bool): Whether to skip filtering rules. If it
is True, the filter rule will not be applied, and the
`min_bbox_size` and `min_area_ratio` and `max_aspect_ratio`
Expand All @@ -2258,6 +2278,7 @@ def __init__(self,
min_bbox_size=5,
min_area_ratio=0.2,
max_aspect_ratio=20,
bbox_clip_border=True,
skip_filter=True):
assert isinstance(img_scale, tuple)
self.dynamic_scale = img_scale
Expand All @@ -2268,6 +2289,7 @@ def __init__(self,
self.min_bbox_size = min_bbox_size
self.min_area_ratio = min_area_ratio
self.max_aspect_ratio = max_aspect_ratio
self.bbox_clip_border = bbox_clip_border
self.skip_filter = skip_filter

def __call__(self, results):
Expand Down Expand Up @@ -2371,21 +2393,29 @@ def _mixup_transform(self, results):

# 6. adjust bbox
retrieve_gt_bboxes = retrieve_results['gt_bboxes']
retrieve_gt_bboxes[:, 0::2] = np.clip(
retrieve_gt_bboxes[:, 0::2] * scale_ratio, 0, origin_w)
retrieve_gt_bboxes[:, 1::2] = np.clip(
retrieve_gt_bboxes[:, 1::2] * scale_ratio, 0, origin_h)
retrieve_gt_bboxes[:, 0::2] = retrieve_gt_bboxes[:, 0::2] * scale_ratio
retrieve_gt_bboxes[:, 1::2] = retrieve_gt_bboxes[:, 1::2] * scale_ratio
if self.bbox_clip_border:
retrieve_gt_bboxes[:, 0::2] = np.clip(retrieve_gt_bboxes[:, 0::2],
0, origin_w)
retrieve_gt_bboxes[:, 1::2] = np.clip(retrieve_gt_bboxes[:, 1::2],
0, origin_h)

if is_filp:
retrieve_gt_bboxes[:, 0::2] = (
origin_w - retrieve_gt_bboxes[:, 0::2][:, ::-1])

# 7. filter
cp_retrieve_gt_bboxes = retrieve_gt_bboxes.copy()
cp_retrieve_gt_bboxes[:, 0::2] = np.clip(
cp_retrieve_gt_bboxes[:, 0::2] - x_offset, 0, target_w)
cp_retrieve_gt_bboxes[:, 1::2] = np.clip(
cp_retrieve_gt_bboxes[:, 1::2] - y_offset, 0, target_h)
cp_retrieve_gt_bboxes[:, 0::2] = \
cp_retrieve_gt_bboxes[:, 0::2] - x_offset
cp_retrieve_gt_bboxes[:, 1::2] = \
cp_retrieve_gt_bboxes[:, 1::2] - y_offset
if self.bbox_clip_border:
cp_retrieve_gt_bboxes[:, 0::2] = np.clip(
cp_retrieve_gt_bboxes[:, 0::2], 0, target_w)
cp_retrieve_gt_bboxes[:, 1::2] = np.clip(
cp_retrieve_gt_bboxes[:, 1::2], 0, target_h)

# 8. mix up
ori_img = ori_img.astype(np.float32)
Expand All @@ -2405,6 +2435,11 @@ def _mixup_transform(self, results):
mixup_gt_labels = np.concatenate(
(results['gt_labels'], retrieve_gt_labels), axis=0)

# remove outside bbox
inside_inds = find_inside_bboxes(mixup_gt_bboxes, target_h, target_w)
mixup_gt_bboxes = mixup_gt_bboxes[inside_inds]
mixup_gt_labels = mixup_gt_labels[inside_inds]

results['img'] = mixup_img.astype(np.uint8)
results['img_shape'] = mixup_img.shape
results['gt_bboxes'] = mixup_gt_bboxes
Expand Down Expand Up @@ -2471,6 +2506,10 @@ class RandomAffine:
max_aspect_ratio (float): Aspect ratio of width and height
threshold to filter bboxes. If max(h/w, w/h) larger than this
value, the box will be removed.
bbox_clip_border (bool, optional): Whether to clip the objects outside
the border of the image. In some dataset like MOT17, the gt bboxes
are allowed to cross the border of images. Therefore, we don't
need to clip the gt bboxes in these cases. Defaults to True.
skip_filter (bool): Whether to skip filtering rules. If it
is True, the filter rule will not be applied, and the
`min_bbox_size` and `min_area_ratio` and `max_aspect_ratio`
Expand All @@ -2487,6 +2526,7 @@ def __init__(self,
min_bbox_size=2,
min_area_ratio=0.2,
max_aspect_ratio=20,
bbox_clip_border=True,
skip_filter=True):
assert 0 <= max_translate_ratio <= 1
assert scaling_ratio_range[0] <= scaling_ratio_range[1]
Expand All @@ -2500,6 +2540,7 @@ def __init__(self,
self.min_bbox_size = min_bbox_size
self.min_area_ratio = min_area_ratio
self.max_aspect_ratio = max_aspect_ratio
self.bbox_clip_border = bbox_clip_border
self.skip_filter = skip_filter

def __call__(self, results):
Expand Down Expand Up @@ -2560,20 +2601,25 @@ def __call__(self, results):
warp_bboxes = np.vstack(
(xs.min(1), ys.min(1), xs.max(1), ys.max(1))).T

warp_bboxes[:, [0, 2]] = warp_bboxes[:, [0, 2]].clip(0, width)
warp_bboxes[:, [1, 3]] = warp_bboxes[:, [1, 3]].clip(0, height)
if self.bbox_clip_border:
warp_bboxes[:, [0, 2]] = \
warp_bboxes[:, [0, 2]].clip(0, width)
warp_bboxes[:, [1, 3]] = \
warp_bboxes[:, [1, 3]].clip(0, height)

# remove outside bbox
valid_index = find_inside_bboxes(warp_bboxes, height, width)
if not self.skip_filter:
# filter bboxes
valid_index = self.filter_gt_bboxes(
filter_index = self.filter_gt_bboxes(
bboxes * scaling_ratio, warp_bboxes)
results[key] = warp_bboxes[valid_index]
if key in ['gt_bboxes']:
if 'gt_labels' in results:
results['gt_labels'] = results['gt_labels'][
valid_index]
else:
results[key] = warp_bboxes
valid_index = valid_index & filter_index

results[key] = warp_bboxes[valid_index]
if key in ['gt_bboxes']:
if 'gt_labels' in results:
results['gt_labels'] = results['gt_labels'][
valid_index]

if 'gt_masks' in results:
raise NotImplementedError(
Expand Down

0 comments on commit 6ead450

Please sign in to comment.