From 0970b45f54be2910c57acc6215439eb758167afb Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Thu, 28 Jun 2018 23:33:04 -0700
Subject: [PATCH] make gluon rnn layers hybrid blocks WIP

---
 python/mxnet/gluon/rnn/rnn_layer.py     | 100 +++++++++++++-----------
 src/operator/rnn.cc                     |   6 +-
 tests/python/unittest/test_gluon_rnn.py |  32 ++++++--
 3 files changed, 81 insertions(+), 57 deletions(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index 418c497ce832..6ffce09578e8 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -23,12 +23,11 @@ from __future__ import print_function
 __all__ = ['RNN', 'LSTM', 'GRU']
 
-from ... import ndarray
-from .. import Block
+from ... import ndarray, symbol
+from .. import HybridBlock, tensor_types
 from . import rnn_cell
 
-
-class _RNNLayer(Block):
+class _RNNLayer(HybridBlock):
     """Implementation of recurrent layers."""
     def __init__(self, hidden_size, num_layers, layout,
                  dropout, bidirectional, input_size,
@@ -98,8 +97,15 @@ def __repr__(self):
     def state_info(self, batch_size=0):
         raise NotImplementedError
 
-    def _unfuse(self):
-        """Unfuses the fused RNN in to a stack of rnn cells."""
+    def unfuse(self):
+        """Unfuses the fused RNN into a stack of rnn cells.
+
+        Returns
+        -------
+        cell : SequentialRNNCell
+            A sequential RNN cell that replicates the structure of the RNN layer, with shared
+            weights.
+        """
         get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size,
                                                                   activation='relu',
                                                                   **kwargs),
@@ -169,55 +175,55 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
             states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
         return states
 
-    def forward(self, inputs, states=None):
-        batch_size = inputs.shape[self._layout.find('N')]
+    def hybrid_forward(self, F, inputs, states=None, **kwargs):
+        if F is ndarray:
+            batch_size = inputs.shape[self._layout.find('N')]
+        if self._input_size == 0:
+            for i in range(self._dir):
+                self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
+                self.i2h_weight[i]._finish_deferred_init()
         skip_states = states is None
         if skip_states:
-            states = self.begin_state(batch_size, ctx=inputs.context)
-        if isinstance(states, ndarray.NDArray):
+            if F is ndarray:
+                states = self.begin_state(batch_size, ctx=inputs.context)
+            else:
+                states = self.begin_state(0, func=symbol.zeros)
+        if isinstance(states, tensor_types):
             states = [states]
-        for state, info in zip(states, self.state_info(batch_size)):
-            if state.shape != info['shape']:
-                raise ValueError(
-                    "Invalid recurrent state shape. Expecting %s, got %s."%(
-                        str(info['shape']), str(state.shape)))
-        if self._input_size == 0:
-            for i in range(self._dir):
-                self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
-                self.i2h_weight[i]._finish_deferred_init()
-        out = self._forward_kernel(inputs, states)
+        if F is ndarray:
+            for state, info in zip(states, self.state_info(batch_size)):
+                if state.shape != info['shape']:
+                    raise ValueError(
+                        "Invalid recurrent state shape. Expecting %s, got %s."%(
+                            str(info['shape']), str(state.shape)))
+        out = self._forward_kernel(F, inputs, states, **kwargs)
 
         # out is (output, state)
         return out[0] if skip_states else out
 
-    def _forward(self, inputs, states):
-        """forward using gluon cell"""
-        ns = len(states)
-        axis = self._layout.find('T')
-        states = sum(zip(*((j for j in i) for i in states)), ())
-        outputs, states = self._unfused.unroll(
-            inputs.shape[axis], inputs, states,
-            layout=self._layout, merge_outputs=True)
-        new_states = []
-        for i in range(ns):
-            state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0)
-            new_states.append(state)
-
-        return outputs, new_states
-
-    def _forward_kernel(self, inputs, states):
+    def __call__(self, inputs, *states):
+        if self._input_size == 0 and isinstance(inputs, ndarray.NDArray):
+            for i in range(self._dir):
+                self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2])
+                self.i2h_weight[i]._finish_deferred_init()
+        states = list(filter(lambda x: x is not None, states))
+        return super(_RNNLayer, self).__call__(inputs, *states)
+
+    def _forward_kernel(self, F, inputs, states, **kwargs):
         """ forward using CUDNN or CPU kenrel"""
         if self._layout == 'NTC':
-            inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
-        ctx = inputs.context
-        params = sum(zip(self.i2h_weight, self.h2h_weight), ())
-        params += sum(zip(self.i2h_bias, self.h2h_bias), ())
-        params = (i.data(ctx).reshape((-1,)) for i in params)
-        params = ndarray.concat(*params, dim=0)
-
-        rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
-                          num_layers=self._num_layers, bidirectional=self._dir == 2,
-                          p=self._dropout, state_outputs=True, mode=self._mode)
+            inputs = F.swapaxes(inputs, dim1=0, dim2=1)
+        prefix = self._prefix[:-1] if self._prefix[-1] == '_' else self._prefix
+        params = (kwargs['{}_{}{}_{}_{}'.format(prefix, j, i, c, p)].reshape((-1,))
+                  for p in ['weight', 'bias']
+                  for c in ['i2h', 'h2h']
+                  for i in range(self._num_layers)
+                  for j in (['l', 'r'] if self._dir == 2 else ['l']))
+        params = F.concat(*params, dim=0)
+
+        rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
+                    num_layers=self._num_layers, bidirectional=self._dir == 2,
+                    p=self._dropout, state_outputs=True, mode=self._mode)
 
         if self._mode == 'lstm':
             outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -225,7 +231,7 @@ def _forward_kernel(self, inputs, states):
             outputs, states = rnn[0], [rnn[1]]
 
         if self._layout == 'NTC':
-            outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1)
+            outputs = F.swapaxes(outputs, dim1=0, dim2=1)
 
         return outputs, states
 
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 1e670a9047f0..73ef4f0f42a7 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -45,12 +45,12 @@ Operator *RNNProp::CreateOperatorEx(Context ctx,
 DMLC_REGISTER_PARAMETER(RNNParam);
 
 MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
-.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are 
+.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
 implemented, with both multi-layer and bidirectional support.
 
 **Vanilla RNN**
 
-Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: 
+Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
 ReLU and Tanh.
 
 With ReLU activation function:
@@ -63,7 +63,7 @@ With Tanh activtion function:
 .. math::
     h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})
 
-Reference paper: Finding structure in time - Elman, 1988. 
+Reference paper: Finding structure in time - Elman, 1988.
 https://crl.ucsd.edu/~elman/Papers/fsit.pdf
 
 **LSTM**
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 6167f660d2c1..ffda4d2a96c6 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -396,9 +396,12 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
     layer.collect_params().initialize()
     inputs.attach_grad()
     with mx.autograd.record():
-        out = layer(inputs, states)
+        if states is None:
+            out = layer(inputs)
+        else:
+            out = layer(inputs, states)
         if states is not None:
-            assert isinstance(out, tuple) and len(out) == 2
+            assert isinstance(out, (list, tuple)) and len(out) == 2
             out = out[0]
         else:
             assert isinstance(out, mx.nd.NDArray)
@@ -410,15 +413,19 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False):
 
     layer.hybridize()
     with mx.autograd.record():
-        out = layer(inputs, states)
         if states is not None:
-            assert isinstance(out, tuple) and len(out) == 2
+            out = layer(inputs, states)
+            assert isinstance(out, (list, tuple)) and len(out) == 2
             out = out[0]
         else:
+            out = layer(inputs)
             assert isinstance(out, mx.nd.NDArray)
         out.backward()
 
-    layer(inputs, states) # test is_training = false
+    if states is not None:
+        layer(inputs, states) # test is_training = false
+    else:
+        layer(inputs)
 
     if not run_only:
         mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5)
@@ -448,15 +455,26 @@ def test_rnn_layers():
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5),
                             mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True)
 
-    net = gluon.nn.Sequential()
-    net.add(gluon.rnn.LSTM(10, 2, bidirectional=True))
+    net = gluon.nn.HybridSequential()
+    net.add(gluon.rnn.LSTM(10, bidirectional=True))
     net.add(gluon.nn.BatchNorm(axis=2))
     net.add(gluon.nn.Flatten())
     net.add(gluon.nn.Dense(3, activation='relu'))
+    net.hybridize()
     net.collect_params().initialize()
     with mx.autograd.record():
         net(mx.nd.ones((2, 3, 10))).backward()
 
+    net2 = gluon.nn.HybridSequential()
+    net2.add(gluon.rnn.LSTM(10, bidirectional=True))
+    net2.add(gluon.nn.BatchNorm(axis=2))
+    net2.add(gluon.nn.Flatten())
+    net2.add(gluon.nn.Dense(3, activation='relu'))
+    net2.hybridize()
+    net2.collect_params().initialize()
+    with mx.autograd.record():
+        net2(mx.nd.ones((2, 3, 10))).backward()
+
 def test_rnn_unroll_variant_length():
     # Test for imperative usage
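
Note (reviewer illustration, not part of the patch): with _RNNLayer deriving from HybridBlock, a fused Gluon RNN
layer can now be hybridized and called with or without explicit begin states, mirroring what the updated
check_rnn_layer_forward/test_rnn_layers exercise. A minimal usage sketch under that assumption; the layer and
tensor sizes below are arbitrary examples, not values taken from the patch:

    import mxnet as mx
    from mxnet import gluon, autograd

    # Fused LSTM layer, now usable as a HybridBlock.
    layer = gluon.rnn.LSTM(10, num_layers=2, bidirectional=True)  # layout defaults to 'TNC'
    layer.collect_params().initialize()
    layer.hybridize()  # legal after this change; builds the cached symbolic graph on first call

    inputs = mx.nd.ones((5, 3, 20))            # (seq_len, batch, feature) for 'TNC'
    states = layer.begin_state(batch_size=3)   # [h0, c0] for an LSTM

    with autograd.record():
        # With explicit states, the layer returns (outputs, new_states).
        outputs, new_states = layer(inputs, states)
        outputs.backward()

    # Without states, zero begin states are created internally and only outputs are returned.
    outputs_only = layer(inputs)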