Add basic/advanced usage of backward #552

Merged (1 commit) on Dec 3, 2019
203 changes: 186 additions & 17 deletions python/src/nnabla/_variable.pyx
@@ -131,8 +131,8 @@ cdef class Variable:
the arithmetic operation returns an :class:`~nnabla.NdArray` which stores the
output of the computation immediately invoked. Otherwise, it returns
:class:`~nnabla.Variable` holds the graph connection. The computation
is invoked immediately when :function:`nnabla.auto_forward`
or :function:`nnabla.set_auto_forward(True)` is used.
is invoked immediately when `nnabla.auto_forward`
or `nnabla.set_auto_forward(True)` is used.

Note:
Relational operators :code:`==` and :code:`!=` of two :obj:`Variable` s are
@@ -348,7 +348,7 @@ cdef class Variable:
Args:
var (:obj:`nnabla.Variable`):
The array elements and the parent function of ``var`` is
copied to ```self`` as references.
copied to ``self`` as references.
Note that the parent function of ``var`` is removed.

Example:
Expand Down Expand Up @@ -422,7 +422,7 @@ cdef class Variable:
NNabla array.
This method can be called as a setter to set the value held by this variable.
Refer to the documentation of the setter `nnabla._nd_array.NdArray.data`
for detailed behvaiors of the setter.
for detailed behaviors of the setter.

Args:
value(:obj:`numpy.ndarray`) (optional)
@@ -445,7 +445,7 @@ cdef class Variable:
NNabla array.
This method can be called as a setter to set the gradient held by this variable.
Refer to the documentation of the setter `nnabla._nd_array.NdArray.data`
for detailed behvaiors of the setter.
for detailed behaviors of the setter.

Args:
value(:obj:`numpy.ndarray`)
@@ -546,16 +546,20 @@ cdef class Variable:
The propagation will stop at a variable with need_grad=False.

Args:
grad(scalar, :obj:`numpy.ndarray`, or :obj:`nnabla._nd_array.NdArray`):
grad(scalar, :obj:`numpy.ndarray`, :obj:`nnabla._nd_array.NdArray`, or None):
The gradient signal value(s) of this variable.
The default value 1 is used in an usual neural network
training. This option is useful if you have a gradient
computation module outside NNabla, and want to use it as a
gradient signal of the neural network built in NNabla.
Note that this doesn't modifies the grad values of this
variable.
The default value 1 is used in usual neural network training.
This option is useful if you have a gradient computation module outside NNabla,
and want to use that result as a gradient signal.
Note that this does not modify the grad values of this variable;
instead, the received values are temporarily assigned as its gradient.
Also, if the :class:`~nnabla.Variable` on which you execute
`nnabla._variable.Variable.backward` is separated from another graph by an unlinked variable,
and the corresponding :class:`~nnabla.Variable` holds pre-computed gradient values,
**you need to set grad=None**; otherwise, for that backward pass (propagated from the unlinked :class:`~nnabla.Variable`),
the pre-computed gradient values are **ignored**.
clear_buffer(bool): Clears the no longer referenced variables
during backpropagation to save memory.
communicator_callbacks(:obj:`nnabla.CommunicatorBackwardCallback` or list of :obj:`nnabla.CommunicatorBackwardCallback`):
The callback functions invoked when 1) backward computation
of each function is finished and 2) all backward
@@ -569,6 +573,171 @@ cdef class Variable:
It must take :obj:`~nnabla.function.Function` as an input.
The default is None.


Example:

We first explain simple backward usage.

.. code-block:: python

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import numpy as np
import nnabla.initializer as I

rng = np.random.seed(217)
initializer = I.UniformInitializer((-0.1, 0.1), rng=rng)

x = nn.Variable((8, 3, 32, 32))
x.d = np.random.random(x.shape) # random input, just for example.

y0 = PF.convolution(x, outmaps=64, kernel=(3, 3), pad=(1, 1), stride=(2, 2), w_init=initializer, name="conv1", with_bias=False)
y1 = F.relu(y0)
y2 = PF.convolution(y1, outmaps=128, kernel=(3, 3), pad=(1, 1), stride=(2, 2), w_init=initializer, name="conv2", with_bias=False)
y3 = F.relu(y2)
y4 = F.average_pooling(y3, kernel=y3.shape[2:])
y5 = PF.affine(y4, 1, w_init=initializer)
loss = F.mean(F.abs(y5 - 1.))
loss.forward() # Execute forward

# We can check the current gradient of parameter.
print(nn.get_parameters()["conv1/conv/W"].g)

Output:

.. code-block:: plaintext

[[[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
...

Initially all the gradient values should be zero.
Then let's see what happens after calling backward.

.. code-block:: python

loss.backward()
print(nn.get_parameters()["conv1/conv/W"].g)

Output:

.. code-block:: plaintext

[[[[ 0.00539637 0.00770839 0.0090611 ]
[ 0.0078223 0.00978992 0.00720569]
[ 0.00879023 0.00578172 0.00790895]]
...

Now we can see that the gradient values have been computed and registered by calling `backward`.
Note that calling `backward` successively **accumulates** the result.
This means that if we execute `backward` again, the result is doubled.

.. code-block:: python

loss.backward() # execute again.
print(nn.get_parameters()["conv1/conv/W"].g)

We can see it's accumulated.

.. code-block:: plaintext

[[[[ 0.01079273 0.01541678 0.0181222 ]
[ 0.01564459 0.01957984 0.01441139]
[ 0.01758046 0.01156345 0.0158179 ]]
...
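If you do not want this accumulation, you can reset the gradients to zero before calling `backward` again; the same kind of reset is used in the advanced example below. The following is a minimal sketch of such a reset.

.. code-block:: python

    # Zero the parameter gradients so that the next backward call
    # starts from a clean state instead of accumulating.
    for v in nn.get_parameters().values():
        v.g = 0.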

Next is an advanced usage with an unlinked variable (please refer to `get_unlinked_variable`).
We use the same network, but it is now separated into two graphs by the unlinked variable.

.. code-block:: python

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import numpy as np
import nnabla.initializer as I

rng = np.random.seed(217) # use the same random seed.
initializer = I.UniformInitializer((-0.1, 0.1), rng=rng)

x = nn.Variable((8, 3, 32, 32))
x.d = np.random.random(x.shape) # random input, just for example.

y0 = PF.convolution(x, outmaps=64, kernel=(3, 3), pad=(1, 1), stride=(2, 2), w_init=initializer, name="conv1", with_bias=False)
y1 = F.relu(y0)
y2 = PF.convolution(y1, outmaps=128, kernel=(3, 3), pad=(1, 1), stride=(2, 2), w_init=initializer, name="conv2", with_bias=False)
y3 = F.relu(y2)
y3_unlinked = y3.get_unlinked_variable() # the computation graph is cut apart here.
y4 = F.average_pooling(y3_unlinked, kernel=y3_unlinked.shape[2:])
y5 = PF.affine(y4, 1, w_init=initializer)
loss = F.mean(F.abs(y5 - 1.))

# Execute forward.
y3.forward() # You need to execute forward for the graph before the unlinked variable first.
loss.forward() # Then execute forward at the leaf variable.

# Execute backward.
loss.backward() # works, but backpropagation stops at y3_unlinked.
print(nn.get_parameters()["conv1/conv/W"].g) # no gradient registered yet.

Output:

.. code-block:: plaintext

[[[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
...

We can confirm that backpropagation stops at `y3_unlinked`.
Then let's see how to execute backpropagation to the root variable (`x`).
Since it is a little bit complicated, let us first give an example of a common pitfall.
**Note that this is an incorrect way and is intended only to show the behavior of backward.**

.. code-block:: python

y3.backward() # this works, but computed gradient values are not correct.
print(nn.get_parameters()["conv1/conv/W"].g)

Output:

.. code-block:: plaintext

[[[[ 17.795254 23.960905 25.51168 ]
[ 20.661646 28.484127 19.406212 ]
[ 26.91042 22.239697 23.395714 ]]
...

**Note that this is a wrong result.** The gradient held by `y3_unlinked` has been totally ignored.
As described above, when you just call `backward`, the gradient of the leaf variable on which you call `backward` is taken to be 1.

To execute backpropagation over the 2 separate graphs **correctly**, we need to specify `grad=None` as shown below; then the gradient currently held by that variable is used for the computation.
(`y3.backward(grad=y3_unlinked.g)` does the same thing.)

.. code-block:: python

# Reset all the gradient values.
for v in nn.get_parameters().values():
    v.g = 0.
for v in [y0, y1, y2, y3, y4, y5]:
    v.g = 0.  # Reset the intermediate variables' gradients as well.

loss.backward() # backpropagation starts from the leaf variable again.
y3.backward(grad=None) # With grad=None, it takes over the gradient held by y3_unlinked.
print(nn.get_parameters()["conv1/conv/W"].g) # correct result.

This time you should get the same result as in the first example.

.. code-block:: plaintext

[[[[ 0.00539637 0.00770839 0.0090611 ]
[ 0.0078223 0.00978992 0.00720569]
[ 0.00879023 0.00578172 0.00790895]]
...
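As another sketch of the `grad` argument described above (illustration only; it assumes you have gradient values computed outside NNabla), you can pass a scalar or a :obj:`numpy.ndarray` directly as the initial gradient signal:

.. code-block:: python

    # A stand-in for an externally computed gradient signal;
    # here we simply use an array of ones with the shape of the loss.
    external_grad = np.ones(loss.shape, dtype=np.float32)
    loss.backward(grad=external_grad)  # Backpropagation starts from this signal.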


"""
cdef NdArrayPtr p
if grad is None:
@@ -630,14 +799,14 @@ cdef class Variable:
It is recommended to explicitly specify this option to avoid an
unintended behavior.

Returns: nnabla._variable.Variable
Returns: :class:`~nnabla.Variable`


Note:
The unlinked Variable behaves equivalent to the original variable
in a comparison operator and hash function regardless whether or
not the `need_grad` attribute is changed.
See a note in the `Variable` class documentation.
See a note in the `Variable` class documentation. Also, for backward execution with unlinked variable(s), please refer to `backward` and its example.

Example:

@@ -727,7 +896,7 @@ cdef class Variable:
pred = PF.affine(h, 10, name="pred")
return pred

# You can modify this PrintFunc to get the other informations like inputs(nnabla_func.inputs), outputs and arguments(nnabla_func.info.args) of nnabla functions.
# You can modify this PrintFunc to get other information such as the inputs (nnabla_func.inputs), outputs, and arguments (nnabla_func.info.args) of nnabla functions.
class PrintFunc(object):
    def __call__(self, nnabla_func):
        print(nnabla_func.info.type_name)
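# As a sketch of such a modification (illustration only; it assumes the
# nnabla_func.inputs / nnabla_func.outputs attributes of the passed function
# object), a callback that also reports arguments and input/output shapes:
class DetailedPrintFunc(object):
    def __call__(self, nnabla_func):
        print(nnabla_func.info.type_name,
              nnabla_func.info.args,
              [i.shape for i in nnabla_func.inputs],
              [o.shape for o in nnabla_func.outputs])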
@@ -816,7 +985,7 @@ cdef class Variable:

.. code-block:: python

nn.clear_parameters() # call this in case you want to run the following code agian
nn.clear_parameters() # call this in case you want to run the following code again
output = network_graph(x, add_avg_pool=False) # Exclusion of AveragePooling function in the graph
print("The return value of visit_check() method is : {}".format(output.visit_check(PrintFunc())))

3 changes: 3 additions & 0 deletions python/src/nnabla/models/imagenet/inception.py
@@ -23,10 +23,13 @@
class InceptionV3(ImageNetBase):
'''
InceptionV3 architecture.

The following is a list of strings that can be specified to the ``use_up_to`` option in the ``__call__`` method;

* ``'classifier'`` (default): The output of the final affine layer for classification.
* ``'pool'``: The output of the final global average pooling.
* ``'prepool'``: The input of the final global average pooling, i.e. the output of the final inception block.
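A brief usage sketch (illustration only; it assumes the standard ``nnabla.models.imagenet`` loading pattern and the ``input_shape`` attribute provided by the base class):

.. code-block:: python

    import nnabla as nn
    from nnabla.models.imagenet import InceptionV3

    model = InceptionV3()
    batch_size = 1
    x = nn.Variable((batch_size,) + model.input_shape)
    # Use the output of the final global average pooling as a feature extractor.
    y = model(x, training=False, use_up_to='pool')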

References:
* `Szegedy et al., Rethinking the Inception Architecture for Computer Vision.
<https://arxiv.org/abs/1512.00567>`_
6 changes: 5 additions & 1 deletion python/src/nnabla/models/imagenet/shufflenet.py
@@ -21,14 +21,18 @@

class ShuffleNet(ImageNetBase):
'''
Model for architecture ShuffleNet, ShuffleNet-0.5x and ShufffleNet-2.0x .
Model for architecture ShuffleNet, ShuffleNet-0.5x and ShuffleNet-2.0x.

Args:
Scaling Factor (str): To customize the network to a desired complexity, we can simply apply a scale factor on the number of channels. This can be chosen from '10', '5' and '20'.

The following is a list of strings that can be specified to the ``use_up_to`` option in the ``__call__`` method;

* ``'classifier'`` (default): The output of the final affine layer for classification.
* ``'pool'``: The output of the final global average pooling.
* ``'lastconv'``: The input of the final global average pooling without ReLU activation.
* ``'lastconv+relu'``: Network up to ``'lastconv'`` followed by ReLU activation.

References:
* `Zhang. et al., ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices.
<https://arxiv.org/abs/1707.01083>`_
2 changes: 1 addition & 1 deletion python/src/nnabla/models/imagenet/xception.py
@@ -28,7 +28,7 @@ class Xception(ImageNetBase):

* ``'classifier'`` (default): The output of the final affine layer for classification.
* ``'pool'``: The output of the final global average pooling.
* ``'lastconv'``: The input of the final global average pooling without ReLU activation..
* ``'lastconv'``: The input of the final global average pooling without ReLU activation.
* ``'lastconv+relu'``: Network up to ``'lastconv'`` followed by ReLU activation.

References:
12 changes: 6 additions & 6 deletions python/src/nnabla/parametric_functions.py
@@ -1781,7 +1781,7 @@ def batch_normalization(inp, axes=[1], decay_rate=0.9, eps=1e-5,
be ``'beta'``, ``'gamma'``, ``'mean'`` or ``'var'``.
A value of the dict must be an :obj:`~nnabla.initializer.Initializer`
or a :obj:`numpy.ndarray`.
E.g. ``{'beta': ConstantIntializer(0), 'gamma': np.ones(gamma_shape) * 2}``.
E.g. ``{'beta': ConstantInitializer(0), 'gamma': np.ones(gamma_shape) * 2}``.
no_scale (bool): If `True`, the scale term is omitted.
no_bias (bool): If `True`, the bias term is omitted.
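A small sketch of passing ``param_init`` (illustration only; the gamma shape assumes a 4-D input normalized over ``axes=[1]``):

.. code-block:: python

    import numpy as np
    import nnabla as nn
    import nnabla.parametric_functions as PF
    from nnabla.initializer import ConstantInitializer

    x = nn.Variable((8, 32, 16, 16))
    # Initialize beta with zeros and gamma with twos via param_init.
    h = PF.batch_normalization(
        x, axes=[1], batch_stat=True,
        param_init={'beta': ConstantInitializer(0),
                    'gamma': np.ones((1, 32, 1, 1)) * 2})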

@@ -1867,7 +1867,7 @@ def sync_batch_normalization(inp, comm, group="world", axes=[1], decay_rate=0.9,
be ``'beta'``, ``'gamma'``, ``'mean'`` or ``'var'``.
A value of the dict must be an :obj:`~nnabla.initializer.Initializer`
or a :obj:`numpy.ndarray`.
E.g. ``{'beta': ConstantIntializer(0), 'gamma': np.ones(gamma_shape) * 2}``.
E.g. ``{'beta': ConstantInitializer(0), 'gamma': np.ones(gamma_shape) * 2}``.
no_scale (bool): If `True`, the scale term is omitted.
no_bias (bool): If `True`, the bias term is omitted.

@@ -1993,7 +1993,7 @@ def layer_normalization(inp, batch_axis=0, eps=1e-05, output_stat=False, fix_par
be ``'gamma'``, ``'beta'``.
A value of the dict must be an :obj:`~nnabla.initializer.Initializer`
or a :obj:`numpy.ndarray`.
E.g. ``{'gamma': np.ones(...) * 2, 'beta': ConstantIntializer(0)}``.
E.g. ``{'gamma': np.ones(...) * 2, 'beta': ConstantInitializer(0)}``.
no_scale (bool): If `True`, the scale term is omitted.
no_bias (bool): If `True`, the bias term is omitted.

@@ -2059,7 +2059,7 @@ def instance_normalization(inp, channel_axis=1, batch_axis=0, eps=1e-05, output_
be ``'gamma'``, ``'beta'``.
A value of the dict must be an :obj:`~nnabla.initializer.Initializer`
or a :obj:`numpy.ndarray`.
E.g. ``{'gamma': np.ones(...) * 2, 'beta': ConstantIntializer(0)}``.
E.g. ``{'gamma': np.ones(...) * 2, 'beta': ConstantInitializer(0)}``.
no_scale (bool): If `True`, the scale term is omitted.
no_bias (bool): If `True`, the bias term is omitted.

@@ -2128,7 +2128,7 @@ def group_normalization(inp, num_groups, channel_axis=1, batch_axis=0, eps=1e-05
be ``'gamma'``, ``'beta'``.
A value of the dict must be an :obj:`~nnabla.initializer.Initializer`
or a :obj:`numpy.ndarray`.
E.g. ``{'gamma': np.ones(...) * 2, 'beta': ConstantIntializer(0)}``.
E.g. ``{'gamma': np.ones(...) * 2, 'beta': ConstantInitializer(0)}``.
no_scale (bool): If `True`, the scale term is omitted.
no_bias (bool): If `True`, the bias term is omitted.

@@ -3649,7 +3649,7 @@ def multi_head_attention(query, key, value, num_heads=12, dropout=0.0, rng=None,
Parameter initializers can be set with a dict. Possible keys of the dict include q_weight, k_weight, v_weight, q_bias, k_bias, v_bias, out_weight, out_bias, attn_bias_k, attn_bias_v.
A value of the dict must be an :obj:`~nnabla.initializer.Initializer`
or a :obj:`numpy.ndarray`.
E.g. ``{'q_bias': ConstantIntializer(0)}``.
E.g. ``{'q_bias': ConstantInitializer(0)}``.

Returns:
~nnabla.Variable: Output :math:`y` with shape :math:`(L_T, B, E)`