diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py
index 21cc8043154e..f4303eab1a27 100644
--- a/python/mxnet/gluon/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/rnn/rnn_cell.py
@@ -400,7 +400,8 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                                num_hidden=self._hidden_size,
                                name=prefix+'h2h')
-        output = self._get_activation(F, i2h + h2h, self._activation,
+        i2h_plus_h2h = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
+        output = self._get_activation(F, i2h_plus_h2h, self._activation,
                                       name=prefix+'out')
 
         return output, [output]
@@ -513,7 +514,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
                                num_hidden=self._hidden_size*4, name=prefix+'i2h')
         h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
                                num_hidden=self._hidden_size*4, name=prefix+'h2h')
-        gates = i2h + h2h
+        gates = F.elemwise_add(i2h, h2h, name=prefix+'plus0')
         slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
         in_gate = self._get_activation(
             F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
@@ -523,9 +524,10 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
             F, slice_gates[2], self._activation, name=prefix+'c')
         out_gate = self._get_activation(
             F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
-        next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
+        next_c = F._internal._plus(F.elemwise_mul(forget_gate, states[1], name=prefix+'mul0'),
+                                   F.elemwise_mul(in_gate, in_transform, name=prefix+'mul1'),
                                    name=prefix+'state')
-        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation),
+        next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation, name=prefix+'activation0'),
                                   name=prefix+'out')
 
         return next_h, [next_h, next_c]
@@ -637,15 +639,22 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
         h2h_r, h2h_z, h2h = F.SliceChannel(h2h, num_outputs=3,
                                            name=prefix+'h2h_slice')
 
-        reset_gate = F.Activation(i2h_r + h2h_r, act_type="sigmoid",
+        reset_gate = F.Activation(F.elemwise_add(i2h_r, h2h_r, name=prefix+'plus0'), act_type="sigmoid",
                                   name=prefix+'r_act')
-        update_gate = F.Activation(i2h_z + h2h_z, act_type="sigmoid",
+        update_gate = F.Activation(F.elemwise_add(i2h_z, h2h_z, name=prefix+'plus1'), act_type="sigmoid",
                                    name=prefix+'z_act')
 
-        next_h_tmp = F.Activation(i2h + reset_gate * h2h, act_type="tanh",
+        next_h_tmp = F.Activation(F.elemwise_add(i2h,
+                                                 F.elemwise_mul(reset_gate, h2h, name=prefix+'mul0'),
+                                                 name=prefix+'plus2'),
+                                  act_type="tanh",
                                   name=prefix+'h_act')
 
-        next_h = F._internal._plus((1. - update_gate) * next_h_tmp, update_gate * prev_state_h,
+        ones = F.ones_like(update_gate, name=prefix+"ones_like0")
+        next_h = F._internal._plus(F.elemwise_mul(F.elemwise_sub(ones, update_gate, name=prefix+'minus0'),
+                                                  next_h_tmp,
+                                                  name=prefix+'mul1'),
+                                   F.elemwise_mul(update_gate, prev_state_h, name=prefix+'mul20'),
                                    name=prefix+'out')
 
         return next_h, [next_h]
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 4e8241ffc1ea..c1d5f6a590f7 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -379,6 +379,54 @@ def test_rnn_cells():
     net.add(gluon.rnn.GRUCell(100, input_size=100))
     check_rnn_forward(net, mx.nd.ones((8, 3, 200)))
 
+
+def test_rnn_cells_export_import():
+    class RNNLayer(gluon.HybridBlock):
+        def __init__(self):
+            super(RNNLayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.RNNCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    class LSTMLayer(gluon.HybridBlock):
+        def __init__(self):
+            super(LSTMLayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.LSTMCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    class GRULayer(gluon.HybridBlock):
+        def __init__(self):
+            super(GRULayer, self).__init__()
+            with self.name_scope():
+                self.cell = gluon.rnn.GRUCell(hidden_size=1)
+
+        def hybrid_forward(self, F, seq):
+            outputs, state = self.cell.unroll(inputs=seq, length=2, merge_outputs=True)
+            return outputs
+
+    for hybrid in [RNNLayer(), LSTMLayer(), GRULayer()]:
+        hybrid.initialize()
+        hybrid.hybridize()
+        input = mx.nd.ones(shape=(1, 2, 1))
+        output1 = hybrid(input)
+        hybrid.export(path="./model", epoch=0)
+        symbol = mx.gluon.SymbolBlock.imports(
+            symbol_file="./model-symbol.json",
+            input_names=["data"],
+            param_file="./model-0000.params",
+            ctx=mx.Context.default_ctx
+        )
+        output2 = symbol(input)
+        assert_almost_equal(output1.asnumpy(), output2.asnumpy())
+
+
 def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     layer.collect_params().initialize()
     inputs.attach_grad()