From 5cf37f1896eaf935ad8e1fc476ec09e186721798 Mon Sep 17 00:00:00 2001 From: Tao Gong Date: Wed, 8 Dec 2021 20:12:50 +0800 Subject: [PATCH 1/6] support 'bbox_clip_border' for the augmentations of YOLOX --- mmdet/core/bbox/__init__.py | 4 +- mmdet/core/bbox/transforms.py | 18 ++++++ mmdet/datasets/pipelines/transforms.py | 86 +++++++++++++++++++------- 3 files changed, 83 insertions(+), 25 deletions(-) diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index 1e3fa12d8fe..cf15e22af07 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -12,7 +12,7 @@ from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping, bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh, - distance2bbox, roi2bbox) + distance2bbox, remove_outside_bboxes, roi2bbox) __all__ = [ 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner', @@ -24,5 +24,5 @@ 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder', 'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy', - 'bbox_xyxy_to_cxcywh', 'RegionAssigner' + 'bbox_xyxy_to_cxcywh', 'RegionAssigner', 'remove_outside_bboxes' ] diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py index 246028b439e..039904f58f5 100644 --- a/mmdet/core/bbox/transforms.py +++ b/mmdet/core/bbox/transforms.py @@ -3,6 +3,24 @@ import torch +def remove_outside_bboxes(bboxes, img_h, img_w): + """Remove bboxes that are out of the image boundary. + + Args: + bboxes (Tensor): Shape (N, 4). + img_h (int): Image height. + img_w (int): Image width. + + Returns: + Tensor: Index of the remaining bboxes. 
+ """ + inside_inds = bboxes[:, 0] < img_w + inside_inds = inside_inds & (bboxes[:, 2] > 0) + inside_inds = inside_inds & (bboxes[:, 1] < img_h) + inside_inds = inside_inds & (bboxes[:, 3] > 0) + return inside_inds + + def bbox_flip(bboxes, img_shape, direction='horizontal'): """Flip bboxes horizontally or vertically. diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 75ec639c2ec..abab8a6d9f7 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -9,7 +9,7 @@ import numpy as np from numpy import random -from mmdet.core import PolygonMasks +from mmdet.core import PolygonMasks, remove_outside_bboxes from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps from ..builder import PIPELINES @@ -1982,6 +1982,8 @@ class Mosaic: output. Default to (0.5, 1.5). min_bbox_size (int | float): The minimum pixel for filtering invalid bboxes after the mosaic pipeline. Default to 0. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. skip_filter (bool): Whether to skip filtering rules. If it is True, the filter rule will not be applied, and the `min_bbox_size` is invalid. Default to True. 
@@ -1992,12 +1994,14 @@ def __init__(self, img_scale=(640, 640), center_ratio_range=(0.5, 1.5), min_bbox_size=0, + bbox_clip_border=True, skip_filter=True, pad_val=114): assert isinstance(img_scale, tuple) self.img_scale = img_scale self.center_ratio_range = center_ratio_range self.min_bbox_size = min_bbox_size + self.bbox_clip_border = bbox_clip_border self.skip_filter = skip_filter self.pad_val = pad_val @@ -2099,16 +2103,25 @@ def _mosaic_transform(self, results): if len(mosaic_labels) > 0: mosaic_bboxes = np.concatenate(mosaic_bboxes, 0) - mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0, - 2 * self.img_scale[1]) - mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0, - 2 * self.img_scale[0]) mosaic_labels = np.concatenate(mosaic_labels, 0) + if self.bbox_clip_border: + mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0, + 2 * self.img_scale[1]) + mosaic_bboxes[:, 1::2] = np.clip(mosaic_bboxes[:, 1::2], 0, + 2 * self.img_scale[0]) + if not self.skip_filter: mosaic_bboxes, mosaic_labels = \ self._filter_box_candidates(mosaic_bboxes, mosaic_labels) + # remove outside bboxes + inside_inds = remove_outside_bboxes(mosaic_bboxes, + 2 * self.img_scale[0], + 2 * self.img_scale[1]) + mosaic_bboxes = mosaic_bboxes[inside_inds] + mosaic_labels = mosaic_labels[inside_inds] + results['img'] = mosaic_img results['img_shape'] = mosaic_img.shape results['gt_bboxes'] = mosaic_bboxes @@ -2242,6 +2255,8 @@ class MixUp: max_aspect_ratio (float): Aspect ratio of width and height threshold to filter bboxes. If max(h/w, w/h) larger than this value, the box will be removed. Default: 20. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. skip_filter (bool): Whether to skip filtering rules. 
If it is True, the filter rule will not be applied, and the `min_bbox_size` and `min_area_ratio` and `max_aspect_ratio` @@ -2257,6 +2272,7 @@ def __init__(self, min_bbox_size=5, min_area_ratio=0.2, max_aspect_ratio=20, + bbox_clip_border=True, skip_filter=True): assert isinstance(img_scale, tuple) self.dynamic_scale = img_scale @@ -2267,6 +2283,7 @@ def __init__(self, self.min_bbox_size = min_bbox_size self.min_area_ratio = min_area_ratio self.max_aspect_ratio = max_aspect_ratio + self.bbox_clip_border = bbox_clip_border self.skip_filter = skip_filter def __call__(self, results): @@ -2370,10 +2387,13 @@ def _mixup_transform(self, results): # 6. adjust bbox retrieve_gt_bboxes = retrieve_results['gt_bboxes'] - retrieve_gt_bboxes[:, 0::2] = np.clip( - retrieve_gt_bboxes[:, 0::2] * scale_ratio, 0, origin_w) - retrieve_gt_bboxes[:, 1::2] = np.clip( - retrieve_gt_bboxes[:, 1::2] * scale_ratio, 0, origin_h) + retrieve_gt_bboxes[:, 0::2] = retrieve_gt_bboxes[:, 0::2] * scale_ratio + retrieve_gt_bboxes[:, 1::2] = retrieve_gt_bboxes[:, 1::2] * scale_ratio + if self.bbox_clip_border: + retrieve_gt_bboxes[:, 0::2] = np.clip(retrieve_gt_bboxes[:, 0::2], + 0, origin_w) + retrieve_gt_bboxes[:, 1::2] = np.clip(retrieve_gt_bboxes[:, 1::2], + 0, origin_h) if is_filp: retrieve_gt_bboxes[:, 0::2] = ( @@ -2381,10 +2401,15 @@ def _mixup_transform(self, results): # 7. 
filter cp_retrieve_gt_bboxes = retrieve_gt_bboxes.copy() - cp_retrieve_gt_bboxes[:, 0::2] = np.clip( - cp_retrieve_gt_bboxes[:, 0::2] - x_offset, 0, target_w) - cp_retrieve_gt_bboxes[:, 1::2] = np.clip( - cp_retrieve_gt_bboxes[:, 1::2] - y_offset, 0, target_h) + cp_retrieve_gt_bboxes[:, 0::2] = \ + cp_retrieve_gt_bboxes[:, 0::2] - x_offset + cp_retrieve_gt_bboxes[:, 1::2] = \ + cp_retrieve_gt_bboxes[:, 1::2] - y_offset + if self.bbox_clip_border: + cp_retrieve_gt_bboxes[:, 0::2] = np.clip( + cp_retrieve_gt_bboxes[:, 0::2], 0, target_w) + cp_retrieve_gt_bboxes[:, 1::2] = np.clip( + cp_retrieve_gt_bboxes[:, 1::2], 0, target_h) # 8. mix up ori_img = ori_img.astype(np.float32) @@ -2404,6 +2429,12 @@ def _mixup_transform(self, results): mixup_gt_labels = np.concatenate( (results['gt_labels'], retrieve_gt_labels), axis=0) + # remove outside bbox + inside_inds = remove_outside_bboxes(mixup_gt_bboxes, target_h, + target_w) + mixup_gt_bboxes = mixup_gt_bboxes[inside_inds] + mixup_gt_labels = mixup_gt_labels[inside_inds] + results['img'] = mixup_img.astype(np.uint8) results['img_shape'] = mixup_img.shape results['gt_bboxes'] = mixup_gt_bboxes @@ -2470,6 +2501,8 @@ class RandomAffine: max_aspect_ratio (float): Aspect ratio of width and height threshold to filter bboxes. If max(h/w, w/h) larger than this value, the box will be removed. + bbox_clip_border (bool, optional): Whether clip the objects outside + the border of the image. Defaults to True. skip_filter (bool): Whether to skip filtering rules. 
If it is True, the filter rule will not be applied, and the `min_bbox_size` and `min_area_ratio` and `max_aspect_ratio` @@ -2486,6 +2519,7 @@ def __init__(self, min_bbox_size=2, min_area_ratio=0.2, max_aspect_ratio=20, + bbox_clip_border=True, skip_filter=True): assert 0 <= max_translate_ratio <= 1 assert scaling_ratio_range[0] <= scaling_ratio_range[1] @@ -2499,6 +2533,7 @@ def __init__(self, self.min_bbox_size = min_bbox_size self.min_area_ratio = min_area_ratio self.max_aspect_ratio = max_aspect_ratio + self.bbox_clip_border = bbox_clip_border self.skip_filter = skip_filter def __call__(self, results): @@ -2559,20 +2594,25 @@ def __call__(self, results): warp_bboxes = np.vstack( (xs.min(1), ys.min(1), xs.max(1), ys.max(1))).T - warp_bboxes[:, [0, 2]] = warp_bboxes[:, [0, 2]].clip(0, width) - warp_bboxes[:, [1, 3]] = warp_bboxes[:, [1, 3]].clip(0, height) + if self.bbox_clip_border: + warp_bboxes[:, [0, 2]] = \ + warp_bboxes[:, [0, 2]].clip(0, width) + warp_bboxes[:, [1, 3]] = \ + warp_bboxes[:, [1, 3]].clip(0, height) + # remove outside bbox + valid_index = remove_outside_bboxes(warp_bboxes, height, width) if not self.skip_filter: # filter bboxes - valid_index = self.filter_gt_bboxes( + filter_index = self.filter_gt_bboxes( bboxes * scaling_ratio, warp_bboxes) - results[key] = warp_bboxes[valid_index] - if key in ['gt_bboxes']: - if 'gt_labels' in results: - results['gt_labels'] = results['gt_labels'][ - valid_index] - else: - results[key] = warp_bboxes + valid_index = valid_index & filter_index + + results[key] = warp_bboxes[valid_index] + if key in ['gt_bboxes']: + if 'gt_labels' in results: + results['gt_labels'] = results['gt_labels'][ + valid_index] if 'gt_masks' in results: raise NotImplementedError( From 161009adb9087164d6b17f8b2aaa078aa92a0385 Mon Sep 17 00:00:00 2001 From: Tao Gong Date: Thu, 9 Dec 2021 13:34:20 +0800 Subject: [PATCH 2/6] update based on 1-st comments --- mmdet/core/bbox/transforms.py | 8 +++----- 1 file changed, 3 insertions(+), 5 
deletions(-) diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py index 039904f58f5..4ac16be7912 100644 --- a/mmdet/core/bbox/transforms.py +++ b/mmdet/core/bbox/transforms.py @@ -4,7 +4,7 @@ def remove_outside_bboxes(bboxes, img_h, img_w): - """Remove bboxes that are out of the image boundary. + """Remove bboxes that are completely out of the image boundary. Args: bboxes (Tensor): Shape (N, 4). @@ -14,10 +14,8 @@ def remove_outside_bboxes(bboxes, img_h, img_w): Returns: Tensor: Index of the remaining bboxes. """ - inside_inds = bboxes[:, 0] < img_w - inside_inds = inside_inds & (bboxes[:, 2] > 0) - inside_inds = inside_inds & (bboxes[:, 1] < img_h) - inside_inds = inside_inds & (bboxes[:, 3] > 0) + inside_inds = (bboxes[:, 0] < img_w) & (bboxes[:, 2] > 0) \ + & (bboxes[:, 1] < img_h) & (bboxes[:, 3] > 0) return inside_inds From 2c12a8cd203594e2959fcb0400cc1d372945e15e Mon Sep 17 00:00:00 2001 From: Tao Gong Date: Thu, 9 Dec 2021 15:21:23 +0800 Subject: [PATCH 3/6] add comments --- mmdet/datasets/pipelines/transforms.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index abab8a6d9f7..aaba2514c83 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -239,6 +239,9 @@ def _resize_bboxes(self, results): """Resize bounding boxes with ``results['scale_factor']``.""" for key in results.get('bbox_fields', []): bboxes = results[key] * results['scale_factor'] + # In some dataset like MOT17, the gt bboxes are allowed to cross + # the border of images. Therefore, we don't need to clip the gt + # bboxes in these cases. 
if self.bbox_clip_border: img_shape = results['img_shape'] bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) @@ -2105,6 +2108,9 @@ def _mosaic_transform(self, results): mosaic_bboxes = np.concatenate(mosaic_bboxes, 0) mosaic_labels = np.concatenate(mosaic_labels, 0) + # In some dataset like MOT17, the gt bboxes are allowed to cross + # the border of images. Therefore, we don't need to clip the gt + # bboxes in these cases. if self.bbox_clip_border: mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0, 2 * self.img_scale[1]) @@ -2389,6 +2395,9 @@ def _mixup_transform(self, results): retrieve_gt_bboxes = retrieve_results['gt_bboxes'] retrieve_gt_bboxes[:, 0::2] = retrieve_gt_bboxes[:, 0::2] * scale_ratio retrieve_gt_bboxes[:, 1::2] = retrieve_gt_bboxes[:, 1::2] * scale_ratio + # In some dataset like MOT17, the gt bboxes are allowed to cross + # the border of images. Therefore, we don't need to clip the gt + # bboxes in these cases. if self.bbox_clip_border: retrieve_gt_bboxes[:, 0::2] = np.clip(retrieve_gt_bboxes[:, 0::2], 0, origin_w) @@ -2405,6 +2414,9 @@ def _mixup_transform(self, results): cp_retrieve_gt_bboxes[:, 0::2] - x_offset cp_retrieve_gt_bboxes[:, 1::2] = \ cp_retrieve_gt_bboxes[:, 1::2] - y_offset + # In some dataset like MOT17, the gt bboxes are allowed to cross + # the border of images. Therefore, we don't need to clip the gt + # bboxes in these cases. if self.bbox_clip_border: cp_retrieve_gt_bboxes[:, 0::2] = np.clip( cp_retrieve_gt_bboxes[:, 0::2], 0, target_w) @@ -2594,6 +2606,9 @@ def __call__(self, results): warp_bboxes = np.vstack( (xs.min(1), ys.min(1), xs.max(1), ys.max(1))).T + # In some dataset like MOT17, the gt bboxes are allowed to + # cross the border of images. Therefore, we don't need to clip + # the gt bboxes in these cases. 
if self.bbox_clip_border: warp_bboxes[:, [0, 2]] = \ warp_bboxes[:, [0, 2]].clip(0, width) From 85da782f8b4ef2fa7b7378943a4a030887384045 Mon Sep 17 00:00:00 2001 From: Tao Gong Date: Thu, 9 Dec 2021 19:24:46 +0800 Subject: [PATCH 4/6] fix typos --- mmdet/datasets/pipelines/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index aaba2514c83..99ce1b3aa37 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -54,7 +54,7 @@ class Resize: ratio_range (tuple[float]): (min_ratio, max_ratio) keep_ratio (bool): Whether to keep the aspect ratio when resizing the image. - bbox_clip_border (bool, optional): Whether clip the objects outside + bbox_clip_border (bool, optional): Whether to clip the objects outside the border of the image. Defaults to True. backend (str): Image resize backend, choices are 'cv2' and 'pillow'. These two backends generates slightly different results. Defaults @@ -1985,7 +1985,7 @@ class Mosaic: output. Default to (0.5, 1.5). min_bbox_size (int | float): The minimum pixel for filtering invalid bboxes after the mosaic pipeline. Default to 0. - bbox_clip_border (bool, optional): Whether clip the objects outside + bbox_clip_border (bool, optional): Whether to clip the objects outside the border of the image. Defaults to True. skip_filter (bool): Whether to skip filtering rules. If it is True, the filter rule will not be applied, and the @@ -2261,7 +2261,7 @@ class MixUp: max_aspect_ratio (float): Aspect ratio of width and height threshold to filter bboxes. If max(h/w, w/h) larger than this value, the box will be removed. Default: 20. - bbox_clip_border (bool, optional): Whether clip the objects outside + bbox_clip_border (bool, optional): Whether to clip the objects outside the border of the image. Defaults to True. skip_filter (bool): Whether to skip filtering rules. 
If it is True, the filter rule will not be applied, and the @@ -2513,7 +2513,7 @@ class RandomAffine: max_aspect_ratio (float): Aspect ratio of width and height threshold to filter bboxes. If max(h/w, w/h) larger than this value, the box will be removed. - bbox_clip_border (bool, optional): Whether clip the objects outside + bbox_clip_border (bool, optional): Whether to clip the objects outside the border of the image. Defaults to True. skip_filter (bool): Whether to skip filtering rules. If it is True, the filter rule will not be applied, and the From 32d934dc4671e1f5cdcb818b65472933242e5e28 Mon Sep 17 00:00:00 2001 From: Tao Gong Date: Fri, 10 Dec 2021 10:25:20 +0800 Subject: [PATCH 5/6] rename remove_outside_bboxes to find_inside_bboxes --- mmdet/core/bbox/__init__.py | 4 ++-- mmdet/core/bbox/transforms.py | 4 ++-- mmdet/datasets/pipelines/transforms.py | 12 +++++------- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index cf15e22af07..371eba198e9 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -12,7 +12,7 @@ from .transforms import (bbox2distance, bbox2result, bbox2roi, bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping, bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh, - distance2bbox, remove_outside_bboxes, roi2bbox) + distance2bbox, find_inside_bboxes, roi2bbox) __all__ = [ 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner', @@ -24,5 +24,5 @@ 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'DistancePointBBoxCoder', 'CenterRegionAssigner', 'bbox_rescale', 'bbox_cxcywh_to_xyxy', - 'bbox_xyxy_to_cxcywh', 'RegionAssigner', 'remove_outside_bboxes' + 'bbox_xyxy_to_cxcywh', 'RegionAssigner', 'find_inside_bboxes' ] diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py index 4ac16be7912..6d72076a562 100644 --- a/mmdet/core/bbox/transforms.py +++ b/mmdet/core/bbox/transforms.py @@ -3,8 
+3,8 @@ import torch -def remove_outside_bboxes(bboxes, img_h, img_w): - """Remove bboxes that are completely out of the image boundary. +def find_inside_bboxes(bboxes, img_h, img_w): + """Find bboxes as long as a part of bboxes is inside the image. Args: bboxes (Tensor): Shape (N, 4). diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 99ce1b3aa37..6e349c10c13 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -9,7 +9,7 @@ import numpy as np from numpy import random -from mmdet.core import PolygonMasks, remove_outside_bboxes +from mmdet.core import PolygonMasks, find_inside_bboxes from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps from ..builder import PIPELINES @@ -2122,9 +2122,8 @@ def _mosaic_transform(self, results): self._filter_box_candidates(mosaic_bboxes, mosaic_labels) # remove outside bboxes - inside_inds = remove_outside_bboxes(mosaic_bboxes, - 2 * self.img_scale[0], - 2 * self.img_scale[1]) + inside_inds = find_inside_bboxes(mosaic_bboxes, 2 * self.img_scale[0], + 2 * self.img_scale[1]) mosaic_bboxes = mosaic_bboxes[inside_inds] mosaic_labels = mosaic_labels[inside_inds] @@ -2442,8 +2441,7 @@ def _mixup_transform(self, results): (results['gt_labels'], retrieve_gt_labels), axis=0) # remove outside bbox - inside_inds = remove_outside_bboxes(mixup_gt_bboxes, target_h, - target_w) + inside_inds = find_inside_bboxes(mixup_gt_bboxes, target_h, target_w) mixup_gt_bboxes = mixup_gt_bboxes[inside_inds] mixup_gt_labels = mixup_gt_labels[inside_inds] @@ -2616,7 +2614,7 @@ def __call__(self, results): warp_bboxes[:, [1, 3]].clip(0, height) # remove outside bbox - valid_index = remove_outside_bboxes(warp_bboxes, height, width) + valid_index = find_inside_bboxes(warp_bboxes, height, width) if not self.skip_filter: # filter bboxes filter_index = self.filter_gt_bboxes( From f04645aba3762536d66e833db60f0842f28798b8 Mon Sep 17 00:00:00 2001 From: Tao Gong Date: 
Fri, 10 Dec 2021 14:10:21 +0800 Subject: [PATCH 6/6] move comments to docstring --- mmdet/datasets/pipelines/transforms.py | 31 ++++++++++---------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py index 6e349c10c13..d11586d0e60 100644 --- a/mmdet/datasets/pipelines/transforms.py +++ b/mmdet/datasets/pipelines/transforms.py @@ -55,7 +55,9 @@ class Resize: keep_ratio (bool): Whether to keep the aspect ratio when resizing the image. bbox_clip_border (bool, optional): Whether to clip the objects outside - the border of the image. Defaults to True. + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. backend (str): Image resize backend, choices are 'cv2' and 'pillow'. These two backends generates slightly different results. Defaults to 'cv2'. @@ -239,9 +241,6 @@ def _resize_bboxes(self, results): """Resize bounding boxes with ``results['scale_factor']``.""" for key in results.get('bbox_fields', []): bboxes = results[key] * results['scale_factor'] - # In some dataset like MOT17, the gt bboxes are allowed to cross - # the border of images. Therefore, we don't need to clip the gt - # bboxes in these cases. if self.bbox_clip_border: img_shape = results['img_shape'] bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1]) @@ -1986,7 +1985,9 @@ class Mosaic: min_bbox_size (int | float): The minimum pixel for filtering invalid bboxes after the mosaic pipeline. Default to 0. bbox_clip_border (bool, optional): Whether to clip the objects outside - the border of the image. Defaults to True. + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. skip_filter (bool): Whether to skip filtering rules. 
If it is True, the filter rule will not be applied, and the `min_bbox_size` is invalid. Default to True. @@ -2108,9 +2109,6 @@ def _mosaic_transform(self, results): mosaic_bboxes = np.concatenate(mosaic_bboxes, 0) mosaic_labels = np.concatenate(mosaic_labels, 0) - # In some dataset like MOT17, the gt bboxes are allowed to cross - # the border of images. Therefore, we don't need to clip the gt - # bboxes in these cases. if self.bbox_clip_border: mosaic_bboxes[:, 0::2] = np.clip(mosaic_bboxes[:, 0::2], 0, 2 * self.img_scale[1]) @@ -2261,7 +2259,9 @@ class MixUp: threshold to filter bboxes. If max(h/w, w/h) larger than this value, the box will be removed. Default: 20. bbox_clip_border (bool, optional): Whether to clip the objects outside - the border of the image. Defaults to True. + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. skip_filter (bool): Whether to skip filtering rules. If it is True, the filter rule will not be applied, and the `min_bbox_size` and `min_area_ratio` and `max_aspect_ratio` @@ -2394,9 +2394,6 @@ def _mixup_transform(self, results): retrieve_gt_bboxes = retrieve_results['gt_bboxes'] retrieve_gt_bboxes[:, 0::2] = retrieve_gt_bboxes[:, 0::2] * scale_ratio retrieve_gt_bboxes[:, 1::2] = retrieve_gt_bboxes[:, 1::2] * scale_ratio - # In some dataset like MOT17, the gt bboxes are allowed to cross - # the border of images. Therefore, we don't need to clip the gt - # bboxes in these cases. if self.bbox_clip_border: retrieve_gt_bboxes[:, 0::2] = np.clip(retrieve_gt_bboxes[:, 0::2], 0, origin_w) @@ -2413,9 +2410,6 @@ def _mixup_transform(self, results): cp_retrieve_gt_bboxes[:, 0::2] - x_offset cp_retrieve_gt_bboxes[:, 1::2] = \ cp_retrieve_gt_bboxes[:, 1::2] - y_offset - # In some dataset like MOT17, the gt bboxes are allowed to cross - # the border of images. 
Therefore, we don't need to clip the gt - # bboxes in these cases. if self.bbox_clip_border: cp_retrieve_gt_bboxes[:, 0::2] = np.clip( cp_retrieve_gt_bboxes[:, 0::2], 0, target_w) @@ -2512,7 +2506,9 @@ class RandomAffine: threshold to filter bboxes. If max(h/w, w/h) larger than this value, the box will be removed. bbox_clip_border (bool, optional): Whether to clip the objects outside - the border of the image. Defaults to True. + the border of the image. In some dataset like MOT17, the gt bboxes + are allowed to cross the border of images. Therefore, we don't + need to clip the gt bboxes in these cases. Defaults to True. skip_filter (bool): Whether to skip filtering rules. If it is True, the filter rule will not be applied, and the `min_bbox_size` and `min_area_ratio` and `max_aspect_ratio` @@ -2604,9 +2600,6 @@ def __call__(self, results): warp_bboxes = np.vstack( (xs.min(1), ys.min(1), xs.max(1), ys.max(1))).T - # In some dataset like MOT17, the gt bboxes are allowed to - # cross the border of images. Therefore, we don't need to clip - # the gt bboxes in these cases. if self.bbox_clip_border: warp_bboxes[:, [0, 2]] = \ warp_bboxes[:, [0, 2]].clip(0, width)