Skip to content

Commit

Permalink
[Fix] Add swap_labe_pairs in RandomFlip (#2332)
Browse files Browse the repository at this point in the history
* [Fix] Add `swap_labe_pairs` in `RandomFlip`

* [Fix] Add `swap_labe_pairs` in `RandomFlip`

* add reference info

* add swap_label_pairs in results

* revise according to comments

* revise according to comments

* revise according to comments

* docstring

* docstring
  • Loading branch information
MeowZheng authored Oct 17, 2022
1 parent a4c8261 commit bf48ca0
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 27 deletions.
79 changes: 58 additions & 21 deletions mmcv/transforms/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,21 +1025,25 @@ class RandomFlip(BaseTransform):
- flip
- flip_direction
- swap_seg_labels (optional)
Args:
prob (float | list[float], optional): The flipping probability.
Defaults to None.
direction(str | list[str]): The flipping direction. Options
If input is a list, the length must equal ``prob``. Each
element in ``prob`` indicates the flip probability of
corresponding direction. Defaults to 'horizontal'.
prob (float | list[float], optional): The flipping probability.
Defaults to None.
direction(str | list[str]): The flipping direction. Options
If input is a list, the length must equal ``prob``. Each
element in ``prob`` indicates the flip probability of
corresponding direction. Defaults to 'horizontal'.
swap_seg_labels (list, optional): The label pair need to be swapped
for ground truth, like 'left arm' and 'right arm' need to be
swapped after horizontal flipping. For example, ``[(1, 5)]``,
where 1/5 is the label of the left/right arm. Defaults to None.
"""

def __init__(
self,
prob: Optional[Union[float, Iterable[float]]] = None,
direction: Union[str,
Sequence[Optional[str]]] = 'horizontal') -> None:
def __init__(self,
prob: Optional[Union[float, Iterable[float]]] = None,
direction: Union[str, Sequence[Optional[str]]] = 'horizontal',
swap_seg_labels: Optional[Sequence] = None) -> None:
if isinstance(prob, list):
assert mmengine.is_list_of(prob, float)
assert 0 <= sum(prob) <= 1
Expand All @@ -1049,6 +1053,7 @@ def __init__(
raise ValueError(f'probs must be float or list of float, but \
got `{type(prob)}`.')
self.prob = prob
self.swap_seg_labels = swap_seg_labels

valid_directions = ['horizontal', 'vertical', 'diagonal']
if isinstance(direction, str):
Expand All @@ -1064,8 +1069,8 @@ def __init__(
if isinstance(prob, list):
assert len(prob) == len(self.direction)

def flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
direction: str) -> np.ndarray:
def _flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
direction: str) -> np.ndarray:
"""Flip bboxes horizontally.
Args:
Expand Down Expand Up @@ -1096,8 +1101,12 @@ def flip_bbox(self, bboxes: np.ndarray, img_shape: Tuple[int, int],
or 'diagonal', but got '{direction}'")
return flipped

def flip_keypoints(self, keypoints: np.ndarray, img_shape: Tuple[int, int],
direction: str) -> np.ndarray:
def _flip_keypoints(
self,
keypoints: np.ndarray,
img_shape: Tuple[int, int],
direction: str,
) -> np.ndarray:
"""Flip keypoints horizontally, vertically or diagonally.
Args:
Expand Down Expand Up @@ -1127,6 +1136,33 @@ def flip_keypoints(self, keypoints: np.ndarray, img_shape: Tuple[int, int],
flipped = np.concatenate([keypoints, meta_info], axis=-1)
return flipped

def _flip_seg_map(self, seg_map: dict, direction: str) -> np.ndarray:
"""Flip segmentation map horizontally, vertically or diagonally.
Args:
seg_map (numpy.ndarray): segmentation map, shape (H, W).
direction (str): Flip direction. Options are 'horizontal',
'vertical'.
Returns:
numpy.ndarray: Flipped segmentation map.
"""
seg_map = mmcv.imflip(seg_map, direction=direction)
if self.swap_seg_labels is not None:
# to handle datasets with left/right annotations
# like 'Left-arm' and 'Right-arm' in LIP dataset
# Modified from https://github.com/openseg-group/openseg.pytorch/blob/master/lib/datasets/tools/cv2_aug_transforms.py # noqa:E501
# Licensed under MIT license
temp = seg_map.copy()
assert isinstance(self.swap_seg_labels, (tuple, list))
for pair in self.swap_seg_labels:
assert isinstance(pair, (tuple, list)) and len(pair) == 2, \
'swap_seg_labels must be a sequence with pair, but got ' \
f'{self.swap_seg_labels}.'
seg_map[temp == pair[0]] = pair[1]
seg_map[temp == pair[1]] = pair[0]
return seg_map

@cache_randomness
def _choose_direction(self) -> str:
"""Choose the flip direction according to `prob` and `direction`"""
Expand Down Expand Up @@ -1162,19 +1198,20 @@ def _flip(self, results: dict) -> None:

# flip bboxes
if results.get('gt_bboxes', None) is not None:
results['gt_bboxes'] = self.flip_bbox(results['gt_bboxes'],
img_shape,
results['flip_direction'])
results['gt_bboxes'] = self._flip_bbox(results['gt_bboxes'],
img_shape,
results['flip_direction'])

# flip keypoints
if results.get('gt_keypoints', None) is not None:
results['gt_keypoints'] = self.flip_keypoints(
results['gt_keypoints'] = self._flip_keypoints(
results['gt_keypoints'], img_shape, results['flip_direction'])

# flip segs
# flip seg map
if results.get('gt_seg_map', None) is not None:
results['gt_seg_map'] = mmcv.imflip(
results['gt_seg_map'] = self._flip_seg_map(
results['gt_seg_map'], direction=results['flip_direction'])
results['swap_seg_labels'] = self.swap_seg_labels

def _flip_on_direction(self, results: dict) -> None:
"""Function to flip images, bounding boxes, semantic segmentation map
Expand Down
38 changes: 32 additions & 6 deletions tests/test_transforms/test_transforms_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,49 +777,75 @@ def test_transform(self):
'img': np.random.random((224, 224, 3)),
'gt_bboxes': np.array([[0, 1, 100, 101]]),
'gt_keypoints': np.array([[[100, 100, 1.0]]]),
'gt_seg_map': np.random.random((224, 224, 3))
# seg map flip is irrelative with image, so there is no requirement
# that gt_set_map of test data matches image.
'gt_seg_map': np.array([[0, 1], [2, 3]])
}

# horizontal flip
TRANSFORMS = RandomFlip([1.0], ['horizontal'])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
101]])).all()
assert (results_update['gt_seg_map'] == np.array([[1, 0], [3,
2]])).all()

# diagnal flip
# diagonal flip
TRANSFORMS = RandomFlip([1.0], ['diagonal'])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 123, 224,
223]])).all()
assert (results_update['gt_seg_map'] == np.array([[3, 2], [1,
0]])).all()

# vertical flip
TRANSFORMS = RandomFlip([1.0], ['vertical'])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[0, 123, 100,
223]])).all()
assert (results_update['gt_seg_map'] == np.array([[2, 3], [0,
1]])).all()

# horizontal flip when direction is None
TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[124, 1, 224,
101]])).all()
assert (results_update['gt_seg_map'] == np.array([[1, 0], [3,
2]])).all()

# horizontal flip and swap label pair
TRANSFORMS = RandomFlip([1.0], ['horizontal'],
swap_seg_labels=[[0, 1]])
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_seg_map'] == np.array([[0, 1], [3,
2]])).all()
assert results_update['swap_seg_labels'] == [[0, 1]]

TRANSFORMS = RandomFlip(0.0)
results_update = TRANSFORMS.transform(copy.deepcopy(results))
assert (results_update['gt_bboxes'] == np.array([[0, 1, 100,
101]])).all()
assert (results_update['gt_seg_map'] == np.array([[0, 1], [2,
3]])).all()

# flip direction is invalid in bbox flip
with pytest.raises(ValueError):
TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.flip_bbox(results['gt_bboxes'],
(224, 224), 'invalid')
results_update = TRANSFORMS._flip_bbox(results['gt_bboxes'],
(224, 224), 'invalid')

# flip direction is invalid in keypoints flip
with pytest.raises(ValueError):
TRANSFORMS = RandomFlip(1.0)
results_update = TRANSFORMS.flip_keypoints(results['gt_keypoints'],
(224, 224), 'invalid')
results_update = TRANSFORMS._flip_keypoints(
results['gt_keypoints'], (224, 224), 'invalid')

# swap pair is invalid
with pytest.raises(AssertionError):
TRANSFORMS = RandomFlip(1.0, swap_seg_labels='invalid')
results_update = TRANSFORMS._flip_seg_map(results['gt_seg_map'],
'horizontal')

def test_repr(self):
TRANSFORMS = RandomFlip(0.1)
Expand Down

0 comments on commit bf48ca0

Please sign in to comment.