forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhance] Better result visualization (open-mmlab#419)
* Imporve result visualization to support wait time and change the backend to matplotlib. * Add unit test for visualization * Add adaptive dpi function * Rename `imshow_cls_result` to `imshow_infos`. * Support str in `imshow_infos` * Improve docstring.
- Loading branch information
Showing
7 changed files
with
251 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .image import color_val_matplotlib, imshow_infos | ||
|
||
__all__ = ['imshow_infos', 'color_val_matplotlib'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
import matplotlib.pyplot as plt | ||
import mmcv | ||
import numpy as np | ||
|
||
# A small value | ||
EPS = 1e-2 | ||
|
||
|
||
def color_val_matplotlib(color): | ||
"""Convert various input in BGR order to normalized RGB matplotlib color | ||
tuples, | ||
Args: | ||
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs | ||
Returns: | ||
tuple[float]: A tuple of 3 normalized floats indicating RGB channels. | ||
""" | ||
color = mmcv.color_val(color) | ||
color = [color / 255 for color in color[::-1]] | ||
return tuple(color) | ||
|
||
|
||
def imshow_infos(img, | ||
infos, | ||
text_color='white', | ||
font_size=26, | ||
row_width=20, | ||
win_name='', | ||
show=True, | ||
fig_size=(15, 10), | ||
wait_time=0, | ||
out_file=None): | ||
"""Show image with extra infomation. | ||
Args: | ||
img (str | ndarray): The image to be displayed. | ||
infos (dict): Extra infos to display in the image. | ||
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos | ||
display color. Defaults to 'white'. | ||
font_size (int): Extra infos display font size. Defaults to 26. | ||
row_width (int): width between each row of results on the image. | ||
win_name (str): The image title. Defaults to '' | ||
show (bool): Whether to show the image. Defaults to True. | ||
fig_size (tuple): Image show figure size. Defaults to (15, 10). | ||
wait_time (int): How many seconds to display the image. Defaults to 0. | ||
out_file (Optional[str]): The filename to write the image. | ||
Defaults to None. | ||
Returns: | ||
np.ndarray: The image with extra infomations. | ||
""" | ||
img = mmcv.imread(img).astype(np.uint8) | ||
|
||
x, y = 3, row_width // 2 | ||
text_color = color_val_matplotlib(text_color) | ||
|
||
img = mmcv.bgr2rgb(img) | ||
width, height = img.shape[1], img.shape[0] | ||
img = np.ascontiguousarray(img) | ||
|
||
# A proper dpi for image save with default font size. | ||
fig = plt.figure(win_name, frameon=False, figsize=fig_size, dpi=36) | ||
plt.title(win_name) | ||
canvas = fig.canvas | ||
dpi = fig.get_dpi() | ||
# add a small EPS to avoid precision lost due to matplotlib's truncation | ||
# (https://github.com/matplotlib/matplotlib/issues/15363) | ||
fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi) | ||
|
||
# remove white edges by set subplot margin | ||
plt.subplots_adjust(left=0, right=1, bottom=0, top=1) | ||
ax = plt.gca() | ||
ax.axis('off') | ||
|
||
for k, v in infos.items(): | ||
if isinstance(v, float): | ||
v = f'{v:.2f}' | ||
label_text = f'{k}: {v}' | ||
ax.text( | ||
x, | ||
y, | ||
f'{label_text}', | ||
bbox={ | ||
'facecolor': 'black', | ||
'alpha': 0.7, | ||
'pad': 0.2, | ||
'edgecolor': 'none', | ||
'boxstyle': 'round' | ||
}, | ||
color=text_color, | ||
fontsize=font_size, | ||
family='monospace', | ||
verticalalignment='top', | ||
horizontalalignment='left') | ||
y += row_width | ||
|
||
plt.imshow(img) | ||
stream, _ = canvas.print_to_buffer() | ||
buffer = np.frombuffer(stream, dtype='uint8') | ||
img_rgba = buffer.reshape(height, width, 4) | ||
rgb, _ = np.split(img_rgba, [3], axis=2) | ||
img = rgb.astype('uint8') | ||
img = mmcv.rgb2bgr(img) | ||
|
||
if show: | ||
# Matplotlib will adjust text size depends on window size and image | ||
# aspect ratio. It's hard to get, so here we set an adaptive dpi | ||
# according to screen height. 20 here is an empirical parameter. | ||
fig_manager = plt.get_current_fig_manager() | ||
if hasattr(fig_manager, 'window'): | ||
# Figure manager doesn't have window if no screen. | ||
screen_dpi = fig_manager.window.winfo_screenheight() // 20 | ||
fig.set_dpi(screen_dpi) | ||
|
||
# We do not use cv2 for display because in some cases, opencv will | ||
# conflict with Qt, it will output a warning: Current thread | ||
# is not the object's thread. You can refer to | ||
# https://github.com/opencv/opencv-python/issues/46 for details | ||
if wait_time == 0: | ||
plt.show() | ||
else: | ||
plt.show(block=False) | ||
plt.pause(wait_time) | ||
if out_file is not None: | ||
mmcv.imwrite(img, out_file) | ||
|
||
plt.close() | ||
|
||
return img |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# Copyright (c) Open-MMLab. All rights reserved. | ||
import os | ||
import os.path as osp | ||
import shutil | ||
import tempfile | ||
from unittest.mock import Mock, patch | ||
|
||
import mmcv | ||
import numpy as np | ||
import pytest | ||
|
||
from mmcls.core import visualization as vis | ||
|
||
|
||
def test_color(): | ||
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.) | ||
assert vis.color_val_matplotlib('green') == (0., 1., 0.) | ||
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255) | ||
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255) | ||
assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.) | ||
# forbid white color | ||
with pytest.raises(TypeError): | ||
vis.color_val_matplotlib([255, 255, 255]) | ||
# forbid float | ||
with pytest.raises(TypeError): | ||
vis.color_val_matplotlib(1.0) | ||
# overflowed | ||
with pytest.raises(AssertionError): | ||
vis.color_val_matplotlib((0, 0, 500)) | ||
|
||
|
||
def test_imshow_infos(): | ||
tmp_dir = osp.join(tempfile.gettempdir(), 'infos_image') | ||
tmp_filename = osp.join(tmp_dir, 'image.jpg') | ||
|
||
image = np.ones((10, 10, 3), np.uint8) | ||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98} | ||
out_image = vis.imshow_infos( | ||
image, result, out_file=tmp_filename, show=False) | ||
assert osp.isfile(tmp_filename) | ||
assert image.shape == out_image.shape | ||
assert not np.allclose(image, out_image) | ||
os.remove(tmp_filename) | ||
|
||
# test grayscale images | ||
image = np.ones((10, 10), np.uint8) | ||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98} | ||
out_image = vis.imshow_infos( | ||
image, result, out_file=tmp_filename, show=False) | ||
assert osp.isfile(tmp_filename) | ||
assert image.shape == out_image.shape[:2] | ||
os.remove(tmp_filename) | ||
|
||
# test show=True | ||
image = np.ones((10, 10, 3), np.uint8) | ||
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98} | ||
|
||
def save_args(*args, **kwargs): | ||
args_list = ['args'] | ||
args_list += [ | ||
str(arg) for arg in args if isinstance(arg, (str, bool, int)) | ||
] | ||
args_list += [ | ||
f'{k}-{v}' for k, v in kwargs.items() | ||
if isinstance(v, (str, bool, int)) | ||
] | ||
out_path = osp.join(tmp_dir, '_'.join(args_list)) | ||
with open(out_path, 'w') as f: | ||
f.write('test') | ||
|
||
with patch('matplotlib.pyplot.show', save_args), \ | ||
patch('matplotlib.pyplot.pause', save_args): | ||
vis.imshow_infos(image, result, show=True, wait_time=5) | ||
assert osp.exists(osp.join(tmp_dir, 'args_block-False')) | ||
assert osp.exists(osp.join(tmp_dir, 'args_5')) | ||
|
||
vis.imshow_infos(image, result, show=True, wait_time=0) | ||
assert osp.exists(osp.join(tmp_dir, 'args')) | ||
|
||
# test adaptive dpi | ||
def mock_fig_manager(): | ||
fig_manager = Mock() | ||
fig_manager.window.winfo_screenheight = Mock(return_value=1440) | ||
return fig_manager | ||
|
||
with patch('matplotlib.pyplot.get_current_fig_manager', | ||
mock_fig_manager), patch('matplotlib.pyplot.show'): | ||
vis.imshow_infos(image, result, show=True) | ||
|
||
shutil.rmtree(tmp_dir) |