Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

add scale_mask #799

Merged
merged 1 commit into from
Feb 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions chainercv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from chainercv.utils.link import prepare_pretrained_model # NOQA
from chainercv.utils.mask.mask_iou import mask_iou # NOQA
from chainercv.utils.mask.mask_to_bbox import mask_to_bbox # NOQA
from chainercv.utils.mask.scale_mask import scale_mask # NOQA
from chainercv.utils.testing import assert_is_bbox # NOQA
from chainercv.utils.testing import assert_is_bbox_dataset # NOQA
from chainercv.utils.testing import assert_is_detection_link # NOQA
Expand Down
70 changes: 70 additions & 0 deletions chainercv/utils/mask/scale_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import division

import numpy as np
import PIL.Image

import chainer
from chainercv import transforms


def scale_mask(mask, bbox, size):
    """Scale instance segmentation mask while keeping the aspect ratio.

    This function exploits the sparsity of :obj:`mask` to speed up
    resize operation: only the region inside each bounding box is
    resized, and the result is pasted into an all-``False`` output.

    The input image will be resized so that
    the shorter edge will be scaled to length :obj:`size` after
    resizing.

    Args:
        mask (array): An array whose shape is :math:`(R, H, W)`.
            :math:`R` is the number of masks.
            The dtype should be :obj:`bool`.
        bbox (array): The bounding boxes around the masked region
            of :obj:`mask`. This is expected to be the value
            obtained by :obj:`bbox = chainercv.utils.mask_to_bbox(mask)`.
            This array is not modified.
        size (int): The length of the smaller edge.

    Returns:
        array:
        An array whose shape is :math:`(R, H', W')`, where the shorter of
        :math:`H'` and :math:`W'` equals :obj:`size` and the other edge is
        scaled by the same factor (truncated to an integer).
        :math:`R` is the number of masks.
        The dtype is :obj:`bool`. The array lives on the same device
        (NumPy or CuPy) as the input :obj:`mask`.

    """
    xp = chainer.backends.cuda.get_array_module(mask)
    mask = chainer.backends.cuda.to_cpu(mask)
    bbox = chainer.backends.cuda.to_cpu(bbox)

    R, H, W = mask.shape
    if H < W:
        out_size = (size, int(size * W / H))
        scale = size / H
    else:
        out_size = (int(size * H / W), size)
        scale = size / W

    # Build fresh arrays rather than writing into `bbox`: for NumPy input
    # `to_cpu` hands back the caller's own array, so in-place rounding would
    # silently modify it.
    bbox = np.concatenate(
        (np.floor(bbox[:, :2]), np.ceil(bbox[:, 2:])),
        axis=1).astype(np.int32)
    scaled_bbox = bbox * scale
    scaled_bbox = np.concatenate(
        (np.floor(scaled_bbox[:, :2]), np.ceil(scaled_bbox[:, 2:])),
        axis=1)
    # `out_size` is truncated with int() while the box is rounded up, so for
    # non-integer scales the box can stick out of the output by one pixel.
    # Clip it so that the paste region and the resized crop always agree.
    scaled_bbox[:, 2] = np.minimum(scaled_bbox[:, 2], out_size[0])
    scaled_bbox[:, 3] = np.minimum(scaled_bbox[:, 3], out_size[1])
    scaled_bbox = scaled_bbox.astype(np.int32)

    # `np.bool` was removed in NumPy 1.24; the builtin `bool` is equivalent.
    out_mask = xp.zeros((R,) + out_size, dtype=bool)
    for i, (m, bb, scaled_bb) in enumerate(
            zip(mask, bbox, scaled_bbox)):
        cropped_m = m[bb[0]:bb[2], bb[1]:bb[3]]
        h = int(scaled_bb[2] - scaled_bb[0])
        w = int(scaled_bb[3] - scaled_bb[1])
        if cropped_m.size == 0 or h <= 0 or w <= 0:
            # Nothing to paste: either the box is empty or it collapses to
            # zero pixels in the output. Leave this instance all-False.
            continue
        cropped_m = transforms.resize(
            cropped_m[None].astype(np.float32),
            (h, w),
            interpolation=PIL.Image.NEAREST)[0]
        if xp is not np:
            cropped_m = xp.array(cropped_m)
        # Assigning float 0/1 values into the bool array casts them to
        # False/True.
        out_mask[i, scaled_bb[0]:scaled_bb[2],
                 scaled_bb[1]:scaled_bb[3]] = cropped_m
    return out_mask
4 changes: 4 additions & 0 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ mask_to_bbox
~~~~~~~~~~~~
.. autofunction:: mask_to_bbox

scale_mask
~~~~~~~~~~
.. autofunction:: scale_mask


Testing Utilities
-----------------
Expand Down
78 changes: 78 additions & 0 deletions tests/utils_tests/mask_tests/test_scale_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import division

import unittest

import numpy as np
import PIL.Image

from chainer.backends import cuda
from chainer import testing
from chainer.testing import attr

from chainercv.transforms import resize
from chainercv.utils import generate_random_bbox
from chainercv.utils import mask_to_bbox
from chainercv.utils import scale_mask


@testing.parameterize(
    {'mask': np.array(
        [[[False, False],
          [False, True]]]),
     'expected': np.array(
         [[[False, False, False, False],
           [False, False, False, False],
           [False, False, True, True],
           [False, False, True, True]]])
     }
)
class TestScaleMaskSimple(unittest.TestCase):
    # Hand-written case: a 2x2 mask with one foreground pixel is scaled to
    # 4x4, and the result is compared against a literal expected array.

    def check(self, mask, expected):
        # The output must be the same array type (NumPy or CuPy) as the
        # input and must keep the boolean dtype.
        in_type = type(mask)
        bbox = mask_to_bbox(mask)
        size = 4
        out_mask = scale_mask(mask, bbox, size)

        self.assertIsInstance(out_mask, in_type)
        # `np.bool` was removed in NumPy 1.24; `np.bool_` is the scalar type.
        self.assertEqual(out_mask.dtype, np.bool_)

        np.testing.assert_equal(
            cuda.to_cpu(out_mask),
            cuda.to_cpu(expected))

    def test_scale_mask_simple_cpu(self):
        self.check(self.mask, self.expected)

    @attr.gpu
    def test_scale_mask_simple_gpu(self):
        self.check(cuda.to_gpu(self.mask), cuda.to_gpu(self.expected))


class TestScaleMaskCompareResize(unittest.TestCase):
    # Cross-check: for an exact 2x upscale, `scale_mask` must give the same
    # result as resizing the full dense mask with nearest-neighbour
    # interpolation.

    def test(self):
        H = 80
        W = 90
        n_inst = 10

        # `np.bool` was removed in NumPy 1.24; use the builtin `bool`.
        mask = np.zeros((n_inst, H, W), dtype=bool)
        # Presumably the last two arguments are the minimum and maximum box
        # edge lengths, which would make index [5, 5] below always valid --
        # confirm against generate_random_bbox.
        bbox = generate_random_bbox(n_inst, (H, W), 10, 30).astype(np.int32)
        for i, bb in enumerate(bbox):
            y_min, x_min, y_max, x_max = bb
            m = np.random.randint(0, 2, size=(y_max - y_min, x_max - x_min))
            m[5, 5] = 1  # At least one element is one
            mask[i, y_min:y_max, x_min:x_max] = m
        # Recompute the boxes from the mask so that they bound the actual
        # foreground pixels.
        bbox = mask_to_bbox(mask)
        size = H * 2
        out_H = size
        out_W = W * 2
        out_mask = scale_mask(mask, bbox, size)

        expected = resize(
            mask.astype(np.float32), (out_H, out_W),
            interpolation=PIL.Image.NEAREST).astype(bool)
        np.testing.assert_equal(out_mask, expected)


# Presumably lets this test file be run directly as a script -- confirm
# against the chainer.testing.run_module documentation.
testing.run_module(__name__, __file__)