diff --git a/demo/README.md b/demo/README.md
index f824bbf0cf..60ecbc3398 100644
--- a/demo/README.md
+++ b/demo/README.md
@@ -65,3 +65,11 @@ This page provides tutorials about running demos. Please click the caption for m
[3D hand_pose demo](docs/3d_hand_demo.md)
+
+
+
+
+
+![](https://user-images.githubusercontent.com/15977946/124059525-ce20c580-da5d-11eb-8e4a-2d96cd31fe9f.gif)
+
+[Webcam demo](docs/webcam_demo.md)
+
diff --git a/demo/docs/webcam_demo.md b/demo/docs/webcam_demo.md
new file mode 100644
index 0000000000..17db79752f
--- /dev/null
+++ b/demo/docs/webcam_demo.md
@@ -0,0 +1,49 @@
+## Webcam Demo
+
+We provide a webcam demo tool that integrates detection and 2D pose estimation for humans and animals. You can simply run the following command:
+
+```shell
+python demo/webcam_demo.py
+```
+
+It will launch a window displaying the webcam video stream with detection and pose estimation results:
+
+
+
+![](https://user-images.githubusercontent.com/15977946/124059525-ce20c580-da5d-11eb-8e4a-2d96cd31fe9f.gif)
+
+
+### Usage Tips
+
+- **Which models are used in the demo tool?**
+
+ Please check the following default arguments in the script. You can also choose other models from the [MMDetection Model Zoo](https://github.com/open-mmlab/mmdetection/blob/master/docs/model_zoo.md) and the [MMPose Model Zoo](https://mmpose.readthedocs.io/en/latest/modelzoo.html#), or use your own models, as shown in the example after the table.
+
+ | Model | Arguments |
+ | :--: | :-- |
+ | Detection | `--det_config`, `--det_checkpoint` |
+ | Human Pose | `--human_pose_config`, `--human_pose_checkpoint` |
+ | Animal Pose | `--animal_pose_config`, `--animal_pose_checkpoint` |
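+
+ For example, a possible invocation that overrides the detection model (the config and checkpoint paths below are placeholders, not files shipped with this repo):
+
+ ```shell
+ python demo/webcam_demo.py \
+     --det_config path/to/your_det_config.py \
+     --det_checkpoint path/to/your_det_checkpoint.pth
+ ```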
+
+- **Can this tool run without a GPU?**
+
+ Yes, you can set `--device=cpu` and the model inference will be performed on the CPU. Note that this may result in a lower inference FPS than running on a GPU.
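+
+ For example, a minimal sketch that runs the demo with the default models on CPU:
+
+ ```shell
+ python demo/webcam_demo.py --device=cpu
+ ```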
+
+- **Why is there a time delay between the pose visualization and the video?**
+
+ The video I/O and model inference run asynchronously, and the latter usually takes more time to process a single frame. To alleviate the time delay, you can do either of the following (see the example after the list):
+
+ 1. set `--display-delay=MILLISECONDS` to defer the video stream, according to the inference delay shown at the top-left corner. Or,
+
+ 2. set `--synchronous-mode` to force the video stream to be aligned with the inference results. This may reduce the video display FPS.
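+
+ For example (the 100 ms delay below is only an assumed value; use the inference delay shown in the overlay):
+
+ ```shell
+ # Option 1: delay the displayed video by roughly the inference delay
+ python demo/webcam_demo.py --display-delay=100
+
+ # Option 2: strictly align the video stream with inference results
+ python demo/webcam_demo.py --synchronous-mode
+ ```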
+
+- **Can this tool process video files?**
+
+ Yes. You can set `--cam_id=VIDEO_FILE_PATH` to run the demo tool in offline mode on a video file. Note that `--synchronous-mode` should be set in this case.
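+
+ For example (assuming a hypothetical video file `path/to/video.mp4`):
+
+ ```shell
+ python demo/webcam_demo.py --cam_id=path/to/video.mp4 --synchronous-mode
+ ```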
+
+- **How to enable/disable the special effects?**
+
+ The special effects can be enabled/disabled at launch time by setting arguments such as `--bugeye` or `--sunglasses`. You can also toggle the effects with keyboard shortcuts such as `b` and `s` while the tool is running.
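+
+ For example, to launch the demo with both effects enabled:
+
+ ```shell
+ python demo/webcam_demo.py --sunglasses --bugeye
+ ```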
+
+- **What if my computer doesn't have a camera?**
+
+ You can use a smartphone as a webcam with apps like [Camo](https://reincubate.com/camo/) or [DroidCam](https://www.dev47apps.com/).
diff --git a/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py b/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py
index f45ad10cb0..4e60b6b739 100644
--- a/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py
+++ b/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py
@@ -206,7 +206,7 @@
max_per_img=100)))
dataset_type = 'CocoDataset'
-data_root = 'data/coco/'
+data_root = 'data/coco'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
@@ -239,17 +239,17 @@
workers_per_gpu=2,
train=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_train2017.json',
- img_prefix=data_root + 'train2017/',
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_val2017.json',
- img_prefix=data_root + 'val2017/',
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_val2017.json',
- img_prefix=data_root + 'val2017/',
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
diff --git a/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_coco.py b/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_coco.py
index 58b96e62e1..f91bd0d105 100644
--- a/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_coco.py
+++ b/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_coco.py
@@ -207,7 +207,7 @@
max_per_img=100)))
dataset_type = 'CocoDataset'
-data_root = 'data/coco/'
+data_root = 'data/coco'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
@@ -240,17 +240,17 @@
workers_per_gpu=2,
train=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_train2017.json',
- img_prefix=data_root + 'train2017/',
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_val2017.json',
- img_prefix=data_root + 'val2017/',
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_val2017.json',
- img_prefix=data_root + 'val2017/',
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
diff --git a/demo/mmdetection_cfg/faster_rcnn_r50_fpn_1class.py b/demo/mmdetection_cfg/faster_rcnn_r50_fpn_1class.py
index 8eddf232b7..ee54f5b66b 100644
--- a/demo/mmdetection_cfg/faster_rcnn_r50_fpn_1class.py
+++ b/demo/mmdetection_cfg/faster_rcnn_r50_fpn_1class.py
@@ -133,7 +133,7 @@
))
dataset_type = 'CocoDataset'
-data_root = 'data/coco/'
+data_root = 'data/coco'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
@@ -166,17 +166,17 @@
workers_per_gpu=2,
train=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_train2017.json',
- img_prefix=data_root + 'train2017/',
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_val2017.json',
- img_prefix=data_root + 'val2017/',
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_val2017.json',
- img_prefix=data_root + 'val2017/',
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
diff --git a/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py b/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py
index d5c17df89a..a9ad9528b2 100644
--- a/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py
+++ b/demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py
@@ -133,7 +133,7 @@
))
dataset_type = 'CocoDataset'
-data_root = 'data/coco/'
+data_root = 'data/coco'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
@@ -166,17 +166,17 @@
workers_per_gpu=2,
train=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_train2017.json',
- img_prefix=data_root + 'train2017/',
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_val2017.json',
- img_prefix=data_root + 'val2017/',
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
- ann_file=data_root + 'annotations/instances_val2017.json',
- img_prefix=data_root + 'val2017/',
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')
diff --git a/demo/mmdetection_cfg/yolov3_d53_320_273e_coco.py b/demo/mmdetection_cfg/yolov3_d53_320_273e_coco.py
new file mode 100644
index 0000000000..d7e9cca1eb
--- /dev/null
+++ b/demo/mmdetection_cfg/yolov3_d53_320_273e_coco.py
@@ -0,0 +1,140 @@
+# model settings
+model = dict(
+ type='YOLOV3',
+ pretrained='open-mmlab://darknet53',
+ backbone=dict(type='Darknet', depth=53, out_indices=(3, 4, 5)),
+ neck=dict(
+ type='YOLOV3Neck',
+ num_scales=3,
+ in_channels=[1024, 512, 256],
+ out_channels=[512, 256, 128]),
+ bbox_head=dict(
+ type='YOLOV3Head',
+ num_classes=80,
+ in_channels=[512, 256, 128],
+ out_channels=[1024, 512, 256],
+ anchor_generator=dict(
+ type='YOLOAnchorGenerator',
+ base_sizes=[[(116, 90), (156, 198), (373, 326)],
+ [(30, 61), (62, 45), (59, 119)],
+ [(10, 13), (16, 30), (33, 23)]],
+ strides=[32, 16, 8]),
+ bbox_coder=dict(type='YOLOBBoxCoder'),
+ featmap_strides=[32, 16, 8],
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0,
+ reduction='sum'),
+ loss_conf=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0,
+ reduction='sum'),
+ loss_xy=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=2.0,
+ reduction='sum'),
+ loss_wh=dict(type='MSELoss', loss_weight=2.0, reduction='sum')),
+ # training and testing settings
+ train_cfg=dict(
+ assigner=dict(
+ type='GridAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.5,
+ min_pos_iou=0)),
+ test_cfg=dict(
+ nms_pre=1000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ conf_thr=0.005,
+ nms=dict(type='nms', iou_threshold=0.45),
+ max_per_img=100))
+# dataset settings
+dataset_type = 'CocoDataset'
+data_root = 'data/coco'
+img_norm_cfg = dict(mean=[0, 0, 0], std=[255., 255., 255.], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile', to_float32=True),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(type='PhotoMetricDistortion'),
+ dict(
+ type='Expand',
+ mean=img_norm_cfg['mean'],
+ to_rgb=img_norm_cfg['to_rgb'],
+ ratio_range=(1, 2)),
+ dict(
+ type='MinIoURandomCrop',
+ min_ious=(0.4, 0.5, 0.6, 0.7, 0.8, 0.9),
+ min_crop_size=0.3),
+ dict(type='Resize', img_scale=(320, 320), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(320, 320),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img'])
+ ])
+]
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_train2017.json',
+ img_prefix=f'{data_root}/train2017/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=f'{data_root}/annotations/instances_val2017.json',
+ img_prefix=f'{data_root}/val2017/',
+ pipeline=test_pipeline))
+# optimizer
+optimizer = dict(type='SGD', lr=0.001, momentum=0.9, weight_decay=0.0005)
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=2000, # same as burn-in in darknet
+ warmup_ratio=0.1,
+ step=[218, 246])
+# runtime settings
+runner = dict(type='EpochBasedRunner', max_epochs=273)
+evaluation = dict(interval=1, metric=['bbox'])
+
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+custom_hooks = [dict(type='NumClassCheckHook')]
+
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/demo/mmtracking_cfg/tracktor_faster-rcnn_r50_fpn_4e_mot17-private.py b/demo/mmtracking_cfg/tracktor_faster-rcnn_r50_fpn_4e_mot17-private.py
index af31ccaf18..2973d16e87 100644
--- a/demo/mmtracking_cfg/tracktor_faster-rcnn_r50_fpn_4e_mot17-private.py
+++ b/demo/mmtracking_cfg/tracktor_faster-rcnn_r50_fpn_4e_mot17-private.py
@@ -202,7 +202,7 @@
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
- dict(type='ImageToTensor', keys=['img']),
+ dict(type='DefaultFormatBundle', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]
@@ -272,7 +272,7 @@
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
- dict(type='ImageToTensor', keys=['img']),
+ dict(type='DefaultFormatBundle', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]),
@@ -296,7 +296,7 @@
std=[58.395, 57.12, 57.375],
to_rgb=True),
dict(type='Pad', size_divisor=32),
- dict(type='ImageToTensor', keys=['img']),
+ dict(type='DefaultFormatBundle', keys=['img']),
dict(type='VideoCollect', keys=['img'])
])
]))
diff --git a/demo/resources/sunglasses.jpg b/demo/resources/sunglasses.jpg
new file mode 100644
index 0000000000..5d3cee8702
Binary files /dev/null and b/demo/resources/sunglasses.jpg differ
diff --git a/demo/webcam_demo.py b/demo/webcam_demo.py
new file mode 100644
index 0000000000..aaf2f14820
--- /dev/null
+++ b/demo/webcam_demo.py
@@ -0,0 +1,567 @@
+import argparse
+import time
+from collections import deque
+from queue import Queue
+from threading import Event, Lock, Thread
+
+import cv2
+import numpy as np
+
+from mmpose.apis import (get_track_id, inference_top_down_pose_model,
+ init_pose_model, vis_pose_result)
+from mmpose.core import apply_bugeye_effect, apply_sunglasses_effect
+from mmpose.utils import StopWatch
+
+try:
+ from mmdet.apis import inference_detector, init_detector
+ has_mmdet = True
+except (ImportError, ModuleNotFoundError):
+ has_mmdet = False
+
+try:
+ import psutil
+ psutil_proc = psutil.Process()
+except (ImportError, ModuleNotFoundError):
+ psutil_proc = None
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cam_id', type=str, default='0')
+ parser.add_argument(
+ '--det_config',
+ type=str,
+ default='demo/mmdetection_cfg/yolov3_d53_320_273e_coco.py',
+ help='Config file for detection')
+ parser.add_argument(
+ '--det_checkpoint',
+ type=str,
+ default='https://download.openmmlab.com/mmdetection/v2.0/yolo/'
+ 'yolov3_d53_320_273e_coco/yolov3_d53_320_273e_coco-421362b6.pth',
+ help='Checkpoint file for detection')
+ parser.add_argument(
+ '--enable_human_pose',
+ type=int,
+ default=1,
+ help='Enable human pose estimation')
+ parser.add_argument(
+ '--enable_animal_pose',
+ type=int,
+ default=1,
+ help='Enable animal pose estimation')
+ parser.add_argument(
+ '--human_pose_config',
+ type=str,
+ default='configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/'
+ 'hrnet_w48_coco_256x192.py',
+ help='Config file for human pose')
+ parser.add_argument(
+ '--human_pose_checkpoint',
+ type=str,
+ default='https://download.openmmlab.com/mmpose/top_down/hrnet/'
+ 'hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth',
+ help='Checkpoint file for human pose')
+ parser.add_argument(
+ '--human_det_ids',
+ type=int,
+ default=[1],
+ nargs='+',
+ help='Object category label of human in detection results. '
+ 'Default is [1(person)], following COCO definition.')
+ parser.add_argument(
+ '--animal_pose_config',
+ type=str,
+ default='configs/animal/2d_kpt_sview_rgb_img/topdown_heatmap/'
+ 'animalpose/hrnet_w32_animalpose_256x256.py',
+ help='Config file for animal pose')
+ parser.add_argument(
+ '--animal_pose_checkpoint',
+ type=str,
+ default='https://download.openmmlab.com/mmpose/animal/hrnet/'
+ 'hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth',
+ help='Checkpoint file for animal pose')
+ parser.add_argument(
+ '--animal_det_ids',
+ type=int,
+ default=[16, 17, 18, 19, 20],
+ nargs='+',
+ help='Object category label of animals in detection results. '
+ 'Default is [16(cat), 17(dog), 18(horse), 19(sheep), 20(cow)], '
+ 'following COCO definition.')
+ parser.add_argument(
+ '--device', default='cuda:0', help='Device used for inference')
+ parser.add_argument(
+ '--det-score-thr',
+ type=float,
+ default=0.5,
+ help='bbox score threshold')
+ parser.add_argument(
+ '--kpt-thr', type=float, default=0.3, help='keypoint score threshold')
+ parser.add_argument(
+ '--show-pose',
+ type=lambda s: s != 'False',
+ default=True,
+ help='Show pose estimation results. Set False to disable the pose '
+ 'visualization. Default: True')
+ parser.add_argument(
+ '--sunglasses', action='store_true', help='Apply `sunglasses` effect.')
+ parser.add_argument(
+ '--bugeye', action='store_true', help='Apply `bug-eye` effect.')
+
+ parser.add_argument(
+ '--out-video-file',
+ type=str,
+ default=None,
+ help='Record the video into a file. This may reduce the frame rate')
+
+ parser.add_argument(
+ '--out-video-fps',
+ type=int,
+ default=20,
+ help='Set the FPS of the output video file.')
+
+ parser.add_argument(
+ '--buffer-size',
+ type=int,
+ default=-1,
+ help='Frame buffer size. If set to -1, the buffer size will be '
+ 'automatically inferred from the display delay time. Default: -1')
+
+ parser.add_argument(
+ '--inference_fps',
+ type=int,
+ default=10,
+ help='Maximum inference FPS. This is to limit the resource consumption, '
+ 'especially when the detection and pose models are lightweight and '
+ 'very fast. Default: 10.')
+
+ parser.add_argument(
+ '--display-delay',
+ type=int,
+ default=0,
+ help='Delay the output video in milliseconds. This can be used to '
+ 'align the output video and inference results. The delay can be '
+ 'disabled by setting a non-positive delay time. Default: 0')
+
+ parser.add_argument(
+ '--synchronous-mode',
+ action='store_true',
+ help='Enable synchronous mode, in which video I/O and inference are '
+ 'temporally aligned. Note that this will reduce the display FPS.')
+
+ return parser.parse_args()
+
+
+def process_mmdet_results(mmdet_results, class_names=None, cat_ids=1):
+ """Convert mmdet results to the mmpose input format.
+
+ Args:
+ mmdet_results: raw output of mmdet model
+ class_names: class names of mmdet model
+ cat_ids (int or List[int]): category id list that will be preserved
+ Returns:
+ List[Dict]: detection results for mmpose input
+ """
+ if isinstance(mmdet_results, tuple):
+ mmdet_results = mmdet_results[0]
+
+ if not isinstance(cat_ids, (list, tuple)):
+ cat_ids = [cat_ids]
+
+ # only keep bboxes of interested classes
+ bbox_results = [mmdet_results[i - 1] for i in cat_ids]
+ bboxes = np.vstack(bbox_results)
+
+ # get textual labels of classes
+ labels = np.concatenate([
+ np.full(bbox.shape[0], i - 1, dtype=np.int32)
+ for i, bbox in zip(cat_ids, bbox_results)
+ ])
+ if class_names is None:
+ labels = [f'class: {i}' for i in labels]
+ else:
+ labels = [class_names[i] for i in labels]
+
+ det_results = []
+ for bbox, label in zip(bboxes, labels):
+ det_result = dict(bbox=bbox, label=label)
+ det_results.append(det_result)
+ return det_results
+
+
+def read_camera():
+ # init video reader
+ print('Thread "input" started')
+ cam_id = args.cam_id
+ if cam_id.isdigit():
+ cam_id = int(cam_id)
+ vid_cap = cv2.VideoCapture(cam_id)
+ if not vid_cap.isOpened():
+ print(f'Cannot open camera (ID={cam_id})')
+ exit()
+
+ while not event_exit.is_set():
+ # capture a camera frame
+ ret_val, frame = vid_cap.read()
+ if ret_val:
+ ts_input = time.time()
+
+ event_inference_done.clear()
+ with input_queue_mutex:
+ input_queue.append((ts_input, frame))
+
+ if args.synchronous_mode:
+ event_inference_done.wait()
+
+ frame_buffer.put((ts_input, frame))
+ else:
+ # input ending signal
+ frame_buffer.put((None, None))
+ break
+
+ vid_cap.release()
+
+
+def inference_detection():
+ print('Thread "det" started')
+ stop_watch = StopWatch(window=10)
+ min_interval = 1.0 / args.inference_fps
+ _ts_last = None # timestamp when last inference was done
+
+ while True:
+ while len(input_queue) < 1:
+ time.sleep(0.001)
+ with input_queue_mutex:
+ ts_input, frame = input_queue.popleft()
+ # inference detection
+ with stop_watch.timeit('Det'):
+ mmdet_results = inference_detector(det_model, frame)
+
+ t_info = stop_watch.report_strings()
+ with det_result_queue_mutex:
+ det_result_queue.append((ts_input, frame, t_info, mmdet_results))
+
+ # limit the inference FPS
+ _ts = time.time()
+ if _ts_last is not None and _ts - _ts_last < min_interval:
+ time.sleep(min_interval - _ts + _ts_last)
+ _ts_last = time.time()
+
+
+def inference_pose():
+ print('Thread "pose" started')
+ stop_watch = StopWatch(window=10)
+
+ while True:
+ while len(det_result_queue) < 1:
+ time.sleep(0.001)
+ with det_result_queue_mutex:
+ ts_input, frame, t_info, mmdet_results = det_result_queue.popleft()
+
+ pose_results_list = []
+ for model_info, pose_history in zip(pose_model_list,
+ pose_history_list):
+ model_name = model_info['name']
+ pose_model = model_info['model']
+ cat_ids = model_info['cat_ids']
+ pose_results_last = pose_history['pose_results_last']
+ next_id = pose_history['next_id']
+
+ with stop_watch.timeit(model_name):
+ # process mmdet results
+ det_results = process_mmdet_results(
+ mmdet_results,
+ class_names=det_model.CLASSES,
+ cat_ids=cat_ids)
+
+ # inference pose model
+ dataset_name = pose_model.cfg.data['test']['type']
+ pose_results, _ = inference_top_down_pose_model(
+ pose_model,
+ frame,
+ det_results,
+ bbox_thr=args.det_score_thr,
+ format='xyxy',
+ dataset=dataset_name)
+
+ pose_results, next_id = get_track_id(
+ pose_results,
+ pose_results_last,
+ next_id,
+ use_oks=False,
+ tracking_thr=0.3,
+ use_one_euro=True,
+ fps=None)
+
+ pose_results_list.append(pose_results)
+
+ # update pose history
+ pose_history['pose_results_last'] = pose_results
+ pose_history['next_id'] = next_id
+
+ t_info += stop_watch.report_strings()
+ with pose_result_queue_mutex:
+ pose_result_queue.append((ts_input, t_info, pose_results_list))
+
+ event_inference_done.set()
+
+
+def display():
+ print('Thread "display" started')
+ stop_watch = StopWatch(window=10)
+
+ # initialize result status
+ ts_inference = None # timestamp of the latest inference result
+ fps_inference = 0. # inference FPS
+ t_delay_inference = 0. # inference result time delay
+ pose_results_list = None # latest inference result
+ t_info = [] # upstream time information (list[str])
+
+ # initialize visualization and output
+ sunglasses_img = None # resource image for sunglasses effect
+ text_color = (228, 183, 61) # text color to show time/system information
+ vid_out = None # video writer
+
+ # show instructions
+ print('Keyboard shortcuts: ')
+ print('"v": Toggle the visualization of bounding boxes and poses.')
+ print('"s": Toggle the sunglasses effect.')
+ print('"b": Toggle the bug-eye effect.')
+ print('"Q", "q" or Esc: Exit.')
+
+ while True:
+ with stop_watch.timeit('_FPS_'):
+ # acquire a frame from buffer
+ ts_input, frame = frame_buffer.get()
+ # input ending signal
+ if ts_input is None:
+ break
+
+ img = frame
+
+ # get pose estimation results
+ if len(pose_result_queue) > 0:
+ with pose_result_queue_mutex:
+ _result = pose_result_queue.popleft()
+ _ts_input, t_info, pose_results_list = _result
+
+ _ts = time.time()
+ if ts_inference is not None:
+ fps_inference = 1.0 / (_ts - ts_inference)
+ ts_inference = _ts
+ t_delay_inference = (_ts - _ts_input) * 1000
+
+ # visualize detection and pose results
+ if pose_results_list is not None:
+ for model_info, pose_results in zip(pose_model_list,
+ pose_results_list):
+ pose_model = model_info['model']
+ bbox_color = model_info['bbox_color']
+
+ dataset_name = pose_model.cfg.data['test']['type']
+
+ # show pose results
+ if args.show_pose:
+ img = vis_pose_result(
+ pose_model,
+ img,
+ pose_results,
+ radius=4,
+ thickness=2,
+ dataset=dataset_name,
+ kpt_score_thr=args.kpt_thr,
+ bbox_color=bbox_color)
+
+ # sunglasses effect
+ if args.sunglasses:
+ if dataset_name == 'TopDownCocoDataset':
+ left_eye_idx = 1
+ right_eye_idx = 2
+ elif dataset_name == 'AnimalPoseDataset':
+ left_eye_idx = 0
+ right_eye_idx = 1
+ else:
+ raise ValueError(
+ 'Sunglasses effect does not support'
+ f'{dataset_name}')
+ if sunglasses_img is None:
+ # The image is attributed to:
+ # https://www.vecteezy.com/free-vector/glass
+ # Glass Vectors by Vecteezy
+ sunglasses_img = cv2.imread(
+ 'demo/resources/sunglasses.jpg')
+ img = apply_sunglasses_effect(img, pose_results,
+ sunglasses_img,
+ left_eye_idx,
+ right_eye_idx)
+ # bug-eye effect
+ if args.bugeye:
+ if dataset_name == 'TopDownCocoDataset':
+ left_eye_idx = 1
+ right_eye_idx = 2
+ elif dataset_name == 'AnimalPoseDataset':
+ left_eye_idx = 0
+ right_eye_idx = 1
+ else:
+ raise ValueError('Bug-eye effect does not support'
+ f'{dataset_name}')
+ img = apply_bugeye_effect(img, pose_results,
+ left_eye_idx, right_eye_idx)
+
+ # delay control
+ if args.display_delay > 0:
+ t_sleep = args.display_delay * 0.001 - (time.time() - ts_input)
+ if t_sleep > 0:
+ time.sleep(t_sleep)
+ t_delay = (time.time() - ts_input) * 1000
+
+ # show time information
+ t_info_display = stop_watch.report_strings() # display fps
+ t_info_display.append(f'Inference FPS: {fps_inference:>5.1f}')
+ t_info_display.append(f'Delay: {t_delay:>3.0f}')
+ t_info_display.append(
+ f'Inference Delay: {t_delay_inference:>3.0f}')
+ t_info_str = ' | '.join(t_info_display + t_info)
+ cv2.putText(img, t_info_str, (20, 20), cv2.FONT_HERSHEY_DUPLEX,
+ 0.3, text_color, 1)
+ # collect system information
+ sys_info = [
+ f'RES: {img.shape[1]}x{img.shape[0]}',
+ f'Buffer: {frame_buffer.qsize()}/{frame_buffer.maxsize}'
+ ]
+ if psutil_proc is not None:
+ sys_info += [
+ f'CPU: {psutil_proc.cpu_percent():.1f}%',
+ f'MEM: {psutil_proc.memory_percent():.1f}%'
+ ]
+ sys_info_str = ' | '.join(sys_info)
+ cv2.putText(img, sys_info_str, (20, 40), cv2.FONT_HERSHEY_DUPLEX,
+ 0.3, text_color, 1)
+
+ # save the output video frame
+ if args.out_video_file is not None:
+ if vid_out is None:
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+ fps = args.out_video_fps
+ frame_size = (img.shape[1], img.shape[0])
+ vid_out = cv2.VideoWriter(args.out_video_file, fourcc, fps,
+ frame_size)
+
+ vid_out.write(img)
+
+ # display
+ cv2.imshow('mmpose webcam demo', img)
+ keyboard_input = cv2.waitKey(1)
+ if keyboard_input in (27, ord('q'), ord('Q')):
+ break
+ elif keyboard_input == ord('s'):
+ args.sunglasses = not args.sunglasses
+ elif keyboard_input == ord('b'):
+ args.bugeye = not args.bugeye
+ elif keyboard_input == ord('v'):
+ args.show_pose = not args.show_pose
+
+ cv2.destroyAllWindows()
+ if vid_out is not None:
+ vid_out.release()
+ event_exit.set()
+
+
+def main():
+ global args
+ global frame_buffer
+ global input_queue, input_queue_mutex
+ global det_result_queue, det_result_queue_mutex
+ global pose_result_queue, pose_result_queue_mutex
+ global det_model, pose_model_list, pose_history_list
+ global event_exit, event_inference_done
+
+ args = parse_args()
+
+ assert has_mmdet, 'Please install mmdet to run the demo.'
+ assert args.det_config is not None
+ assert args.det_checkpoint is not None
+
+ # build detection model
+ det_model = init_detector(
+ args.det_config, args.det_checkpoint, device=args.device.lower())
+
+ # build pose models
+ pose_model_list = []
+ if args.enable_human_pose:
+ pose_model = init_pose_model(
+ args.human_pose_config,
+ args.human_pose_checkpoint,
+ device=args.device.lower())
+ model_info = {
+ 'name': 'HumanPose',
+ 'model': pose_model,
+ 'cat_ids': args.human_det_ids,
+ 'bbox_color': (148, 139, 255),
+ }
+ pose_model_list.append(model_info)
+ if args.enable_animal_pose:
+ pose_model = init_pose_model(
+ args.animal_pose_config,
+ args.animal_pose_checkpoint,
+ device=args.device.lower())
+ model_info = {
+ 'name': 'AnimalPose',
+ 'model': pose_model,
+ 'cat_ids': args.animal_det_ids,
+ 'bbox_color': 'cyan',
+ }
+ pose_model_list.append(model_info)
+
+ # store pose history for pose tracking
+ pose_history_list = []
+ for _ in range(len(pose_model_list)):
+ pose_history_list.append({'pose_results_last': [], 'next_id': 0})
+
+ # frame buffer
+ if args.buffer_size > 0:
+ buffer_size = args.buffer_size
+ else:
+ # infer buffer size from the display delay time
+ # assume that the maximum video fps is 30
+ buffer_size = round(30 * (1 + max(args.display_delay, 0) / 1000.))
+ frame_buffer = Queue(maxsize=buffer_size)
+
+ # queue of input frames
+ # element: (timestamp, frame)
+ input_queue = deque(maxlen=1)
+ input_queue_mutex = Lock()
+
+ # queue of detection results
+ # element: tuple(timestamp, frame, time_info, det_results)
+ det_result_queue = deque(maxlen=1)
+ det_result_queue_mutex = Lock()
+
+ # queue of detection/pose results
+ # element: (timestamp, time_info, pose_results_list)
+ pose_result_queue = deque(maxlen=1)
+ pose_result_queue_mutex = Lock()
+
+ try:
+ event_exit = Event()
+ event_inference_done = Event()
+ t_input = Thread(target=read_camera, args=())
+ t_det = Thread(target=inference_detection, args=(), daemon=True)
+ t_pose = Thread(target=inference_pose, args=(), daemon=True)
+
+ t_input.start()
+ t_det.start()
+ t_pose.start()
+
+ # run display in the main thread
+ display()
+ # join the input thread (non-daemon)
+ t_input.join()
+
+ except KeyboardInterrupt:
+ pass
+
+
+if __name__ == '__main__':
+ main()
diff --git a/mmpose/apis/inference.py b/mmpose/apis/inference.py
index 550c4e3cd6..b6668070ad 100644
--- a/mmpose/apis/inference.py
+++ b/mmpose/apis/inference.py
@@ -520,6 +520,7 @@ def vis_pose_result(model,
radius=4,
thickness=1,
kpt_score_thr=0.3,
+ bbox_color='green',
dataset='TopDownCocoDataset',
show=False,
out_file=None):
@@ -740,8 +741,8 @@ def vis_pose_result(model,
[14, 18], [7, 11], [11, 15], [15, 19], [7, 12], [12, 16],
[16, 20]]
- pose_limb_color = palette[[0] * 20]
- pose_kpt_color = palette[[0] * 20]
+ pose_kpt_color = palette[[15] * 5 + [0] * 7 + [9] * 8]
+ pose_limb_color = palette[[15] * 5 + [0] * 3 + [0, 9, 9] * 4]
else:
raise NotImplementedError()
@@ -755,6 +756,7 @@ def vis_pose_result(model,
pose_kpt_color=pose_kpt_color,
pose_limb_color=pose_limb_color,
kpt_score_thr=kpt_score_thr,
+ bbox_color=bbox_color,
show=show,
out_file=out_file)
diff --git a/mmpose/core/visualization/__init__.py b/mmpose/core/visualization/__init__.py
index d5e0f78170..eda8d48f9d 100644
--- a/mmpose/core/visualization/__init__.py
+++ b/mmpose/core/visualization/__init__.py
@@ -1,3 +1,7 @@
-from .image import imshow_keypoints, imshow_keypoints_3d
+from .effects import apply_bugeye_effect, apply_sunglasses_effect
+from .image import imshow_bboxes, imshow_keypoints, imshow_keypoints_3d
-__all__ = ['imshow_keypoints', 'imshow_keypoints_3d']
+__all__ = [
+ 'imshow_keypoints', 'imshow_keypoints_3d', 'imshow_bboxes',
+ 'apply_bugeye_effect', 'apply_sunglasses_effect'
+]
diff --git a/mmpose/core/visualization/effects.py b/mmpose/core/visualization/effects.py
new file mode 100644
index 0000000000..8711f1b52d
--- /dev/null
+++ b/mmpose/core/visualization/effects.py
@@ -0,0 +1,110 @@
+import cv2
+import numpy as np
+
+
+def apply_bugeye_effect(img,
+ pose_results,
+ left_eye_index,
+ right_eye_index,
+ kpt_thr=0.5):
+ """Apply bug-eye effect.
+
+ Args:
+ img (np.ndarray): Image data.
+ pose_results (list[dict]): The pose estimation results containing:
+ - "bbox" ([K, 4(or 5)]): detection bbox in
+ [x1, y1, x2, y2, (score)]
+ - "keypoints" ([K,3]): keypoint detection result in [x, y, score]
+ left_eye_index (int): Keypoint index of left eye
+ right_eye_index (int): Keypoint index of right eye
+ kpt_thr (float): The score threshold of required keypoints.
+ """
+
+ xx, yy = np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[0]))
+ xx = xx.astype(np.float32)
+ yy = yy.astype(np.float32)
+
+ for pose in pose_results:
+ bbox = pose['bbox']
+ kpts = pose['keypoints']
+
+ if kpts[left_eye_index, 2] < kpt_thr or kpts[right_eye_index,
+ 2] < kpt_thr:
+ continue
+
+ kpt_leye = kpts[left_eye_index, :2]
+ kpt_reye = kpts[right_eye_index, :2]
+ for xc, yc in [kpt_leye, kpt_reye]:
+
+ # distortion parameters
+ k1 = 0.001
+ epe = 1e-5
+
+ scale = (bbox[2] - bbox[0])**2 + (bbox[3] - bbox[1])**2
+ r2 = ((xx - xc)**2 + (yy - yc)**2)
+ r2 = (r2 + epe) / scale # normalized by bbox scale
+
+ xx = (xx - xc) / (1 + k1 / r2) + xc
+ yy = (yy - yc) / (1 + k1 / r2) + yc
+
+ img = cv2.remap(
+ img,
+ xx,
+ yy,
+ interpolation=cv2.INTER_AREA,
+ borderMode=cv2.BORDER_REPLICATE)
+ return img
+
+
+def apply_sunglasses_effect(img,
+ pose_results,
+ sunglasses_img,
+ left_eye_index,
+ right_eye_index,
+ kpt_thr=0.5):
+ """Apply sunglasses effect.
+
+ Args:
+ img (np.ndarray): Image data.
+ pose_results (list[dict]): The pose estimation results containing:
+ - "keypoints" ([K,3]): keypoint detection result in [x, y, score]
+ sunglasses_img (np.ndarray): Sunglasses image with white background.
+ left_eye_index (int): Keypoint index of left eye
+ right_eye_index (int): Keypoint index of right eye
+ kpt_thr (float): The score threshold of required keypoints.
+ """
+
+ hm, wm = sunglasses_img.shape[:2]
+ # anchor points in the sunglasses mask
+ pts_src = np.array([[0.3 * wm, 0.3 * hm], [0.3 * wm, 0.7 * hm],
+ [0.7 * wm, 0.3 * hm], [0.7 * wm, 0.7 * hm]],
+ dtype=np.float32)
+
+ for pose in pose_results:
+ kpts = pose['keypoints']
+
+ if kpts[left_eye_index, 2] < kpt_thr or kpts[right_eye_index,
+ 2] < kpt_thr:
+ continue
+
+ kpt_leye = kpts[left_eye_index, :2]
+ kpt_reye = kpts[right_eye_index, :2]
+ # orthogonal vector to the left-to-right eyes
+ vo = 0.5 * (kpt_reye - kpt_leye)[::-1] * [-1, 1]
+
+ # anchor points in the image by eye positions
+ pts_tar = np.vstack(
+ [kpt_reye + vo, kpt_reye - vo, kpt_leye + vo, kpt_leye - vo])
+
+ h_mat, _ = cv2.findHomography(pts_src, pts_tar)
+ patch = cv2.warpPerspective(
+ sunglasses_img,
+ h_mat,
+ dsize=(img.shape[1], img.shape[0]),
+ borderValue=(255, 255, 255))
+ # mask the white background area in the patch with a threshold of 200
+ mask = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
+ mask = (mask < 200).astype(np.uint8)
+ img = cv2.copyTo(patch, mask, img)
+
+ return img
diff --git a/mmpose/core/visualization/image.py b/mmpose/core/visualization/image.py
index ba9ddc4e95..09bbacc79b 100644
--- a/mmpose/core/visualization/image.py
+++ b/mmpose/core/visualization/image.py
@@ -6,6 +6,82 @@
from matplotlib import pyplot as plt
+def imshow_bboxes(img,
+ bboxes,
+ labels=None,
+ colors='green',
+ text_color='white',
+ thickness=1,
+ font_scale=0.5,
+ show=True,
+ win_name='',
+ wait_time=0,
+ out_file=None):
+ """Draw bboxes with labels (optional) on an image. This is a wrapper of
+ mmcv.imshow_bboxes.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ bboxes (ndarray): ndarray of shape (k, 4), each row is a bbox in
+ format [x1, y1, x2, y2].
+ labels (str or list[str], optional): labels of each bbox.
+ colors (list[str or tuple or :obj:`Color`]): A list of colors.
+ text_color (str or tuple or :obj:`Color`): Color of texts.
+ thickness (int): Thickness of lines.
+ font_scale (float): Font scales of texts.
+ show (bool): Whether to show the image.
+ win_name (str): The window name.
+ wait_time (int): Value of waitKey param.
+ out_file (str, optional): The filename to write the image.
+
+ Returns:
+ ndarray: The image with bboxes drawn on it.
+ """
+
+ # adapt to mmcv.imshow_bboxes input format
+ bboxes = np.split(bboxes, bboxes.shape[0], axis=0)
+ if not isinstance(colors, list):
+ colors = [colors for _ in range(len(bboxes))]
+ colors = [mmcv.color_val(c) for c in colors]
+ assert len(bboxes) == len(colors)
+
+ img = mmcv.imshow_bboxes(
+ img,
+ bboxes,
+ colors,
+ top_k=-1,
+ thickness=thickness,
+ show=False,
+ out_file=None)
+
+ if labels is not None:
+ if not isinstance(labels, list):
+ labels = [labels for _ in range(len(bboxes))]
+ assert len(labels) == len(bboxes)
+
+ for bbox, label, color in zip(bboxes, labels, colors):
+ bbox_int = bbox[0, :4].astype(np.int32)
+ # roughly estimate the proper font size
+ text_size, text_baseline = cv2.getTextSize(label,
+ cv2.FONT_HERSHEY_DUPLEX,
+ font_scale, thickness)
+ text_x1 = bbox_int[0]
+ text_y1 = max(0, bbox_int[1] - text_size[1] - text_baseline)
+ text_x2 = bbox_int[0] + text_size[0]
+ text_y2 = text_y1 + text_size[1] + text_baseline
+ cv2.rectangle(img, (text_x1, text_y1), (text_x2, text_y2), color,
+ cv2.FILLED)
+ cv2.putText(img, label, (text_x1, text_y2 - text_baseline),
+ cv2.FONT_HERSHEY_DUPLEX, font_scale,
+ mmcv.color_val(text_color), thickness)
+
+ if show:
+ mmcv.imshow(img, win_name, wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+ return img
+
+
def imshow_keypoints(img,
pose_result,
skeleton=None,
diff --git a/mmpose/models/detectors/pose_lifter.py b/mmpose/models/detectors/pose_lifter.py
index 81ca236268..effc405607 100644
--- a/mmpose/models/detectors/pose_lifter.py
+++ b/mmpose/models/detectors/pose_lifter.py
@@ -3,7 +3,7 @@
import mmcv
import numpy as np
-from mmpose.core import imshow_keypoints, imshow_keypoints_3d
+from mmpose.core import imshow_bboxes, imshow_keypoints, imshow_keypoints_3d
from .. import builder
from ..builder import POSENETS
from .base import BasePose
@@ -345,11 +345,10 @@ def show_result(self,
if len(bbox_result) > 0:
bboxes = np.vstack(bbox_result)
- mmcv.imshow_bboxes(
+ imshow_bboxes(
img,
bboxes,
colors='green',
- top_k=-1,
thickness=thickness,
show=False)
if len(pose_input_2d) > 0:
diff --git a/mmpose/models/detectors/top_down.py b/mmpose/models/detectors/top_down.py
index 7134df2978..8aa26bb9e7 100644
--- a/mmpose/models/detectors/top_down.py
+++ b/mmpose/models/detectors/top_down.py
@@ -5,7 +5,7 @@
from mmcv.image import imwrite
from mmcv.visualization.image import imshow
-from mmpose.core import imshow_keypoints
+from mmpose.core import imshow_bboxes, imshow_keypoints
from .. import builder
from ..builder import POSENETS
from .base import BasePose
@@ -222,10 +222,11 @@ def show_result(self,
bbox_color='green',
pose_kpt_color=None,
pose_limb_color=None,
- text_color=(255, 0, 0),
+ text_color='white',
radius=4,
thickness=1,
font_scale=0.5,
+ bbox_thickness=1,
win_name='',
show=False,
show_keypoint_weight=False,
@@ -264,7 +265,6 @@ def show_result(self,
img = mmcv.imread(img)
img = img.copy()
- img_h, img_w, _ = img.shape
bbox_result = []
pose_result = []
@@ -274,13 +274,18 @@ def show_result(self,
if len(bbox_result) > 0:
bboxes = np.vstack(bbox_result)
+ labels = None
+ if 'label' in result[0]:
+ labels = [res['label'] for res in result]
# draw bounding boxes
- mmcv.imshow_bboxes(
+ imshow_bboxes(
img,
bboxes,
+ labels=labels,
colors=bbox_color,
- top_k=-1,
- thickness=thickness,
+ text_color=text_color,
+ thickness=bbox_thickness,
+ font_scale=font_scale,
show=False)
imshow_keypoints(img, pose_result, skeleton, kpt_score_thr,
diff --git a/mmpose/utils/__init__.py b/mmpose/utils/__init__.py
index ac489e2dbb..860a2d4a57 100644
--- a/mmpose/utils/__init__.py
+++ b/mmpose/utils/__init__.py
@@ -1,4 +1,5 @@
from .collect_env import collect_env
from .logger import get_root_logger
+from .timer import StopWatch
-__all__ = ['get_root_logger', 'collect_env']
+__all__ = ['get_root_logger', 'collect_env', 'StopWatch']
diff --git a/mmpose/utils/timer.py b/mmpose/utils/timer.py
new file mode 100644
index 0000000000..2e499dcbd2
--- /dev/null
+++ b/mmpose/utils/timer.py
@@ -0,0 +1,92 @@
+from collections import defaultdict
+
+import numpy as np
+from mmcv import Timer
+
+
+class StopWatch:
+ r"""A helper class to measure FPS and the detailed time consumption of each
+ phase in a video processing loop or similar scenarios.
+
+ Args:
+ window (int): The sliding window size to calculate the running average
+ of the time consuming.
+
+ Example::
+ >>> stop_watch = StopWatch(window=10)
+ >>> while True:
+ ... with stop_watch.timeit('total'):
+ ... sleep(1)
+ ... # 'timeit' supports nested use
+ ... with stop_watch.timeit('phase1'):
+ ... sleep(1)
+ ... with stop_watch.timeit('phase2'):
+ ... sleep(2)
+ ... sleep(2)
+ ... report = stop_watch.report()
+ report = {'total': 6., 'phase1': 1., 'phase2': 2.}
+
+ """
+
+ def __init__(self, window=1):
+ self._record = defaultdict(list)
+ self._timer_stack = []
+ self.window = window
+
+ def timeit(self, timer_name='_FPS_'):
+ """Timing a code snippet with an assigned name.
+
+ Args:
+ timer_name (str): The unique name of the code snippet of interest, used
+ to handle multiple timers and generate reports. Note that '_FPS_'
+ is a special key whose measurement will be reported in `fps` instead
+ of `milliseconds`. Also see `report` and `report_strings`.
+ Default: '_FPS_'.
+ Note:
+ This function should always be used in a `with` statement, as shown
+ in the example.
+ """
+ self._timer_stack.append((timer_name, Timer()))
+ return self
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_value, trackback):
+ timer_name, timer = self._timer_stack.pop()
+ self._record[timer_name].append(timer.since_start())
+ self._record[timer_name] = self._record[timer_name][-self.window:]
+
+ def report(self):
+ """Report timing information.
+
+ Returns:
+ dict: The key is the timer name and the value is the corresponding
+ average time consuming.
+ """
+ result = {
+ name: np.mean(vals) * 1000.
+ for name, vals in self._record.items()
+ }
+ return result
+
+ def report_strings(self):
+ """Report timing information as text strings.
+
+ Returns:
+ list(str): Each element is the information string of a timed event,
+ in the format '{timer_name}: {time_in_ms}'. In particular, if
+ the timer_name is '_FPS_', the result will be converted to
+ fps.
+ """
+ result = self.report()
+ strings = []
+ if '_FPS_' in result:
+ fps = 1000. / result.pop('_FPS_')
+ strings.append(f'FPS: {fps:>5.1f}')
+ strings += [f'{name}: {val:>3.0f}' for name, val in result.items()]
+ return strings
+
+ def reset(self):
+ self._record = defaultdict(list)
+ self._timer_stack = []
diff --git a/tests/test_regularization.py b/tests/test_regularization.py
index b23b86f1c7..3de10aceef 100644
--- a/tests/test_regularization.py
+++ b/tests/test_regularization.py
@@ -15,4 +15,4 @@ def test_weight_norm_clip():
_ = module(x)
weight_norm = module.weight.norm().item()
- np.testing.assert_allclose(weight_norm, 1.0, rtol=1e-6)
+ np.testing.assert_almost_equal(weight_norm, 1.0, decimal=6)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 7057dbedb0..67fa096cb8 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,10 +1,13 @@
+import time
+
import cv2
import mmcv
+import numpy as np
import torch
import torchvision
import mmpose
-from mmpose.utils import collect_env
+from mmpose.utils import StopWatch, collect_env
def test_collect_env():
@@ -15,3 +18,25 @@ def test_collect_env():
assert env_info['MMCV'] == mmcv.__version__
assert '+' in env_info['MMPose']
assert mmpose.__version__ in env_info['MMPose']
+
+
+def test_stopwatch():
+ window_size = 5
+ test_loop = 10
+ outer_time = 100
+ inner_time = 100
+
+ stop_watch = StopWatch(window=window_size)
+ for _ in range(test_loop):
+ with stop_watch.timeit():
+ time.sleep(outer_time / 1000.)
+ with stop_watch.timeit('inner'):
+ time.sleep(inner_time / 1000.)
+
+ report = stop_watch.report()
+ _ = stop_watch.report_strings()
+
+ np.testing.assert_allclose(
+ report['_FPS_'], outer_time + inner_time, rtol=0.01)
+
+ np.testing.assert_allclose(report['inner'], inner_time, rtol=0.01)
diff --git a/tests/test_visualization.py b/tests/test_visualization.py
index 8403713adf..927afaafb4 100644
--- a/tests/test_visualization.py
+++ b/tests/test_visualization.py
@@ -1,6 +1,10 @@
+import tempfile
+
+import mmcv
import numpy as np
-from mmpose.core import imshow_keypoints, imshow_keypoints_3d
+from mmpose.core import (apply_bugeye_effect, apply_sunglasses_effect,
+ imshow_bboxes, imshow_keypoints, imshow_keypoints_3d)
def test_imshow_keypoints():
@@ -29,3 +33,48 @@ def test_imshow_keypoints():
pose_kpt_color=pose_kpt_color,
pose_limb_color=pose_limb_color,
vis_height=400)
+
+
+def test_imshow_bbox():
+ img = np.zeros((100, 100, 3), dtype=np.uint8)
+ bboxes = np.array([[10, 10, 30, 30], [10, 50, 30, 80]], dtype=np.float32)
+ labels = ['label 1', 'label 2']
+ colors = ['red', 'green']
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ _ = imshow_bboxes(
+ img,
+ bboxes,
+ labels=labels,
+ colors=colors,
+ show=False,
+ out_file=f'{tmpdir}/out.png')
+
+
+def test_effects():
+ img = np.zeros((100, 100, 3), dtype=np.uint8)
+ kpts = np.array([[10., 10., 0.8], [20., 10., 0.8]], dtype=np.float32)
+ bbox = np.array([0, 0, 50, 50], dtype=np.float32)
+ pose_results = [dict(bbox=bbox, keypoints=kpts)]
+ # sunglasses
+ sunglasses_img = mmcv.imread('demo/resources/sunglasses.jpg')
+ _ = apply_sunglasses_effect(
+ img,
+ pose_results,
+ sunglasses_img,
+ left_eye_index=1,
+ right_eye_index=0,
+ kpt_thr=0.5)
+ _ = apply_sunglasses_effect(
+ img,
+ pose_results,
+ sunglasses_img,
+ left_eye_index=1,
+ right_eye_index=0,
+ kpt_thr=0.9)
+
+ # bug-eye
+ _ = apply_bugeye_effect(
+ img, pose_results, left_eye_index=1, right_eye_index=0, kpt_thr=0.5)
+ _ = apply_bugeye_effect(
+ img, pose_results, left_eye_index=1, right_eye_index=0, kpt_thr=0.9)