From eb1a41b59a67ff153ea624bf8976d29ff6ee3f16 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Thu, 26 Jul 2018 11:14:09 -0700 Subject: [PATCH] make Gluon RNN layer hybrid block --- python/mxnet/gluon/rnn/rnn_layer.py | 128 +++++++++++------------- src/operator/nn/concat.cc | 87 ++++++++++++++++ src/operator/nn/concat.cu | 4 + src/operator/rnn.cc | 6 +- src/operator/tensor/matrix_op-inl.h | 31 +++++- tests/python/unittest/test_gluon_rnn.py | 34 +++++-- 6 files changed, 206 insertions(+), 84 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 418c497ce832..5bbc6f420649 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -23,12 +23,11 @@ from __future__ import print_function __all__ = ['RNN', 'LSTM', 'GRU'] -from ... import ndarray -from .. import Block +from ... import ndarray, symbol +from .. import HybridBlock, tensor_types from . import rnn_cell - -class _RNNLayer(Block): +class _RNNLayer(HybridBlock): """Implementation of recurrent layers.""" def __init__(self, hidden_size, num_layers, layout, dropout, bidirectional, input_size, @@ -52,34 +51,31 @@ def __init__(self, hidden_size, num_layers, layout, self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] - self.i2h_weight = [] - self.h2h_weight = [] - self.i2h_bias = [] - self.h2h_bias = [] - ng, ni, nh = self._gates, input_size, hidden_size for i in range(num_layers): for j in (['l', 'r'] if self._dir == 2 else ['l']): - self.i2h_weight.append( - self.params.get('%s%d_i2h_weight'%(j, i), shape=(ng*nh, ni), - init=i2h_weight_initializer, - allow_deferred_init=True)) - self.h2h_weight.append( - self.params.get('%s%d_h2h_weight'%(j, i), shape=(ng*nh, nh), - init=h2h_weight_initializer, - allow_deferred_init=True)) - self.i2h_bias.append( - self.params.get('%s%d_i2h_bias'%(j, i), shape=(ng*nh,), - init=i2h_bias_initializer, - allow_deferred_init=True)) - self.h2h_bias.append( - self.params.get('%s%d_h2h_bias'%(j, i), shape=(ng*nh,), - init=h2h_bias_initializer, - allow_deferred_init=True)) + self._register_param('{}{}_i2h_weight'.format(j, i), + shape=(ng*nh, ni), + init=i2h_weight_initializer) + self._register_param('{}{}_h2h_weight'.format(j, i), + shape=(ng*nh, nh), + init=h2h_weight_initializer) + self._register_param('{}{}_i2h_bias'.format(j, i), + shape=(ng*nh,), + init=i2h_bias_initializer) + self._register_param('{}{}_h2h_bias'.format(j, i), + shape=(ng*nh,), + init=h2h_bias_initializer) ni = nh * self._dir self._unfused = self._unfuse() + def _register_param(self, name, shape, init): + p = self.params.get(name, shape=shape, init=init, + allow_deferred_init=True) + setattr(self, name, p) + return p + def __repr__(self): s = '{name}({mapping}, {_layout}' if self._num_layers != 1: @@ -89,7 +85,7 @@ def __repr__(self): if self._dir == 2: s += ', bidirectional' s += ')' - shape = self.i2h_weight[0].shape + shape = self.l0_i2h_weight.shape mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates) return s.format(name=self.__class__.__name__, mapping=mapping, @@ -99,7 +95,14 @@ def state_info(self, batch_size=0): raise NotImplementedError def _unfuse(self): - """Unfuses the fused RNN in to a stack of rnn cells.""" + """Unfuses the fused RNN in to a stack of rnn cells. + + Returns + ------- + cell : SequentialRNNCell + A sequential RNN cell that replicates the structure of the RNN layer, with shared + weights. 
+ """ get_cell = {'rnn_relu': lambda **kwargs: rnn_cell.RNNCell(self._hidden_size, activation='relu', **kwargs), @@ -111,7 +114,7 @@ def _unfuse(self): 'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size, **kwargs)}[self._mode] - stack = rnn_cell.SequentialRNNCell(prefix=self.prefix, params=self.params) + stack = rnn_cell.HybridSequentialRNNCell(prefix=self.prefix, params=self.params) with stack.name_scope(): ni = self._input_size for i in range(self._num_layers): @@ -169,55 +172,42 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs): states.append(func(name='%sh0_%d'%(self.prefix, i), **info)) return states - def forward(self, inputs, states=None): - batch_size = inputs.shape[self._layout.find('N')] + def hybrid_forward(self, F, inputs, states=None, **kwargs): + if F is ndarray: + batch_size = inputs.shape[self._layout.find('N')] skip_states = states is None if skip_states: - states = self.begin_state(batch_size, ctx=inputs.context) - if isinstance(states, ndarray.NDArray): + if F is ndarray: + states = self.begin_state(batch_size, ctx=inputs.context) + else: + states = self.begin_state(0, func=symbol.zeros) + if isinstance(states, tensor_types): states = [states] - for state, info in zip(states, self.state_info(batch_size)): - if state.shape != info['shape']: - raise ValueError( - "Invalid recurrent state shape. Expecting %s, got %s."%( - str(info['shape']), str(state.shape))) - if self._input_size == 0: - for i in range(self._dir): - self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2]) - self.i2h_weight[i]._finish_deferred_init() - out = self._forward_kernel(inputs, states) + if F is ndarray: + for state, info in zip(states, self.state_info(batch_size)): + if state.shape != info['shape']: + raise ValueError( + "Invalid recurrent state shape. 
Expecting %s, got %s."%( + str(info['shape']), str(state.shape))) + out = self._forward_kernel(F, inputs, states, **kwargs) # out is (output, state) return out[0] if skip_states else out - def _forward(self, inputs, states): - """forward using gluon cell""" - ns = len(states) - axis = self._layout.find('T') - states = sum(zip(*((j for j in i) for i in states)), ()) - outputs, states = self._unfused.unroll( - inputs.shape[axis], inputs, states, - layout=self._layout, merge_outputs=True) - new_states = [] - for i in range(ns): - state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0) - new_states.append(state) - - return outputs, new_states - - def _forward_kernel(self, inputs, states): + def _forward_kernel(self, F, inputs, states, **kwargs): """ forward using CUDNN or CPU kenrel""" if self._layout == 'NTC': - inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1) - ctx = inputs.context - params = sum(zip(self.i2h_weight, self.h2h_weight), ()) - params += sum(zip(self.i2h_bias, self.h2h_bias), ()) - params = (i.data(ctx).reshape((-1,)) for i in params) - params = ndarray.concat(*params, dim=0) - - rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size, - num_layers=self._num_layers, bidirectional=self._dir == 2, - p=self._dropout, state_outputs=True, mode=self._mode) + inputs = F.swapaxes(inputs, dim1=0, dim2=1) + params = (kwargs['{}{}_{}_{}'.format(j, i, c, p)].reshape(-1) + for p in ['weight', 'bias'] + for c in ['i2h', 'h2h'] + for i in range(self._num_layers) + for j in (['l', 'r'] if self._dir == 2 else ['l'])) + params = F._internal._rnn_param_concat(*params, dim=0) + + rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size, + num_layers=self._num_layers, bidirectional=self._dir == 2, + p=self._dropout, state_outputs=True, mode=self._mode) if self._mode == 'lstm': outputs, states = rnn[0], [rnn[1], rnn[2]] @@ -225,7 +215,7 @@ def _forward_kernel(self, inputs, states): outputs, states = rnn[0], [rnn[1]] if self._layout == 'NTC': - outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1) + outputs = F.swapaxes(outputs, dim1=0, dim2=1) return outputs, states diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 266ccb1b1a14..735c4af1c1ca 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -74,6 +74,57 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs, return dshape.Size() != 0; } +static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs, + std::vector *in_shape, + std::vector *out_shape) { + using namespace mshadow; + const ConcatParam& param_ = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); + TShape dshape; + index_t size = 0; + int num_zero = 0; + int axis = -1; + for (int i = 0; i < param_.num_args; ++i) { + TShape tmp = (*in_shape)[i]; + if (tmp.ndim()) { + axis = CheckAxis(param_.dim, tmp.ndim()); + num_zero += tmp[axis] == 0; + size += tmp[axis]; + tmp[axis] = 0; + shape_assign(&dshape, tmp); + } + } + + TShape tmp = (*out_shape)[0]; + if (tmp.ndim()) { + axis = CheckAxis(param_.dim, tmp.ndim()); + tmp[axis] = 0; + shape_assign(&dshape, tmp); + } + + if (dshape.ndim() == 0) return false; + + for (int i = 0; i < param_.num_args; ++i) { + CHECK(shape_assign(&(*in_shape)[i], dshape)) + << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i]; + } + + if (!num_zero) dshape[axis] = size; + CHECK(shape_assign(&(*out_shape)[0], dshape)) + << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0]; + index_t 
residual = (*out_shape)[0][axis] - size; + if ((*out_shape)[0].Size() != 0 && residual > 0 && num_zero) { + bool need_infer = false; + for (int i = 0; i < num_zero; i++) { + (*in_shape)[i][axis] = residual / num_zero; + need_infer = need_infer || (*in_shape)[i].Size() == 0; + } + return !need_infer; + } + + return dshape.Size() != 0; +} + static bool ConcatType(const nnvm::NodeAttrs& attrs, std::vector *in_type, std::vector *out_type) { @@ -320,5 +371,41 @@ NNVM_REGISTER_OP(_backward_Concat) #endif .set_attr("FCompute", ConcatGradCompute); + +NNVM_REGISTER_OP(_rnn_param_concat) +.set_num_inputs([](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + return params.num_args; +}) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const ConcatParam& params = nnvm::get(attrs.parsed); + std::vector ret; + for (int i = 0; i < params.num_args; ++i) { + ret.push_back(std::string("arg") + std::to_string(i)); + } + return ret; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +#if MXNET_USE_MKLDNN == 1 +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +#endif +.set_attr("FInferShape", RNNParamConcatShape) +.set_attr("FInferType", ConcatType) +.set_attr("FInferStorageType", ConcatForwardInferStorageType) +.set_attr("FCompute", ConcatCompute) +.set_attr("FComputeEx", ConcatComputeExCPU) +.set_attr("FGradient", ConcatGrad{"_backward_Concat"}) +.set_attr("key_var_num_args", "num_args") +.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate") +.add_arguments(ConcatParam::__FIELDS__()); + } // namespace op } // namespace mxnet diff --git a/src/operator/nn/concat.cu b/src/operator/nn/concat.cu index 4f6b8fc9ebef..2872d527898e 100644 --- a/src/operator/nn/concat.cu +++ b/src/operator/nn/concat.cu @@ -50,6 +50,10 @@ NNVM_REGISTER_OP(Concat) .set_attr("FCompute", ConcatCompute) .set_attr("FComputeEx", ConcatComputeExGPU); +NNVM_REGISTER_OP(_rnn_param_concat) +.set_attr("FCompute", ConcatCompute) +.set_attr("FComputeEx", ConcatComputeExGPU); + NNVM_REGISTER_OP(_backward_Concat) .set_attr("FCompute", ConcatGradCompute); diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index 1e670a9047f0..73ef4f0f42a7 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -45,12 +45,12 @@ Operator *RNNProp::CreateOperatorEx(Context ctx, DMLC_REGISTER_PARAMETER(RNNParam); MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp) -.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are +.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are implemented, with both multi-layer and bidirectional support. **Vanilla RNN** -Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: +Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: ReLU and Tanh. With ReLU activation function: @@ -63,7 +63,7 @@ With Tanh activtion function: .. math:: h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh}) -Reference paper: Finding structure in time - Elman, 1988. +Reference paper: Finding structure in time - Elman, 1988. 
https://crl.ucsd.edu/~elman/Papers/fsit.pdf **LSTM** diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 555b9cc7cb59..3da37d4f5dd1 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -151,9 +151,30 @@ inline TShape InferReshapeShape(const nnvm::Tuple& shape, return oshape; } +inline bool ReverseReshapeInferShape(TShape *in, const TShape& out) { + if (in->Size() && out.Size()) { + return true; + } if (!out.Size()) { + return false; + } else { + int zero_axis = -1; + int non_zero_prod = 1; + for (index_t i = 0; i < in->ndim(); i++) { + if ((*in)[i] == 0) { + if (zero_axis != -1) return false; // more than 1 zero found. + else zero_axis = i; + } else { + non_zero_prod *= (*in)[i]; + } + } + (*in)[zero_axis] = out.Size() / non_zero_prod; + return true; + } +} + inline bool ReshapeShape(const nnvm::NodeAttrs& attrs, - std::vector *in_attrs, - std::vector *out_attrs) { + std::vector *in_attrs, + std::vector *out_attrs) { const ReshapeParam& param_ = nnvm::get(attrs.parsed); CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]"; CHECK_EQ(out_attrs->size(), 1U); @@ -182,14 +203,16 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs, oshape[inf_idx] = dshape.Size() / oshape.Size(); } } else { - return (*out_attrs)[0].ndim(); + if ((*out_attrs)[0].ndim()) { + return ReverseReshapeInferShape(&(*in_attrs)[0], oshape); + } } CHECK_EQ(oshape.Size(), dshape.Size()) << "Target shape size is different to source. " << "Target: " << oshape << "\nSource: " << dshape; SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape); - return true; + return ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]); } inline bool FlattenShape(const nnvm::NodeAttrs& attrs, diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index 6167f660d2c1..90c4df16a751 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -396,9 +396,12 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False): layer.collect_params().initialize() inputs.attach_grad() with mx.autograd.record(): - out = layer(inputs, states) + if states is None: + out = layer(inputs) + else: + out = layer(inputs, states) if states is not None: - assert isinstance(out, tuple) and len(out) == 2 + assert isinstance(out, (list, tuple)) and len(out) == 2 out = out[0] else: assert isinstance(out, mx.nd.NDArray) @@ -410,15 +413,19 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False): layer.hybridize() with mx.autograd.record(): - out = layer(inputs, states) if states is not None: - assert isinstance(out, tuple) and len(out) == 2 + out = layer(inputs, states) + assert isinstance(out, (list, tuple)) and len(out) == 2 out = out[0] else: + out = layer(inputs) assert isinstance(out, mx.nd.NDArray) out.backward() - layer(inputs, states) # test is_training = false + if states is not None: + layer(inputs, states) # test is_training = false + else: + layer(inputs) if not run_only: mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5) @@ -448,15 +455,26 @@ def test_rnn_layers(): check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True) - net = gluon.nn.Sequential() - net.add(gluon.rnn.LSTM(10, 2, bidirectional=True)) + net = gluon.nn.HybridSequential() + net.add(gluon.rnn.LSTM(10, bidirectional=True)) net.add(gluon.nn.BatchNorm(axis=2)) net.add(gluon.nn.Flatten()) 
net.add(gluon.nn.Dense(3, activation='relu')) + net.hybridize() net.collect_params().initialize() with mx.autograd.record(): net(mx.nd.ones((2, 3, 10))).backward() + net2 = gluon.nn.HybridSequential() + net2.add(gluon.rnn.LSTM(10, bidirectional=True)) + net2.add(gluon.nn.BatchNorm(axis=2)) + net2.add(gluon.nn.Flatten()) + net2.add(gluon.nn.Dense(3, activation='relu')) + net2.hybridize() + net2.collect_params().initialize() + with mx.autograd.record(): + net2(mx.nd.ones((2, 3, 10))).backward() + def test_rnn_unroll_variant_length(): # Test for imperative usage @@ -545,7 +563,7 @@ def test_layer_fill_shape(): layer.hybridize() check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7))) print(layer) - assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1] + assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1] if __name__ == '__main__':
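
A minimal usage sketch (not part of the patch, and assuming a build that includes this change): with _RNNLayer now deriving from HybridBlock, a Gluon RNN layer can be hybridized with a deferred input size and exported after one forward pass. The prefix 'lstm_layer' and the shapes below are illustrative only.

    import mxnet as mx
    from mxnet import gluon

    # Input size is left unspecified; the new _rnn_param_concat operator and the
    # reverse reshape inference added above allow shape inference to fill it in
    # from the first batch, even when the layer is hybridized.
    layer = gluon.rnn.LSTM(10, num_layers=2, bidirectional=True)
    layer.collect_params().initialize()
    layer.hybridize()   # legal now that _RNNLayer is a HybridBlock

    inputs = mx.nd.ones((5, 3, 20))   # default 'TNC' layout: (seq_len, batch, input_size)
    outputs = layer(inputs)           # states are created internally when omitted
    print(outputs.shape)              # (5, 3, 20): bidirectional doubles the hidden size of 10

    # A hybridized block that has run once imperatively can be serialized for deployment.
    layer.export('lstm_layer')        # writes lstm_layer-symbol.json and lstm_layer-0000.params

This mirrors the new HybridSequential test above; the export step is what the switch from an ndarray-only forward() to hybrid_forward() ultimately enables.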