fix rnn (apache#10954)
szha authored and Jin Huang committed May 29, 2018
1 parent 0cdd80c · commit 8017a5f
Showing 2 changed files with 25 additions and 9 deletions.
5 changes: 3 additions & 2 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,7 +23,7 @@
 from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-from ... import ndarray
+from ... import ndarray, autograd
 from .. import Block
 from . import rnn_cell
 
@@ -185,7 +185,8 @@ def forward(self, inputs, states=None):
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
-        if inputs.context.device_type == 'gpu' or self._mode == 'lstm':
+        if inputs.context.device_type == 'gpu' or \
+                self._mode == 'lstm' and not (self._dropout and autograd.is_training()):
             out = self._forward_kernel(inputs, states)
         else:
             out = self._forward(inputs, states)
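Read together, the two hunks make the fused-kernel dispatch dropout-aware: on GPU the layer still always calls _forward_kernel, but on CPU an LSTM now takes the fused path only when it is not applying dropout under training, and otherwise falls back to the unfused, cell-based _forward. A minimal sketch of the new predicate (my illustration, not part of the commit; the helper name use_fused_kernel and the flat arguments stand in for the instance attributes):

from mxnet import autograd

def use_fused_kernel(device_type, mode, dropout):
    # Python parses `a or b and c` as `a or (b and c)`: GPU always wins,
    # and on CPU only a dropout-free or inference-time LSTM is fused.
    return device_type == 'gpu' or \
        mode == 'lstm' and not (dropout and autograd.is_training())

print(use_fused_kernel('cpu', 'lstm', 0.5))      # True: not recording, so inference
with autograd.record():                          # train_mode=True by default
    print(use_fused_kernel('cpu', 'lstm', 0.5))  # False: unfused fallback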
29 changes: 22 additions & 7 deletions tests/python/unittest/test_gluon_rnn.py
@@ -80,7 +80,7 @@ def test_lstm_cpu_inference():
 
     mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT,
                                       rtol=1e-3, atol=1e-5)
-
+
 def test_gru():
     cell = gluon.rnn.GRUCell(100, prefix='rnn_')
@@ -242,7 +242,7 @@ def test_rnn_cells():
     net.add(gluon.rnn.GRUCell(100, input_size=100))
     check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
 
-def check_rnn_layer_forward(layer, inputs, states=None):
+def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     layer.collect_params().initialize()
     inputs.attach_grad()
     with mx.autograd.record():
@@ -268,17 +268,32 @@ def check_rnn_layer_forward(layer, inputs, states=None):
         assert isinstance(out, mx.nd.NDArray)
         out.backward()
 
-    mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
-    mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
+    if not run_only:
+        mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
+        mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy(), rtol=1e-3, atol=1e-5)
 
 
 def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)))
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)))
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2), mx.nd.ones((8, 3, 20)), [mx.nd.ones((2, 3, 10)), mx.nd.ones((2, 3, 10))])
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))])
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)))
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2), mx.nd.ones((8, 3, 20)), mx.nd.ones((2, 3, 10)))
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)))
 
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
+                            run_only=True)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5),
+                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
+                            run_only=True)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5),
+                            mx.nd.ones((8, 3, 20)),
+                            [mx.nd.ones((4, 3, 10)), mx.nd.ones((4, 3, 10))], run_only=True)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5), mx.nd.ones((8, 3, 20)),
+                            run_only=True)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
+                            mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
 
     net = gluon.nn.Sequential()
     net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))
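The dropout variants are exercised with run_only=True because dropout draws a fresh random mask on every training-mode pass, so the output and gradient comparisons in check_rnn_layer_forward between two passes cannot hold. A short illustration of that non-determinism, assembled here (not part of the commit) with the same shapes the tests use:

import mxnet as mx
from mxnet import autograd, gluon

layer = gluon.rnn.LSTM(10, 2, dropout=0.5)  # dropout acts between the 2 layers
layer.collect_params().initialize()
x = mx.nd.ones((8, 3, 20))
with autograd.record():  # training mode, so dropout is active
    y1 = layer(x)        # first pass draws one dropout mask
    y2 = layer(x)        # second pass draws another
print((y1 - y2).abs().sum().asscalar() > 0)  # True: outputs differ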
