From 7185b5abf7c470b9946caf0d5ccaceed3cb5a7f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= Date: Tue, 20 Dec 2022 10:13:29 +0800 Subject: [PATCH] [Refactor] Make sure the pipeline argument shape is in WH order (#9324) * Keep input wh shape format in pipeline * update * fix lint * add doc * update * update * update * update * fix lint * update title * fix comment --- configs/yolox/yolox_s_8xb8-300e_coco.py | 3 +- configs/yolox/yolox_tiny_8xb8-300e_coco.py | 3 +- docs/en/advanced_guides/conventions.md | 33 +++++++ docs/zh_cn/advanced_guides/conventions.md | 34 +++++++ mmdet/datasets/transforms/formatting.py | 4 +- mmdet/datasets/transforms/transforms.py | 94 ++++++++++--------- .../test_transforms/test_transforms.py | 58 ++++++------ 7 files changed, 151 insertions(+), 78 deletions(-) diff --git a/configs/yolox/yolox_s_8xb8-300e_coco.py b/configs/yolox/yolox_s_8xb8-300e_coco.py index e78a82faa46..da37c758224 100644 --- a/configs/yolox/yolox_s_8xb8-300e_coco.py +++ b/configs/yolox/yolox_s_8xb8-300e_coco.py @@ -1,6 +1,6 @@ _base_ = ['../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'] -img_scale = (640, 640) # height, width +img_scale = (640, 640) # width, height # model settings model = dict( @@ -83,6 +83,7 @@ dict( type='RandomAffine', scaling_ratio_range=(0.1, 2), + # img_scale is (width, height) border=(-img_scale[0] // 2, -img_scale[1] // 2)), dict( type='MixUp', diff --git a/configs/yolox/yolox_tiny_8xb8-300e_coco.py b/configs/yolox/yolox_tiny_8xb8-300e_coco.py index b4f5bdeda1d..b15480bed0a 100644 --- a/configs/yolox/yolox_tiny_8xb8-300e_coco.py +++ b/configs/yolox/yolox_tiny_8xb8-300e_coco.py @@ -13,7 +13,7 @@ neck=dict(in_channels=[96, 192, 384], out_channels=96), bbox_head=dict(in_channels=96, feat_channels=96)) -img_scale = (640, 640) # height, width +img_scale = (640, 640) # width, height # file_client_args = dict( # backend='petrel', @@ -28,6 +28,7 @@ dict( type='RandomAffine', 
scaling_ratio_range=(0.5, 1.5), + # img_scale is (width, height) border=(-img_scale[0] // 2, -img_scale[1] // 2)), dict(type='YOLOXHSVRandomAug'), dict(type='RandomFlip', prob=0.5), diff --git a/docs/en/advanced_guides/conventions.md b/docs/en/advanced_guides/conventions.md index 67b6678bfd2..da159ac699f 100644 --- a/docs/en/advanced_guides/conventions.md +++ b/docs/en/advanced_guides/conventions.md @@ -2,6 +2,39 @@ Please check the following conventions if you would like to modify MMDetection as your own project. +## About the order of image shape + +In OpenMMLab 2.0, to be consistent with the input arguments of OpenCV, arguments that describe an image shape in the data transformation pipeline are always in `(width, height)` order. In contrast, for computational convenience, the shape fields that pass through the data pipeline and the model are in `(height, width)` order. Specifically, in the results produced by each data transform, the fields and the meaning of their values are as follows: + +- img_shape: (height, width) +- ori_shape: (height, width) +- pad_shape: (height, width) +- batch_input_shape: (height, width) + +As an example, the initialization arguments of `Mosaic` are as follows: + +```python +@TRANSFORMS.register_module() +class Mosaic(BaseTransform): + def __init__(self, + img_scale: Tuple[int, int] = (640, 640), + center_ratio_range: Tuple[float, float] = (0.5, 1.5), + bbox_clip_border: bool = True, + pad_val: float = 114.0, + prob: float = 1.0) -> None: + ... + + # img_scale order should be (width, height) + self.img_scale = img_scale + + def transform(self, results: dict) -> dict: + ... + + results['img'] = mosaic_img + # (height, width) + results['img_shape'] = mosaic_img.shape[:2] +``` + ## Loss In MMDetection, a `dict` containing losses and metrics will be returned by `model(**data)`. 
diff --git a/docs/zh_cn/advanced_guides/conventions.md b/docs/zh_cn/advanced_guides/conventions.md index ff52d6499e0..261f5ed5eb7 100644 --- a/docs/zh_cn/advanced_guides/conventions.md +++ b/docs/zh_cn/advanced_guides/conventions.md @@ -2,6 +2,40 @@ 如果你想把 MMDetection 修改为自己的项目,请遵循下面的约定。 +## 关于图片 shape 顺序的说明 + +在OpenMMLab 2.0中, 为了与 OpenCV 的输入参数相一致,图片处理 pipeline 中关于图像 shape 的输入参数总是以 `(width, height)` 的顺序排列。 +相反,为了计算方便,经过 pipeline 和 model 的字段的顺序是 `(height, width)`。具体来说在每个数据 pipeline 处理的结果中,字段和它们的值含义如下: + +- img_shape: (height, width) +- ori_shape: (height, width) +- pad_shape: (height, width) +- batch_input_shape: (height, width) + +以 `Mosaic` 为例,其初始化参数如下所示: + +```python +@TRANSFORMS.register_module() +class Mosaic(BaseTransform): + def __init__(self, + img_scale: Tuple[int, int] = (640, 640), + center_ratio_range: Tuple[float, float] = (0.5, 1.5), + bbox_clip_border: bool = True, + pad_val: float = 114.0, + prob: float = 1.0) -> None: + ... + + # img_scale 顺序应该是 (width, height) + self.img_scale = img_scale + + def transform(self, results: dict) -> dict: + ... + + results['img'] = mosaic_img + # (height, width) + results['img_shape'] = mosaic_img.shape[:2] +``` + ## 损失 在 MMDetection 中,`model(**data)` 的返回值是一个字典,包含着所有的损失和评价指标,他们将会由 `model(**data)` 返回。 diff --git a/mmdet/datasets/transforms/formatting.py b/mmdet/datasets/transforms/formatting.py index e34ef4219be..98248ee12be 100644 --- a/mmdet/datasets/transforms/formatting.py +++ b/mmdet/datasets/transforms/formatting.py @@ -21,10 +21,10 @@ class PackDetInputs(BaseTransform): - ``img_path``: path to the image file - - ``ori_shape``: original shape of the image as a tuple (h, w, c) + - ``ori_shape``: original shape of the image as a tuple (h, w) - ``img_shape``: shape of the image input to the network as a tuple \ - (h, w, c). Note that images may be zero padded on the \ + (h, w). Note that images may be zero padded on the \ bottom/right if the batch tensor is larger than this shape. 
- ``scale_factor``: a float indicating the preprocessing scale diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py index c9e95bd7476..18b646bd26d 100644 --- a/mmdet/datasets/transforms/transforms.py +++ b/mmdet/datasets/transforms/transforms.py @@ -545,7 +545,7 @@ class Pad(MMCV_Pad): Args: size (tuple, optional): Fixed padding size. - Expected padding shape (w, h). Defaults to None. + Expected padding shape (width, height). Defaults to None. size_divisor (int, optional): The divisor of padded size. Defaults to None. pad_to_square (bool): Whether to pad the image into a square. @@ -630,7 +630,7 @@ class RandomCrop(BaseTransform): Args: crop_size (tuple): The relative ratio or absolute pixels of - height and width. + (width, height). crop_type (str, optional): One of "relative_range", "relative", "absolute", "absolute_range". "relative" randomly crops (h * crop_size[0], w * crop_size[1]) part from an input of size @@ -776,7 +776,7 @@ def _rand_offset(self, margin: Tuple[int, int]) -> Tuple[int, int]: offset_h = np.random.randint(0, margin_h + 1) offset_w = np.random.randint(0, margin_w + 1) - return (offset_h, offset_w) + return offset_h, offset_w @cache_randomness def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]: @@ -791,7 +791,7 @@ def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]: """ h, w = image_size if self.crop_type == 'absolute': - return (min(self.crop_size[0], h), min(self.crop_size[1], w)) + return min(self.crop_size[1], h), min(self.crop_size[0], w) elif self.crop_type == 'absolute_range': crop_h = np.random.randint( min(h, self.crop_size[0]), @@ -801,7 +801,7 @@ def _get_crop_size(self, image_size: Tuple[int, int]) -> Tuple[int, int]: min(w, self.crop_size[1]) + 1) return crop_h, crop_w elif self.crop_type == 'relative': - crop_h, crop_w = self.crop_size + crop_w, crop_h = self.crop_size return int(h * crop_h + 0.5), int(w * crop_w + 0.5) else: # 'relative_range' 
@@ -1668,8 +1668,8 @@ class RandomCenterCropPad(BaseTransform): Args: crop_size (tuple, optional): expected size after crop, final size will - computed according to ratio. Requires (h, w) in train mode, and - None in test mode. + computed according to ratio. Requires (width, height) + in train mode, and None in test mode. ratios (tuple, optional): random select a ratio from tuple and crop image to (crop_size[0] * ratio) * (crop_size[1] * ratio). Only available in train mode. Defaults to (0.9, 1.0, 1.1). @@ -1844,8 +1844,8 @@ def _train_aug(self, results): gt_bboxes = results['gt_bboxes'] while True: scale = random.choice(self.ratios) - new_h = int(self.crop_size[0] * scale) - new_w = int(self.crop_size[1] * scale) + new_h = int(self.crop_size[1] * scale) + new_w = int(self.crop_size[0] * scale) h_border = self._get_border(self.border, h) w_border = self._get_border(self.border, w) @@ -2107,7 +2107,7 @@ class Mosaic(BaseTransform): Args: img_scale (Sequence[int]): Image size after mosaic pipeline of single - image. The shape order should be (height, width). + image. The shape order should be (width, height). Defaults to (640, 640). center_ratio_range (Sequence[float]): Center ratio range of mosaic output. Defaults to (0.5, 1.5). @@ -2130,7 +2130,7 @@ def __init__(self, assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. ' \ f'got {prob}.' 
- log_img_scale(img_scale, skip_square=True) + log_img_scale(img_scale, skip_square=True, shape_order='wh') self.img_scale = img_scale self.center_ratio_range = center_ratio_range self.bbox_clip_border = bbox_clip_border @@ -2170,20 +2170,20 @@ def transform(self, results: dict) -> dict: mosaic_ignore_flags = [] if len(results['img'].shape) == 3: mosaic_img = np.full( - (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3), + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3), self.pad_val, dtype=results['img'].dtype) else: mosaic_img = np.full( - (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)), self.pad_val, dtype=results['img'].dtype) # mosaic center x, y center_x = int( - random.uniform(*self.center_ratio_range) * self.img_scale[1]) - center_y = int( random.uniform(*self.center_ratio_range) * self.img_scale[0]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) center_position = (center_x, center_y) loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') @@ -2196,8 +2196,8 @@ def transform(self, results: dict) -> dict: img_i = results_patch['img'] h_i, w_i = img_i.shape[:2] # keep_ratio resize - scale_ratio_i = min(self.img_scale[0] / h_i, - self.img_scale[1] / w_i) + scale_ratio_i = min(self.img_scale[1] / h_i, + self.img_scale[0] / w_i) img_i = mmcv.imresize( img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) @@ -2228,10 +2228,10 @@ def transform(self, results: dict) -> dict: mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0) if self.bbox_clip_border: - mosaic_bboxes.clip_([2 * self.img_scale[0], 2 * self.img_scale[1]]) + mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]]) # remove outside bboxes inside_inds = mosaic_bboxes.is_inside( - [2 * self.img_scale[0], 2 * self.img_scale[1]]).numpy() + [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy() mosaic_bboxes = mosaic_bboxes[inside_inds] 
mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds] mosaic_ignore_flags = mosaic_ignore_flags[inside_inds] @@ -2277,7 +2277,7 @@ def _mosaic_combine( x1, y1, x2, y2 = center_position_xy[0], \ max(center_position_xy[1] - img_shape_wh[1], 0), \ min(center_position_xy[0] + img_shape_wh[0], - self.img_scale[1] * 2), \ + self.img_scale[0] * 2), \ center_position_xy[1] crop_coord = 0, img_shape_wh[1] - (y2 - y1), min( img_shape_wh[0], x2 - x1), img_shape_wh[1] @@ -2287,7 +2287,7 @@ def _mosaic_combine( x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \ center_position_xy[1], \ center_position_xy[0], \ - min(self.img_scale[0] * 2, center_position_xy[1] + + min(self.img_scale[1] * 2, center_position_xy[1] + img_shape_wh[1]) crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min( y2 - y1, img_shape_wh[1]) @@ -2297,8 +2297,8 @@ def _mosaic_combine( x1, y1, x2, y2 = center_position_xy[0], \ center_position_xy[1], \ min(center_position_xy[0] + img_shape_wh[0], - self.img_scale[1] * 2), \ - min(self.img_scale[0] * 2, center_position_xy[1] + + self.img_scale[0] * 2), \ + min(self.img_scale[1] * 2, center_position_xy[1] + img_shape_wh[1]) crop_coord = 0, 0, min(img_shape_wh[0], x2 - x1), min(y2 - y1, img_shape_wh[1]) @@ -2362,7 +2362,7 @@ class MixUp(BaseTransform): Args: img_scale (Sequence[int]): Image output size after mixup pipeline. - The shape order should be (height, width). Defaults to (640, 640). + The shape order should be (width, height). Defaults to (640, 640). ratio_range (Sequence[float]): Scale ratio of mixup image. Defaults to (0.5, 1.5). flip_ratio (float): Horizontal flip ratio of mixup image. 
@@ -2385,7 +2385,7 @@ def __init__(self, max_iters: int = 15, bbox_clip_border: bool = True) -> None: assert isinstance(img_scale, tuple) - log_img_scale(img_scale, skip_square=True) + log_img_scale(img_scale, skip_square=True, shape_order='wh') self.dynamic_scale = img_scale self.ratio_range = ratio_range self.flip_ratio = flip_ratio @@ -2439,15 +2439,16 @@ def transform(self, results: dict) -> dict: if len(retrieve_img.shape) == 3: out_img = np.ones( - (self.dynamic_scale[0], self.dynamic_scale[1], 3), + (self.dynamic_scale[1], self.dynamic_scale[0], 3), dtype=retrieve_img.dtype) * self.pad_val else: out_img = np.ones( - self.dynamic_scale, dtype=retrieve_img.dtype) * self.pad_val + self.dynamic_scale[::-1], + dtype=retrieve_img.dtype) * self.pad_val # 1. keep_ratio resize - scale_ratio = min(self.dynamic_scale[0] / retrieve_img.shape[0], - self.dynamic_scale[1] / retrieve_img.shape[1]) + scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0], + self.dynamic_scale[0] / retrieve_img.shape[1]) retrieve_img = mmcv.imresize( retrieve_img, (int(retrieve_img.shape[1] * scale_ratio), int(retrieve_img.shape[0] * scale_ratio))) @@ -2567,7 +2568,7 @@ class RandomAffine(BaseTransform): scaling transform. Defaults to (0.5, 1.5). max_shear_degree (float): Maximum degrees of shear transform. Defaults to 2. - border (tuple[int]): Distance from height and width sides of input + border (tuple[int]): Distance from width and height sides of input image to adjust output shape. Only used in mosaic dataset. Defaults to (0, 0). border_val (tuple[int]): Border padding values of 3 channels. 
@@ -2630,8 +2631,8 @@ def _get_random_homography_matrix(self, height, width): @autocast_box_type() def transform(self, results: dict) -> dict: img = results['img'] - height = img.shape[0] + self.border[0] * 2 - width = img.shape[1] + self.border[1] * 2 + height = img.shape[0] + self.border[1] * 2 + width = img.shape[1] + self.border[0] * 2 warp_matrix = self._get_random_homography_matrix(height, width) @@ -3167,7 +3168,7 @@ class CachedMosaic(Mosaic): Args: img_scale (Sequence[int]): Image size after mosaic pipeline of single - image. The shape order should be (height, width). + image. The shape order should be (width, height). Defaults to (640, 640). center_ratio_range (Sequence[float]): Center ratio range of mosaic output. Defaults to (0.5, 1.5). @@ -3249,20 +3250,20 @@ def transform(self, results: dict) -> dict: if len(results['img'].shape) == 3: mosaic_img = np.full( - (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3), + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2), 3), self.pad_val, dtype=results['img'].dtype) else: mosaic_img = np.full( - (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)), + (int(self.img_scale[1] * 2), int(self.img_scale[0] * 2)), self.pad_val, dtype=results['img'].dtype) # mosaic center x, y center_x = int( - random.uniform(*self.center_ratio_range) * self.img_scale[1]) - center_y = int( random.uniform(*self.center_ratio_range) * self.img_scale[0]) + center_y = int( + random.uniform(*self.center_ratio_range) * self.img_scale[1]) center_position = (center_x, center_y) loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right') @@ -3275,8 +3276,8 @@ def transform(self, results: dict) -> dict: img_i = results_patch['img'] h_i, w_i = img_i.shape[:2] # keep_ratio resize - scale_ratio_i = min(self.img_scale[0] / h_i, - self.img_scale[1] / w_i) + scale_ratio_i = min(self.img_scale[1] / h_i, + self.img_scale[0] / w_i) img_i = mmcv.imresize( img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i))) @@ -3321,10 
+3322,10 @@ def transform(self, results: dict) -> dict: mosaic_ignore_flags = np.concatenate(mosaic_ignore_flags, 0) if self.bbox_clip_border: - mosaic_bboxes.clip_([2 * self.img_scale[0], 2 * self.img_scale[1]]) + mosaic_bboxes.clip_([2 * self.img_scale[1], 2 * self.img_scale[0]]) # remove outside bboxes inside_inds = mosaic_bboxes.is_inside( - [2 * self.img_scale[0], 2 * self.img_scale[1]]).numpy() + [2 * self.img_scale[1], 2 * self.img_scale[0]]).numpy() mosaic_bboxes = mosaic_bboxes[inside_inds] mosaic_bboxes_labels = mosaic_bboxes_labels[inside_inds] mosaic_ignore_flags = mosaic_ignore_flags[inside_inds] @@ -3399,7 +3400,7 @@ class CachedMixUp(BaseTransform): Args: img_scale (Sequence[int]): Image output size after mixup pipeline. - The shape order should be (height, width). Defaults to (640, 640). + The shape order should be (width, height). Defaults to (640, 640). ratio_range (Sequence[float]): Scale ratio of mixup image. Defaults to (0.5, 1.5). flip_ratio (float): Horizontal flip ratio of mixup image. @@ -3509,15 +3510,16 @@ def transform(self, results: dict) -> dict: if len(retrieve_img.shape) == 3: out_img = np.ones( - (self.dynamic_scale[0], self.dynamic_scale[1], 3), + (self.dynamic_scale[1], self.dynamic_scale[0], 3), dtype=retrieve_img.dtype) * self.pad_val else: out_img = np.ones( - self.dynamic_scale, dtype=retrieve_img.dtype) * self.pad_val + self.dynamic_scale[::-1], + dtype=retrieve_img.dtype) * self.pad_val # 1. 
keep_ratio resize - scale_ratio = min(self.dynamic_scale[0] / retrieve_img.shape[0], - self.dynamic_scale[1] / retrieve_img.shape[1]) + scale_ratio = min(self.dynamic_scale[1] / retrieve_img.shape[0], + self.dynamic_scale[0] / retrieve_img.shape[1]) retrieve_img = mmcv.imresize( retrieve_img, (int(retrieve_img.shape[1] * scale_ratio), int(retrieve_img.shape[0] * scale_ratio))) diff --git a/tests/test_datasets/test_transforms/test_transforms.py b/tests/test_datasets/test_transforms/test_transforms.py index 2f84673f5f6..b0cb2c29d04 100644 --- a/tests/test_datasets/test_transforms/test_transforms.py +++ b/tests/test_datasets/test_transforms/test_transforms.py @@ -593,28 +593,29 @@ def test_init(self): def test_transform(self): # test relative and absolute crop src_results = { - 'img': np.random.randint(0, 255, size=(32, 24), dtype=np.int32) + 'img': np.random.randint(0, 255, size=(24, 32), dtype=np.int32) } - target_shape = (16, 12) + target_shape = (12, 16) for crop_type, crop_size in zip(['relative', 'absolute'], [(0.5, 0.5), (16, 12)]): transform = RandomCrop(crop_size=crop_size, crop_type=crop_type) results = transform(copy.deepcopy(src_results)) + print(results['img'].shape[:2]) self.assertEqual(results['img'].shape[:2], target_shape) # test absolute_range crop transform = RandomCrop(crop_size=(10, 20), crop_type='absolute_range') results = transform(copy.deepcopy(src_results)) h, w = results['img'].shape - self.assertTrue(10 <= h <= 20) self.assertTrue(10 <= w <= 20) + self.assertTrue(10 <= h <= 20) # test relative_range crop transform = RandomCrop( crop_size=(0.5, 0.5), crop_type='relative_range') results = transform(copy.deepcopy(src_results)) h, w = results['img'].shape - self.assertTrue(16 <= h <= 32) - self.assertTrue(12 <= w <= 24) + self.assertTrue(16 <= w <= 32) + self.assertTrue(12 <= h <= 24) # test with gt_bboxes, gt_bboxes_labels, gt_ignore_flags, # gt_masks, gt_seg_map @@ -636,23 +637,23 @@ def test_transform(self): 'gt_seg_map': gt_seg_map } 
transform = RandomCrop( - crop_size=(5, 5), + crop_size=(7, 5), allow_negative_crop=False, recompute_bbox=False, bbox_clip_border=True) results = transform(copy.deepcopy(src_results)) h, w = results['img'].shape self.assertEqual(h, 5) - self.assertEqual(w, 5) + self.assertEqual(w, 7) self.assertEqual(results['gt_bboxes'].shape[0], 2) self.assertEqual(results['gt_bboxes_labels'].shape[0], 2) self.assertEqual(results['gt_ignore_flags'].shape[0], 2) - self.assertTupleEqual(results['gt_seg_map'].shape[:2], (5, 5)) + self.assertTupleEqual(results['gt_seg_map'].shape[:2], (5, 7)) # test geometric transformation with homography matrix bboxes = copy.deepcopy(src_results['gt_bboxes']) self.assertTrue((bbox_project(bboxes, results['homography_matrix'], - (5, 5)) == results['gt_bboxes']).all()) + (5, 7)) == results['gt_bboxes']).all()) # test recompute_bbox = True gt_masks_ = np.zeros((2, 10, 10), np.uint8) @@ -665,7 +666,7 @@ def test_transform(self): } target_gt_bboxes = np.zeros((1, 4), dtype=np.float32) transform = RandomCrop( - crop_size=(10, 10), + crop_size=(10, 11), allow_negative_crop=False, recompute_bbox=True, bbox_clip_border=True) @@ -675,7 +676,7 @@ def test_transform(self): # test bbox_clip_border = False src_results = {'img': img, 'gt_bboxes': gt_bboxes} transform = RandomCrop( - crop_size=(10, 10), + crop_size=(10, 11), allow_negative_crop=False, recompute_bbox=True, bbox_clip_border=False) @@ -688,7 +689,7 @@ def test_transform(self): img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8) gt_bboxes = np.zeros((0, 4), dtype=np.float32) src_results = {'img': img, 'gt_bboxes': gt_bboxes} - transform = RandomCrop(crop_size=(5, 5), allow_negative_crop=False) + transform = RandomCrop(crop_size=(5, 3), allow_negative_crop=False) results = transform(copy.deepcopy(src_results)) self.assertIsNone(results) @@ -696,7 +697,7 @@ def test_transform(self): img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8) gt_bboxes = np.zeros((0, 4), dtype=np.float32) 
src_results = {'img': img, 'gt_bboxes': gt_bboxes} - transform = RandomCrop(crop_size=(5, 5), allow_negative_crop=True) + transform = RandomCrop(crop_size=(5, 3), allow_negative_crop=True) results = transform(copy.deepcopy(src_results)) self.assertTrue(isinstance(results, dict)) @@ -721,24 +722,25 @@ def test_transform_use_box_type(self): 'gt_seg_map': gt_seg_map } transform = RandomCrop( - crop_size=(5, 5), + crop_size=(7, 5), allow_negative_crop=False, recompute_bbox=False, bbox_clip_border=True) results = transform(copy.deepcopy(src_results)) h, w = results['img'].shape self.assertEqual(h, 5) - self.assertEqual(w, 5) + self.assertEqual(w, 7) self.assertEqual(results['gt_bboxes'].shape[0], 2) self.assertEqual(results['gt_bboxes_labels'].shape[0], 2) self.assertEqual(results['gt_ignore_flags'].shape[0], 2) - self.assertTupleEqual(results['gt_seg_map'].shape[:2], (5, 5)) + self.assertTupleEqual(results['gt_seg_map'].shape[:2], (5, 7)) # test geometric transformation with homography matrix bboxes = copy.deepcopy(src_results['gt_bboxes'].numpy()) + print(bboxes, results['gt_bboxes']) self.assertTrue( (bbox_project(bboxes, results['homography_matrix'], - (5, 5)) == results['gt_bboxes'].numpy()).all()) + (5, 7)) == results['gt_bboxes'].numpy()).all()) # test recompute_bbox = True gt_masks_ = np.zeros((2, 10, 10), np.uint8) @@ -751,7 +753,7 @@ def test_transform_use_box_type(self): } target_gt_bboxes = np.zeros((1, 4), dtype=np.float32) transform = RandomCrop( - crop_size=(10, 10), + crop_size=(10, 11), allow_negative_crop=False, recompute_bbox=True, bbox_clip_border=True) @@ -776,7 +778,7 @@ def test_transform_use_box_type(self): img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8) gt_bboxes = HorizontalBoxes(np.zeros((0, 4), dtype=np.float32)) src_results = {'img': img, 'gt_bboxes': gt_bboxes} - transform = RandomCrop(crop_size=(5, 5), allow_negative_crop=False) + transform = RandomCrop(crop_size=(5, 2), allow_negative_crop=False) results = 
transform(copy.deepcopy(src_results)) self.assertIsNone(results) @@ -784,13 +786,13 @@ def test_transform_use_box_type(self): img = np.random.randint(0, 255, size=(10, 10), dtype=np.uint8) gt_bboxes = HorizontalBoxes(np.zeros((0, 4), dtype=np.float32)) src_results = {'img': img, 'gt_bboxes': gt_bboxes} - transform = RandomCrop(crop_size=(5, 5), allow_negative_crop=True) + transform = RandomCrop(crop_size=(5, 2), allow_negative_crop=True) results = transform(copy.deepcopy(src_results)) self.assertTrue(isinstance(results, dict)) def test_repr(self): crop_type = 'absolute' - crop_size = (10, 10) + crop_size = (10, 5) allow_negative_crop = False recompute_bbox = True bbox_clip_border = False @@ -903,7 +905,7 @@ def test_transform(self): with self.assertRaises(AssertionError): transform = Mosaic(prob=1.5) - transform = Mosaic(img_scale=(10, 12)) + transform = Mosaic(img_scale=(12, 10)) # test assertion for invalid mix_results with self.assertRaises(AssertionError): results = transform(copy.deepcopy(self.results)) @@ -921,7 +923,7 @@ def test_transform_with_no_gt(self): self.results['gt_bboxes'] = np.empty((0, 4), dtype=np.float32) self.results['gt_bboxes_labels'] = np.empty((0, ), dtype=np.int64) self.results['gt_ignore_flags'] = np.empty((0, ), dtype=np.bool) - transform = Mosaic(img_scale=(10, 12)) + transform = Mosaic(img_scale=(12, 10)) self.results['mix_results'] = [copy.deepcopy(self.results)] * 3 results = transform(copy.deepcopy(self.results)) self.assertIsInstance(results, dict) @@ -934,7 +936,7 @@ def test_transform_with_no_gt(self): self.assertTrue(results['gt_ignore_flags'].dtype == bool) def test_transform_use_box_type(self): - transform = Mosaic(img_scale=(10, 12)) + transform = Mosaic(img_scale=(12, 10)) results = copy.deepcopy(self.results) results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes']) results['mix_results'] = [results] * 3 @@ -984,7 +986,7 @@ def test_transform(self): with self.assertRaises(AssertionError): transform = 
MixUp(img_scale=640) - transform = MixUp(img_scale=(10, 12)) + transform = MixUp(img_scale=(12, 10)) # test assertion for invalid mix_results with self.assertRaises(AssertionError): results = transform(copy.deepcopy(self.results)) @@ -1006,7 +1008,7 @@ def test_transform_use_box_type(self): results = copy.deepcopy(self.results) results['gt_bboxes'] = HorizontalBoxes(results['gt_bboxes']) - transform = MixUp(img_scale=(10, 12)) + transform = MixUp(img_scale=(12, 10)) results['mix_results'] = [results] results = transform(results) self.assertTrue(results['img'].shape[:2] == (224, 224)) @@ -1233,7 +1235,7 @@ def test_transform(self): results['gt_bboxes_labels'] = gt_bboxes_labels results['gt_ignore_flags'] = gt_ignore_flags crop_module = RandomCenterCropPad( - crop_size=(h - 20, w - 20), + crop_size=(w - 20, h - 20), ratios=(1.0, ), border=128, mean=[123.675, 116.28, 103.53], @@ -1278,7 +1280,7 @@ def test_transform_use_box_type(self): results['gt_bboxes_labels'] = gt_bboxes_labels results['gt_ignore_flags'] = gt_ignore_flags crop_module = RandomCenterCropPad( - crop_size=(h - 20, w - 20), + crop_size=(w - 20, h - 20), ratios=(1.0, ), border=128, mean=[123.675, 116.28, 103.53],