Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace imshow_det_bboxes visualization backend #4389

Merged
merged 22 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions mmdet/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
Expand Down Expand Up @@ -162,9 +161,8 @@ def show_result_pyplot(model,
img,
result,
score_thr=0.3,
fig_size=(15, 10),
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
title='result',
block=True):
wait_time=0):
"""Visualize the detection results on the image.

Args:
Expand All @@ -173,15 +171,17 @@ def show_result_pyplot(model,
result (tuple[list] or list): The detection result, can be either
(bbox, segm) or just bbox.
score_thr (float): The threshold to visualize the bboxes and masks.
fig_size (tuple): Figure size of the pyplot figure.
title (str): Title of the pyplot figure.
block (bool): Whether to block GUI.
wait_time (float): Value of waitKey param.
Default: 0.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, score_thr=score_thr, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.title(title)
plt.tight_layout()
plt.show(block=block)
model.show_result(
img,
result,
score_thr=score_thr,
show=True,
wait_time=wait_time,
bbox_color=(72, 101, 241),
text_color=(72, 101, 241))
3 changes: 3 additions & 0 deletions mmdet/core/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .image import color_val_matplotlib, imshow_det_bboxes

__all__ = ['imshow_det_bboxes', 'color_val_matplotlib']
117 changes: 117 additions & 0 deletions mmdet/core/visualization/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import os.path as osp
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
from matplotlib.collections import PatchCollection
from matplotlib.patches import Polygon


def color_val_matplotlib(color):
"""Convert opencv color value to matplotlib (BGR->RGB->Norm)"""
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
color = mmcv.color_val(color)
color = [color / 255 for color in color[::-1]]
return tuple(color)


def imshow_det_bboxes(img,
bboxes,
labels,
class_names=None,
score_thr=0,
bbox_color='green',
text_color='green',
thickness=2,
font_scale=0.5,
font_size=10,
win_name='',
show=True,
wait_time=0,
out_file=None):
"""Draw bboxes and class labels (with scores) on an image.

Args:
img (str or ndarray): The image to be displayed.
bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
(n, 5).
labels (ndarray): Labels of bboxes.
class_names (list[str]): Names of each classes.
score_thr (float): Minimum score of bboxes to be shown.
bbox_color (str or tuple or :obj:`Color`): Color of bbox lines.
text_color (str or tuple or :obj:`Color`): Color of texts.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
font_size (int): Font size of texts.
show (bool): Whether to show the image.
win_name (str): The window name.
wait_time (float): Value of waitKey param.
out_file (str or None): The filename to write the image.

Returns:
ndarray: The image with bboxes drawn on it.
"""
warnings.warn('"font_scale" will be deprecated in v2.9.0,'
'Please use "font_size"')
assert bboxes.ndim == 2
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
assert labels.ndim == 1
assert bboxes.shape[0] == labels.shape[0]
assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5
img = mmcv.imread(img).copy()

if score_thr > 0:
assert bboxes.shape[1] == 5
scores = bboxes[:, -1]
inds = scores > score_thr
bboxes = bboxes[inds, :]
labels = labels[inds]

bbox_color = color_val_matplotlib(bbox_color)
text_color = color_val_matplotlib(text_color)

img = mmcv.bgr2rgb(img)
img = np.ascontiguousarray(img)

plt.figure(win_name)
plt.title(win_name)
plt.imshow(img)
plt.axis('off')
ax = plt.gca()
ax.set_autoscale_on(False)
polygons = []
color = []

for bbox, label in zip(bboxes, labels):
bbox_int = bbox.astype(np.int32)
poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
[bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
np_poly = np.array(poly).reshape((4, 2))
polygons.append(Polygon(np_poly))
color.append(bbox_color)
label_text = class_names[
label] if class_names is not None else f'cls {label}'
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
if len(bbox) > 4:
label_text += f'|{bbox[-1]:.02f}'
ax.text(
bbox_int[0],
bbox_int[1],
'%s' % label_text,
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
color=text_color,
fontsize=font_size)

p = PatchCollection(
polygons, facecolor='none', edgecolors=color, linewidths=thickness)
ax.add_collection(p)

if out_file is not None:
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
dir_name = osp.abspath(osp.dirname(out_file))
mmcv.mkdir_or_exist(dir_name)
plt.savefig(out_file)
if show:
if wait_time == 0:
plt.show()
else:
plt.show(block=False)
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
plt.pause(wait_time)
plt.close()
return mmcv.rgb2bgr(img)
12 changes: 8 additions & 4 deletions mmdet/models/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mmcv.runner import auto_fp16
from mmcv.utils import print_log

from mmdet.core.visualization import imshow_det_bboxes
from mmdet.utils import get_root_logger


Expand Down Expand Up @@ -270,10 +271,11 @@ def show_result(self,
img,
result,
score_thr=0.3,
bbox_color='green',
text_color='green',
bbox_color=(72, 101, 241),
text_color=(72, 101, 241),
thickness=1,
font_scale=0.5,
font_size=10,
hhaAndroid marked this conversation as resolved.
Show resolved Hide resolved
win_name='',
show=False,
wait_time=0,
Expand All @@ -290,8 +292,9 @@ def show_result(self,
text_color (str or tuple or :obj:`Color`): Color of texts.
thickness (int): Thickness of lines.
font_scale (float): Font scales of texts.
font_size (int): Font size of texts.
win_name (str): The window name.
wait_time (int): Value of waitKey param.
wait_time (float): Value of waitKey param.
Default: 0.
show (bool): Whether to show the image.
Default: False.
Expand Down Expand Up @@ -336,7 +339,7 @@ def show_result(self,
if out_file is not None:
show = False
# draw bounding boxes
mmcv.imshow_det_bboxes(
imshow_det_bboxes(
img,
bboxes,
labels,
Expand All @@ -346,6 +349,7 @@ def show_result(self,
text_color=text_color,
thickness=thickness,
font_scale=font_scale,
font_size=font_size,
win_name=win_name,
show=show,
wait_time=wait_time,
Expand Down
34 changes: 34 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os.path as osp
import tempfile

import mmcv
import numpy as np
import pytest

from mmdet.core import visualization as vis


def test_color():
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
assert vis.color_val_matplotlib(np.zeros(3, dtype=np.int)) == (0., 0., 0.)
with pytest.raises(TypeError):
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
vis.color_val_matplotlib([255, 255, 255])
with pytest.raises(TypeError):
vis.color_val_matplotlib(1.0)
with pytest.raises(AssertionError):
vis.color_val_matplotlib((0, 0, 500))


def test_imshow_det_bboxes():
tmp_filename = osp.join(tempfile.gettempdir(), 'det_bboxes_image',
'image.jpg')
image = np.ones((10, 10, 3), np.uint8)
bbox = np.array([[2, 1, 3, 3], [3, 4, 6, 6]])
label = np.array([0, 1])
vis.imshow_det_bboxes(
image, bbox, label, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
14 changes: 9 additions & 5 deletions tools/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import mmcv
from mmcv import Config

from mmdet.core.visualization import imshow_det_bboxes
from mmdet.datasets.builder import build_dataset


Expand All @@ -25,9 +26,9 @@ def parse_args():
parser.add_argument('--not-show', default=False, action='store_true')
parser.add_argument(
'--show-interval',
type=int,
default=999,
help='the interval of show (ms)')
type=float,
default=1,
help='the interval of show (s)')
args = parser.parse_args()
return args

Expand All @@ -53,14 +54,17 @@ def main():
filename = os.path.join(args.output_dir,
Path(item['filename']).name
) if args.output_dir is not None else None
mmcv.imshow_det_bboxes(

imshow_det_bboxes(
item['img'],
item['gt_bboxes'],
item['gt_labels'],
class_names=dataset.CLASSES,
show=not args.not_show,
wait_time=args.show_interval,
out_file=filename,
wait_time=args.show_interval)
bbox_color=(255, 102, 61),
text_color=(255, 102, 61))
progress_bar.update()


Expand Down