Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
Merge pull request #685 from Hakuyume/fpn
Browse files Browse the repository at this point in the history
Feature Pyramid Networks
  • Loading branch information
yuyu2172 authored Feb 6, 2019
2 parents 9629c30 + 0f03307 commit a029d2f
Show file tree
Hide file tree
Showing 25 changed files with 2,166 additions and 3 deletions.
2 changes: 2 additions & 0 deletions chainercv/links/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from chainercv.links.model.pixelwise_softmax_classifier import PixelwiseSoftmaxClassifier # NOQA

from chainercv.links.model.faster_rcnn.faster_rcnn_vgg import FasterRCNNVGG16 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet101 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet50 # NOQA
from chainercv.links.model.resnet import ResNet101 # NOQA
from chainercv.links.model.resnet import ResNet152 # NOQA
from chainercv.links.model.resnet import ResNet50 # NOQA
Expand Down
9 changes: 9 additions & 0 deletions chainercv/links/model/fpn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from chainercv.links.model.fpn.faster_rcnn import FasterRCNN # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet101 # NOQA
from chainercv.links.model.fpn.faster_rcnn_fpn_resnet import FasterRCNNFPNResNet50 # NOQA
from chainercv.links.model.fpn.fpn import FPN # NOQA
from chainercv.links.model.fpn.head import Head # NOQA
from chainercv.links.model.fpn.head import head_loss_post # NOQA
from chainercv.links.model.fpn.head import head_loss_pre # NOQA
from chainercv.links.model.fpn.rpn import RPN # NOQA
from chainercv.links.model.fpn.rpn import rpn_loss # NOQA
172 changes: 172 additions & 0 deletions chainercv/links/model/fpn/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from __future__ import division

import numpy as np

import chainer
from chainer.backends import cuda

from chainercv import transforms


class FasterRCNN(chainer.Chain):
    """Base class of Feature Pyramid Networks.

    This is a base class of Feature Pyramid Networks [#]_.

    .. [#] Tsung-Yi Lin et al.
       Feature Pyramid Networks for Object Detection. CVPR 2017

    Args:
        extractor (Link): A link that extracts feature maps.
            This link must have :obj:`scales`, :obj:`mean` and
            :meth:`__call__`.
        rpn (Link): A link that has the same interface as
            :class:`~chainercv.links.model.fpn.RPN`.
            Please refer to the documentation found there.
        head (Link): A link that has the same interface as
            :class:`~chainercv.links.model.fpn.Head`.
            Please refer to the documentation found there.

    Parameters:
        nms_thresh (float): The threshold value
            for :func:`~chainercv.utils.non_maximum_suppression`.
            The default value is :obj:`0.5` (the :obj:`'visualize'` preset
            applied by the constructor).
            This value can be changed directly or by using :meth:`use_preset`.
        score_thresh (float): The threshold value for confidence score.
            If a bounding box whose confidence score is lower than this value,
            the bounding box will be suppressed.
            The default value is :obj:`0.7` (the :obj:`'visualize'` preset
            applied by the constructor).
            This value can be changed directly or by using :meth:`use_preset`.

    """

    # Target length of the shorter image side used by :meth:`prepare`.
    _min_size = 800
    # Upper bound of the longer image side used by :meth:`prepare`.
    _max_size = 1333
    # The padded batch height/width is rounded up to a multiple of this.
    _stride = 32

    def __init__(self, extractor, rpn, head):
        super(FasterRCNN, self).__init__()
        with self.init_scope():
            self.extractor = extractor
            self.rpn = rpn
            self.head = head

        self.use_preset('visualize')

    def use_preset(self, preset):
        """Use the given preset during prediction.

        This method changes values of :obj:`nms_thresh` and
        :obj:`score_thresh`. These values are a threshold value
        used for non maximum suppression and a threshold value
        to discard low confidence proposals in :meth:`predict`,
        respectively.

        If the attributes need to be changed to something
        other than the values provided in the presets, please modify
        them by directly accessing the public attributes.

        Args:
            preset ({'visualize', 'evaluate'}): A string to determine the
                preset to use.

        Raises:
            ValueError: If :obj:`preset` is neither :obj:`'visualize'`
                nor :obj:`'evaluate'`.

        """

        if preset == 'visualize':
            self.nms_thresh = 0.5
            self.score_thresh = 0.7
        elif preset == 'evaluate':
            self.nms_thresh = 0.5
            self.score_thresh = 0.05
        else:
            raise ValueError('preset must be visualize or evaluate')

    def __call__(self, x):
        """Run the extractor, the RPN and the head on a prepared batch.

        This method must not be called in train mode.

        Args:
            x: A batch of preprocessed images, e.g. the first value
                returned by :meth:`prepare`.

        Returns:
            tuple:
            :obj:`(rois, roi_indices, head_locs, head_confs)`, where
            :obj:`rois` and :obj:`roi_indices` are the outputs of
            :obj:`head.distribute` and :obj:`head_locs` and
            :obj:`head_confs` are the outputs of :obj:`head`.

        """

        assert not chainer.config.train
        hs = self.extractor(x)
        rpn_locs, rpn_confs = self.rpn(hs)
        # One anchor set per pyramid level, built from its spatial size.
        anchors = self.rpn.anchors(h.shape[2:] for h in hs)
        rois, roi_indices = self.rpn.decode(
            rpn_locs, rpn_confs, anchors, x.shape)
        rois, roi_indices = self.head.distribute(rois, roi_indices)
        head_locs, head_confs = self.head(hs, rois, roi_indices)
        return rois, roi_indices, head_locs, head_confs

    def predict(self, imgs):
        """Detect objects from images.

        This method predicts objects for each image.

        Args:
            imgs (iterable of numpy.ndarray): Arrays holding images.
                All images are in CHW and RGB format
                and the range of their value is :math:`[0, 255]`.

        Returns:
           tuple of lists:
           This method returns a tuple of three lists,
           :obj:`(bboxes, labels, scores)`.

           * **bboxes**: A list of float arrays of shape :math:`(R, 4)`, \
               where :math:`R` is the number of bounding boxes in an image. \
               Each bounding box is organized by \
               :math:`(y_{min}, x_{min}, y_{max}, x_{max})` \
               in the second axis.
           * **labels** : A list of integer arrays of shape :math:`(R,)`. \
               Each value indicates the class of the bounding box. \
               Values are in range :math:`[0, L - 1]`, where :math:`L` is the \
               number of the foreground classes.
           * **scores** : A list of float arrays of shape :math:`(R,)`. \
               Each value indicates how confident the prediction is.

        """

        # The images are traversed twice (here and in `prepare`), so a
        # one-shot iterable has to be materialized first.
        imgs = list(imgs)
        sizes = [img.shape[1:] for img in imgs]
        x, scales = self.prepare(imgs)

        with chainer.using_config('train', False), chainer.no_backprop_mode():
            rois, roi_indices, head_locs, head_confs = self(x)
        bboxes, labels, scores = self.head.decode(
            rois, roi_indices, head_locs, head_confs,
            scales, sizes, self.nms_thresh, self.score_thresh)

        bboxes = [cuda.to_cpu(bbox) for bbox in bboxes]
        labels = [cuda.to_cpu(label) for label in labels]
        scores = [cuda.to_cpu(score) for score in scores]
        return bboxes, labels, scores

    def prepare(self, imgs):
        """Preprocess images.

        Each image is resized so that its shorter side becomes
        :obj:`_min_size` unless that would make the longer side exceed
        :obj:`_max_size`, in which case the longer side becomes
        :obj:`_max_size`. The mean of the extractor is then subtracted and
        the images are zero-padded at the bottom/right into one batch
        whose height and width are multiples of :obj:`_stride`.

        Args:
            imgs (iterable of numpy.ndarray): Arrays holding images.
                All images are in CHW and RGB format
                and the range of their value is :math:`[0, 255]`.

        Returns:
            tuple:
            :obj:`(x, scales)`. :obj:`x` is the batch of preprocessed
            images and :obj:`scales` is a list of the scales that were
            calculated in preprocessing, one per image.

        """

        scales = []
        resized_imgs = []
        for img in imgs:
            _, H, W = img.shape
            scale = self._min_size / min(H, W)
            if scale * max(H, W) > self._max_size:
                scale = self._max_size / max(H, W)
            scales.append(scale)
            H, W = int(H * scale), int(W * scale)
            img = transforms.resize(img, (H, W))
            # Out-of-place subtraction: the array handed back by `resize`
            # is never modified, whether or not it is a fresh copy.
            img = img - self.extractor.mean
            resized_imgs.append(img)

        size = np.array([im.shape[1:] for im in resized_imgs]).max(axis=0)
        size = (np.ceil(size / self._stride) * self._stride).astype(int)
        # `len(resized_imgs)` instead of `len(imgs)` so that iterables
        # without `__len__` are accepted, as documented.
        x = np.zeros(
            (len(resized_imgs), 3, size[0], size[1]), dtype=np.float32)
        for i, img in enumerate(resized_imgs):
            _, H, W = img.shape
            x[i, :, :H, :W] = img

        x = self.xp.array(x)
        return x, scales
145 changes: 145 additions & 0 deletions chainercv/links/model/fpn/faster_rcnn_fpn_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import division

import chainer
import chainer.functions as F
import chainer.links as L

from chainercv.links.model.fpn.faster_rcnn import FasterRCNN
from chainercv.links.model.fpn.fpn import FPN
from chainercv.links.model.fpn.head import Head
from chainercv.links.model.fpn.rpn import RPN
from chainercv.links.model.resnet import ResNet101
from chainercv.links.model.resnet import ResNet50
from chainercv import utils


class FasterRCNNFPNResNet(FasterRCNN):
    """Common implementation of the ResNet based FPN detectors.

    Concrete models derive from this class and provide two class
    attributes: :obj:`_base`, the ResNet class used as the backbone, and
    :obj:`_models`, the table of pretrained weights that is passed to
    :func:`~chainercv.utils.prepare_pretrained_model`.

    """

    def __init__(self, n_fg_class=None, pretrained_model=None):
        param, path = utils.prepare_pretrained_model(
            {'n_fg_class': n_fg_class}, pretrained_model, self._models)

        def _pool1(x):
            # Replacement for the backbone's first pooling layer.
            return F.max_pooling_2d(x, 3, stride=2, pad=1, cover_all=False)

        # Backbone that outputs the res2..res5 feature maps.
        backbone = self._base(n_class=1, arch='he')
        backbone.pick = ('res2', 'res3', 'res4', 'res5')
        backbone.pool1 = _pool1
        backbone.remove_unused()

        pyramid_scales = (1 / 4, 1 / 8, 1 / 16, 1 / 32, 1 / 64)
        extractor = FPN(backbone, len(backbone.pick), pyramid_scales)

        rpn = RPN(extractor.scales)
        # One extra class for the background.
        head = Head(param['n_fg_class'] + 1, extractor.scales)
        super(FasterRCNNFPNResNet, self).__init__(
            extractor=extractor, rpn=rpn, head=head)

        if path == 'imagenet':
            # Only the backbone is initialized from the ImageNet weights.
            imagenet_model = self._base(pretrained_model='imagenet', arch='he')
            _copyparams(self.extractor.base, imagenet_model)
        elif path:
            chainer.serializers.load_npz(path, self)


class FasterRCNNFPNResNet50(FasterRCNNFPNResNet):
    """Feature Pyramid Networks with ResNet-50.

    This is a model of Feature Pyramid Networks [#]_.
    This model uses :class:`~chainercv.links.ResNet50` as
    its base feature extractor.

    .. [#] Tsung-Yi Lin et al.
       Feature Pyramid Networks for Object Detection. CVPR 2017

    Args:
       n_fg_class (int): The number of classes excluding the background.
       pretrained_model (string): The weight file to be loaded.
           This can take :obj:`'coco'`, :obj:`'imagenet'`, `filepath`
           or :obj:`None`.
           The default value is :obj:`None`.

            * :obj:`'coco'`: Load weights trained on train split of \
                MS COCO 2017. \
                The weight file is downloaded and cached automatically. \
                :obj:`n_fg_class` must be :obj:`80` or :obj:`None`.
            * :obj:`'imagenet'`: Load weights of ResNet-50 trained on \
                ImageNet. \
                The weight file is downloaded and cached automatically. \
                This option initializes weights partially and the rest are \
                initialized randomly. In this case, :obj:`n_fg_class` \
                can be set to any number.
            * `filepath`: A path of npz file. In this case, :obj:`n_fg_class` \
                must be specified properly.
            * :obj:`None`: Do not load weights.

    """

    # Backbone class instantiated by the base class constructor.
    _base = ResNet50
    # Pretrained weights selectable through `pretrained_model`.
    _models = {
        'coco': {
            'param': {'n_fg_class': 80},
            'url': 'https://chainercv-models.preferred.jp/'
            'faster_rcnn_fpn_resnet50_coco_trained_2018_12_13.npz',
            'cv2': True
        },
    }


class FasterRCNNFPNResNet101(FasterRCNNFPNResNet):
    """Feature Pyramid Networks with ResNet-101.

    This is a model of Feature Pyramid Networks [#]_.
    This model uses :class:`~chainercv.links.ResNet101` as
    its base feature extractor.

    .. [#] Tsung-Yi Lin et al.
       Feature Pyramid Networks for Object Detection. CVPR 2017

    Args:
       n_fg_class (int): The number of classes excluding the background.
       pretrained_model (string): The weight file to be loaded.
           This can take :obj:`'coco'`, :obj:`'imagenet'`, `filepath`
           or :obj:`None`.
           The default value is :obj:`None`.

            * :obj:`'coco'`: Load weights trained on train split of \
                MS COCO 2017. \
                The weight file is downloaded and cached automatically. \
                :obj:`n_fg_class` must be :obj:`80` or :obj:`None`.
            * :obj:`'imagenet'`: Load weights of ResNet-101 trained on \
                ImageNet. \
                The weight file is downloaded and cached automatically. \
                This option initializes weights partially and the rest are \
                initialized randomly. In this case, :obj:`n_fg_class` \
                can be set to any number.
            * `filepath`: A path of npz file. In this case, :obj:`n_fg_class` \
                must be specified properly.
            * :obj:`None`: Do not load weights.

    """

    # Backbone class instantiated by the base class constructor.
    _base = ResNet101
    # Pretrained weights selectable through `pretrained_model`.
    _models = {
        'coco': {
            'param': {'n_fg_class': 80},
            'url': 'https://chainercv-models.preferred.jp/'
            'faster_rcnn_fpn_resnet101_coco_trained_2018_12_13.npz',
            'cv2': True
        },
    }


def _copyparams(dst, src):
    """Recursively copy the parameters of ``src`` into ``dst``.

    Both links are walked in parallel: children of a
    :class:`chainer.Chain` are matched by name and children of a
    :class:`chainer.ChainList` by position. For any other link the
    parameters are copied with :meth:`copyparams`; for batch
    normalization links the running statistics of ``src`` are assigned
    to ``dst`` as well.

    Args:
        dst (Link): The link that receives the values.
        src (Link): The link that provides the values.

    """
    if isinstance(dst, chainer.Chain):
        for child in dst.children():
            _copyparams(child, src[child.name])
        return

    if isinstance(dst, chainer.ChainList):
        for index, child in enumerate(dst):
            _copyparams(child, src[index])
        return

    # Leaf link.
    dst.copyparams(src)
    if isinstance(dst, L.BatchNormalization):
        dst.avg_mean = src.avg_mean
        dst.avg_var = src.avg_var
57 changes: 57 additions & 0 deletions chainercv/links/model/fpn/fpn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import chainer
import chainer.functions as F
from chainer import initializers
import chainer.links as L


class FPN(chainer.Chain):
    """An extractor class of Feature Pyramid Networks.

    This class wraps a feature extractor and provides
    multi-scale features.

    Args:
        base (Link): A base feature extractor.
            It should have :meth:`__call__` and :obj:`mean`.
            :meth:`__call__` should take a batch of images and return
            feature maps of them. The size of the :math:`k+1`-th feature map
            should be the half as that of the :math:`k`-th feature map.
        n_base_output (int): The number of feature maps
            that :obj:`base` returns.
        scales (tuple of floats): The scales of feature maps.

    """

    def __init__(self, base, n_base_output, scales):
        super(FPN, self).__init__()
        with self.init_scope():
            self.base = base
            self.inner = chainer.ChainList()
            self.outer = chainer.ChainList()

        # One 1x1 (inner) and one 3x3 (outer) convolution per base output.
        conv_kwargs = {'initialW': initializers.GlorotNormal()}
        for _ in range(n_base_output):
            self.inner.append(L.Convolution2D(256, 1, **conv_kwargs))
            self.outer.append(L.Convolution2D(256, 3, pad=1, **conv_kwargs))

        self.scales = scales

    @property
    def mean(self):
        """The mean value of the wrapped base extractor."""
        return self.base.mean

    def __call__(self, x):
        """Compute the feature maps of all pyramid levels.

        Args:
            x: A batch of images that is passed to :obj:`base`.

        Returns:
            list:
            Feature maps, one per entry of :obj:`scales` when :obj:`scales`
            is at least as long as the output of :obj:`base`.

        """
        feats = list(self.base(x))
        n_level = len(feats)

        # Walk from the last level to the first one, adding the upsampled
        # result of the following level to each inner convolution output.
        for level in range(n_level - 1, -1, -1):
            lateral = self.inner[level](feats[level])
            if level + 1 < n_level:
                lateral += F.unpooling_2d(
                    feats[level + 1], 2, cover_all=False)
            feats[level] = lateral

        feats = [self.outer[level](feat) for level, feat in enumerate(feats)]

        # Extra levels are obtained by subsampling the last feature map.
        for _ in range(len(self.scales) - len(feats)):
            feats.append(
                F.max_pooling_2d(feats[-1], 1, stride=2, cover_all=False))

        return feats
Loading

0 comments on commit a029d2f

Please sign in to comment.