Merge pull request #545 from sony/feature/20191108-weight-normalization
Add weight normalization
TakuyaNarihira authored Nov 21, 2019
2 parents 951ddc0 + 3d3b9f2 commit 54ece10
Showing 3 changed files with 107 additions and 0 deletions.
1 change: 1 addition & 0 deletions doc/python/api/parametric_function.rst
@@ -111,6 +111,7 @@ Here is the list of parametric functions.
.. automethod:: __call__(x, w_init, b_init, fix_parameters)

.. autofunction:: spectral_norm
.. autofunction:: weight_normalization
.. autofunction:: multi_head_attention
.. autoclass:: transformer
.. autoclass:: transformer_encode
67 changes: 67 additions & 0 deletions python/src/nnabla/parametric_functions.py
@@ -3597,6 +3597,73 @@ def _spectral_norm_outer_most_dim(w, dim, itr=1, eps=1e-12, test=False,
return w_sn


@parametric_function_api("wn", [
('g', 'Weight Normalization adaptive scale, one scalar per output map.', 'v.shape[dim]', True),
])
def weight_normalization(v, dim=0, eps=1e-12, fix_parameters=False):
"""Weight Normalization.
.. math::
\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}
where :math:`v` is the input matrix,
and :math:`g` is learnable multiplication factors each of which is applied to each output map at `dim`.
This function is in general used as callback passed to apply_w for PF.convolution, PF.affine and so on.
According to the author`s original implementation (https://github.com/TimSalimans/weight_norm), :math:`v` should be initialized by :math:`N(0, 0.05)`.
To meet this condition, initializer should be passed to convolution which Weight Normalization is applied, like an example below.
References:
* `Tim Salimans, Diederik P. Kingma, Weight Normalization: A Simple Reparameterization to Accelerate Training of Deep Neural Networks.
<https://arxiv.org/abs/1602.07868>`_
Args:
W (~nnabla.Variable): Input N-D array with shape. This is normally network parameter.
dim (`int`):
Output dimension. Default is 0.
If the dimension is not 0, then the specified dimension becomes the most-left dimension by transposing.
eps (`float`): Epsilon for the normalization. Default is 1e-12.
Returns:
~nnabla.Variable: :math:`W_{sn}` with the same shape as :math:`W`.
Example:
.. code-block:: python
import nnabla as nn
import nnabla.parametric_functions as PF
import nnabla.initializer as I
# h is nn.Variable.
# convolution
# according to the original implementation, w should be initialized by N(0, 0.05).
h = PF.convolution(h, ..., apply_w=PF.weight_normalization, w_init=I.NormalInitializer(0.05))
# affine
h = PF.affine(h, ..., apply_w=lambda w: PF.weight_normalization(w, dim=1), w_init=I.NormalInitializer(0.05))
"""
assert -len(v.shape) <= dim < len(v.shape), \
    "`dim` must satisfy `-len(v.shape) <= dim < len(v.shape)`."
assert eps > 0, "`eps` must be greater than 0."

# The weight passed via the apply_w callback plays the role of v in the paper.

outmaps = v.shape[dim]
g = get_parameter_or_create("g", (outmaps,),
initializer=ConstantInitializer(1.), need_grad=True, as_need_grad=not fix_parameters)

sh = tuple([outmaps if i == dim else 1 for i in range(len(v.shape))])  # broadcast shape for g
ax = tuple([i for i in range(len(v.shape)) if i != dim])  # axes to normalize over

normalized_v = v / (F.sum(v ** 2, axis=ax, keepdims=True) + eps) ** 0.5

return F.reshape(g, sh) * normalized_v


@parametric_function_api("multi_head_attention", [
('q_weight', 'weights for query', '(E, E)', True),
('k_weight', 'weights for key', '(E_k, E)', True),
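For intuition, here is a minimal NumPy sketch of what `weight_normalization` computes (an editorial illustration assuming only `numpy`, not part of this diff): after normalization, each output map of `w` has L2 norm equal to its entry of `g`.

```python
import numpy as np

# Example weights: 8 output maps on dim 0 (e.g. a conv kernel).
v = np.random.randn(8, 4, 3, 3)
dim, eps = 0, 1e-12
g = np.ones(v.shape[dim])  # matches the ConstantInitializer(1.) above

ax = tuple(i for i in range(v.ndim) if i != dim)                  # axes to normalize over
sh = tuple(v.shape[i] if i == dim else 1 for i in range(v.ndim))  # broadcast shape for g
w = g.reshape(sh) * v / np.sqrt(np.sum(v ** 2, axis=ax, keepdims=True) + eps)

# Each output map of w now has L2 norm ~= its entry of g (all ones here).
print(np.sqrt(np.sum(w ** 2, axis=ax)))
```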
39 changes: 39 additions & 0 deletions python/test/test_parametric_functions.py
@@ -1587,4 +1587,43 @@ def test_pf_transformer_decode_execution(g_rng, tgt_len, batch_size, embed_dim,
else:
assert len(nn.get_parameters()) == 26


@pytest.mark.parametrize("func", ["conv", "affine"])
def test_pf_weight_norm_execution(g_rng, func):
# python implementation
def ref_weight_normalization(v, g, dim, eps=1e-12):
axis = tuple([i for i in range(len(v.shape)) if i != dim])
v_norm = np.sqrt(np.sum(v ** 2, axis=axis, keepdims=True) + eps)

return g * v / v_norm

dim = {"conv": 0, "affine": 1}[func]

def wn_clbk(v): return PF.weight_normalization(v, dim=dim)

x = nn.Variable.from_numpy_array(g_rng.randn(2, 4, 5, 5))
if func == "conv":
# assume channel-first layout
y = PF.convolution(x, 8, (3, 3), apply_w=wn_clbk)
elif func == "affine":
y = PF.affine(x, 8, apply_w=wn_clbk)
else:
raise ValueError("unexpected function name {}".format(func))

# Execution
y.forward()
y.backward()

params = nn.get_parameters()
assert len(params) == 3 # w, b, g

# Check values
v = params["{}/W".format(func)]
w = y.parent.inputs[1]

v_np = v.d
w_np = ref_weight_normalization(v_np, 1, dim)  # g is initialized to 1

assert_allclose(w.d, w_np, atol=1e-2, rtol=1e-5)

# TODO: Test all parametric functions.
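A short usage sketch of the new API (again an editorial illustration, not part of this diff; the exact parameter key for the scale, e.g. `conv/wn/g`, is an assumption based on the `"wn"` scope above): applying the callback to a convolution registers `g` alongside `W` and `b`.

```python
import numpy as np
import nnabla as nn
import nnabla.parametric_functions as PF
import nnabla.initializer as I

nn.clear_parameters()
x = nn.Variable.from_numpy_array(np.random.randn(2, 4, 5, 5))

# Weight-normalized convolution, initialized as the docstring recommends.
y = PF.convolution(x, 8, (3, 3), apply_w=PF.weight_normalization,
                   w_init=I.NormalInitializer(0.05))
y.forward()

# Three parameters are expected: the conv weight W, the bias b, and the
# weight-normalization scale g (registered under the "wn" scope).
print(sorted(nn.get_parameters().keys()))
```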
