This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

use scales (tuple of floats) in region_proposal_network #729

Merged
merged 6 commits on Nov 7, 2018
12 changes: 6 additions & 6 deletions chainercv/links/model/faster_rcnn/faster_rcnn.py
@@ -112,10 +112,10 @@ def n_class(self):
# Total number of classes including the background.
return self.head.n_class

def __call__(self, x, scale=1.):
def __call__(self, x, scales=None):
"""Forward Faster R-CNN.

Scaling parameter :obj:`scale` is used by RPN to determine the
Scaling parameter :obj:`scales` is used by RPN to determine the
threshold to select small objects, which are going to be
rejected irrespective of their confidence scores.

@@ -132,8 +132,8 @@ def __call__(self, x, scale=1.):

Args:
x (~chainer.Variable): 4D image variable.
scale (float): Amount of scaling applied to the raw image
during preprocessing.
scales (tuple of floats): Amount of scaling applied to each input
image during preprocessing.

Returns:
Variable, Variable, array, array:
@@ -153,7 +153,7 @@ def __call__(self, x, scale=1.):

h = self.extractor(x)
rpn_locs, rpn_scores, rois, roi_indices, anchor =\
self.rpn(h, img_size, scale)
self.rpn(h, img_size, scales)
roi_cls_locs, roi_scores = self.head(
h, rois, roi_indices)
return roi_cls_locs, roi_scores, rois, roi_indices
@@ -286,7 +286,7 @@ def predict(self, imgs):
img_var = chainer.Variable(self.xp.asarray(img[None]))
scale = img_var.shape[3] / size[1]
roi_cls_locs, roi_scores, rois, _ = self.__call__(
img_var, scale=scale)
img_var, scales=[scale])
# We are assuming that batch size is 1.
roi_cls_loc = roi_cls_locs.array
roi_score = roi_scores.array
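A minimal sketch of how the updated `__call__` signature above would be used from user code, assuming a `FasterRCNNVGG16` model; the weights and the input image are placeholders, and the per-image scale is computed the same way `predict` does before being wrapped in a list:

```python
import numpy as np
import chainer
from chainercv.links import FasterRCNNVGG16

# Sketch only: randomly initialised model and a dummy image.
model = FasterRCNNVGG16(n_fg_class=20)
img = np.random.uniform(0, 255, size=(3, 480, 640)).astype(np.float32)

prepared = model.prepare(img)             # resizing done during preprocessing
scale = prepared.shape[2] / img.shape[2]  # scaling applied along the width
x = chainer.Variable(model.xp.asarray(prepared[None]))

with chainer.using_config('train', False):
    # After this change, one scale per image in the batch is passed.
    roi_cls_locs, roi_scores, rois, roi_indices = model(x, scales=[scale])
```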
14 changes: 7 additions & 7 deletions chainercv/links/model/faster_rcnn/faster_rcnn_train_chain.py
@@ -60,7 +60,7 @@ def __init__(self, faster_rcnn, rpn_sigma=3., roi_sigma=1.,
self.loc_normalize_mean = faster_rcnn.loc_normalize_mean
self.loc_normalize_std = faster_rcnn.loc_normalize_std

def __call__(self, imgs, bboxes, labels, scale):
def __call__(self, imgs, bboxes, labels, scales):
"""Forward Faster R-CNN and calculate losses.

Here are notations used.
@@ -79,8 +79,8 @@ def __call__(self, imgs, bboxes, labels, scale):
the definition, which means that the range of the value
is :math:`[0, L - 1]`. :math:`L` is the number of foreground
classes.
scale (float or ~chainer.Variable): Amount of scaling applied to
the raw image during preprocessing.
scales (~chainer.Variable): Amount of scaling applied to
each input image during preprocessing.

Returns:
chainer.Variable:
@@ -93,9 +93,9 @@ def __call__(self, imgs, bboxes, labels, scale):
bboxes = bboxes.array
if isinstance(labels, chainer.Variable):
labels = labels.array
if isinstance(scale, chainer.Variable):
scale = scale.array
scale = np.asscalar(cuda.to_cpu(scale))
if isinstance(scales, chainer.Variable):
scales = scales.array
scales = cuda.to_cpu(scales)
n = bboxes.shape[0]
if n != 1:
raise ValueError('Currently only batch size 1 is supported.')
@@ -105,7 +105,7 @@ def __call__(self, imgs, bboxes, labels, scale):

features = self.faster_rcnn.extractor(imgs)
rpn_locs, rpn_scores, rois, roi_indices, anchor = self.faster_rcnn.rpn(
features, img_size, scale)
features, img_size, scales)

# Since batch size is one, convert variables to singular form
bbox = bboxes[0]
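On the training side, a rough sketch of how the chain could be driven with the new `scales` argument; the data and the randomly initialised model are illustrative only, and batch size is still restricted to 1:

```python
import numpy as np
from chainercv.links import FasterRCNNVGG16
from chainercv.links.model.faster_rcnn import FasterRCNNTrainChain

# Sketch only: dummy data and a randomly initialised model.
faster_rcnn = FasterRCNNVGG16(n_fg_class=20)
train_chain = FasterRCNNTrainChain(faster_rcnn)

imgs = np.random.uniform(size=(1, 3, 600, 800)).astype(np.float32)
bboxes = np.array([[[10., 10., 200., 300.]]], dtype=np.float32)  # (N, R, 4)
labels = np.array([[0]], dtype=np.int32)                         # (N, R)
scales = np.array([1.6], dtype=np.float32)                       # one per image

loss = train_chain(imgs, bboxes, labels, scales)
loss.backward()
```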
14 changes: 10 additions & 4 deletions chainercv/links/model/faster_rcnn/region_proposal_network.py
@@ -67,7 +67,7 @@ def __init__(
self.loc = L.Convolution2D(
mid_channels, n_anchor * 4, 1, 1, 0, initialW=initialW)

def __call__(self, x, img_size, scale=1.):
def __call__(self, x, img_size, scales=None):
"""Forward Region Proposal Network.

Here are notations.
@@ -82,8 +82,8 @@ def __call__(self, x, img_size, scale=1.):
Its shape is :math:`(N, C, H, W)`.
img_size (tuple of ints): A tuple :obj:`height, width`,
which contains image size after scaling.
scale (float): The amount of scaling done to the input images after
reading them from files.
scales (tuple of floats): The amount of scaling done to each input
image during preprocessing.

Returns:
(~chainer.Variable, ~chainer.Variable, array, array, array):
@@ -106,7 +106,13 @@ def __call__(self, x, img_size, scale=1.):
Its shape is :math:`(H W A, 4)`.

"""

n, _, hh, ww = x.shape
if scales is None:
scales = [1.0] * n
if not isinstance(scales, chainer.utils.collections_abc.Iterable):
scales = [scales] * n

anchor = _enumerate_shifted_anchor(
self.xp.array(self.anchor_base), self.feat_stride, hh, ww)
n_anchor = anchor.shape[0] // (hh * ww)
@@ -127,7 +133,7 @@ def __call__(self, x, img_size, scale=1.):
for i in range(n):
roi = self.proposal_layer(
rpn_locs[i].array, rpn_fg_scores[i].array, anchor, img_size,
scale=scale)
scale=scales[i])
batch_index = i * self.xp.ones((len(roi),), dtype=np.int32)
rois.append(roi)
roi_indices.append(batch_index)
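The new argument handling above accepts `None`, a bare scalar (for backward compatibility with the old `scale` float), or one value per image in the batch; the diff uses `chainer.utils.collections_abc` only for Python 2/3 compatibility. Roughly, it behaves like the following standalone helper (the function name is illustrative, not part of chainercv):

```python
from collections import abc

def normalize_scales(scales, batch_size):
    """Illustrative re-statement of the normalisation in RPN.__call__."""
    if scales is None:
        # No scaling information: assume the images were not resized.
        return [1.0] * batch_size
    if not isinstance(scales, abc.Iterable):
        # A single float: broadcast it over the whole batch.
        return [scales] * batch_size
    # Already one value per image.
    return list(scales)

print(normalize_scales(None, 2))        # [1.0, 1.0]
print(normalize_scales(1.5, 2))         # [1.5, 1.5]
print(normalize_scales([1.0, 2.0], 2))  # [1.0, 2.0]
```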
@@ -37,7 +37,8 @@ def check_call(self):
xp = self.link.xp

x1 = chainer.Variable(_random_array(xp, (1, 3, 600, 800)))
roi_cls_locs, roi_scores, rois, roi_indices = self.link(x1)
scales = chainer.Variable(xp.array([1.], dtype=np.float32))
roi_cls_locs, roi_scores, rois, roi_indices = self.link(x1, scales)

self.assertIsInstance(roi_cls_locs, chainer.Variable)
self.assertIsInstance(roi_cls_locs.array, xp.ndarray)
@@ -9,15 +9,19 @@
from chainercv.links.model.faster_rcnn import RegionProposalNetwork


@testing.parameterize(
{'train': True},
{'train': False},
)
@testing.parameterize(*(testing.product({
'B': [1],
'train': [True, False],
'scales': [None, 1.0, 2.0, [1.0]],
}) + testing.product({
'B': [2],
'train': [True, False],
'scales': [None, 1.0, 2.0, [1.0, 2.0]],
})))
class TestRegionProposalNetwork(unittest.TestCase):

def setUp(self):
feat_stride = 4
self.B = 2
C = 16
H = 8
W = 12
@@ -37,10 +41,10 @@ def setUp(self):

chainer.config.train = self.train

def _check_call(self, x, img_size):
def _check_call(self, x, img_size, scales):
_, _, H, W = x.shape
rpn_locs, rpn_scores, rois, roi_indices, anchor = self.link(
chainer.Variable(x), img_size)
chainer.Variable(x), img_size, scales)
self.assertIsInstance(rpn_locs, chainer.Variable)
self.assertIsInstance(rpn_locs.array, type(x))
self.assertIsInstance(rpn_scores, chainer.Variable)
@@ -74,13 +78,13 @@ def _check_call(self, x, img_size):
self.assertEqual(anchor.shape, (A * H * W, 4))

def test_call_cpu(self):
self._check_call(self.x, self.img_size)
self._check_call(self.x, self.img_size, self.scales)

@attr.gpu
def test_call_gpu(self):
self.link.to_gpu()
self._check_call(
chainer.backends.cuda.to_gpu(self.x), self.img_size)
chainer.backends.cuda.to_gpu(self.x), self.img_size, self.scales)


testing.run_module(__name__, __file__)
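For reference, `chainer.testing.product` builds the cross product of the given value lists, so the two blocks in the parameterization above generate 8 + 8 = 16 parameter sets covering batch sizes 1 and 2 with `None`, scalar, and per-image `scales`. A quick sketch of what the first block expands to:

```python
from chainer import testing

params = testing.product({
    'B': [1],
    'train': [True, False],
    'scales': [None, 1.0, 2.0, [1.0]],
})
print(len(params))  # 8
# Each entry is a dict such as {'B': 1, 'train': True, 'scales': None};
# @testing.parameterize(*params) turns every entry into its own test case.
```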