Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix the bug of BidirectionalCell #13575

Merged
merged 8 commits into from
Dec 13, 2018
8 changes: 4 additions & 4 deletions python/mxnet/gluon/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -1041,8 +1041,8 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
reversed_inputs = F.SequenceReverse(F.stack(*inputs, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
reversed_inputs = _as_list(F.split(reversed_inputs, axis=0, num_outputs=length,
squeeze_axis=True))
reversed_inputs = list(F.split(reversed_inputs, axis=0, num_outputs=length,
squeeze_axis=True))
begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

states = begin_state
Expand All @@ -1063,8 +1063,8 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
sequence_length=valid_length,
use_sequence_length=True,
axis=0)
reversed_r_outputs = _as_list(F.split(reversed_r_outputs, axis=0, num_outputs=length,
squeeze_axis=True))
reversed_r_outputs = list(F.split(reversed_r_outputs, axis=0, num_outputs=length,
squeeze_axis=True))
if merge_outputs is None:
merge_outputs = isinstance(l_outputs, tensor_types)
l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, merge_outputs)
Expand Down