Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Better result visualization #419

Merged
merged 6 commits into from
Aug 31, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions mmcls/apis/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import warnings

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import torch
Expand Down Expand Up @@ -89,18 +88,19 @@ def inference_model(model, img):
return result


def show_result_pyplot(model, img, result, fig_size=(15, 10)):
def show_result_pyplot(model, img, result, fig_size=(15, 10), wait_time=0):
mzr1996 marked this conversation as resolved.
Show resolved Hide resolved
"""Visualize the classification results on the image.

Args:
model (nn.Module): The loaded classifier.
img (str or np.ndarray): Image filename or loaded image.
result (list): The classification result.
fig_size (tuple): Figure size of the pyplot figure.
Defaults to (15, 10).
wait_time (int): How many seconds to display the image.
Defaults to 0.
"""
if hasattr(model, 'module'):
model = model.module
img = model.show_result(img, result, show=False)
plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.show()
model.show_result(
img, result, show=True, fig_size=fig_size, wait_time=wait_time)
3 changes: 3 additions & 0 deletions mmcls/core/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .image import color_val_matplotlib, imshow_infos

__all__ = ['imshow_infos', 'color_val_matplotlib']
130 changes: 130 additions & 0 deletions mmcls/core/visualization/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import matplotlib.pyplot as plt
import mmcv
import numpy as np

# A small value
EPS = 1e-2
mzr1996 marked this conversation as resolved.
Show resolved Hide resolved


def color_val_matplotlib(color):
"""Convert various input in BGR order to normalized RGB matplotlib color
tuples,

Args:
color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs

Returns:
tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
"""
color = mmcv.color_val(color)
color = [color / 255 for color in color[::-1]]
return tuple(color)


def imshow_infos(img,
infos,
text_color='white',
font_size=26,
row_width=20,
win_name='',
show=True,
fig_size=(15, 10),
wait_time=0,
out_file=None):
"""Show image with extra infomation.

Args:
img (str | ndarray): The image to be displayed.
infos (dict): Extra infos to display in the image.
text_color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Extra infos
display color. Defaults to 'white'.
font_size (int): Extra infos display font size. Defaults to 26.
row_width (int): width between each row of results on the image.
win_name (str): The image title. Defaults to ''
show (bool): Whether to show the image. Defaults to True.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
wait_time (int): How many seconds to display the image. Defaults to 0.
out_file (Optional[str]): The filename to write the image.
Defaults to None.

Returns:
np.ndarray: The image with extra infomations.
"""
img = mmcv.imread(img).astype(np.uint8)

x, y = 3, row_width // 2
text_color = color_val_matplotlib(text_color)

img = mmcv.bgr2rgb(img)
width, height = img.shape[1], img.shape[0]
img = np.ascontiguousarray(img)

# A proper dpi for image save with default font size.
fig = plt.figure(win_name, frameon=False, figsize=fig_size, dpi=36)
plt.title(win_name)
canvas = fig.canvas
dpi = fig.get_dpi()
# add a small EPS to avoid precision lost due to matplotlib's truncation
# (https://github.com/matplotlib/matplotlib/issues/15363)
fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)

# remove white edges by set subplot margin
plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
ax = plt.gca()
ax.axis('off')

for k, v in infos.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
ax.text(
x,
y,
f'{label_text}',
bbox={
'facecolor': 'black',
'alpha': 0.7,
'pad': 0.2,
'edgecolor': 'none',
'boxstyle': 'round'
},
color=text_color,
fontsize=font_size,
family='monospace',
verticalalignment='top',
horizontalalignment='left')
y += row_width

plt.imshow(img)
stream, _ = canvas.print_to_buffer()
buffer = np.frombuffer(stream, dtype='uint8')
img_rgba = buffer.reshape(height, width, 4)
rgb, _ = np.split(img_rgba, [3], axis=2)
img = rgb.astype('uint8')
img = mmcv.rgb2bgr(img)

if show:
# Matplotlib will adjust text size depends on window size and image
# aspect ratio. It's hard to get, so here we set an adaptive dpi
# according to screen height. 20 here is an empirical parameter.
fig_manager = plt.get_current_fig_manager()
if hasattr(fig_manager, 'window'):
# Figure manager doesn't have window if no screen.
screen_dpi = fig_manager.window.winfo_screenheight() // 20
fig.set_dpi(screen_dpi)

# We do not use cv2 for display because in some cases, opencv will
# conflict with Qt, it will output a warning: Current thread
# is not the object's thread. You can refer to
# https://github.com/opencv/opencv-python/issues/46 for details
if wait_time == 0:
plt.show()
else:
plt.show(block=False)
plt.pause(wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmcv.imwrite(image, out_file, auto_mkdir=True) can suto make a dir if users forget to makedir.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The default value of auto_mkdir should be True?


plt.close()

return img
51 changes: 21 additions & 30 deletions mmcls/models/classifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict

import cv2
import mmcv
import torch
import torch.distributed as dist
from mmcv import color_val
from mmcv.runner import BaseModule

from mmcls.core.visualization import imshow_infos

# TODO import `auto_fp16` from mmcv and delete them from mmcls
try:
from mmcv.runner import auto_fp16
Expand Down Expand Up @@ -168,10 +168,11 @@ def val_step(self, data, optimizer):
def show_result(self,
img,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

api expresses that img can be a (str or ndarray), but here img can only be the image path.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmcv.imread supports np.ndarray.

result,
text_color='green',
text_color='white',
font_scale=0.5,
row_width=20,
show=False,
fig_size=(15, 10),
win_name='',
wait_time=0,
out_file=None):
Expand All @@ -185,39 +186,29 @@ def show_result(self,
row_width (int): width between each row of results on the image.
show (bool): Whether to show the image.
Default: False.
fig_size (tuple): Image show figure size. Defaults to (15, 10).
win_name (str): The window name.
wait_time (int): Value of waitKey param.
Default: 0.
wait_time (int): How many seconds to display the image.
Defaults to 0.
out_file (str or None): The filename to write the image.
Default: None.

Returns:
img (ndarray): Only if not `show` or `out_file`
img (ndarray): Image with overlayed results.
"""
img = mmcv.imread(img)
img = img.copy()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this line is redundant

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line prevents this function from modifying the original image.


# write results on left-top of the image
x, y = 0, row_width
text_color = color_val(text_color)
for k, v in result.items():
if isinstance(v, float):
v = f'{v:.2f}'
label_text = f'{k}: {v}'
cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
font_scale, text_color)
y += row_width

# if out_file specified, do not show image in window
if out_file is not None:
show = False

if show:
mmcv.imshow(img, win_name, wait_time)
if out_file is not None:
mmcv.imwrite(img, out_file)

if not (show or out_file):
warnings.warn('show==False and out_file is not specified, only '
'result image will be returned')
return img
img = imshow_infos(
img,
result,
text_color=text_color,
font_size=int(font_scale * 50),
row_width=row_width,
win_name=win_name,
show=show,
fig_size=fig_size,
wait_time=wait_time,
out_file=out_file)

return img
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcls
known_third_party = PIL,cv2,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
known_third_party = PIL,matplotlib,mmcv,mmdet,numpy,onnxruntime,packaging,pytest,seaborn,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
11 changes: 0 additions & 11 deletions tests/test_models/test_classifiers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os.path as osp
import tempfile
from copy import deepcopy
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -84,16 +83,6 @@ def test_image_classifier():
model.show_result(img, result, out_file=out_file)
assert osp.exists(out_file)

def save_show(_, *args):
out_path = osp.join(tmpdir, '_'.join([str(arg) for arg in args]))
with open(out_path, 'w') as f:
f.write('test')

with patch('mmcv.imshow', save_show):
model.show_result(
img, result, show=True, win_name='img', wait_time=5)
assert osp.exists(osp.join(tmpdir, 'img_5'))


def test_image_classifier_with_mixup():
# Test mixup in ImageClassifier
Expand Down
90 changes: 90 additions & 0 deletions tests/test_utils/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
import shutil
import tempfile
from unittest.mock import Mock, patch

import mmcv
import numpy as np
import pytest

from mmcls.core import visualization as vis


def test_color():
assert vis.color_val_matplotlib(mmcv.Color.blue) == (0., 0., 1.)
assert vis.color_val_matplotlib('green') == (0., 1., 0.)
assert vis.color_val_matplotlib((1, 2, 3)) == (3 / 255, 2 / 255, 1 / 255)
assert vis.color_val_matplotlib(100) == (100 / 255, 100 / 255, 100 / 255)
assert vis.color_val_matplotlib(np.zeros(3, dtype=int)) == (0., 0., 0.)
# forbid white color
with pytest.raises(TypeError):
vis.color_val_matplotlib([255, 255, 255])
# forbid float
with pytest.raises(TypeError):
vis.color_val_matplotlib(1.0)
# overflowed
with pytest.raises(AssertionError):
vis.color_val_matplotlib((0, 0, 500))


def test_imshow_infos():
tmp_dir = osp.join(tempfile.gettempdir(), 'infos_image')
tmp_filename = osp.join(tmp_dir, 'image.jpg')

image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
out_image = vis.imshow_infos(
image, result, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
assert image.shape == out_image.shape
assert not np.allclose(image, out_image)
os.remove(tmp_filename)

# test grayscale images
image = np.ones((10, 10), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}
out_image = vis.imshow_infos(
image, result, out_file=tmp_filename, show=False)
assert osp.isfile(tmp_filename)
assert image.shape == out_image.shape[:2]
os.remove(tmp_filename)

# test show=True
image = np.ones((10, 10, 3), np.uint8)
result = {'pred_label': 1, 'pred_class': 'bird', 'pred_score': 0.98}

def save_args(*args, **kwargs):
args_list = ['args']
args_list += [
str(arg) for arg in args if isinstance(arg, (str, bool, int))
]
args_list += [
f'{k}-{v}' for k, v in kwargs.items()
if isinstance(v, (str, bool, int))
]
out_path = osp.join(tmp_dir, '_'.join(args_list))
with open(out_path, 'w') as f:
f.write('test')

with patch('matplotlib.pyplot.show', save_args), \
patch('matplotlib.pyplot.pause', save_args):
vis.imshow_infos(image, result, show=True, wait_time=5)
assert osp.exists(osp.join(tmp_dir, 'args_block-False'))
assert osp.exists(osp.join(tmp_dir, 'args_5'))

vis.imshow_infos(image, result, show=True, wait_time=0)
assert osp.exists(osp.join(tmp_dir, 'args'))

# test adaptive dpi
def mock_fig_manager():
fig_manager = Mock()
fig_manager.window.winfo_screenheight = Mock(return_value=1440)
return fig_manager

with patch('matplotlib.pyplot.get_current_fig_manager',
mock_fig_manager), patch('matplotlib.pyplot.show'):
vis.imshow_infos(image, result, show=True)

shutil.rmtree(tmp_dir)