Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add RandomRotate transform #215

Merged
merged 13 commits into from
Nov 7, 2020
2 changes: 1 addition & 1 deletion configs/_base_/datasets/ade20k.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/datasets/cityscapes_769x769.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/datasets/pascal_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
2 changes: 1 addition & 1 deletion configs/_base_/datasets/pascal_voc12.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
99 changes: 92 additions & 7 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning
from numpy import random

from ..builder import PIPELINES
Expand Down Expand Up @@ -232,16 +233,17 @@ class RandomFlip(object):
method.

Args:
flip_ratio (float, optional): The flipping probability. Default: None.
prob (float, optional): The flipping probability. Default: None.
direction(str, optional): The flipping direction. Options are
'horizontal' and 'vertical'. Default: 'horizontal'.
"""

def __init__(self, flip_ratio=None, direction='horizontal'):
self.flip_ratio = flip_ratio
@deprecated_api_warning({'flip_ratio': 'prob'}, cls_name='RandomFlip')
def __init__(self, prob=None, direction='horizontal'):
self.prob = prob
self.direction = direction
if flip_ratio is not None:
assert flip_ratio >= 0 and flip_ratio <= 1
if prob is not None:
assert prob >= 0 and prob <= 1
assert direction in ['horizontal', 'vertical']

def __call__(self, results):
Expand All @@ -257,7 +259,7 @@ def __call__(self, results):
"""

if 'flip' not in results:
flip = True if np.random.rand() < self.flip_ratio else False
flip = True if np.random.rand() < self.prob else False
results['flip'] = flip
if 'flip_direction' not in results:
results['flip_direction'] = self.direction
Expand All @@ -274,7 +276,7 @@ def __call__(self, results):
return results

def __repr__(self):
return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
return self.__class__.__name__ + f'(prob={self.prob})'


@PIPELINES.register_module()
Expand Down Expand Up @@ -463,6 +465,89 @@ def __repr__(self):
return self.__class__.__name__ + f'(crop_size={self.crop_size})'


@PIPELINES.register_module()
class RandomRotate(object):
"""Rotate the image & seg.

Args:
prob (float): The rotation probability.
degree (float, tuple[float]): Range of degrees to select from. If
degree is a number instead of tuple like (min, max),
the range of degree will be (``-degree``, ``+degree``)
pad_val (float, optional): Padding value of image. Default: 0.
seg_pad_val (float, optional): Padding value of segmentation map.
Default: 255.
center (tuple[float], optional): Center point (w, h) of the rotation in
the source image. If not specified, the center of the image will be
used. Default: None.
auto_bound (bool): Whether to adjust the image size to cover the whole
rotated image. Default: False
"""

def __init__(self,
prob,
degree,
pad_val=0,
seg_pad_val=255,
center=None,
auto_bound=False):
self.prob = prob
assert prob >= 0 and prob <= 1
if isinstance(degree, (float, int)):
assert degree > 0, f'degree {degree} should be positive'
self.degree = (-degree, degree)
else:
self.degree = degree
assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
f'tuple of (min, max)'
self.pal_val = pad_val
self.seg_pad_val = seg_pad_val
self.center = center
self.auto_bound = auto_bound

def __call__(self, results):
"""Call function to rotate image, semantic segmentation maps.

Args:
results (dict): Result dict from loading pipeline.

Returns:
dict: Rotated results.
"""

rotate = True if np.random.rand() < self.prob else False
degree = np.random.uniform(min(*self.degree), max(*self.degree))
if rotate:
# rotate image
results['img'] = mmcv.imrotate(
results['img'],
angle=degree,
border_value=self.pal_val,
center=self.center,
auto_bound=self.auto_bound)

# rotate segs
for key in results.get('seg_fields', []):
results[key] = mmcv.imrotate(
results[key],
angle=degree,
border_value=self.seg_pad_val,
center=self.center,
auto_bound=self.auto_bound,
interpolation='nearest')
return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, ' \
f'degree={self.degree}, ' \
f'pad_val={self.pal_val}, ' \
f'seg_pad_val={self.seg_pad_val}, ' \
f'center={self.center}, ' \
f'auto_bound={self.auto_bound})'
return repr_str


@PIPELINES.register_module()
class SegRescale(object):
"""Rescale semantic segmentation maps.
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_custom_dataset():
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(128, 256), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
Expand Down
51 changes: 46 additions & 5 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,17 @@ def test_resize():


def test_flip():
# test assertion for invalid flip_ratio
# test assertion for invalid prob
with pytest.raises(AssertionError):
transform = dict(type='RandomFlip', flip_ratio=1.5)
transform = dict(type='RandomFlip', prob=1.5)
build_from_cfg(transform, PIPELINES)

# test assertion for invalid direction
with pytest.raises(AssertionError):
transform = dict(
type='RandomFlip', flip_ratio=1, direction='horizonta')
transform = dict(type='RandomFlip', prob=1, direction='horizonta')
build_from_cfg(transform, PIPELINES)

transform = dict(type='RandomFlip', flip_ratio=1)
transform = dict(type='RandomFlip', prob=1)
flip_module = build_from_cfg(transform, PIPELINES)

results = dict()
Expand Down Expand Up @@ -197,6 +196,48 @@ def test_pad():
assert img_shape[1] % 32 == 0


def test_rotate():
# test assertion degree should be tuple[float] or float
with pytest.raises(AssertionError):
transform = dict(type='RandomRotate', rotate_ratio=0.5, degree=-10)
build_from_cfg(transform, PIPELINES)
# test assertion degree should be tuple[float] or float
with pytest.raises(AssertionError):
transform = dict(
type='RandomRotate', rotate_ratio=0.5, degree=(10., 20., 30.))
build_from_cfg(transform, PIPELINES)

transform = dict(type='RandomRotate', degree=10., rotate_ratio=1.)
transform = build_from_cfg(transform, PIPELINES)

assert str(transform) == f'RandomRotate(' \
f'rotate_ratio={1.}, ' \
f'degree=({-10.}, {10.}), ' \
f'pad_val={0}, ' \
f'seg_pad_val={255}, ' \
f'center={None}, ' \
f'auto_bound={False})'

results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
h, w, _ = img.shape
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0

results = transform(results)
assert results['img'].shape[:2] == (h, w)
assert results['gt_semantic_seg'].shape[:2] == (h, w)


def test_normalize():
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
Expand Down