From 5474b086757a8df94984fb95622ead0047ac78b4 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sat, 4 Aug 2018 12:17:50 -0700 Subject: [PATCH] make gluon rnn layers hybrid blocks (#11482) * make Gluon RNN layer hybrid block * separate gluon gpu tests * remove excess assert_raises_cudnn_disabled usage * add comments and refactor * add bidirectional test * temporarily remove hybridize in test_gluon_rnn.test_layer_fill_shape --- python/mxnet/gluon/rnn/rnn_layer.py | 132 ++++++++------- src/operator/nn/concat.cc | 127 ++++++++++++--- src/operator/nn/concat.cu | 4 + src/operator/rnn.cc | 6 +- tests/python/gpu/test_gluon_gpu.py | 203 ++++++++++++++++++++++++ tests/python/gpu/test_operator_gpu.py | 124 --------------- tests/python/unittest/test_gluon_rnn.py | 91 +++++++++-- 7 files changed, 449 insertions(+), 238 deletions(-) create mode 100644 tests/python/gpu/test_gluon_gpu.py diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 418c497ce832..4a7a0be2bc30 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -23,12 +23,11 @@ from __future__ import print_function __all__ = ['RNN', 'LSTM', 'GRU'] -from ... import ndarray -from .. import Block +from ... import ndarray, symbol +from .. import HybridBlock, tensor_types from . import rnn_cell - -class _RNNLayer(Block): +class _RNNLayer(HybridBlock): """Implementation of recurrent layers.""" def __init__(self, hidden_size, num_layers, layout, dropout, bidirectional, input_size, @@ -52,33 +51,28 @@ def __init__(self, hidden_size, num_layers, layout, self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode] - self.i2h_weight = [] - self.h2h_weight = [] - self.i2h_bias = [] - self.h2h_bias = [] - ng, ni, nh = self._gates, input_size, hidden_size for i in range(num_layers): - for j in (['l', 'r'] if self._dir == 2 else ['l']): - self.i2h_weight.append( - self.params.get('%s%d_i2h_weight'%(j, i), shape=(ng*nh, ni), - init=i2h_weight_initializer, - allow_deferred_init=True)) - self.h2h_weight.append( - self.params.get('%s%d_h2h_weight'%(j, i), shape=(ng*nh, nh), - init=h2h_weight_initializer, - allow_deferred_init=True)) - self.i2h_bias.append( - self.params.get('%s%d_i2h_bias'%(j, i), shape=(ng*nh,), - init=i2h_bias_initializer, - allow_deferred_init=True)) - self.h2h_bias.append( - self.params.get('%s%d_h2h_bias'%(j, i), shape=(ng*nh,), - init=h2h_bias_initializer, - allow_deferred_init=True)) + for j in ['l', 'r'][:self._dir]: + self._register_param('{}{}_i2h_weight'.format(j, i), + shape=(ng*nh, ni), + init=i2h_weight_initializer) + self._register_param('{}{}_h2h_weight'.format(j, i), + shape=(ng*nh, nh), + init=h2h_weight_initializer) + self._register_param('{}{}_i2h_bias'.format(j, i), + shape=(ng*nh,), + init=i2h_bias_initializer) + self._register_param('{}{}_h2h_bias'.format(j, i), + shape=(ng*nh,), + init=h2h_bias_initializer) ni = nh * self._dir - self._unfused = self._unfuse() + def _register_param(self, name, shape, init): + p = self.params.get(name, shape=shape, init=init, + allow_deferred_init=True) + setattr(self, name, p) + return p def __repr__(self): s = '{name}({mapping}, {_layout}' @@ -89,12 +83,23 @@ def __repr__(self): if self._dir == 2: s += ', bidirectional' s += ')' - shape = self.i2h_weight[0].shape + shape = self.l0_i2h_weight.shape mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0] // self._gates) return s.format(name=self.__class__.__name__, mapping=mapping, **self.__dict__) + def _collect_params_with_prefix(self, 
prefix=''): + if prefix: + prefix += '.' + def convert_key(key): # for compatibility with old parameter format + key = key.split('_') + return '_unfused.{}.{}_cell.{}'.format(key[0][1:], key[0][0], '_'.join(key[1:])) + ret = {prefix + convert_key(key) : val for key, val in self._reg_params.items()} + for name, child in self._children.items(): + ret.update(child._collect_params_with_prefix(prefix + name)) + return ret + def state_info(self, batch_size=0): raise NotImplementedError @@ -111,7 +116,7 @@ def _unfuse(self): 'gru': lambda **kwargs: rnn_cell.GRUCell(self._hidden_size, **kwargs)}[self._mode] - stack = rnn_cell.SequentialRNNCell(prefix=self.prefix, params=self.params) + stack = rnn_cell.HybridSequentialRNNCell(prefix=self.prefix, params=self.params) with stack.name_scope(): ni = self._input_size for i in range(self._num_layers): @@ -169,55 +174,42 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs): states.append(func(name='%sh0_%d'%(self.prefix, i), **info)) return states - def forward(self, inputs, states=None): - batch_size = inputs.shape[self._layout.find('N')] + def hybrid_forward(self, F, inputs, states=None, **kwargs): + if F is ndarray: + batch_size = inputs.shape[self._layout.find('N')] skip_states = states is None if skip_states: - states = self.begin_state(batch_size, ctx=inputs.context) - if isinstance(states, ndarray.NDArray): + if F is ndarray: + states = self.begin_state(batch_size, ctx=inputs.context) + else: + states = self.begin_state(0, func=symbol.zeros) + if isinstance(states, tensor_types): states = [states] - for state, info in zip(states, self.state_info(batch_size)): - if state.shape != info['shape']: - raise ValueError( - "Invalid recurrent state shape. Expecting %s, got %s."%( - str(info['shape']), str(state.shape))) - if self._input_size == 0: - for i in range(self._dir): - self.i2h_weight[i].shape = (self._gates*self._hidden_size, inputs.shape[2]) - self.i2h_weight[i]._finish_deferred_init() - out = self._forward_kernel(inputs, states) + if F is ndarray: + for state, info in zip(states, self.state_info(batch_size)): + if state.shape != info['shape']: + raise ValueError( + "Invalid recurrent state shape. 
Expecting %s, got %s."%(
+                            str(info['shape']), str(state.shape)))
+        out = self._forward_kernel(F, inputs, states, **kwargs)
 
         # out is (output, state)
         return out[0] if skip_states else out
 
-    def _forward(self, inputs, states):
-        """forward using gluon cell"""
-        ns = len(states)
-        axis = self._layout.find('T')
-        states = sum(zip(*((j for j in i) for i in states)), ())
-        outputs, states = self._unfused.unroll(
-            inputs.shape[axis], inputs, states,
-            layout=self._layout, merge_outputs=True)
-        new_states = []
-        for i in range(ns):
-            state = ndarray.concat(*(j.reshape((1,)+j.shape) for j in states[i::ns]), dim=0)
-            new_states.append(state)
-
-        return outputs, new_states
-
-    def _forward_kernel(self, inputs, states):
+    def _forward_kernel(self, F, inputs, states, **kwargs):
         """forward using CUDNN or CPU kernel"""
         if self._layout == 'NTC':
-            inputs = ndarray.swapaxes(inputs, dim1=0, dim2=1)
-        ctx = inputs.context
-        params = sum(zip(self.i2h_weight, self.h2h_weight), ())
-        params += sum(zip(self.i2h_bias, self.h2h_bias), ())
-        params = (i.data(ctx).reshape((-1,)) for i in params)
-        params = ndarray.concat(*params, dim=0)
-
-        rnn = ndarray.RNN(inputs, params, *states, state_size=self._hidden_size,
-                          num_layers=self._num_layers, bidirectional=self._dir == 2,
-                          p=self._dropout, state_outputs=True, mode=self._mode)
+            inputs = F.swapaxes(inputs, dim1=0, dim2=1)
+        params = (kwargs['{}{}_{}_{}'.format(d, l, g, t)].reshape(-1)
+                  for t in ['weight', 'bias']
+                  for l in range(self._num_layers)
+                  for d in ['l', 'r'][:self._dir]
+                  for g in ['i2h', 'h2h'])
+        params = F._internal._rnn_param_concat(*params, dim=0)
+
+        rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
+                    num_layers=self._num_layers, bidirectional=self._dir == 2,
+                    p=self._dropout, state_outputs=True, mode=self._mode)
 
         if self._mode == 'lstm':
             outputs, states = rnn[0], [rnn[1], rnn[2]]
@@ -225,7 +217,7 @@ def _forward_kernel(self, inputs, states):
             outputs, states = rnn[0], [rnn[1]]
 
         if self._layout == 'NTC':
-            outputs = ndarray.swapaxes(outputs, dim1=0, dim2=1)
+            outputs = F.swapaxes(outputs, dim1=0, dim2=1)
 
         return outputs, states
 
diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc
index 266ccb1b1a14..7c7f403d6985 100644
--- a/src/operator/nn/concat.cc
+++ b/src/operator/nn/concat.cc
@@ -74,6 +74,65 @@ static bool ConcatShape(const nnvm::NodeAttrs& attrs,
   return dshape.Size() != 0;
 }
 
+// Concat for RNN param deals with the reverse shape inference from output
+// for the special case of concatenating RNN parameters.
+// The first (and sometimes the second) input may be unknown on the target axis.
+// If the two inputs are unknown, they always have the same shape.
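+// As a worked illustration (sizes here are hypothetical, not from the code):
+// a single-layer RNN-ReLU layer with hidden_size = 2 concatenates the
+// flattened [i2h_weight, h2h_weight, i2h_bias, h2h_bias] along axis 0.
+// If i2h_weight is unknown (2*input_size elements) while the known inputs
+// contribute 4 + 2 + 2 = 8 elements and the output is known to hold 16,
+// the residual 16 - 8 = 8 is assigned to i2h_weight, from which the RNN op
+// can later recover input_size = 4.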
+static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
+                                std::vector<TShape> *in_shape,
+                                std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
+  TShape dshape;
+  index_t size = 0;
+  int num_zero = 0;
+  int axis = -1;
+  for (int i = 0; i < param_.num_args; ++i) {
+    TShape tmp = (*in_shape)[i];
+    if (tmp.ndim()) {
+      axis = CheckAxis(param_.dim, tmp.ndim());
+      num_zero += tmp[axis] == 0;
+      size += tmp[axis];
+      tmp[axis] = 0;
+      shape_assign(&dshape, tmp);
+    }
+  }
+
+  TShape tmp = (*out_shape)[0];
+  if (tmp.ndim()) {
+    axis = CheckAxis(param_.dim, tmp.ndim());
+    tmp[axis] = 0;
+    shape_assign(&dshape, tmp);
+  }
+
+  if (dshape.ndim() == 0) return false;
+
+  for (int i = 0; i < param_.num_args; ++i) {
+    CHECK(shape_assign(&(*in_shape)[i], dshape))
+        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
+  }
+
+  if (!num_zero) dshape[axis] = size;
+  CHECK(shape_assign(&(*out_shape)[0], dshape))
+      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
+  if ((*out_shape)[0][axis] != 0 && num_zero) {
+    int residual = (*out_shape)[0][axis] - size;
+    CHECK_GE(residual, 0)
+        << "Input size already exceeds output size. Residual: " << residual;
+    CHECK(num_zero <= 2 && num_zero >= 0)
+        << "Expecting 1 or 2 inputs that need shape inference. Got: " << num_zero;
+    bool need_infer = !(*out_shape)[0].Size();
+    for (int i = 0; i < num_zero; i++) {
+      (*in_shape)[i*2][axis] = residual / num_zero;
+      need_infer = need_infer || !(*in_shape)[i].Size();
+    }
+    return !need_infer;
+  }
+
+  return dshape.Size() != 0;
+}
+
 static bool ConcatType(const nnvm::NodeAttrs& attrs,
                        std::vector<int> *in_type,
                        std::vector<int> *out_type) {
@@ -228,6 +287,34 @@ struct ConcatGrad {
 
 DMLC_REGISTER_PARAMETER(ConcatParam);
 
+#define CONCAT_FORWARD_ATTRS \
+.set_num_inputs([](const NodeAttrs& attrs) { \
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
+  return params.num_args; \
+}) \
+.set_num_outputs(1) \
+.set_attr_parser(ParamParser<ConcatParam>) \
+.set_attr<nnvm::FListInputNames>("FListInputNames", \
+    [](const NodeAttrs& attrs) { \
+  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
+  std::vector<std::string> ret; \
+  for (int i = 0; i < params.num_args; ++i) { \
+    ret.push_back(std::string("arg") + std::to_string(i)); \
+  } \
+  return ret; \
+}) \
+.set_attr<nnvm::FListOutputNames>("FListOutputNames", \
+    [](const NodeAttrs& attrs) { \
+  return std::vector<std::string>{"output"}; \
+}) \
+.set_attr<nnvm::FInferType>("FInferType", ConcatType) \
+.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
+.set_attr<FCompute>("FCompute", ConcatCompute<cpu>) \
+.set_attr<FComputeEx>("FComputeEx", ConcatComputeExCPU) \
+.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
+.set_attr<std::string>("key_var_num_args", "num_args")
+
+
 NNVM_REGISTER_OP(Concat)
 MXNET_ADD_SPARSE_OP_ALIAS(concat)
 .add_alias("concat")
@@ -268,37 +355,13 @@ Example::
    [ 5.,  5.,  8.,  8.]]
 
 )code" ADD_FILELINE)
-.set_num_inputs([](const NodeAttrs& attrs) {
-  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
-  return params.num_args;
-})
-.set_num_outputs(1)
-.set_attr_parser(ParamParser<ConcatParam>)
-.set_attr<nnvm::FListInputNames>("FListInputNames",
-    [](const NodeAttrs& attrs) {
-  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
-  std::vector<std::string> ret;
-  for (int i = 0; i < params.num_args; ++i) {
-    ret.push_back(std::string("arg") + std::to_string(i));
-  }
-  return ret;
-})
-.set_attr<nnvm::FListOutputNames>("FListOutputNames",
-    [](const NodeAttrs& attrs) {
-  return std::vector<std::string>{"output"};
-})
 #if MXNET_USE_MKLDNN == 1
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 #endif
+CONCAT_FORWARD_ATTRS
 .set_attr<nnvm::FInferShape>("FInferShape", ConcatShape)
-.set_attr<nnvm::FInferType>("FInferType", ConcatType)
-.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType)
-.set_attr<FCompute>("FCompute", ConcatCompute<cpu>)
-.set_attr<FComputeEx>("FComputeEx", ConcatComputeExCPU)
-.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"})
-.set_attr<std::string>("key_var_num_args", "num_args")
 .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
 .add_arguments(ConcatParam::__FIELDS__());
 
@@ -320,5 +383,19 @@ NNVM_REGISTER_OP(_backward_Concat)
 #endif
 .set_attr<FCompute>("FCompute", ConcatGradCompute<cpu>);
 
+// _rnn_param_concat is a custom concat op with specialized infer_shape,
+// which handles the case where the first one or two inputs may have
+// unknown shape that can be inferred from output shape.
+NNVM_REGISTER_OP(_rnn_param_concat)
+#if MXNET_USE_MKLDNN == 1
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+#endif
+CONCAT_FORWARD_ATTRS
+.set_attr<nnvm::FInferShape>("FInferShape", RNNParamConcatShape)
+.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
+.add_arguments(ConcatParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/concat.cu b/src/operator/nn/concat.cu
index 4f6b8fc9ebef..2872d527898e 100644
--- a/src/operator/nn/concat.cu
+++ b/src/operator/nn/concat.cu
@@ -50,6 +50,10 @@ NNVM_REGISTER_OP(Concat)
 .set_attr<FCompute>("FCompute", ConcatCompute<gpu>)
 .set_attr<FComputeEx>("FComputeEx", ConcatComputeExGPU);
 
+NNVM_REGISTER_OP(_rnn_param_concat)
+.set_attr<FCompute>("FCompute", ConcatCompute<gpu>)
+.set_attr<FComputeEx>("FComputeEx", ConcatComputeExGPU);
+
 NNVM_REGISTER_OP(_backward_Concat)
 .set_attr<FCompute>("FCompute", ConcatGradCompute<gpu>);
 
diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc
index 1e670a9047f0..73ef4f0f42a7 100644
--- a/src/operator/rnn.cc
+++ b/src/operator/rnn.cc
@@ -45,12 +45,12 @@ Operator *RNNProp::CreateOperatorEx(Context ctx,
 DMLC_REGISTER_PARAMETER(RNNParam);
 
 MXNET_REGISTER_OP_PROPERTY(RNN, RNNProp)
-.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are 
+.describe(R"code(Applies recurrent layers to input data. Currently, vanilla RNN, LSTM and GRU are
 implemented, with both multi-layer and bidirectional support.
 
 **Vanilla RNN**
 
-Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported: 
+Applies a single-gate recurrent layer to input X. Two kinds of activation function are supported:
 ReLU and Tanh.
 
 With ReLU activation function:
@@ -63,7 +63,7 @@ With Tanh activation function:
 
 .. math::
     h_t = \tanh(W_{ih} * x_t + b_{ih} + W_{hh} * h_{(t-1)} + b_{hh})
 
-Reference paper: Finding structure in time - Elman, 1988. 
+Reference paper: Finding structure in time - Elman, 1988.
 https://crl.ucsd.edu/~elman/Papers/fsit.pdf
 
 **LSTM**
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
new file mode 100644
index 000000000000..42d65dab5fdc
--- /dev/null
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import print_function
+import sys
+import os
+import time
+import multiprocessing as mp
+import unittest
+import mxnet as mx
+import numpy as np
+from nose.tools import assert_raises
+from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal
+from mxnet.base import MXNetError
+from mxnet import autograd
+from numpy.testing import assert_allclose
+
+curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
+sys.path.insert(0, os.path.join(curr_path, '../unittest'))
+from common import setup_module, with_seed, teardown, assert_raises_cudnn_disabled
+from test_gluon import *
+from test_loss import *
+from test_gluon_rnn import *
+
+set_default_context(mx.gpu(0))
+
+def check_rnn_layer(layer):
+    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
+    with mx.gpu(0):
+        x = mx.nd.ones((10, 16, 30))
+        states = layer.begin_state(16)
+        go, gs = layer(x, states)
+
+    with mx.cpu(0):
+        x = mx.nd.ones((10, 16, 30))
+        states = layer.begin_state(16)
+        co, cs = layer(x, states)
+
+    # atol of 1e-6 required, as exposed by seed 2124685726
+    assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6)
+    for g, c in zip(gs, cs):
+        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
+
+
+def check_rnn_layer_w_rand_inputs(layer):
+    layer.collect_params().initialize(ctx=[mx.cpu(0), mx.gpu(0)])
+    x = mx.nd.uniform(shape=(10, 16, 30))
+    with mx.gpu(0):
+        x = x.copyto(mx.gpu(0))
+        states = layer.begin_state(16)
+        go, gs = layer(x, states)
+
+    with mx.cpu(0):
+        x = x.copyto(mx.cpu(0))
+        states = layer.begin_state(16)
+        co, cs = layer(x, states)
+
+    assert_almost_equal(go.asnumpy(), co.asnumpy(), rtol=1e-2, atol=1e-6)
+    for g, c in zip(gs, cs):
+        assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6)
+
+
+@with_seed()
+@assert_raises_cudnn_disabled()
+def test_rnn_layer():
+    check_rnn_layer(gluon.rnn.RNN(100, num_layers=3))
+    check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3))
+    check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3))
+    check_rnn_layer(gluon.rnn.GRU(100, num_layers=3))
+
+    check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
+    check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True))
+
+
+@with_seed()
+def test_gluon_ctc_consistency():
+    loss = mx.gluon.loss.CTCLoss()
+    data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0)
+    cpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.cpu(0))
+    gpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.gpu(0))
+
+    cpu_data = data.copy().as_in_context(mx.cpu(0))
+    cpu_data.attach_grad()
+    with mx.autograd.record():
+        l_cpu = loss(cpu_data, cpu_label)
+        l_cpu.backward()
+
+    gpu_data = data.copyto(mx.gpu(0))
+    gpu_data.attach_grad()
+    with mx.autograd.record():
+        l_gpu = loss(gpu_data, gpu_label)
+        l_gpu.backward()
+
+    assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3)
+
+
+@with_seed()
+def test_global_norm_clip_multi_device():
+    x1 = mx.nd.ones((3,3), ctx=mx.gpu(0))
+    x2 = mx.nd.ones((4,4),
ctx=mx.cpu(0)) + norm = gluon.utils.clip_global_norm([x1, x2], 1.0) + assert norm == 5.0 + assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) + assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) + + +def _check_batchnorm_result(input, num_devices=1, cuda=False): + from mxnet.gluon.utils import split_and_load + def _find_bn(module): + if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2, ctx): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(ctx)) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + nch = input.shape[1] + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) + + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # assert forwarding + assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), + atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), + atol=1e-3, rtol=1e-3) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3) + + +def test_sync_batchnorm(): + def get_num_devices(): + for i in range(100): + try: + mx.nd.zeros((1,), ctx=mx.gpu(i)) + except: + return i + # no need to use SyncBN with 1 gpu + if get_num_devices() < 2: + return + ndev = 2 + # check with unsync version + for i in range(10): + _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=True) + +if __name__ == '__main__': + import nose + nose.runmodule() diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index a3e663a68274..3d799aa5319b 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -36,11 +36,8 @@ from test_operator import * from test_optimizer import * from test_random import * -from test_gluon import * -from test_loss import * from test_exc_handling import * #from test_rnn import * -from test_gluon_rnn import * from test_sparse_ndarray import * from test_sparse_operator import * from test_ndarray import * @@ 
-1660,17 +1657,6 @@ def check_rnn_layer_w_rand_inputs(layer): for g, c in zip(gs, cs): assert_almost_equal(g.asnumpy(), c.asnumpy(), rtol=1e-2, atol=1e-6) -@with_seed() -@assert_raises_cudnn_disabled() -def test_rnn_layer(): - check_rnn_layer(gluon.rnn.RNN(100, num_layers=3)) - check_rnn_layer(gluon.rnn.RNN(100, activation='tanh', num_layers=3)) - check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3)) - check_rnn_layer(gluon.rnn.GRU(100, num_layers=3)) - - check_rnn_layer(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) - check_rnn_layer_w_rand_inputs(gluon.rnn.LSTM(100, num_layers=3, bidirectional=True)) - @with_seed() def test_sequence_reverse(): check_sequence_reverse(mx.gpu(0)) @@ -1688,28 +1674,6 @@ def test_autograd_save_memory(): x.backward() -@with_seed() -def test_gluon_ctc_consistency(): - loss = mx.gluon.loss.CTCLoss() - data = mx.nd.arange(0, 4, repeat=40, ctx=mx.gpu(0)).reshape((2,20,4)).flip(axis=0) - cpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.cpu(0)) - gpu_label = mx.nd.array([[2,1,-1,-1],[3,2,2,-1]], ctx=mx.gpu(0)) - - cpu_data = data.copy().as_in_context(mx.cpu(0)) - cpu_data.attach_grad() - with mx.autograd.record(): - l_cpu = loss(cpu_data, cpu_label) - l_cpu.backward() - - gpu_data = data.copyto(mx.gpu(0)) - gpu_data.attach_grad() - with mx.autograd.record(): - l_gpu = loss(gpu_data, gpu_label) - l_gpu.backward() - - assert_almost_equal(cpu_data.grad.asnumpy(), gpu_data.grad.asnumpy(), atol=1e-3, rtol=1e-3) - - @with_seed() def test_cuda_rtc(): source = r''' @@ -1740,16 +1704,6 @@ def test_cuda_rtc(): assert (y.asnumpy() == 12).all() -@with_seed() -def test_global_norm_clip_multi_device(): - x1 = mx.nd.ones((3,3), ctx=mx.gpu(0)) - x2 = mx.nd.ones((4,4), ctx=mx.cpu(0)) - norm = gluon.utils.clip_global_norm([x1, x2], 1.0) - assert norm == 5.0 - assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) - assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) - - @with_seed() def test_cross_device_autograd(): x = mx.nd.random.uniform(shape=(10,)) @@ -1968,84 +1922,6 @@ def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. 
assert mx.context.num_gpus() > 0 -def _check_batchnorm_result(input, num_devices=1, cuda=False): - from mxnet.gluon.utils import split_and_load - def _find_bn(module): - if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2, ctx): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(ctx)) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - if cuda: - input1 = input.as_in_context(mx.gpu(0)) - ctx_list = [mx.gpu(i) for i in range(num_devices)] - else: - ctx_list = [mx.cpu(0) for _ in range(num_devices)] - - nch = input.shape[1] - bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) - bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) - - bn1.initialize(ctx=ctx_list[0]) - bn2.initialize(ctx=ctx_list) - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - # assert forwarding - assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3) - assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), - _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), - atol=1e-3, rtol=1e-3) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3) - -def test_sync_batchnorm(): - def get_num_devices(): - for i in range(100): - try: - mx.nd.zeros((1,), ctx=mx.gpu(i)) - except: - return i - # no need to use SyncBN with 1 gpu - if get_num_devices() < 2: - return - ndev = 2 - # check with unsync version - for i in range(10): - _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=True) - if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index a9a2904e1e13..4e8241ffc1ea 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -16,7 +16,7 @@ # under the License. 
import mxnet as mx -from mxnet import gluon +from mxnet import gluon, nd import numpy as np import copy from numpy.testing import assert_allclose @@ -25,7 +25,6 @@ from common import assert_raises_cudnn_disabled -@assert_raises_cudnn_disabled() def test_rnn(): cell = gluon.rnn.RNNCell(100, prefix='rnn_') inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] @@ -51,7 +50,6 @@ def test_lstm(): assert outs == [(10, 100), (10, 100), (10, 100)] -@assert_raises_cudnn_disabled() def test_lstm_forget_bias(): forget_bias = 2.0 stack = gluon.rnn.SequentialRNNCell() @@ -77,19 +75,23 @@ def test_lstm_forget_bias(): def test_lstm_cpu_inference(): # should behave the same as lstm cell EXPECTED_LSTM_OUTPUT = np.array([[[0.72045636, 0.72045636, 0.95215213, 0.95215213], - [0.72045636, 0.72045636, 0.95215213, 0.95215213]], - [[0.95215213, 0.95215213, 0.72045636, 0.72045636], - [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]) + [0.72045636, 0.72045636, 0.95215213, 0.95215213]], + [[0.95215213, 0.95215213, 0.72045636, 0.72045636], + [0.95215213, 0.95215213, 0.72045636, 0.72045636]]]) x = mx.nd.ones(shape=(2, 2, 2)) model = mx.gluon.rnn.LSTM(2, num_layers=6, bidirectional=True) + model_cell = model._unfuse() model.initialize(mx.init.One()) + y = model(x).asnumpy() + y_cell = model_cell.unroll(2, x, layout='TNC', merge_outputs=True)[0].asnumpy() + mx.test_utils.assert_almost_equal(y_cell, EXPECTED_LSTM_OUTPUT, + rtol=1e-3, atol=1e-5) mx.test_utils.assert_almost_equal(y, EXPECTED_LSTM_OUTPUT, rtol=1e-3, atol=1e-5) -@assert_raises_cudnn_disabled() def test_gru(): cell = gluon.rnn.GRUCell(100, prefix='rnn_') inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] @@ -241,6 +243,46 @@ def test_bidirectional(): assert outs == [(10, 200), (10, 200), (10, 200)] +@assert_raises_cudnn_disabled() +def test_layer_bidirectional(): + class RefBiLSTM(gluon.Block): + def __init__(self, size, **kwargs): + super(RefBiLSTM, self).__init__(**kwargs) + with self.name_scope(): + self._lstm_fwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='l0') + self._lstm_bwd = gluon.rnn.LSTM(size, bidirectional=False, prefix='r0') + + def forward(self, inpt): + fwd = self._lstm_fwd(inpt) + bwd_inpt = nd.flip(inpt, 0) + bwd = self._lstm_bwd(bwd_inpt) + bwd = nd.flip(bwd, 0) + return nd.concat(fwd, bwd, dim=2) + + size = 7 + in_size = 5 + weights = {} + for d in ['l', 'r']: + weights['lstm_{}0_i2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, in_size)) + weights['lstm_{}0_h2h_weight'.format(d)] = mx.random.uniform(shape=(size*4, size)) + weights['lstm_{}0_i2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) + weights['lstm_{}0_h2h_bias'.format(d)] = mx.random.uniform(shape=(size*4,)) + + net = gluon.rnn.LSTM(size, bidirectional=True, prefix='lstm_') + ref_net = RefBiLSTM(size, prefix='lstm_') + net.initialize() + ref_net.initialize() + net_params = net.collect_params() + ref_net_params = ref_net.collect_params() + for k in weights: + net_params[k].set_data(weights[k]) + ref_net_params[k.replace('l0', 'l0l0').replace('r0', 'r0l0')].set_data(weights[k]) + + data = mx.random.uniform(shape=(3, 10, in_size)) + assert_allclose(net(data).asnumpy(), ref_net(data).asnumpy()) + + + def test_zoneout(): cell = gluon.rnn.ZoneoutCell(gluon.rnn.RNNCell(100, prefix='rnn_'), zoneout_outputs=0.5, zoneout_states=0.5) @@ -341,9 +383,12 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False): layer.collect_params().initialize() inputs.attach_grad() with mx.autograd.record(): - out = layer(inputs, states) + if 
states is None: + out = layer(inputs) + else: + out = layer(inputs, states) if states is not None: - assert isinstance(out, tuple) and len(out) == 2 + assert isinstance(out, (list, tuple)) and len(out) == 2 out = out[0] else: assert isinstance(out, mx.nd.NDArray) @@ -355,15 +400,19 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False): layer.hybridize() with mx.autograd.record(): - out = layer(inputs, states) if states is not None: - assert isinstance(out, tuple) and len(out) == 2 + out = layer(inputs, states) + assert isinstance(out, (list, tuple)) and len(out) == 2 out = out[0] else: + out = layer(inputs) assert isinstance(out, mx.nd.NDArray) out.backward() - layer(inputs, states) # test is_training = false + if states is not None: + layer(inputs, states) # test is_training = false + else: + layer(inputs) if not run_only: mx.test_utils.assert_almost_equal(np_out, out.asnumpy(), rtol=1e-3, atol=1e-5) @@ -393,15 +442,26 @@ def test_rnn_layers(): check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5), mx.nd.ones((8, 3, 20)), mx.nd.ones((4, 3, 10)), run_only=True) - net = gluon.nn.Sequential() - net.add(gluon.rnn.LSTM(10, 2, bidirectional=True)) + net = gluon.nn.HybridSequential() + net.add(gluon.rnn.LSTM(10, bidirectional=True)) net.add(gluon.nn.BatchNorm(axis=2)) net.add(gluon.nn.Flatten()) net.add(gluon.nn.Dense(3, activation='relu')) + net.hybridize() net.collect_params().initialize() with mx.autograd.record(): net(mx.nd.ones((2, 3, 10))).backward() + net2 = gluon.nn.HybridSequential() + net2.add(gluon.rnn.LSTM(10, bidirectional=True)) + net2.add(gluon.nn.BatchNorm(axis=2)) + net2.add(gluon.nn.Flatten()) + net2.add(gluon.nn.Dense(3, activation='relu')) + net2.hybridize() + net2.collect_params().initialize() + with mx.autograd.record(): + net2(mx.nd.ones((2, 3, 10))).backward() + def test_rnn_unroll_variant_length(): # Test for imperative usage @@ -487,10 +547,9 @@ def test_cell_fill_shape(): @assert_raises_cudnn_disabled() def test_layer_fill_shape(): layer = gluon.rnn.LSTM(10) - layer.hybridize() check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7))) print(layer) - assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1] + assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1] if __name__ == '__main__':
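
For reference, a minimal usage sketch of what this patch enables (not part of the
change itself; the layer sizes and input shapes below are arbitrary examples).
Because gluon RNN layers now derive from HybridBlock, they can be hybridized, and
the first layer's i2h_weight shape is recovered symbolically through
_rnn_param_concat's reverse shape inference instead of imperative deferred
initialization:

    import mxnet as mx
    from mxnet import gluon

    layer = gluon.rnn.LSTM(100, num_layers=2, bidirectional=True)
    layer.initialize()
    layer.hybridize()  # now supported: _RNNLayer implements hybrid_forward

    x = mx.nd.ones((5, 3, 10))                # (T, N, C) for the default 'TNC' layout
    states = layer.begin_state(batch_size=3)  # [h0, c0], each (num_layers*dirs, 3, 100)
    output, new_states = layer(x, states)     # output shape: (5, 3, 200)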