Skip to content
This repository has been archived by the owner on Feb 6, 2023. It is now read-only.

Visualize images without Trainer #206

Merged
merged 13 commits into from
Dec 4, 2018
1 change: 0 additions & 1 deletion chainerui/extensions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
from chainerui.extensions.commands_extension import CommandsExtension # NOQA
from chainerui.extensions.image_reporter_extension import ImageReport # NOQA
109 changes: 0 additions & 109 deletions chainerui/extensions/image_reporter_extension.py

This file was deleted.

Empty file added chainerui/report/__init__.py
Empty file.
112 changes: 112 additions & 0 deletions chainerui/report/image_report.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import datetime
import hashlib
import os
import time
import warnings

import chainer
from chainer import cuda
import numpy


def _setup_image_module():
try:
from PIL import Image # NOQA
return Image

except (ImportError, TypeError):
return None


_Image = _setup_image_module()
_available = _Image is not None


def check_available():
if not _available:
warnings.warn('Pillow is not installed on your environment, '
'so no image will be output at this time.'
'Please install Pillow to save images.\n\n'
' $ pip install Pillow\n')
return _available


def report(images, out, name, ch_axis=1, row=0, mode=None, batched=True):
if isinstance(images, chainer.Variable):
images = images.data
images = cuda.to_cpu(images)
if batched:
stuck_image = _get_stuck_batched_image(images, ch_axis, row)
else:
stuck_image = _get_stuck_image(images, ch_axis)

now = datetime.datetime.now()
# ts = now.timestamp(), but Python2.7 does not support the method.
ts = time.mktime(now.timetuple()) + now.microsecond/1e6
filename = '{}_{}.png'.format(name, _get_hash('{}'.format(ts)))
filepath = os.path.join(out, filename)
_save_image(_normalize_8bit(stuck_image), filepath, mode=mode)

return filename, now


def _get_stuck_batched_image(images, ch_axis, row):
ndim = images.ndim
if not (ndim == 3 or ndim == 4):
raise ValueError(
'Number of array dimension {:d} must be 3 or 4'.format(ndim))

if ndim == 4:
images = _move_ch_to_last(images, ch_axis)
B, H, W, C = images.shape
if row == 0:
row = B
col = B // row
images = images.reshape(row, col, H, W, C)
images = images.transpose(0, 2, 1, 3, 4)
return images.reshape(row * H, col * W, C)

# ndim == 3
B, H, W = images.shape
if row == 0:
row = B
col = B // row
images = images.reshape(row, col, H, W)
images = images.transpose(0, 2, 1, 3)
return images.reshape(row * H, col * W)


def _get_stuck_image(images, ch_axis):
ndim = images.ndim
if not (ndim == 2 or ndim == 3):
raise ValueError(
'Number of array dimension {:d} must be 2 or 3'.format(ndim))
if ndim == 2:
return images

# ndim == 3
return _move_ch_to_last(images, ch_axis)


def _move_ch_to_last(x, axis):
if axis == -1:
return x
rolled_ax = numpy.append(numpy.delete(numpy.arange(x.ndim), axis), axis)
return x.transpose(rolled_ax)


def _get_hash(key):
return hashlib.md5(key.encode('utf-8')).hexdigest()[:12]


def _normalize_8bit(array):
if array.dtype == numpy.uint8:
return array
return numpy.asarray(numpy.clip(array*255, 0.0, 255.0), dtype=numpy.uint8)


def _save_image(img, name, ext='PNG', mode=None):
if mode is None:
_Image.fromarray(img).save(name, format=ext)
elif mode.lower() == 'hsv':
_Image.fromarray(img, mode='HSV').convert('RGB').save(name, format=ext)
Loading