Skip to content

Commit

Permalink
[Refactor] Fix bugs of browsing transformed training images (open-mml…
Browse files Browse the repository at this point in the history
…ab#1591)

* fix bugs of browsing transformed img

* fix unittest
  • Loading branch information
Tau-J authored Aug 24, 2022
1 parent f76069d commit 19aa89d
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 6 deletions.
10 changes: 8 additions & 2 deletions mmpose/datasets/transforms/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ class PackPoseInputs(BaseTransform):
'bbox_scale': 'bbox_scales',
'bbox_score': 'bbox_scores',
'keypoints': 'keypoints',
'keypoints_visible': 'keypoints_visible'
'keypoints_visible': 'keypoints_visible',
'transformed_keypoints': 'transformed_keypoints'
}

label_mapping_table = {
Expand All @@ -94,8 +95,10 @@ class PackPoseInputs(BaseTransform):
def __init__(self,
meta_keys=('id', 'img_id', 'img_path', 'ori_shape',
'img_shape', 'input_size', 'flip',
'flip_direction', 'flip_indices')):
'flip_direction', 'flip_indices'),
pack_transformed=False):
self.meta_keys = meta_keys
self.pack_transformed = pack_transformed

def transform(self, results: dict) -> dict:
"""Method to pack the input data.
Expand All @@ -122,6 +125,9 @@ def transform(self, results: dict) -> dict:
for key, packed_key in self.instance_mapping_table.items():
if key in results:
gt_instances.set_field(results[key], packed_key)
if not self.pack_transformed:
if 'transformed_keypoints' in gt_instances:
del gt_instances['transformed_keypoints']
data_sample.gt_instances = gt_instances

# pack instance labels
Expand Down
3 changes: 2 additions & 1 deletion mmpose/engine/hooks/visualization_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Optional, Sequence

import mmcv
import mmengine
from mmengine.hooks import Hook
from mmengine.runner import Runner
from mmengine.visualization import Visualizer
Expand Down Expand Up @@ -131,7 +132,7 @@ def after_test_iter(self, runner: Runner, batch_idx: int,
mmcv.mkdir_or_exist(self.test_out_dir)

if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
self.file_client = mmengine.FileClient(**self.file_client_args)

for input_data, output in zip(data_batch, outputs):
self._test_index += 1
Expand Down
10 changes: 9 additions & 1 deletion mmpose/structures/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox import (bbox_cs2xywh, bbox_cs2xyxy, bbox_xywh2cs, bbox_xywh2xyxy,
bbox_xyxy2cs, bbox_xyxy2xywh, flip_bbox,
get_udp_warp_matrix, get_warp_matrix)
from .keypoint import flip_keypoints
from .multilevel_pixel_data import MultilevelPixelData
from .pose_data_sample import PoseDataSample

__all__ = ['PoseDataSample', 'MultilevelPixelData']
__all__ = [
'PoseDataSample', 'MultilevelPixelData', 'bbox_cs2xywh', 'bbox_cs2xyxy',
'bbox_xywh2cs', 'bbox_xywh2xyxy', 'bbox_xyxy2cs', 'bbox_xyxy2xywh',
'flip_bbox', 'get_udp_warp_matrix', 'get_warp_matrix', 'flip_keypoints'
]
3 changes: 2 additions & 1 deletion mmpose/visualization/local_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ def _draw_instances_kpts(self,
img_h, img_w, _ = image.shape

if 'keypoints' in instances:
keypoints = instances.keypoints
keypoints = instances.get('transformed_keypoints',
instances.keypoints)

if 'scores' in instances and self.show_keypoint_weight:
scores = instances.scores
Expand Down
10 changes: 10 additions & 0 deletions tests/test_datasets/test_transforms/test_formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,19 @@ def setUp(self):
np.random.randint(0, 100, (1, 17, 2)).astype(np.float32),
'keypoint_y_labels':
np.random.randint(0, 100, (1, 17, 2)).astype(np.float32),
'transformed_keypoints':
np.random.randint(0, 100, (1, 17, 2)).astype(np.float32),
}
self.meta_keys = ('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'flip', 'flip_direction')

def test_transform(self):
transform = PackPoseInputs(
meta_keys=self.meta_keys, pack_transformed=True)
results = transform(copy.deepcopy(self.results_topdown))
self.assertIn('transformed_keypoints',
results['data_sample'].gt_instances)

transform = PackPoseInputs(meta_keys=self.meta_keys)
results = transform(copy.deepcopy(self.results_topdown))
self.assertIn('inputs', results)
Expand All @@ -78,6 +86,8 @@ def test_transform(self):
self.assertEqual(len(results['data_sample'].gt_instances), 1)
self.assertIsInstance(results['data_sample'].gt_fields.heatmaps,
torch.Tensor)
self.assertNotIn('transformed_keypoints',
results['data_sample'].gt_instances)

# test when results['img'] is sequence of frames
results = copy.deepcopy(self.results_topdown)
Expand Down
6 changes: 5 additions & 1 deletion tools/misc/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def parse_args():
help='the interval of show (s)')
parser.add_argument(
'--mode',
default='original',
default='transformed',
type=str,
choices=['original', 'transformed'],
help='display mode; display original pictures or transformed pictures'
Expand Down Expand Up @@ -90,6 +90,10 @@ def main():
if 'bbox_file' in cfg[f'{args.phase}_dataloader'].dataset:
cfg[f'{args.phase}_dataloader'].dataset.bbox_file = None

# pack transformed keypoints for visualization
cfg[f'{args.phase}_dataloader'].dataset.pipeline[
-1].pack_transformed = True

dataset = build_from_cfg(cfg[f'{args.phase}_dataloader'].dataset, DATASETS)

visualizer = VISUALIZERS.build(cfg.visualizer)
Expand Down

0 comments on commit 19aa89d

Please sign in to comment.