diff --git a/demo/README.md b/demo/README.md index f824bbf0cf..60ecbc3398 100644 --- a/demo/README.md +++ b/demo/README.md @@ -65,3 +65,11 @@ This page provides tutorials about running demos. Please click the caption for m [3D hand_pose demo](docs/3d_hand_demo.md) + +
+ +
+
+ +[Webcam demo](docs/webcam_demo.md) +
diff --git a/demo/docs/webcam_demo.md b/demo/docs/webcam_demo.md new file mode 100644 index 0000000000..17db79752f --- /dev/null +++ b/demo/docs/webcam_demo.md @@ -0,0 +1,49 @@ +## Webcam Demo + +We provide a webcam demo tool which integrates detection and 2D pose estimation for humans and animals. You can simply run the following command: + +```python +python demo/webcam_demo.py +``` + +It will launch a window to display the webcam video stream with detection and pose estimation results: +
+
+
+ +### Usage Tips + +- **Which model is used in the demo tool?** + + Please check the following default arguments in the script. You can also choose other models from the [MMDetection Model Zoo](https://github.com/open-mmlab/mmdetection/blob/master/docs/model_zoo.md) and [MMPose Model Zoo](https://mmpose.readthedocs.io/en/latest/modelzoo.html#) or use your own models. + + | Model | Arguments | + | :--: | :-- | + | Detection | `--det-config`, `--det-checkpoint` | + | Human Pose | `--human-pose-config`, `--human-pose-checkpoint` | + | Animal Pose | `--animal-pose-config`, `--animal-pose-checkpoint` | + +- **Can this tool run without GPU?** + + Yes, you can set `--device=cpu` so that model inference is performed on the CPU. Note that this may lead to a lower inference FPS than using a GPU. + +- **Why is there a time delay between the pose visualization and the video?** + + The video I/O and model inference run asynchronously, and the latter usually takes more time for a single frame. To alleviate the time delay, you can: + + 1. set `--display-delay=MILLISECONDS` to defer the video stream, according to the inference delay shown at the top-left corner. Or, + + 2. set `--synchronous-mode` to force the video stream to be aligned with the inference results. This may reduce the video display FPS. + +- **Can this tool process video files?** + + Yes. You can set `--cam_id=VIDEO_FILE_PATH` to run the demo tool in offline mode on a video file. Note that `--synchronous-mode` should be set in this case. See the example commands below. + +- **How to enable/disable the special effects?** + + The special effects can be enabled/disabled at launch time by setting arguments like `--bugeye`, `--sunglasses`, *etc*. You can also toggle the effects with keyboard shortcuts like `b` and `s` while the tool is running. + +- **What if my computer doesn't have a camera?** + + You can use a smartphone as a webcam with apps like [Camo](https://reincubate.com/camo/) or [DroidCam](https://www.dev47apps.com/). 
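+
+### Example Commands
+
+Below are a few example invocations that combine the options described in the tips above. The input and output file paths are placeholders; replace them with your own.
+
+```python
+# run on CPU with the sunglasses effect enabled
+python demo/webcam_demo.py --device=cpu --sunglasses
+
+# defer the displayed video by 100 milliseconds to roughly align it with the inference results
+python demo/webcam_demo.py --display-delay=100
+
+# process a local video file in offline (synchronous) mode and record the output
+python demo/webcam_demo.py --cam_id=/path/to/video.mp4 --synchronous-mode --out-video-file=/path/to/output.mp4
+```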
diff --git a/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py b/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py index f45ad10cb0..4e60b6b739 100644 --- a/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py +++ b/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py @@ -206,7 +206,7 @@ max_per_img=100))) dataset_type = 'CocoDataset' -data_root = 'data/coco/' +data_root = 'data/coco' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ @@ -239,17 +239,17 @@ workers_per_gpu=2, train=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_train2017.json', - img_prefix=data_root + 'train2017/', + ann_file=f'{data_root}/annotations/instances_train2017.json', + img_prefix=f'{data_root}/train2017/', pipeline=train_pipeline), val=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', pipeline=test_pipeline), test=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', pipeline=test_pipeline)) evaluation = dict(interval=1, metric='bbox') diff --git a/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_coco.py b/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_coco.py index 58b96e62e1..f91bd0d105 100644 --- a/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_coco.py +++ b/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_coco.py @@ -207,7 +207,7 @@ max_per_img=100))) dataset_type = 'CocoDataset' -data_root = 'data/coco/' +data_root = 'data/coco' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ @@ -240,17 +240,17 @@ workers_per_gpu=2, train=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_train2017.json', - img_prefix=data_root + 'train2017/', + ann_file=f'{data_root}/annotations/instances_train2017.json', + img_prefix=f'{data_root}/train2017/', pipeline=train_pipeline), val=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', pipeline=test_pipeline), test=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', pipeline=test_pipeline)) evaluation = dict(interval=1, metric='bbox') diff --git a/demo/mmdetection_cfg/faster_rcnn_r50_fpn_1class.py b/demo/mmdetection_cfg/faster_rcnn_r50_fpn_1class.py index 8eddf232b7..ee54f5b66b 100644 --- a/demo/mmdetection_cfg/faster_rcnn_r50_fpn_1class.py +++ b/demo/mmdetection_cfg/faster_rcnn_r50_fpn_1class.py @@ -133,7 +133,7 @@ )) dataset_type = 'CocoDataset' -data_root = 'data/coco/' +data_root = 'data/coco' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ @@ -166,17 +166,17 @@ workers_per_gpu=2, train=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_train2017.json', - img_prefix=data_root + 'train2017/', + ann_file=f'{data_root}/annotations/instances_train2017.json', + img_prefix=f'{data_root}/train2017/', pipeline=train_pipeline), 
val=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', pipeline=test_pipeline), test=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', pipeline=test_pipeline)) evaluation = dict(interval=1, metric='bbox') diff --git a/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py b/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py index d5c17df89a..a9ad9528b2 100644 --- a/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py +++ b/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py @@ -133,7 +133,7 @@ )) dataset_type = 'CocoDataset' -data_root = 'data/coco/' +data_root = 'data/coco' img_norm_cfg = dict( mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) train_pipeline = [ @@ -166,17 +166,17 @@ workers_per_gpu=2, train=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_train2017.json', - img_prefix=data_root + 'train2017/', + ann_file=f'{data_root}/annotations/instances_train2017.json', + img_prefix=f'{data_root}/train2017/', pipeline=train_pipeline), val=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', pipeline=test_pipeline), test=dict( type=dataset_type, - ann_file=data_root + 'annotations/instances_val2017.json', - img_prefix=data_root + 'val2017/', + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', pipeline=test_pipeline)) evaluation = dict(interval=1, metric='bbox') diff --git a/demo/mmdetection_cfg/yolov3_d53_320_273e_coco.py b/demo/mmdetection_cfg/yolov3_d53_320_273e_coco.py new file mode 100644 index 0000000000..d7e9cca1eb --- /dev/null +++ b/demo/mmdetection_cfg/yolov3_d53_320_273e_coco.py @@ -0,0 +1,140 @@ +# model settings +model = dict( + type='YOLOV3', + pretrained='open-mmlab://darknet53', + backbone=dict(type='Darknet', depth=53, out_indices=(3, 4, 5)), + neck=dict( + type='YOLOV3Neck', + num_scales=3, + in_channels=[1024, 512, 256], + out_channels=[512, 256, 128]), + bbox_head=dict( + type='YOLOV3Head', + num_classes=80, + in_channels=[512, 256, 128], + out_channels=[1024, 512, 256], + anchor_generator=dict( + type='YOLOAnchorGenerator', + base_sizes=[[(116, 90), (156, 198), (373, 326)], + [(30, 61), (62, 45), (59, 119)], + [(10, 13), (16, 30), (33, 23)]], + strides=[32, 16, 8]), + bbox_coder=dict(type='YOLOBBoxCoder'), + featmap_strides=[32, 16, 8], + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0, + reduction='sum'), + loss_conf=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0, + reduction='sum'), + loss_xy=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=2.0, + reduction='sum'), + loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum')), + # training and testing settings + train_cfg=dict( + assigner=dict( + type='GridAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0)), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + conf_thr=0.005, + nms=dict(type='nms', iou_threshold=0.45), + max_per_img=100)) +# dataset settings +dataset_type = 'CocoDataset' +data_root = 'data/coco' 
+img_norm_cfg = dict(mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True) +train_pipeline = [ + dict(type='LoadImageFromFile', to_float32=True), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='PhotoMetricDistortion'), + dict( + type='Expand', + mean=img_norm_cfg['mean'], + to_rgb=img_norm_cfg['to_rgb'], + ratio_range=(1, 2)), + dict( + type='MinIoURandomCrop', + min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9), + min_crop_size=0.3), + dict(type='Resize', img_scale=(320, 320), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(320, 320), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img']) + ]) +] +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type=dataset_type, + ann_file=f'{data_root}/annotations/instances_train2017.json', + img_prefix=f'{data_root}/train2017/', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + ann_file=f'{data_root}/annotations/instances_val2017.json', + img_prefix=f'{data_root}/val2017/', + pipeline=test_pipeline)) +# optimizer +optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=2000, # same as burn-in in darknet + warmup_ratio=0.1, + step=[218, 246]) +# runtime settings +runner = dict(type='EpochBasedRunner', max_epochs=273) +evaluation = dict(interval=1, metric=['bbox']) + +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +custom_hooks = [dict(type='NumClassCheckHook')] + +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] diff --git a/demo/mmtracking_cfg/tracktor_faster-rcnn_r50_fpn_4e_mot17-private.py b/demo/mmtracking_cfg/tracktor_faster-rcnn_r50_fpn_4e_mot17-private.py index af31ccaf18..2973d16e87 100644 --- a/demo/mmtracking_cfg/tracktor_faster-rcnn_r50_fpn_4e_mot17-private.py +++ b/demo/mmtracking_cfg/tracktor_faster-rcnn_r50_fpn_4e_mot17-private.py @@ -202,7 +202,7 @@ std=[58.395, 57.12, 57.375], to_rgb=True), dict(type='Pad', size_divisor=32), - dict(type='ImageToTensor', keys=['img']), + dict(type='DefaultFormatBundle', keys=['img']), dict(type='VideoCollect', keys=['img']) ]) ] @@ -272,7 +272,7 @@ std=[58.395, 57.12, 57.375], to_rgb=True), dict(type='Pad', size_divisor=32), - dict(type='ImageToTensor', keys=['img']), + dict(type='DefaultFormatBundle', keys=['img']), dict(type='VideoCollect', keys=['img']) ]) ]), @@ -296,7 +296,7 @@ std=[58.395, 57.12, 57.375], to_rgb=True), dict(type='Pad', size_divisor=32), - dict(type='ImageToTensor', keys=['img']), + dict(type='DefaultFormatBundle', keys=['img']), dict(type='VideoCollect', keys=['img']) ]) ])) diff --git a/demo/resources/sunglasses.jpg 
b/demo/resources/sunglasses.jpg new file mode 100644 index 0000000000..5d3cee8702 Binary files /dev/null and b/demo/resources/sunglasses.jpg differ diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py new file mode 100644 index 0000000000..aaf2f14820 --- /dev/null +++ b/demo/webcam_demo.py @@ -0,0 +1,567 @@ +import argparse +import time +from collections import deque +from queue import Queue +from threading import Event, Lock, Thread + +import cv2 +import numpy as np + +from mmpose.apis import (get_track_id, inference_top_down_pose_model, + init_pose_model, vis_pose_result) +from mmpose.core import apply_bugeye_effect, apply_sunglasses_effect +from mmpose.utils import StopWatch + +try: + from mmdet.apis import inference_detector, init_detector + has_mmdet = True +except (ImportError, ModuleNotFoundError): + has_mmdet = False + +try: + import psutil + psutil_proc = psutil.Process() +except (ImportError, ModuleNotFoundError): + psutil_proc = None + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--cam_id', type=str, default='0') + parser.add_argument( + '--det_config', + type=str, + default='demo/mmdetection_cfg/yolov3_d53_320_273e_coco.py', + help='Config file for detection') + parser.add_argument( + '--det_checkpoint', + type=str, + default='https://download.openmmlab.com/mmdetection/v2.0/yolo/' + 'yolov3_d53_320_273e_coco/yolov3_d53_320_273e_coco-421362b6.pth', + help='Checkpoint file for detection') + parser.add_argument( + '--enable_human_pose', + type=int, + default=1, + help='Enable human pose estimation') + parser.add_argument( + '--enable_animal_pose', + type=int, + default=1, + help='Enable animal pose estimation') + parser.add_argument( + '--human_pose_config', + type=str, + default='configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/' + 'hrnet_w48_coco_256x192.py', + help='Config file for human pose') + parser.add_argument( + '--human_pose_checkpoint', + type=str, + default='https://download.openmmlab.com/mmpose/top_down/hrnet/' + 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth', + help='Checkpoint file for human pose') + parser.add_argument( + '--human_det_ids', + type=int, + default=[1], + nargs='+', + help='Object category label of human in detection results.' + 'Default is [1(person)], following COCO definition.') + parser.add_argument( + '--animal_pose_config', + type=str, + default='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap/' + 'animalpose/hrnet_w32_animalpose_256x256.py', + help='Config file for animal pose') + parser.add_argument( + '--animal_pose_checkpoint', + type=str, + default='https://download.openmmlab.com/mmpose/animal/hrnet/' + 'hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth', + help='Checkpoint file for animal pose') + parser.add_argument( + '--animal_det_ids', + type=int, + default=[16, 17, 18, 19, 20], + nargs='+', + help='Object category label of animals in detection results' + 'Default is [16(cat), 17(dog), 18(horse), 19(sheep), 20(cow)], ' + 'following COCO definition.') + parser.add_argument( + '--device', default='cuda:0', help='Device used for inference') + parser.add_argument( + '--det-score-thr', + type=float, + default=0.5, + help='bbox score threshold') + parser.add_argument( + '--kpt-thr', type=float, default=0.3, help='bbox score threshold') + parser.add_argument( + '--show-pose', + type=lambda s: s != 'False', + default=True, + choices=['True', 'False'], + help='Show pose estimation results. Set False to disable the pose' + 'visualization. 
Default: True') + parser.add_argument( + '--sunglasses', action='store_true', help='Apply `sunglasses` effect.') + parser.add_argument( + '--bugeye', action='store_true', help='Apply `bug-eye` effect.') + + parser.add_argument( + '--out-video-file', + type=str, + default=None, + help='Record the video into a file. This may reduce the frame rate') + + parser.add_argument( + '--out-video-fps', + type=int, + default=20, + help='Set the FPS of the output video file.') + + parser.add_argument( + '--buffer-size', + type=int, + default=-1, + help='Frame buffer size. If set -1, the buffer size will be ' + 'automatically inferred from the display delay time. Deafult: -1') + + parser.add_argument( + '--inference_fps', + type=int, + default=10, + help='Maximum inference FPS. This is to limit the resource consuming ' + 'especially when the detection and pose model are lightweight and ' + 'very fast. Default: 10.') + + parser.add_argument( + '--display-delay', + type=int, + default=0, + help='Delay the output video in milliseconds. This can be used to ' + 'align the output video and inference results. The delay can be ' + 'disabled by setting a non-positive delay time. Default: 0') + + parser.add_argument( + '--synchronous-mode', + action='store_true', + help='Enable synchronous mode that video I/O and inference will be ' + 'temporally aligned. Note that this will reduce the display FPS.') + + return parser.parse_args() + + +def process_mmdet_results(mmdet_results, class_names=None, cat_ids=1): + """Process mmdet results to mmpose input format. + + Args: + mmdet_results: raw output of mmdet model + class_names: class names of mmdet model + cat_ids (int or List[int]): category id list that will be preserved + Returns: + List[Dict]: detection results for mmpose input + """ + if isinstance(mmdet_results, tuple): + mmdet_results = mmdet_results[0] + + if not isinstance(cat_ids, (list, tuple)): + cat_ids = [cat_ids] + + # only keep bboxes of interested classes + bbox_results = [mmdet_results[i - 1] for i in cat_ids] + bboxes = np.vstack(bbox_results) + + # get textual labels of classes + labels = np.concatenate([ + np.full(bbox.shape[0], i - 1, dtype=np.int32) + for i, bbox in zip(cat_ids, bbox_results) + ]) + if class_names is None: + labels = [f'class: {i}' for i in labels] + else: + labels = [class_names[i] for i in labels] + + det_results = [] + for bbox, label in zip(bboxes, labels): + det_result = dict(bbox=bbox, label=label) + det_results.append(det_result) + return det_results + + +def read_camera(): + # init video reader + print('Thread "input" started') + cam_id = args.cam_id + if cam_id.isdigit(): + cam_id = int(cam_id) + vid_cap = cv2.VideoCapture(cam_id) + if not vid_cap.isOpened(): + print(f'Cannot open camera (ID={cam_id})') + exit() + + while not event_exit.is_set(): + # capture a camera frame + ret_val, frame = vid_cap.read() + if ret_val: + ts_input = time.time() + + event_inference_done.clear() + with input_queue_mutex: + input_queue.append((ts_input, frame)) + + if args.synchronous_mode: + event_inference_done.wait() + + frame_buffer.put((ts_input, frame)) + else: + # input ending signal + frame_buffer.put((None, None)) + break + + vid_cap.release() + + +def inference_detection(): + print('Thread "det" started') + stop_watch = StopWatch(window=10) + min_interval = 1.0 / args.inference_fps + _ts_last = None # timestamp when last inference was done + + while True: + while len(input_queue) < 1: + time.sleep(0.001) + with input_queue_mutex: + ts_input, frame = input_queue.popleft() + # 
inference detection + with stop_watch.timeit('Det'): + mmdet_results = inference_detector(det_model, frame) + + t_info = stop_watch.report_strings() + with det_result_queue_mutex: + det_result_queue.append((ts_input, frame, t_info, mmdet_results)) + + # limit the inference FPS + _ts = time.time() + if _ts_last is not None and _ts - _ts_last < min_interval: + time.sleep(min_interval - _ts + _ts_last) + _ts_last = time.time() + + +def inference_pose(): + print('Thread "pose" started') + stop_watch = StopWatch(window=10) + + while True: + while len(det_result_queue) < 1: + time.sleep(0.001) + with det_result_queue_mutex: + ts_input, frame, t_info, mmdet_results = det_result_queue.popleft() + + pose_results_list = [] + for model_info, pose_history in zip(pose_model_list, + pose_history_list): + model_name = model_info['name'] + pose_model = model_info['model'] + cat_ids = model_info['cat_ids'] + pose_results_last = pose_history['pose_results_last'] + next_id = pose_history['next_id'] + + with stop_watch.timeit(model_name): + # process mmdet results + det_results = process_mmdet_results( + mmdet_results, + class_names=det_model.CLASSES, + cat_ids=cat_ids) + + # inference pose model + dataset_name = pose_model.cfg.data['test']['type'] + pose_results, _ = inference_top_down_pose_model( + pose_model, + frame, + det_results, + bbox_thr=args.det_score_thr, + format='xyxy', + dataset=dataset_name) + + pose_results, next_id = get_track_id( + pose_results, + pose_results_last, + next_id, + use_oks=False, + tracking_thr=0.3, + use_one_euro=True, + fps=None) + + pose_results_list.append(pose_results) + + # update pose history + pose_history['pose_results_last'] = pose_results + pose_history['next_id'] = next_id + + t_info += stop_watch.report_strings() + with pose_result_queue_mutex: + pose_result_queue.append((ts_input, t_info, pose_results_list)) + + event_inference_done.set() + + +def display(): + print('Thread "display" started') + stop_watch = StopWatch(window=10) + + # initialize result status + ts_inference = None # timestamp of the latest inference result + fps_inference = 0. # infenrece FPS + t_delay_inference = 0. 
# inference result time delay + pose_results_list = None # latest inference result + t_info = [] # upstream time information (list[str]) + + # initialize visualization and output + sunglasses_img = None # resource image for sunglasses effect + text_color = (228, 183, 61) # text color to show time/system information + vid_out = None # video writer + + # show instructions + print('Keyboard shortcuts: ') + print('"v": Toggle the visualization of bounding boxes and poses.') + print('"s": Toggle the sunglasses effect.') + print('"b": Toggle the bug-eye effect.') + print('"Q", "q" or Esc: Exit.') + + while True: + with stop_watch.timeit('_FPS_'): + # acquire a frame from buffer + ts_input, frame = frame_buffer.get() + # input ending signal + if ts_input is None: + break + + img = frame + + # get pose estimation results + if len(pose_result_queue) > 0: + with pose_result_queue_mutex: + _result = pose_result_queue.popleft() + _ts_input, t_info, pose_results_list = _result + + _ts = time.time() + if ts_inference is not None: + fps_inference = 1.0 / (_ts - ts_inference) + ts_inference = _ts + t_delay_inference = (_ts - _ts_input) * 1000 + + # visualize detection and pose results + if pose_results_list is not None: + for model_info, pose_results in zip(pose_model_list, + pose_results_list): + pose_model = model_info['model'] + bbox_color = model_info['bbox_color'] + + dataset_name = pose_model.cfg.data['test']['type'] + + # show pose results + if args.show_pose: + img = vis_pose_result( + pose_model, + img, + pose_results, + radius=4, + thickness=2, + dataset=dataset_name, + kpt_score_thr=args.kpt_thr, + bbox_color=bbox_color) + + # sunglasses effect + if args.sunglasses: + if dataset_name == 'TopDownCocoDataset': + left_eye_idx = 1 + right_eye_idx = 2 + elif dataset_name == 'AnimalPoseDataset': + left_eye_idx = 0 + right_eye_idx = 1 + else: + raise ValueError( + 'Sunglasses effect does not support' + f'{dataset_name}') + if sunglasses_img is None: + # The image attributes to: + # https://www.vecteezy.com/free-vector/glass + # Glass Vectors by Vecteezy + sunglasses_img = cv2.imread( + 'demo/resources/sunglasses.jpg') + img = apply_sunglasses_effect(img, pose_results, + sunglasses_img, + left_eye_idx, + right_eye_idx) + # bug-eye effect + if args.bugeye: + if dataset_name == 'TopDownCocoDataset': + left_eye_idx = 1 + right_eye_idx = 2 + elif dataset_name == 'AnimalPoseDataset': + left_eye_idx = 0 + right_eye_idx = 1 + else: + raise ValueError('Bug-eye effect does not support' + f'{dataset_name}') + img = apply_bugeye_effect(img, pose_results, + left_eye_idx, right_eye_idx) + + # delay control + if args.display_delay > 0: + t_sleep = args.display_delay * 0.001 - (time.time() - ts_input) + if t_sleep > 0: + time.sleep(t_sleep) + t_delay = (time.time() - ts_input) * 1000 + + # show time information + t_info_display = stop_watch.report_strings() # display fps + t_info_display.append(f'Inference FPS: {fps_inference:>5.1f}') + t_info_display.append(f'Delay: {t_delay:>3.0f}') + t_info_display.append( + f'Inference Delay: {t_delay_inference:>3.0f}') + t_info_str = ' | '.join(t_info_display + t_info) + cv2.putText(img, t_info_str, (20, 20), cv2.FONT_HERSHEY_DUPLEX, + 0.3, text_color, 1) + # collect system information + sys_info = [ + f'RES: {img.shape[1]}x{img.shape[0]}', + f'Buffer: {frame_buffer.qsize()}/{frame_buffer.maxsize}' + ] + if psutil_proc is not None: + sys_info += [ + f'CPU: {psutil_proc.cpu_percent():.1f}%', + f'MEM: {psutil_proc.memory_percent():.1f}%' + ] + sys_info_str = ' | '.join(sys_info) + 
cv2.putText(img, sys_info_str, (20, 40), cv2.FONT_HERSHEY_DUPLEX, + 0.3, text_color, 1) + + # save the output video frame + if args.out_video_file is not None: + if vid_out is None: + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + fps = args.out_video_fps + frame_size = (img.shape[1], img.shape[0]) + vid_out = cv2.VideoWriter(args.out_video_file, fourcc, fps, + frame_size) + + vid_out.write(img) + + # display + cv2.imshow('mmpose webcam demo', img) + keyboard_input = cv2.waitKey(1) + if keyboard_input in (27, ord('q'), ord('Q')): + break + elif keyboard_input == ord('s'): + args.sunglasses = not args.sunglasses + elif keyboard_input == ord('b'): + args.bugeye = not args.bugeye + elif keyboard_input == ord('v'): + args.show_pose = not args.show_pose + + cv2.destroyAllWindows() + if vid_out is not None: + vid_out.release() + event_exit.set() + + +def main(): + global args + global frame_buffer + global input_queue, input_queue_mutex + global det_result_queue, det_result_queue_mutex + global pose_result_queue, pose_result_queue_mutex + global det_model, pose_model_list, pose_history_list + global event_exit, event_inference_done + + args = parse_args() + + assert has_mmdet, 'Please install mmdet to run the demo.' + assert args.det_config is not None + assert args.det_checkpoint is not None + + # build detection model + det_model = init_detector( + args.det_config, args.det_checkpoint, device=args.device.lower()) + + # build pose models + pose_model_list = [] + if args.enable_human_pose: + pose_model = init_pose_model( + args.human_pose_config, + args.human_pose_checkpoint, + device=args.device.lower()) + model_info = { + 'name': 'HumanPose', + 'model': pose_model, + 'cat_ids': args.human_det_ids, + 'bbox_color': (148, 139, 255), + } + pose_model_list.append(model_info) + if args.enable_animal_pose: + pose_model = init_pose_model( + args.animal_pose_config, + args.animal_pose_checkpoint, + device=args.device.lower()) + model_info = { + 'name': 'AnimalPose', + 'model': pose_model, + 'cat_ids': args.animal_det_ids, + 'bbox_color': 'cyan', + } + pose_model_list.append(model_info) + + # store pose history for pose tracking + pose_history_list = [] + for _ in range(len(pose_model_list)): + pose_history_list.append({'pose_results_last': [], 'next_id': 0}) + + # frame buffer + if args.buffer_size > 0: + buffer_size = args.buffer_size + else: + # infer buffer size from the display delay time + # assume that the maximum video fps is 30 + buffer_size = round(30 * (1 + max(args.display_delay, 0) / 1000.)) + frame_buffer = Queue(maxsize=buffer_size) + + # queue of input frames + # element: (timestamp, frame) + input_queue = deque(maxlen=1) + input_queue_mutex = Lock() + + # queue of detection results + # element: tuple(timestamp, frame, time_info, det_results) + det_result_queue = deque(maxlen=1) + det_result_queue_mutex = Lock() + + # queue of detection/pose results + # element: (timestamp, time_info, pose_results_list) + pose_result_queue = deque(maxlen=1) + pose_result_queue_mutex = Lock() + + try: + event_exit = Event() + event_inference_done = Event() + t_input = Thread(target=read_camera, args=()) + t_det = Thread(target=inference_detection, args=(), daemon=True) + t_pose = Thread(target=inference_pose, args=(), daemon=True) + + t_input.start() + t_det.start() + t_pose.start() + + # run display in the main thread + display() + # join the input thread (non-daemon) + t_input.join() + + except KeyboardInterrupt: + pass + + +if __name__ == '__main__': + main() diff --git a/mmpose/apis/inference.py 
b/mmpose/apis/inference.py index 550c4e3cd6..b6668070ad 100644 --- a/mmpose/apis/inference.py +++ b/mmpose/apis/inference.py @@ -520,6 +520,7 @@ def vis_pose_result(model, radius=4, thickness=1, kpt_score_thr=0.3, + bbox_color='green', dataset='TopDownCocoDataset', show=False, out_file=None): @@ -740,8 +741,8 @@ def vis_pose_result(model, [14, 18], [7, 11], [11, 15], [15, 19], [7, 12], [12, 16], [16, 20]] - pose_limb_color = palette[[0] * 20] - pose_kpt_color = palette[[0] * 20] + pose_kpt_color = palette[[15] * 5 + [0] * 7 + [9] * 8] + pose_limb_color = palette[[15] * 5 + [0] * 3 + [0, 9, 9] * 4] else: raise NotImplementedError() @@ -755,6 +756,7 @@ def vis_pose_result(model, pose_kpt_color=pose_kpt_color, pose_limb_color=pose_limb_color, kpt_score_thr=kpt_score_thr, + bbox_color=bbox_color, show=show, out_file=out_file) diff --git a/mmpose/core/visualization/__init__.py b/mmpose/core/visualization/__init__.py index d5e0f78170..eda8d48f9d 100644 --- a/mmpose/core/visualization/__init__.py +++ b/mmpose/core/visualization/__init__.py @@ -1,3 +1,7 @@ -from .image import imshow_keypoints, imshow_keypoints_3d +from .effects import apply_bugeye_effect, apply_sunglasses_effect +from .image import imshow_bboxes, imshow_keypoints, imshow_keypoints_3d -__all__ = ['imshow_keypoints', 'imshow_keypoints_3d'] +__all__ = [ + 'imshow_keypoints', 'imshow_keypoints_3d', 'imshow_bboxes', + 'apply_bugeye_effect', 'apply_sunglasses_effect' +] diff --git a/mmpose/core/visualization/effects.py b/mmpose/core/visualization/effects.py new file mode 100644 index 0000000000..8711f1b52d --- /dev/null +++ b/mmpose/core/visualization/effects.py @@ -0,0 +1,110 @@ +import cv2 +import numpy as np + + +def apply_bugeye_effect(img, + pose_results, + left_eye_index, + right_eye_index, + kpt_thr=0.5): + """Apply bug-eye effect. + + Args: + img (np.ndarray): Image data. + pose_results (list[dict]): The pose estimation results containing: + - "bbox" ([K, 4(or 5)]): detection bbox in + [x1, y1, x2, y2, (score)] + - "keypoints" ([K,3]): keypoint detection result in [x, y, score] + left_eye_index (int): Keypoint index of left eye + right_eye_index (int): Keypoint index of right eye + kpt_thr (float): The score threshold of required keypoints. + """ + + xx, yy = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0])) + xx = xx.astype(np.float32) + yy = yy.astype(np.float32) + + for pose in pose_results: + bbox = pose['bbox'] + kpts = pose['keypoints'] + + if kpts[left_eye_index, 2] < kpt_thr or kpts[right_eye_index, + 2] < kpt_thr: + continue + + kpt_leye = kpts[left_eye_index, :2] + kpt_reye = kpts[right_eye_index, :2] + for xc, yc in [kpt_leye, kpt_reye]: + + # distortion parameters + k1 = 0.001 + epe = 1e-5 + + scale = (bbox[2] - bbox[0])**2 + (bbox[3] - bbox[1])**2 + r2 = ((xx - xc)**2 + (yy - yc)**2) + r2 = (r2 + epe) / scale # normalized by bbox scale + + xx = (xx - xc) / (1 + k1 / r2) + xc + yy = (yy - yc) / (1 + k1 / r2) + yc + + img = cv2.remap( + img, + xx, + yy, + interpolation=cv2.INTER_AREA, + borderMode=cv2.BORDER_REPLICATE) + return img + + +def apply_sunglasses_effect(img, + pose_results, + sunglasses_img, + left_eye_index, + right_eye_index, + kpt_thr=0.5): + """Apply sunglasses effect. + + Args: + img (np.ndarray): Image data. + pose_results (list[dict]): The pose estimation results containing: + - "keypoints" ([K,3]): keypoint detection result in [x, y, score] + sunglasses_img (np.ndarray): Sunglasses image with white background. 
+ left_eye_index (int): Keypoint index of left eye + right_eye_index (int): Keypoint index of right eye + kpt_thr (float): The score threshold of required keypoints. + """ + + hm, wm = sunglasses_img.shape[:2] + # anchor points in the sunglasses mask + pts_src = np.array([[0.3 * wm, 0.3 * hm], [0.3 * wm, 0.7 * hm], + [0.7 * wm, 0.3 * hm], [0.7 * wm, 0.7 * hm]], + dtype=np.float32) + + for pose in pose_results: + kpts = pose['keypoints'] + + if kpts[left_eye_index, 2] < kpt_thr or kpts[right_eye_index, + 2] < kpt_thr: + continue + + kpt_leye = kpts[left_eye_index, :2] + kpt_reye = kpts[right_eye_index, :2] + # orthogonal vector to the left-to-right eyes + vo = 0.5 * (kpt_reye - kpt_leye)[::-1] * [-1, 1] + + # anchor points in the image by eye positions + pts_tar = np.vstack( + [kpt_reye + vo, kpt_reye - vo, kpt_leye + vo, kpt_leye - vo]) + + h_mat, _ = cv2.findHomography(pts_src, pts_tar) + patch = cv2.warpPerspective( + sunglasses_img, + h_mat, + dsize=(img.shape[1], img.shape[0]), + borderValue=(255, 255, 255)) + # mask the white background area in the patch with a threshold 200 + mask = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY) + mask = (mask < 200).astype(np.uint8) + img = cv2.copyTo(patch, mask, img) + + return img diff --git a/mmpose/core/visualization/image.py b/mmpose/core/visualization/image.py index ba9ddc4e95..09bbacc79b 100644 --- a/mmpose/core/visualization/image.py +++ b/mmpose/core/visualization/image.py @@ -6,6 +6,82 @@ from matplotlib import pyplot as plt +def imshow_bboxes(img, + bboxes, + labels=None, + colors='green', + text_color='white', + thickness=1, + font_scale=0.5, + show=True, + win_name='', + wait_time=0, + out_file=None): + """Draw bboxes with labels (optional) on an image. This is a wrapper of + mmcv.imshow_bboxes. + + Args: + img (str or ndarray): The image to be displayed. + bboxes (ndarray): ndarray of shape (k, 4), each row is a bbox in + format [x1, y1, x2, y2]. + labels (str or list[str], optional): labels of each bbox. + colors (list[str or tuple or :obj:`Color`]): A list of colors. + text_color (str or tuple or :obj:`Color`): Color of texts. + thickness (int): Thickness of lines. + font_scale (float): Font scales of texts. + show (bool): Whether to show the image. + win_name (str): The window name. + wait_time (int): Value of waitKey param. + out_file (str, optional): The filename to write the image. + + Returns: + ndarray: The image with bboxes drawn on it. 
+ """ + + # adapt to mmcv.imshow_bboxes input format + bboxes = np.split(bboxes, bboxes.shape[0], axis=0) + if not isinstance(colors, list): + colors = [colors for _ in range(len(bboxes))] + colors = [mmcv.color_val(c) for c in colors] + assert len(bboxes) == len(colors) + + img = mmcv.imshow_bboxes( + img, + bboxes, + colors, + top_k=-1, + thickness=thickness, + show=False, + out_file=None) + + if labels is not None: + if not isinstance(labels, list): + labels = [labels for _ in range(len(bboxes))] + assert len(labels) == len(bboxes) + + for bbox, label, color in zip(bboxes, labels, colors): + bbox_int = bbox[0, :4].astype(np.int32) + # roughly estimate the proper font size + text_size, text_baseline = cv2.getTextSize(label, + cv2.FONT_HERSHEY_DUPLEX, + font_scale, thickness) + text_x1 = bbox_int[0] + text_y1 = max(0, bbox_int[1] - text_size[1] - text_baseline) + text_x2 = bbox_int[0] + text_size[0] + text_y2 = text_y1 + text_size[1] + text_baseline + cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color, + cv2.FILLED) + cv2.putText(img, label, (text_x1, text_y2 - text_baseline), + cv2.FONT_HERSHEY_DUPLEX, font_scale, + mmcv.color_val(text_color), thickness) + + if show: + mmcv.imshow(img, win_name, wait_time) + if out_file is not None: + mmcv.imwrite(img, out_file) + return img + + def imshow_keypoints(img, pose_result, skeleton=None, diff --git a/mmpose/models/detectors/pose_lifter.py b/mmpose/models/detectors/pose_lifter.py index 81ca236268..effc405607 100644 --- a/mmpose/models/detectors/pose_lifter.py +++ b/mmpose/models/detectors/pose_lifter.py @@ -3,7 +3,7 @@ import mmcv import numpy as np -from mmpose.core import imshow_keypoints, imshow_keypoints_3d +from mmpose.core import imshow_bboxes, imshow_keypoints, imshow_keypoints_3d from .. import builder from ..builder import POSENETS from .base import BasePose @@ -345,11 +345,10 @@ def show_result(self, if len(bbox_result) > 0: bboxes = np.vstack(bbox_result) - mmcv.imshow_bboxes( + imshow_bboxes( img, bboxes, colors='green', - top_k=-1, thickness=thickness, show=False) if len(pose_input_2d) > 0: diff --git a/mmpose/models/detectors/top_down.py b/mmpose/models/detectors/top_down.py index 7134df2978..8aa26bb9e7 100644 --- a/mmpose/models/detectors/top_down.py +++ b/mmpose/models/detectors/top_down.py @@ -5,7 +5,7 @@ from mmcv.image import imwrite from mmcv.visualization.image import imshow -from mmpose.core import imshow_keypoints +from mmpose.core import imshow_bboxes, imshow_keypoints from .. 
import builder from ..builder import POSENETS from .base import BasePose @@ -222,10 +222,11 @@ def show_result(self, bbox_color='green', pose_kpt_color=None, pose_limb_color=None, - text_color=(255, 0, 0), + text_color='white', radius=4, thickness=1, font_scale=0.5, + bbox_thickness=1, win_name='', show=False, show_keypoint_weight=False, @@ -264,7 +265,6 @@ def show_result(self, img = mmcv.imread(img) img = img.copy() - img_h, img_w, _ = img.shape bbox_result = [] pose_result = [] @@ -274,13 +274,18 @@ def show_result(self, if len(bbox_result) > 0: bboxes = np.vstack(bbox_result) + labels = None + if 'label' in result[0]: + labels = [res['label'] for res in result] # draw bounding boxes - mmcv.imshow_bboxes( + imshow_bboxes( img, bboxes, + labels=labels, colors=bbox_color, - top_k=-1, - thickness=thickness, + text_color=text_color, + thickness=bbox_thickness, + font_scale=font_scale, show=False) imshow_keypoints(img, pose_result, skeleton, kpt_score_thr, diff --git a/mmpose/utils/__init__.py b/mmpose/utils/__init__.py index ac489e2dbb..860a2d4a57 100644 --- a/mmpose/utils/__init__.py +++ b/mmpose/utils/__init__.py @@ -1,4 +1,5 @@ from .collect_env import collect_env from .logger import get_root_logger +from .timer import StopWatch -__all__ = ['get_root_logger', 'collect_env'] +__all__ = ['get_root_logger', 'collect_env', 'StopWatch'] diff --git a/mmpose/utils/timer.py b/mmpose/utils/timer.py new file mode 100644 index 0000000000..2e499dcbd2 --- /dev/null +++ b/mmpose/utils/timer.py @@ -0,0 +1,92 @@ +from collections import defaultdict + +import numpy as np +from mmcv import Timer + + +class StopWatch: + r"""A helper class to measure FPS and detailed time consuming of each phase + in a video processing loop or similar scenarios. + + Args: + window (int): The sliding window size to calculate the running average + of the time consuming. + + Example:: + >>> stop_watch = StopWatch(window=10) + >>> while True: + ... with stop_watch.timeit('total'): + ... sleep(1) + ... # 'timeit' support nested use + ... with stop_watch.timeit('phase1'): + ... sleep(1) + ... with stop_watch.timeit('phase2'): + ... sleep(2) + ... sleep(2) + ... report = stop_watch.report() + report = {'total': 6., 'phase1': 1., 'phase2': 2.} + + """ + + def __init__(self, window=1): + self._record = defaultdict(list) + self._timer_stack = [] + self.window = window + + def timeit(self, timer_name='_FPS_'): + """Timing a code snippet with an assigned name. + + Args: + timer_name (str): The unique name of the interested code snippet to + handle multiple timers and generate reports. Note that '_FPS_' + is a special key that the measurement will be in `fps` instead + of `millisecond`. Also see `report` and `report_strings`. + Default: '_FPS_'. + Note: + This function should always be used in a `with` statement, as shown + in the example. + """ + self._timer_stack.append((timer_name, Timer())) + return self + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, trackback): + timer_name, timer = self._timer_stack.pop() + self._record[timer_name].append(timer.since_start()) + self._record[timer_name] = self._record[timer_name][-self.window:] + + def report(self): + """Report timing information. + + Returns: + dict: The key is the timer name and the value is the corresponding + average time consuming. + """ + result = { + name: np.mean(vals) * 1000. + for name, vals in self._record.items() + } + return result + + def report_strings(self): + """Report timing information in texture strings. 
+ + Returns: + list(str): Each element is the information string of a timed event, + in format of '{timer_name}: {time_in_ms}'. Specially, if + timer_name is '_FPS_', the result will be converted to + fps. + """ + result = self.report() + strings = [] + if '_FPS_' in result: + fps = 1000. / result.pop('_FPS_') + strings.append(f'FPS: {fps:>5.1f}') + strings += [f'{name}: {val:>3.0f}' for name, val in result.items()] + return strings + + def reset(self): + self._record = defaultdict(list) + self._timer_stack = [] diff --git a/tests/test_regularization.py b/tests/test_regularization.py index b23b86f1c7..3de10aceef 100644 --- a/tests/test_regularization.py +++ b/tests/test_regularization.py @@ -15,4 +15,4 @@ def test_weight_norm_clip(): _ = module(x) weight_norm = module.weight.norm().item() - np.testing.assert_allclose(weight_norm, 1.0, rtol=1e-6) + np.testing.assert_almost_equal(weight_norm, 1.0, decimal=6) diff --git a/tests/test_utils.py b/tests/test_utils.py index 7057dbedb0..67fa096cb8 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,13 @@ +import time + import cv2 import mmcv +import numpy as np import torch import torchvision import mmpose -from mmpose.utils import collect_env +from mmpose.utils import StopWatch, collect_env def test_collect_env(): @@ -15,3 +18,25 @@ def test_collect_env(): assert env_info['MMCV'] == mmcv.__version__ assert '+' in env_info['MMPose'] assert mmpose.__version__ in env_info['MMPose'] + + +def test_stopwatch(): + window_size = 5 + test_loop = 10 + outer_time = 100 + inner_time = 100 + + stop_watch = StopWatch(window=window_size) + for _ in range(test_loop): + with stop_watch.timeit(): + time.sleep(outer_time / 1000.) + with stop_watch.timeit('inner'): + time.sleep(inner_time / 1000.) + + report = stop_watch.report() + _ = stop_watch.report_strings() + + np.testing.assert_allclose( + report['_FPS_'], outer_time + inner_time, rtol=0.01) + + np.testing.assert_allclose(report['inner'], inner_time, rtol=0.01) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 8403713adf..927afaafb4 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,6 +1,10 @@ +import tempfile + +import mmcv import numpy as np -from mmpose.core import imshow_keypoints, imshow_keypoints_3d +from mmpose.core import (apply_bugeye_effect, apply_sunglasses_effect, + imshow_bboxes, imshow_keypoints, imshow_keypoints_3d) def test_imshow_keypoints(): @@ -29,3 +33,48 @@ def test_imshow_keypoints(): pose_kpt_color=pose_kpt_color, pose_limb_color=pose_limb_color, vis_height=400) + + +def test_imshow_bbox(): + img = np.zeros((100, 100, 3), dtype=np.uint8) + bboxes = np.array([[10, 10, 30, 30], [10, 50, 30, 80]], dtype=np.float32) + labels = ['label 1', 'label 2'] + colors = ['red', 'green'] + + with tempfile.TemporaryDirectory() as tmpdir: + _ = imshow_bboxes( + img, + bboxes, + labels=labels, + colors=colors, + show=False, + out_file=f'{tmpdir}/out.png') + + +def test_effects(): + img = np.zeros((100, 100, 3), dtype=np.uint8) + kpts = np.array([[10., 10., 0.8], [20., 10., 0.8]], dtype=np.float32) + bbox = np.array([0, 0, 50, 50], dtype=np.float32) + pose_results = [dict(bbox=bbox, keypoints=kpts)] + # sunglasses + sunglasses_img = mmcv.imread('demo/resources/sunglasses.jpg') + _ = apply_sunglasses_effect( + img, + pose_results, + sunglasses_img, + left_eye_index=1, + right_eye_index=0, + kpt_thr=0.5) + _ = apply_sunglasses_effect( + img, + pose_results, + sunglasses_img, + left_eye_index=1, + right_eye_index=0, + kpt_thr=0.9) + + 
# bug-eye + _ = apply_bugeye_effect( + img, pose_results, left_eye_index=1, right_eye_index=0, kpt_thr=0.5) + _ = apply_bugeye_effect( + img, pose_results, left_eye_index=1, right_eye_index=0, kpt_thr=0.9)