Skip to content

Commit

Permalink
added GRU layer
Browse files Browse the repository at this point in the history
  • Loading branch information
Florian Krebs committed Jul 21, 2016
1 parent 1e19d5b commit 8cefa54
Show file tree
Hide file tree
Showing 2 changed files with 220 additions and 1 deletion.
152 changes: 151 additions & 1 deletion madmom/ml/nn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,156 @@ def activate(self, data):
return out


class GRUCell(object):
"""
Cell as used by GRU layers proposed in [1]_. The cell output is computed by
.. math::
h = tanh(W_{xh} * x_t + W_{hh} * h_{t-1} + b).
Parameters
----------
weights : numpy array, shape (num_inputs, num_hiddens)
Weights of the connections between inputs and cell.
recurrent_weights : numpy array, shape (num_hiddens, num_hiddens)
Weights of the connections between cell and cell output of the
previous time step.
bias : scalar or numpy array, shape (num_hiddens,)
Bias.
activation_fn : numpy ufunc, optional
Activation function.
References
----------
.. [1] Kyunghyun Cho, Bart Van Merrienboer, Dzmitry Bahdanau, and Yoshua
Bengio,
"On the properties of neural machine translation: Encoder-decoder
approaches",
http://arxiv.org/abs/1409.1259, 2014.
Notes
-----
There are two formulations of the GRUCell in the literature. Here,
we adopted the (slightly older) one proposed in [1]_, which is also
implemented in the Lasagne toolbox.
"""

def __init__(self, weights, recurrent_weights, bias, activation_fn=tanh):
self.weights = weights
self.recurrent_weights = recurrent_weights
self.bias = bias
self.activation_fn = activation_fn

def activate(self, data, reset_gate, prev):
"""
Activate the gate with the given input, reset_gate and the previous
output.
Parameters
----------
data : scalar or numpy array, shape (num_frames, num_inputs)
Input data for the cell.
reset_gate : scalar or numpy array, shape (num_hiddens,)
Activation of the reset gate.
prev : scalar or numpy array, shape (num_hiddens,)
Cell output of the previous time step.
Returns
-------
numpy array, shape (num_frames, num_hiddens)
Activations of the gate for this data.
"""
# weight input and add bias
out = np.dot(data, self.weights) + self.bias
# weight previous cell output and reset gate
out += reset_gate * np.dot(prev, self.recurrent_weights)
# apply activation function and return it
return self.activation_fn(out)


class GRULayer(Layer):
"""
Recurrent network layer with Gated Recurrent Units (GRU) as proposed in
[1]_.
Parameters
----------
reset_gate : :class:`Gate`
Reset gate.
update_gate : :class:`Gate`
Update gate.
cell : :class:`GRUCell`
GRU cell
hid_init : numpy array, shape (num_hiddens,), optional
Initial state of hidden units.
References
----------
.. [1] Kyunghyun Cho, Bart Van Merrienboer, Dzmitry Bahdanau, and Yoshua
Bengio,
"On the properties of neural machine translation: Encoder-decoder
approaches",
http://arxiv.org/abs/1409.1259, 2014.
Notes
-----
There are two formulations of the GRUCell in the literature. Here,
we adopted the (slightly older) one proposed in [1], which is also
implemented in the Lasagne toolbox.
"""

def __init__(self, reset_gate, update_gate, cell, hid_init=None):
# init the gates
self.reset_gate = reset_gate
self.update_gate = update_gate
self.cell = cell
if hid_init is None:
hid_init = np.zeros(cell.bias.size, dtype=NN_DTYPE)
self.hid_init = hid_init

def activate(self, data):
"""
Activate the GRU layer.
Parameters
----------
data : numpy array, shape (num_frames, num_inputs)
Activate with this data.
Returns
-------
numpy array, shape (num_frames, num_hiddens)
Activations for this data.
"""
# init arrays
size = len(data)
# output matrix for the whole sequence
out = np.zeros((size, self.update_gate.bias.size), dtype=NN_DTYPE)
# output (of the previous time step)
out_ = self.hid_init
# process the input data
for i in range(size):
# cache input data
data_ = data[i]
# reset gate:
# operate on current data and previous output (activation)
rg = self.reset_gate.activate(data_, out_)
# update gate:
# operate on current data and previous output (activation)
ug = self.update_gate.activate(data_, out_)
# hidden_update:
# implemented as proposed in [1]
hug = self.cell.activate(data_, rg, out_)
# output (activation)
out_ = ug * hug + (1 - ug) * out_
out[i] = out_
return out


class ConvolutionalLayer(FeedForwardLayer):
"""
Convolutional network layer.
Expand Down Expand Up @@ -506,7 +656,7 @@ class BatchNormLayer(Layer):
"""
Batch normalization layer with activation function. The previous layer
is usually linear with no bias - the BatchNormLayer's beta parameter
replaces it. See [1] for a detailed understanding of the parameters.
replaces it. See [1]_ for a detailed understanding of the parameters.
Parameters
----------
Expand Down
69 changes: 69 additions & 0 deletions tests/test_ml_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,75 @@ def test_cnn(self):
0.84207922, 0.21631248]))


class TestGRUClass(unittest.TestCase):

W_xr = np.array([[-0.42948743, -1.29989187],
[0.77213901, 0.86070993],
[1.13791823, -0.87066225]])
W_xu = np.array([[0.44875312, 0.07172084],
[-0.24292999, 1.318794],
[1.0270179, 0.16293946]])
W_xhu = np.array([[0.8812559, 1.35859991],
[1.04311944, -0.25449358],
[-1.09539597, 1.19808424]])
W_hr = np.array([[0.96696973, 0.1384294],
[-0.09561655, -1.23413809]])
W_hu = np.array([[0.04664641, 0.59561686],
[1.00325841, -0.11574791]])
W_hhu = np.array([[1.19742848, 1.07850016],
[0.35234964, -1.45348681]])
b_r = np.array([1.41851288, -0.39743243])
b_u = np.array([-0.78729095, 0.83385797])
b_hu = np.array([1.25143065, -0.97715625])

IN = np.array([[0.91298812, -1.47626202, -1.08667502],
[0.49814883, -0.0104938, 0.93869008],
[-1.12282135, 0.3780883, 1.42017503],
[0.62669439, 0.89438929, -0.69354132],
[0.16162221, -1.00166208, 0.23579985]])
H = np.array([0.02345737, 0.34454183])

def setUp(self):
self.reset_gate = layers.Gate(
TestGRUClass.W_xr, TestGRUClass.b_r, TestGRUClass.W_hr,
activation_fn=activations.sigmoid)
self.update_gate = layers.Gate(
TestGRUClass.W_xu, TestGRUClass.b_u, TestGRUClass.W_hu,
activation_fn=activations.sigmoid)
self.gru_cell = layers.GRUCell(
TestGRUClass.W_xhu, TestGRUClass.W_hhu, TestGRUClass.b_hu)
self.gru_1 = layers.GRULayer(self.reset_gate, self.update_gate,
self.gru_cell)
self.gru_2 = layers.GRULayer(self.reset_gate, self.update_gate,
self.gru_cell, hid_init=TestGRUClass.H)

def test_process(self):
self.assertTrue(
np.allclose(self.reset_gate.activate(TestGRUClass.IN[0, :],
TestGRUClass.H), np.array([0.20419282, 0.08861294])))
self.assertTrue(
np.allclose(self.update_gate.activate(TestGRUClass.IN[0, :],
TestGRUClass.H), np.array([0.31254834, 0.2226105])))
self.assertTrue(
np.allclose(self.gru_cell.activate(TestGRUClass.IN[0, :],
TestGRUClass.H, TestGRUClass.H),
np.array([0.9366396, -0.67876764])))
self.assertTrue(
np.allclose(self.gru_1.activate(TestGRUClass.IN),
np.array([[0.22772433, -0.13181415],
[0.49479958, 0.51224858],
[0.08539771, -0.56119639],
[0.1946809, -0.50421363],
[0.17403202, -0.27258521]])))
self.assertTrue(
np.allclose(self.gru_2.activate(TestGRUClass.IN),
np.array([[0.30988133, 0.13258138],
[0.60639685, 0.55714613],
[0.21366976, -0.55568963],
[0.30860096, -0.43686554],
[0.28866628, -0.23025239]])))


class TestBatchNormLayerClass(unittest.TestCase):

IN = np.array([[[0.32400414, 0.31483042],
Expand Down

0 comments on commit 8cefa54

Please sign in to comment.