diff --git a/mmcls/datasets/pipelines/transforms.py b/mmcls/datasets/pipelines/transforms.py index ff3ae755515..a56ce3c362b 100644 --- a/mmcls/datasets/pipelines/transforms.py +++ b/mmcls/datasets/pipelines/transforms.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy import inspect import math import random @@ -1117,19 +1118,23 @@ def mapper(d, keymap): return updated_dict def __call__(self, results): + + # backup gt_label in case Albu modify it. + _gt_label = copy.deepcopy(results.get('gt_label', None)) + # dict to albumentations format results = self.mapper(results, self.keymap_to_albu) + # process aug results = self.aug(**results) - if 'gt_labels' in results: - if isinstance(results['gt_labels'], list): - results['gt_labels'] = np.array(results['gt_labels']) - results['gt_labels'] = results['gt_labels'].astype(np.int64) - # back to the original format results = self.mapper(results, self.keymap_back) + if _gt_label is not None: + # recover backup gt_label + results.update({'gt_label': _gt_label}) + # update final shape if self.update_pad_shape: results['pad_shape'] = results['img'].shape diff --git a/tests/test_data/test_pipelines/test_transform.py b/tests/test_data/test_pipelines/test_transform.py index b973cc1f219..b23e84b5849 100644 --- a/tests/test_data/test_pipelines/test_transform.py +++ b/tests/test_data/test_pipelines/test_transform.py @@ -1268,14 +1268,25 @@ def reset_results(results, original_img): def test_albu_transform(): results = dict( img_prefix=osp.join(osp.dirname(__file__), '../../data'), - img_info=dict(filename='color.jpg')) + img_info=dict(filename='color.jpg'), + gt_label=np.array(1)) # Define simple pipeline load = dict(type='LoadImageFromFile') load = build_from_cfg(load, PIPELINES) albu_transform = dict( - type='Albu', transforms=[dict(type='ChannelShuffle', p=1)]) + type='Albu', + transforms=[ + dict(type='ChannelShuffle', p=1), + dict( + type='ShiftScaleRotate', + shift_limit=0.0625, + scale_limit=0.0, + rotate_limit=0, + interpolation=1, + p=1) + ]) albu_transform = build_from_cfg(albu_transform, PIPELINES) normalize = dict(type='Normalize', mean=[0] * 3, std=[0] * 3, to_rgb=True) @@ -1287,3 +1298,4 @@ def test_albu_transform(): results = normalize(results) assert results['img'].dtype == np.float32 + assert results['gt_label'].shape == np.array(1).shape