From 022b055a661c52489fa01297fc63b905b55c7f8a Mon Sep 17 00:00:00 2001 From: yamengxi <49829199+yamengxi@users.noreply.github.com> Date: Fri, 8 Jan 2021 01:58:34 +0800 Subject: [PATCH] [Bug Fix] Fix TTA resize scale (#334) * fix tta bug * modify as suggested * fix test_tta bug --- mmseg/datasets/pipelines/test_time_aug.py | 2 +- mmseg/datasets/pipelines/transforms.py | 5 +++-- tests/test_data/test_tta.py | 6 +++--- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/mmseg/datasets/pipelines/test_time_aug.py b/mmseg/datasets/pipelines/test_time_aug.py index bab663653f..473a12bc86 100644 --- a/mmseg/datasets/pipelines/test_time_aug.py +++ b/mmseg/datasets/pipelines/test_time_aug.py @@ -104,7 +104,7 @@ def __call__(self, results): aug_data = [] if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float): h, w = results['img'].shape[:2] - img_scale = [(int(h * ratio), int(w * ratio)) + img_scale = [(int(w * ratio), int(h * ratio)) for ratio in self.img_ratios] else: img_scale = self.img_scale diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py index 801c666440..e168280adc 100644 --- a/mmseg/datasets/pipelines/transforms.py +++ b/mmseg/datasets/pipelines/transforms.py @@ -156,8 +156,9 @@ def _random_scale(self, results): if self.ratio_range is not None: if self.img_scale is None: - scale, scale_idx = self.random_sample_ratio( - results['img'].shape[:2], self.ratio_range) + h, w = results['img'].shape[:2] + scale, scale_idx = self.random_sample_ratio((w, h), + self.ratio_range) else: scale, scale_idx = self.random_sample_ratio( self.img_scale[0], self.ratio_range) diff --git a/tests/test_data/test_tta.py b/tests/test_data/test_tta.py index 61fb5aa340..cc8c71e57c 100644 --- a/tests/test_data/test_tta.py +++ b/tests/test_data/test_tta.py @@ -108,7 +108,7 @@ def test_multi_scale_flip_aug(): ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(144, 256), (288, 512), (576, 1024)] + assert tta_results['scale'] == [(256, 144), (512, 288), (1024, 576)] assert tta_results['flip'] == [False, False, False] tta_transform = dict( @@ -120,8 +120,8 @@ def test_multi_scale_flip_aug(): ) tta_module = build_from_cfg(tta_transform, PIPELINES) tta_results = tta_module(results.copy()) - assert tta_results['scale'] == [(144, 256), (144, 256), (288, 512), - (288, 512), (576, 1024), (576, 1024)] + assert tta_results['scale'] == [(256, 144), (256, 144), (512, 288), + (512, 288), (1024, 576), (1024, 576)] assert tta_results['flip'] == [False, True, False, True, False, True] tta_transform = dict(