
Commit

Merge pull request #547 from sony/feature/20191121-orthogonal-initializer-seno

Orthogonal Initializer
TakuyaNarihira authored Nov 26, 2019
2 parents c770f32 + ffeeb94 commit bd483cb
Showing 3 changed files with 61 additions and 0 deletions.
3 changes: 3 additions & 0 deletions doc/python/api/parametric_function.rst
@@ -138,6 +138,9 @@ listed below.
.. autoclass:: UniformInitializer
:show-inheritance:

.. autoclass:: OrthogonalInitializer
:show-inheritance:

.. autofunction:: calc_normal_std_he_forward
.. autofunction:: calc_normal_std_he_backward
.. autofunction:: calc_normal_std_glorot
46 changes: 46 additions & 0 deletions python/src/nnabla/initializer.py
@@ -214,6 +214,52 @@ def __call__(self, shape):
        return np.ones(shape) * self.value


class OrthogonalInitializer(BaseInitializer):

    r"""Generates orthogonal matrix weights, as proposed by Saxe et al.

    Args:
        gain (float): Scaling factor, which should be chosen depending on the type of units.
        rng (numpy.random.RandomState): Random number generator.

    Example:

    .. code-block:: python

        import numpy as np
        import nnabla as nn
        import nnabla.parametric_functions as PF
        import nnabla.initializer as I

        x = nn.Variable([60, 1, 28, 28])
        w = I.OrthogonalInitializer(np.sqrt(2.0))
        b = I.ConstantInitializer(0.0)
        h = PF.convolution(x, 64, [3, 3], w_init=w, b_init=b, pad=[1, 1], name='conv')

    References:
        * `Saxe, et al. Exact solutions to the nonlinear dynamics of
          learning in deep linear neural networks.
          <https://arxiv.org/abs/1312.6120>`_

    """

    def __init__(self, gain=1.0, rng=None):
        if rng is None:
            rng = random.prng
        self.rng = rng
        self.gain = gain

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__,
                               self.gain)

    def __call__(self, shape):
        # Flatten all trailing dimensions so the weight is treated as a 2-D matrix.
        flat_shape = (shape[0], int(np.prod(shape[1:])))
        x = self.rng.normal(0.0, 1.0, flat_shape)
        # The singular vectors of a random Gaussian matrix form an orthonormal
        # basis; keep whichever factor matches the flattened shape.
        u, _, v = np.linalg.svd(x, full_matrices=False)
        q = u if u.shape == flat_shape else v
        return q.reshape(shape).astype('float32') * self.gain


def calc_normal_std_he_forward(inmaps, outmaps, kernel=(1, 1)):
    r"""Calculates the standard deviation proposed by He et al.
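
Not part of this commit, but as a quick illustration of what the new initializer produces, the sketch below (assuming an installed nnabla build where OrthogonalInitializer is exposed via nnabla.initializer) draws one convolution-shaped weight and checks that the Gram matrix of its flattened form equals gain**2 times the identity.

    import numpy as np
    import nnabla.initializer as I

    # Same gain as the docstring example (sqrt(2) for ReLU-style units).
    w_init = I.OrthogonalInitializer(gain=np.sqrt(2.0), rng=np.random.RandomState(313))

    # A 64 x (3*3*3) weight: 64 output maps, 3 input maps, 3x3 kernel.
    w = w_init((64, 3, 3, 3))

    # __call__ flattens the trailing axes, so the 64x27 matrix has orthonormal
    # columns scaled by gain; its Gram matrix is gain**2 * I.
    flat = w.reshape(64, -1)
    print(np.allclose(flat.T @ flat, 2.0 * np.eye(27), atol=1e-5))  # True
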
12 changes: 12 additions & 0 deletions python/test/test_initializer.py
@@ -19,6 +19,17 @@
import nnabla.initializer as I


def orthogonal_test(x):
    # Check semi-orthogonality of the flattened weight: whichever Gram matrix
    # is the smaller one (columns or rows) should be the identity.
    rows, cols = x.shape[0], int(np.prod(x.shape[1:]))
    flattened = x.view().reshape((rows, cols))
    if rows > cols:
        target = np.matmul(flattened.T, flattened)
        return np.allclose(target, np.eye(cols), atol=1e-6)
    else:
        target = np.matmul(flattened, flattened.T)
        return np.allclose(target, np.eye(rows), atol=1e-6)


@pytest.mark.parametrize('rng', [None, np.random.RandomState(313)])
@pytest.mark.parametrize('shape', [
    (10,),
@@ -30,6 +41,7 @@
    (I.UniformInitializer, dict(lim=(-1, 10)),
     lambda x: np.all(x >= -1) and np.all(x < 10)),
    (I.ConstantInitializer, dict(value=-2), lambda x: np.all(x == -2)),
    (I.OrthogonalInitializer, dict(gain=1.0), orthogonal_test)
])
def test_initializer_execution(shape, initializer, opts, condition, rng):
    try:
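
For reference only (not part of the diff), the sketch below exercises both branches of the new orthogonal_test condition: a tall weight where the flattened columns are orthonormal, and a wide one where the rows are. The specific shapes are an assumption chosen for illustration.

    import numpy as np
    import nnabla.initializer as I

    init = I.OrthogonalInitializer(gain=1.0, rng=np.random.RandomState(313))

    # Tall case: shape (10, 2, 3) flattens to 10x6, so the columns are
    # orthonormal and flat.T @ flat matches the 6x6 identity.
    tall = init((10, 2, 3)).reshape(10, -1)
    print(np.allclose(tall.T @ tall, np.eye(6), atol=1e-6))   # True

    # Wide case: shape (4, 2, 3) flattens to 4x6, so the rows are
    # orthonormal and flat @ flat.T matches the 4x4 identity.
    wide = init((4, 2, 3)).reshape(4, -1)
    print(np.allclose(wide @ wide.T, np.eye(4), atol=1e-6))   # True
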
