Skip to content

Commit

Permalink
[Enhance] Better result visualization (open-mmlab#419)
Browse files Browse the repository at this point in the history
* Imporve result visualization to support wait time and change the backend
to matplotlib.

* Add unit test for visualization

* Add adaptive dpi function

* Rename `imshow_cls_result` to `imshow_infos`.

* Support str in `imshow_infos`

* Improve docstring.
  • Loading branch information
mzr1996 authored and Ezra-Yu committed Sep 7, 2021
1 parent e800ffb commit d3d797f
Show file tree
Hide file tree
Showing 7 changed files with 251 additions and 48 deletions.
12 changes: 6 additions & 6 deletions mmcls/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
Expand Down Expand Up @@ -90,18 +89,19 @@ def inference_model(model, img):
return result


def show_result_pyplot(model, img, result, fig_size=(15, 10)):
def show_result_pyplot(model, img, result, fig_size=(15, 10), wait_time=0):
"""Visualize the classification results on the image.
Args:
model (nn.Module): The loaded classifier.
img (str or np.ndarray): Image filename or loaded image.
result (list): The classification result.
fig_size (tuple): Figure size of the pyplot figure.
Defaults to (15, 10).
wait_time (int): How many seconds to display the image.
Defaults to 0.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.show()
model.show_result(
img, result, show=True, fig_size=fig_size, wait_time=wait_time)
3 changes: 3 additions & 0 deletions mmcls/core/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .image import color_val_matplotlib, imshow_infos

__all__ = ['imshow_infos', 'color_val_matplotlib']
130 changes: 130 additions & 0 deletions mmcls/core/visualization/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import matplotlib.pyplot as plt
import mmcv
import numpy as np

# A small value
EPS = 1e-2


def color_val_matplotlib(color):
"""Convert various input in BGR order to normalized RGB matplotlib color
tuples,
Args:
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs
Returns:
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
"""
color = mmcv.color_val(color)
color = [color / 255 for color in color[::-1]]
return tuple(color)


def imshow_infos(img,
infos,
text_color='white',
font_size=26,
row_width=20,
win_name='',
show=True,
fig_size=(15, 10),
wait_time=0,
out_file=None):
"""Show image with extra infomation.
Args:
img (str | ndarray): The image to be displayed.
infos (dict): Extra infos to display in the image.
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
display color. Defaults to 'white'.
font_size (int): Extra infos display font size. Defaults to 26.
row_width (int): width between each row of results on the image.
win_name (str): The image title. Defaults to ''
show (bool): Whether to show the image. Defaults to True.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
wait_time (int): How many seconds to display the image. Defaults to 0.
out_file (Optional[str]): The filename to write the image.
Defaults to None.
Returns:
np.ndarray: The image with extra infomations.
"""
img = mmcv.imread(img).astype(np.uint8)

x, y = 3, row_width // 2
text_color = color_val_matplotlib(text_color)

img = mmcv.bgr2rgb(img)
width, height = img.shape[1], img.shape[0]
img = np.ascontiguousarray(img)

# A proper dpi for image save with default font size.
fig = plt.figure(win_name, frameon=False, figsize=fig_size, dpi=36)
plt.title(win_name)
canvas = fig.canvas
dpi = fig.get_dpi()
# add a small EPS to avoid precision lost due to matplotlib's truncation
# (https://github.com/matplotlib/matplotlib/issues/15363)
fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)

# remove white edges by set subplot margin
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax = plt.gca()
ax.axis('off')

for k, v in infos.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
ax.text(
x,
y,
f'{label_text}',
bbox={
'facecolor': 'black',
'alpha': 0.7,
'pad': 0.2,
'edgecolor': 'none',
'boxstyle': 'round'
},
color=text_color,
fontsize=font_size,
family='monospace',
verticalalignment='top',
horizontalalignment='left')
y += row_width

plt.imshow(img)
stream, _ = canvas.print_to_buffer()
buffer = np.frombuffer(stream, dtype='uint8')
img_rgba = buffer.reshape(height, width, 4)
rgb, _ = np.split(img_rgba, [3], axis=2)
img = rgb.astype('uint8')
img = mmcv.rgb2bgr(img)

if show:
# Matplotlib will adjust text size depends on window size and image
# aspect ratio. It's hard to get, so here we set an adaptive dpi
# according to screen height. 20 here is an empirical parameter.
fig_manager = plt.get_current_fig_manager()
if hasattr(fig_manager, 'window'):
# Figure manager doesn't have window if no screen.
screen_dpi = fig_manager.window.winfo_screenheight() // 20
fig.set_dpi(screen_dpi)

# We do not use cv2 for display because in some cases, opencv will
# conflict with Qt, it will output a warning: Current thread
# is not the object's thread. You can refer to
# https://github.com/opencv/opencv-python/issues/46 for details
if wait_time == 0:
plt.show()
else:
plt.show(block=False)
plt.pause(wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)

plt.close()

return img
51 changes: 21 additions & 30 deletions mmcls/models/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import cv2
import mmcv
import torch
import torch.distributed as dist
from mmcv import color_val
from mmcv.runner import BaseModule

from mmcls.core.visualization import imshow_infos

# TODO import `auto_fp16` from mmcv and delete them from mmcls
try:
from mmcv.runner import auto_fp16
Expand Down Expand Up @@ -169,10 +169,11 @@ def val_step(self, data, optimizer):
def show_result(self,
img,
result,
text_color='green',
text_color='white',
font_scale=0.5,
row_width=20,
show=False,
fig_size=(15, 10),
win_name='',
wait_time=0,
out_file=None):
Expand All @@ -186,39 +187,29 @@ def show_result(self,
row_width (int): width between each row of results on the image.
show (bool): Whether to show the image.
Default: False.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
wait_time (int): How many seconds to display the image.
Defaults to 0.
out_file (str or None): The filename to write the image.
Default: None.
Returns:
img (ndarray): Only if not `show` or `out_file`
img (ndarray): Image with overlayed results.
"""
img = mmcv.imread(img)
img = img.copy()

# write results on left-top of the image
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width

# if out_file specified, do not show image in window
if out_file is not None:
show = False

if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)

if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
img = imshow_infos(
img,
result,
text_color=text_color,
font_size=int(font_scale * 50),
row_width=row_width,
win_name=win_name,
show=show,
fig_size=fig_size,
wait_time=wait_time,
out_file=out_file)

return img
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcls
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
11 changes: 0 additions & 11 deletions tests/test_models/test_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os.path as osp
import tempfile
from copy import deepcopy
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -85,16 +84,6 @@ def test_image_classifier():
model.show_result(img, result, out_file=out_file)
assert osp.exists(out_file)

def save_show(_, *args):
out_path = osp.join(tmpdir, '_'.join([str(arg) for arg in args]))
with open(out_path, 'w') as f:
f.write('test')

with patch('mmcv.imshow', save_show):
model.show_result(
img, result, show=True, win_name='img', wait_time=5)
assert osp.exists(osp.join(tmpdir, 'img_5'))


def test_image_classifier_with_mixup():
# Test mixup in ImageClassifier
Expand Down
90 changes: 90 additions & 0 deletions tests/test_utils/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from unittest.mock import Mock, patch

import mmcv
import numpy as np
import pytest

from mmcls.core import visualization as vis


def test_color():
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.)
# forbid white color
with pytest.raises(TypeError):
vis.color_val_matplotlib([255, 255, 255])
# forbid float
with pytest.raises(TypeError):
vis.color_val_matplotlib(1.0)
# overflowed
with pytest.raises(AssertionError):
vis.color_val_matplotlib((0, 0, 500))


def test_imshow_infos():
tmp_dir = osp.join(tempfile.gettempdir(), 'infos_image')
tmp_filename = osp.join(tmp_dir, 'image.jpg')

image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
out_image = vis.imshow_infos(
image, result, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
os.remove(tmp_filename)

# test grayscale images
image = np.ones((10, 10), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
out_image = vis.imshow_infos(
image, result, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
assert image.shape == out_image.shape[:2]
os.remove(tmp_filename)

# test show=True
image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

def save_args(*args, **kwargs):
args_list = ['args']
args_list += [
str(arg) for arg in args if isinstance(arg, (str, bool, int))
]
args_list += [
f'{k}-{v}' for k, v in kwargs.items()
if isinstance(v, (str, bool, int))
]
out_path = osp.join(tmp_dir, '_'.join(args_list))
with open(out_path, 'w') as f:
f.write('test')

with patch('matplotlib.pyplot.show', save_args), \
patch('matplotlib.pyplot.pause', save_args):
vis.imshow_infos(image, result, show=True, wait_time=5)
assert osp.exists(osp.join(tmp_dir, 'args_block-False'))
assert osp.exists(osp.join(tmp_dir, 'args_5'))

vis.imshow_infos(image, result, show=True, wait_time=0)
assert osp.exists(osp.join(tmp_dir, 'args'))

# test adaptive dpi
def mock_fig_manager():
fig_manager = Mock()
fig_manager.window.winfo_screenheight = Mock(return_value=1440)
return fig_manager

with patch('matplotlib.pyplot.get_current_fig_manager',
mock_fig_manager), patch('matplotlib.pyplot.show'):
vis.imshow_infos(image, result, show=True)

shutil.rmtree(tmp_dir)

0 comments on commit d3d797f

Please sign in to comment.