
Commit

revert
szha committed Jun 29, 2018
1 parent 1abaa13 commit fe468c6
Showing 2 changed files with 9 additions and 4 deletions.
python/mxnet/gluon/rnn/rnn_layer.py (8 additions, 3 deletions)
@@ -180,6 +180,10 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
     def hybrid_forward(self, F, inputs, states=None, **kwargs):
         if F is ndarray:
             batch_size = inputs.shape[self._layout.find('N')]
+        if self._input_size == 0:
+            for i in range(self._dir):
+                self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
+                self.i2h_weight[i]._finish_deferred_init()
         skip_states = states is None
         if skip_states:
             if F is ndarray:
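The restored block finishes deferred parameter initialization when no input size was given: on a forward pass with concrete arrays, each direction's i2h weight is reshaped to (gates * hidden_size, input_size) taken from the input itself. A minimal sketch of that behavior, assuming an MXNet 1.x Gluon setup; the shapes and variable names below are illustrative, not taken from the commit:

```python
from mxnet import gluon, nd

# LSTM built without an input size; its i2h weights stay deferred.
lstm = gluon.rnn.LSTM(10)          # hidden_size=10, input size unknown (0)
lstm.initialize()

# Default 'TNC' layout: (seq_len, batch, features).
x = nd.random.uniform(shape=(5, 3, 20))
out = lstm(x)                      # first call fixes i2h_weight to (4*10, 20)

print(out.shape)                   # (5, 3, 10)
print(lstm.collect_params())       # the i2h weight now reports shape (40, 20)
```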
@@ -199,12 +203,13 @@ def hybrid_forward(self, F, inputs, states=None, **kwargs):
         # out is (output, state)
         return out[0] if skip_states else out
 
-    def infer_shape(self, inputs, *states):
-        if self._input_size == 0:
+    def __call__(self, inputs, *states):
+        if self._input_size == 0 and isinstance(inputs, ndarray.NDArray):
             for i in range(self._dir):
                 self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
                 self.i2h_weight[i]._finish_deferred_init()
-        return super(_RNNLayer, self).infer_shape(inputs, *states)
+        states = list(filter(lambda x: x is not None, states))
+        return super(_RNNLayer, self).__call__(inputs, *states)
 
     def _forward_kernel(self, F, inputs, states, **kwargs):
         """ forward using CUDNN or CPU kenrel"""
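The revert also swaps the infer_shape hook back for a __call__ override: NDArray inputs trigger the same weight-shape fix-up before dispatch, and None states are dropped so the layer can be called with or without an explicit state. A hedged usage sketch, assuming MXNet 1.x Gluon; the input shapes are chosen arbitrarily:

```python
from mxnet import gluon, nd

lstm = gluon.rnn.LSTM(10)
lstm.initialize()
x = nd.ones((5, 3, 20))                   # 'TNC': (seq_len, batch, features)

out = lstm(x)                             # states=None is filtered out; only the output is returned
print(out.shape)                          # (5, 3, 10)

states = lstm.begin_state(batch_size=3)   # [h, c], each (layers*dirs, batch, hidden) = (1, 3, 10)
out, new_states = lstm(x, states)         # with states given, (output, states) comes back
print(out.shape, new_states[0].shape)     # (5, 3, 10) (1, 3, 10)
```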
tests/python/unittest/test_gluon_rnn.py (1 addition, 1 deletion)
@@ -305,7 +305,7 @@ def test_rnn_layers():
         mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
 
     net = gluon.nn.HybridSequential()
-    net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))
+    net.add(gluon.rnn.LSTM(10, bidirectional=True))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
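The updated test drops the explicit num_layers argument, leaving a single-layer bidirectional LSTM inside the hybrid block. Roughly how that network behaves on a plain forward pass, assuming MXNet 1.x Gluon; the input shape is picked only for illustration:

```python
from mxnet import gluon, nd

net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(gluon.rnn.LSTM(10, bidirectional=True))
    net.add(gluon.nn.BatchNorm(axis=2))
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(3, activation='relu'))
net.initialize()

x = nd.ones((8, 3, 20))   # 'TNC': seq_len 8, batch 3, 20 features
y = net(x)                # bi-LSTM -> (8, 3, 20); Flatten -> (8, 60); Dense -> (8, 3)
print(y.shape)            # (8, 3)
```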
