Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Assert input type for transform functions #681

Closed
wants to merge 8 commits into from
Closed
8 changes: 4 additions & 4 deletions chainercv/experimental/links/model/pspnet/pspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,16 @@ def __init__(self, n_layer, initialW, bn_kwargs=None):
x, ksize=3, stride=2, pad=1)
self.res2 = ResBlock(
n_block[0], 128, 64, 256, 1, 1,
initialW, bn_kwargs, stride_first=False)
initialW=initialW, bn_kwargs=bn_kwargs, stride_first=False)
self.res3 = ResBlock(
n_block[1], 256, 128, 512, 2, 1,
initialW, bn_kwargs, stride_first=False)
initialW=initialW, bn_kwargs=bn_kwargs, stride_first=False)
self.res4 = ResBlock(
n_block[2], 512, 256, 1024, 1, 2,
initialW, bn_kwargs, stride_first=False)
initialW=initialW, bn_kwargs=bn_kwargs, stride_first=False)
self.res5 = ResBlock(
n_block[3], 1024, 512, 2048, 1, 4,
initialW, bn_kwargs, stride_first=False)
initialW=initialW, bn_kwargs=bn_kwargs, stride_first=False)


class PSPNet(chainer.Chain):
Expand Down
4 changes: 3 additions & 1 deletion chainercv/transforms/bbox/crop_bbox.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox


def crop_bbox(
bbox, y_slice=None, x_slice=None,
Expand Down Expand Up @@ -46,7 +48,7 @@ def crop_bbox(
bounding boxes.

"""

assert_is_bbox(bbox)
t, b = _slice_to_bounds(y_slice)
l, r = _slice_to_bounds(x_slice)
crop_bb = np.array((t, l, b, r))
Expand Down
4 changes: 4 additions & 0 deletions chainercv/transforms/bbox/flip_bbox.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox


def flip_bbox(bbox, size, y_flip=False, x_flip=False):
"""Flip bounding boxes accordingly.

Expand All @@ -23,6 +26,7 @@ def flip_bbox(bbox, size, y_flip=False, x_flip=False):
Bounding boxes flipped according to the given flips.

"""
assert_is_bbox(bbox, size)
H, W = size
bbox = bbox.copy()
if y_flip:
Expand Down
4 changes: 4 additions & 0 deletions chainercv/transforms/bbox/resize_bbox.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox


def resize_bbox(bbox, in_size, out_size):
"""Resize bounding boxes according to image resize.

Expand All @@ -21,6 +24,7 @@ def resize_bbox(bbox, in_size, out_size):
Bounding boxes rescaled according to the given image shapes.

"""
assert_is_bbox(bbox)
bbox = bbox.copy()
y_scale = float(out_size[0]) / in_size[0]
x_scale = float(out_size[1]) / in_size[1]
Expand Down
5 changes: 4 additions & 1 deletion chainercv/transforms/bbox/translate_bbox.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from chainercv.utils.testing.assertions.assert_is_bbox import assert_is_bbox


def translate_bbox(bbox, y_offset=0, x_offset=0):
"""Translate bounding boxes.

Expand All @@ -24,7 +27,7 @@ def translate_bbox(bbox, y_offset=0, x_offset=0):
Bounding boxes translated according to the given offsets.

"""

assert_is_bbox(bbox)
out_bbox = bbox.copy()
out_bbox[:, :2] += (y_offset, x_offset)
out_bbox[:, 2:] += (y_offset, x_offset)
Expand Down
4 changes: 4 additions & 0 deletions chainercv/transforms/image/center_crop.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def center_crop(img, size, return_param=False, copy=False):
"""Center crop an image by `size`.

Expand Down Expand Up @@ -35,6 +38,7 @@ def center_crop(img, size, return_param=False, copy=False):
out_img = img[:, y_slice, x_slice]

"""
assert_is_image(img, color=None)
_, H, W = img.shape
oH, oW = size
if oH > H or oW > W:
Expand Down
5 changes: 5 additions & 0 deletions chainercv/transforms/image/flip.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def flip(img, y_flip=False, x_flip=False, copy=False):
"""Flip an image in vertical or horizontal direction as specified.

Expand All @@ -11,6 +14,8 @@ def flip(img, y_flip=False, x_flip=False, copy=False):
Returns:
Transformed :obj:`img` in CHW format.
"""

assert_is_image(img, color=None)
if y_flip:
img = img[:, ::-1, :]
if x_flip:
Expand Down
3 changes: 3 additions & 0 deletions chainercv/transforms/image/pca_lighting.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def pca_lighting(img, sigma, eigen_value=None, eigen_vector=None):
"""AlexNet style color augmentation
Expand Down Expand Up @@ -30,6 +32,7 @@ def pca_lighting(img, sigma, eigen_value=None, eigen_vector=None):
An image in CHW format.
"""

assert_is_image(img, color=True)
if sigma <= 0:
return img

Expand Down
3 changes: 3 additions & 0 deletions chainercv/transforms/image/resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import PIL
import warnings

from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


try:
import cv2
Expand Down Expand Up @@ -68,5 +70,6 @@ def resize(img, size, interpolation=PIL.Image.BILINEAR):
~numpy.ndarray: A resize array in CHW format.

"""
assert_is_image(img, color=None)
img = _resize(img, size, interpolation)
return img
2 changes: 2 additions & 0 deletions chainercv/transforms/image/resize_contain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import PIL

from chainercv.transforms import resize
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def resize_contain(img, size, fill=0, interpolation=PIL.Image.BILINEAR,
Expand Down Expand Up @@ -52,6 +53,7 @@ def resize_contain(img, size, fill=0, interpolation=PIL.Image.BILINEAR,
:obj:`height, width`.

"""
assert_is_image(img, color=None)
C, H, W = img.shape
out_H, out_W = size
scale_h = out_H / H
Expand Down
2 changes: 2 additions & 0 deletions chainercv/transforms/image/scale.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import PIL

from chainercv.transforms import resize
from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def scale(img, size, fit_short=True, interpolation=PIL.Image.BILINEAR):
Expand Down Expand Up @@ -30,6 +31,7 @@ def scale(img, size, fit_short=True, interpolation=PIL.Image.BILINEAR):
~numpy.ndarray: A scaled image in CHW format.

"""
assert_is_image(img, color=None)
_, H, W = img.shape

# If resizing is not necessary, return the input as is.
Expand Down
3 changes: 3 additions & 0 deletions chainercv/transforms/image/ten_crop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np

from chainercv.utils.testing.assertions.assert_is_image import assert_is_image


def ten_crop(img, size):
"""Crop 10 regions from an array.
Expand Down Expand Up @@ -31,6 +33,7 @@ def ten_crop(img, size):
The cropped arrays. The shape of tensor is :math:`(10, C, H, W)`.

"""
assert_is_image(img, color=None)
H, W = size
iH, iW = img.shape[1:3]

Expand Down
9 changes: 5 additions & 4 deletions chainercv/utils/testing/assertions/assert_is_bbox.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import numpy as np
from chainer.backends import cuda


def assert_is_bbox(bbox, size=None):
Expand All @@ -16,9 +16,10 @@ def assert_is_bbox(bbox, size=None):
Each bounding box should be within the image.
"""

assert isinstance(bbox, np.ndarray), \
'bbox must be a numpy.ndarray.'
assert bbox.dtype == np.float32, \
xp = cuda.get_array_module(bbox)
assert isinstance(bbox, xp.ndarray), \
'bbox must be a numpy.ndarray or cupy.ndarray.'
assert bbox.dtype == xp.float32, \
'The type of bbox must be numpy.float32,'
assert bbox.shape[1:] == (4,), \
'The shape of bbox must be (*, 4).'
Expand Down
8 changes: 5 additions & 3 deletions chainercv/utils/testing/assertions/assert_is_image.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import numpy as np
from chainer.backends import cuda


def assert_is_image(img, color=True, check_range=True):
Expand All @@ -22,13 +22,15 @@ def assert_is_image(img, color=True, check_range=True):

"""

assert isinstance(img, np.ndarray), 'img must be a numpy.ndarray.'
xp = cuda.get_array_module(img)
assert isinstance(img, xp.ndarray), \
'img must be a numpy.ndarray or cupy.ndarray.'
assert len(img.shape) == 3, 'img must be a 3-dimensional array.'
C, H, W = img.shape

if color:
assert C == 3, 'The number of channels must be 3.'
else:
elif color is False:
assert C == 1, 'The number of channels must be 1.'

if check_range:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def setUp(self):
)

def check_prepare(self):
x = _random_array(np, self.in_shape)
x = np.random.randint(
0, 256, size=self.in_shape).astype(np.float32)
out = self.link.prepare(x)
self.assertIsInstance(out, np.ndarray)
self.assertEqual(out.shape, self.expected_shape)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def setUp(self):
)

def check_prepare(self):
x = _random_array(np, self.in_shape)
x = np.random.randint(
0, 256, size=self.in_shape).astype(np.float32)
out = self.link.prepare(x)
self.assertIsInstance(out, np.ndarray)
self.assertEqual(out.shape, self.expected_shape)
Expand Down
6 changes: 3 additions & 3 deletions tests/transforms_tests/bbox_tests/test_crop_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def setUp(self):
(0, 5, 3, 6),
(1, 2, 3, 4),
(3, 3, 4, 6),
))
), dtype=np.float32)
self.y_slice = slice(1, 5)
self.x_slice = slice(0, 4)

Expand All @@ -25,7 +25,7 @@ def test_crop_bbox(self):
(0, 0, 4, 4),
(0, 2, 2, 4),
(2, 3, 3, 4),
))
), dtype=np.float32)

out, param = crop_bbox(
self.bbox, y_slice=self.y_slice, x_slice=self.x_slice,
Expand All @@ -38,7 +38,7 @@ def test_crop_bbox_disallow_outside_center(self):
(0, 0, 2, 4),
(0, 0, 4, 4),
(0, 2, 2, 4),
))
), dtype=np.float32)

out, param = crop_bbox(
self.bbox, y_slice=self.y_slice, x_slice=self.x_slice,
Expand Down
4 changes: 2 additions & 2 deletions tests/transforms_tests/bbox_tests/test_flip_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from chainer import testing
from chainercv.transforms import flip_bbox
from chainercv.utils.testing.generate_random_bbox import generate_random_bbox


class TestFlipBbox(unittest.TestCase):

def test_flip_bbox(self):
bbox = np.random.uniform(
low=0., high=32., size=(10, 4))
bbox = generate_random_bbox(10, (32, 32), 0, 32)

out = flip_bbox(bbox, size=(34, 32), y_flip=True)
bbox_expected = bbox.copy()
Expand Down
4 changes: 2 additions & 2 deletions tests/transforms_tests/bbox_tests/test_resize_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from chainer import testing
from chainercv.transforms import resize_bbox
from chainercv.utils.testing.generate_random_bbox import generate_random_bbox


class TestResizeBbox(unittest.TestCase):

def test_resize_bbox(self):
bbox = np.random.uniform(
low=0., high=32., size=(10, 4))
bbox = generate_random_bbox(10, (32, 32), 0, 32)

out = resize_bbox(bbox, in_size=(32, 32), out_size=(64, 128))
bbox_expected = bbox.copy()
Expand Down
4 changes: 2 additions & 2 deletions tests/transforms_tests/bbox_tests/test_translate_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from chainer import testing
from chainercv.transforms import translate_bbox
from chainercv.utils.testing.generate_random_bbox import generate_random_bbox


class TestTranslateBbox(unittest.TestCase):

def test_translate_bbox(self):
bbox = np.random.uniform(
low=0., high=32., size=(10, 4))
bbox = generate_random_bbox(10, (32, 32), 0, 32)

out = translate_bbox(bbox, y_offset=5, x_offset=3)
bbox_expected = np.empty_like(bbox)
Expand Down