
Add Conv2DActiv #384

Merged
merged 4 commits on Aug 8, 2017
2 changes: 2 additions & 0 deletions chainercv/links/__init__.py
@@ -1,3 +1,5 @@
from chainercv.links.connection.conv_2d_activ import Conv2DActiv # NOQA

from chainercv.links.model.pixelwise_softmax_classifier import PixelwiseSoftmaxClassifier # NOQA
from chainercv.links.model.sequential_feature_extractor import SequentialFeatureExtractor # NOQA

1 change: 1 addition & 0 deletions chainercv/links/connection/__init__.py
@@ -0,0 +1 @@
from chainercv.links.connection.conv_2d_activ import Conv2DActiv # NOQA
73 changes: 73 additions & 0 deletions chainercv/links/connection/conv_2d_activ.py
@@ -0,0 +1,73 @@
import chainer
from chainer.functions import relu
from chainer.links import Convolution2D


class Conv2DActiv(chainer.Chain):
    """Convolution2D --> Activation

    This is a chain that does two-dimensional convolution
    and applies an activation.

    The arguments are the same as those of
    :class:`chainer.links.Convolution2D`
    except for :obj:`activ`.

    Example:

        There are several ways to initialize a :class:`Conv2DActiv`.

        1. Give the first three arguments explicitly:

            >>> l = Conv2DActiv(5, 10, 3)

        2. Omit :obj:`in_channels` or fill it with :obj:`None`:
Member
The order of the descriptions should be the same as that of the examples.
I mean, "Fill :obj:`in_channels` with :obj:`None` or omit it:" is better. (Changing the order of the examples is also OK.)

        In these ways, attributes are initialized at runtime based on
        the channel size of the input.

            >>> l = Conv2DActiv(10, 3)
            >>> l = Conv2DActiv(None, 10, 3)

    Args:
        in_channels (int or None): Number of channels of input arrays.
            If :obj:`None`, parameter initialization will be deferred until
            the first forward data pass at which time the size will be
            determined.
        out_channels (int): Number of channels of output arrays.
        ksize (int or pair of ints): Size of filters (a.k.a. kernels).
            :obj:`ksize=k` and :obj:`ksize=(k, k)` are equivalent.
        stride (int or pair of ints): Stride of filter applications.
            :obj:`stride=s` and :obj:`stride=(s, s)` are equivalent.
        pad (int or pair of ints): Spatial padding width for input arrays.
            :obj:`pad=p` and :obj:`pad=(p, p)` are equivalent.
        nobias (bool): If :obj:`True`,
            then this link does not use the bias term.
        initialW (4-D array): Initial weight value. If :obj:`None`, the
            default initializer is used.
            May also be a callable that takes :obj:`numpy.ndarray` or
            :obj:`cupy.ndarray` and edits its value.
        initial_bias (1-D array): Initial bias value. If :obj:`None`, the
            bias is set to 0.
            May also be a callable that takes :obj:`numpy.ndarray` or
            :obj:`cupy.ndarray` and edits its value.
        activ (callable): An activation function. The default value is
            :func:`chainer.functions.relu`.

    """

    def __init__(self, in_channels, out_channels, ksize=None,
                 stride=1, pad=0, nobias=False, initialW=None,
                 initial_bias=None, activ=relu):
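        # Mirroring chainer.links.Convolution2D: when `ksize` is omitted,
        # the call was Conv2DActiv(out_channels, ksize), so shift the
        # positional arguments over and defer `in_channels`.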
        if ksize is None:
            out_channels, ksize, in_channels = in_channels, out_channels, None

        self.activ = activ
        super(Conv2DActiv, self).__init__()
        with self.init_scope():
            self.conv = Convolution2D(
                in_channels, out_channels, ksize, stride, pad,
                nobias, initialW, initial_bias)

    def __call__(self, x):
        h = self.conv(x)
        return self.activ(h)
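For reference, a minimal usage sketch of the new link (the channel counts and input shape below are illustrative, not taken from the PR):

    import numpy as np

    from chainercv.links import Conv2DActiv

    # in_channels is deferred with None, so the weight shape is determined
    # on the first forward pass; relu is the default activation.
    l = Conv2DActiv(None, 16, 3, pad=1)
    x = np.random.uniform(-1, 1, (2, 3, 32, 32)).astype(np.float32)
    y = l(x)  # Convolution2D followed by relu
    print(y.shape)  # (2, 16, 32, 32)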
7 changes: 7 additions & 0 deletions docs/source/reference/links.rst
@@ -34,3 +34,10 @@ Classifiers
.. toctree::

    links/classifier


Connection
----------

.. toctree::

    links/connection
9 changes: 9 additions & 0 deletions docs/source/reference/links/connection.rst
@@ -0,0 +1,9 @@
Connection
==========

.. module:: chainercv.links.connection


Conv2DActiv
-----------
.. autoclass:: Conv2DActiv
98 changes: 98 additions & 0 deletions tests/links_tests/connection_tests/test_conv_2d_activ.py
@@ -0,0 +1,98 @@
import unittest

import numpy as np

import chainer
from chainer import cuda
from chainer.functions import relu
from chainer import testing
from chainer.testing import attr

from chainercv.links import Conv2DActiv


def _add_one(x):
    return x + 1


@testing.parameterize(*testing.product({
    'args_style': ['explicit', 'None', 'omit'],
    'activ': ['relu', 'add_one']
}))
class TestConv2DActiv(unittest.TestCase):

    in_channels = 1
    out_channels = 1
    ksize = 3
    stride = 1
    pad = 1

    def setUp(self):
        if self.activ == 'relu':
            activ = relu
        elif self.activ == 'add_one':
            activ = _add_one
        self.x = np.random.uniform(
            -1, 1, (5, self.in_channels, 5, 5)).astype(np.float32)
        self.gy = np.random.uniform(
            -1, 1, (5, self.out_channels, 5, 5)).astype(np.float32)

        # Convolution is the identity function.
        initialW = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]],
                            dtype=np.float32).reshape(1, 1, 3, 3)
        initial_bias = 0
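        # With this identity kernel and zero bias the convolution returns
        # its input unchanged, so the checks below exercise the activation
        # alone.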
        if self.args_style == 'explicit':
            self.l = Conv2DActiv(
                self.in_channels, self.out_channels, self.ksize,
                self.stride, self.pad,
                initialW=initialW, initial_bias=initial_bias,
                activ=activ)
        elif self.args_style == 'None':
            self.l = Conv2DActiv(
                None, self.out_channels, self.ksize, self.stride, self.pad,
                initialW=initialW, initial_bias=initial_bias,
                activ=activ)
        elif self.args_style == 'omit':
            self.l = Conv2DActiv(
                self.out_channels, self.ksize, stride=self.stride,
                pad=self.pad, initialW=initialW, initial_bias=initial_bias,
                activ=activ)

    def check_forward(self, x_data):
        x = chainer.Variable(x_data)
        y = self.l(x)

        self.assertIsInstance(y, chainer.Variable)
        self.assertIsInstance(y.data, self.l.xp.ndarray)

        if self.activ == 'relu':
            np.testing.assert_almost_equal(
                cuda.to_cpu(y.data), np.maximum(cuda.to_cpu(x_data), 0))
        elif self.activ == 'add_one':
            np.testing.assert_almost_equal(
                cuda.to_cpu(y.data), cuda.to_cpu(x_data) + 1)

    def test_forward_cpu(self):
        self.check_forward(self.x)

    @attr.gpu
    def test_forward_gpu(self):
        self.l.to_gpu()
        self.check_forward(cuda.to_gpu(self.x))

    def check_backward(self, x_data, y_grad):
        x = chainer.Variable(x_data)
        y = self.l(x)
        y.grad = y_grad
        y.backward()

    def test_backward_cpu(self):
        self.check_backward(self.x, self.gy)

    @attr.gpu
    def test_backward_gpu(self):
        self.l.to_gpu()
        self.check_backward(cuda.to_gpu(self.x), cuda.to_gpu(self.gy))


testing.run_module(__name__, __file__)