diff --git a/python/nnvm/frontend/mxnet.py b/python/nnvm/frontend/mxnet.py
index 0f5efd3b4ad2..6cbd497fd470 100644
--- a/python/nnvm/frontend/mxnet.py
+++ b/python/nnvm/frontend/mxnet.py
@@ -7,6 +7,12 @@
 __all__ = ['from_mxnet']
 
+def _get_nnvm_op(op_name):
+    op = getattr(_sym, op_name, None)
+    if not op:
+        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
+    return op
+
 def _get_mxnet_version():
     try:
         import mxnet as mx
@@ -39,14 +45,11 @@ def _parse_bool_str(attr, key, default='False'):
     return attr.get(key, default).strip().lower() in ['true', '1', 't', 'y', 'yes']
 
 def _rename(new_name):
-    def impl(attr):
-        return new_name, attr
+    def impl(inputs, attrs):
+        return _get_nnvm_op(new_name)(*inputs, **attrs)
     return impl
 
-def _variable(attrs):
-    return "Variable", attrs
-
-def _pooling(attrs):
+def _pooling(inputs, attrs):
     kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
     if len(kernel) != 2:
         _raise_not_supported('non-2d kernel', 'pool_2d')
@@ -61,9 +64,9 @@ def _pooling(attrs):
     new_attrs['strides'] = attrs.get('stride', (1, 1))
     new_attrs['padding'] = attrs.get('pad', (0, 0))
     new_attrs['ceil_mode'] = (attrs.get('pooling_convention', 'valid') == 'full')
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _batch_norm(attrs):
+def _batch_norm(inputs, attrs):
     if _parse_bool_str(attrs, 'output_mean_var'):
         _raise_not_supported('output_mean_var', 'batch_norm')
     # if _parse_bool_str(attrs, 'fix_gamma'):
@@ -77,14 +80,14 @@ def _batch_norm(attrs):
     new_attrs['epsilon'] = attrs.get('eps', 0.001)
     new_attrs['center'] = True
     new_attrs['scale'] = True
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _concat(attrs):
+def _concat(inputs, attrs):
     op_name = 'concatenate'
     new_attrs = {'axis': attrs.get('dim', 1)}
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _conv2d(attrs):
+def _conv2d(inputs, attrs):
     kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
     if len(kernel) != 2:
         _raise_not_supported('non 2d kernel', 'conv2d')
@@ -100,9 +103,9 @@ def _conv2d(attrs):
     new_attrs['groups'] = attrs.get('num_group', 1)
     new_attrs['layout'] = layout
     new_attrs['use_bias'] = attrs.get('no_bias', 'False').strip() == 'False'
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _conv2d_transpose(attrs):
+def _conv2d_transpose(inputs, attrs):
     if 'target_shape' in attrs:
         _raise_not_supported('target_shape', 'conv2d_transpose')
     kernel = _parse_tshape(_required_attr(attrs, 'kernel'))
@@ -121,51 +124,68 @@
     new_attrs['groups'] = attrs.get('num_group', 1)
     new_attrs['layout'] = layout
     new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _dense(attrs):
+def _dense(inputs, attrs):
     op_name, new_attrs = 'dense', {}
     new_attrs['units'] = _required_attr(attrs, 'num_hidden')
     new_attrs['use_bias'] = not _parse_bool_str(attrs, 'no_bias')
     major, minor, micro = _get_mxnet_version()
     if major >= 0 and minor >= 11 and micro >= 1:
-        new_attrs['flatten'] = _parse_bool_str(attrs, 'flatten', 'True')
-    return op_name, new_attrs
+        use_flatten = _parse_bool_str(attrs, 'flatten', 'True')
+        if use_flatten:
+            inputs[0] = _sym.flatten(inputs[0])
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _dropout(attrs):
+def _dropout(inputs, attrs):
     op_name, new_attrs = 'dropout', {}
     new_attrs['rate'] = attrs.get('p', 0.5)
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _leaky_relu(attrs):
+def _leaky_relu(inputs, attrs):
     act_type = _required_attr(attrs, 'act_type')
-    if act_type not in ['leaky']:
+    if act_type in ['leaky']:
+        op_name, new_attrs = 'leaky_relu', {}
+        new_attrs['alpha'] = attrs.get('slope', 0.25)
+        sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
+    elif act_type == 'elu':
+        slope = float(attrs.get('slope', 0.25))
+        sym = -slope * _sym.relu(1 - _sym.exp(*inputs)) + _sym.relu(*inputs)
+    elif act_type == 'rrelu':
+        lower_bound = float(_required_attr(attrs, 'lower_bound'))
+        upper_bound = float(_required_attr(attrs, 'upper_bound'))
+        slope = (lower_bound + upper_bound) / 2.0
+        op_name, new_attrs = 'leaky_relu', {'alpha': str(slope)}
+        sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
+    else:
         _raise_not_supported('act_type: ' + act_type)
-    op_name, new_attrs = 'leaky_relu', {}
-    new_attrs['alpha'] = attrs.get('slope', 0.25)
-    return op_name, new_attrs
+    return sym
 
-def _activations(attrs):
+def _activations(inputs, attrs):
     act_type = _required_attr(attrs, 'act_type')
-    if act_type not in ['relu', 'sigmoid', 'tanh']:
+    if act_type in ['relu', 'sigmoid', 'tanh']:
+        op_name, new_attrs = act_type, {}
+        sym = _get_nnvm_op(op_name)(*inputs, **new_attrs)
+    elif act_type == 'softrelu':
+        sym = _sym.log(1 + _sym.exp(*inputs))
+    else:
         _raise_not_supported('act_type: ' + act_type)
-    op_name, new_attrs = act_type, {}
-    return op_name, new_attrs
+    return sym
 
-def _reshape(attrs):
+def _reshape(inputs, attrs):
     if _parse_bool_str(attrs, 'reverse'):
         _raise_not_supported('reverse', 'reshape')
     op_name, new_attrs = 'reshape', {}
     new_attrs['shape'] = _required_attr(attrs, 'shape')
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
-def _split(attrs):
+def _split(inputs, attrs):
     if _parse_bool_str(attrs, 'squeeze_axis'):
         _raise_not_supported('squeeze_axis', 'split')
     op_name, new_attrs = 'split', {}
     new_attrs['indices_or_sections'] = _required_attr(attrs, 'num_outputs')
     new_attrs['axis'] = attrs.get('axis', 1)
-    return op_name, new_attrs
+    return _get_nnvm_op(op_name)(*inputs, **new_attrs)
 
 _identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
                   '__div_symbol__', '__mul_scalar__', '__mul_symbol__',
@@ -178,7 +198,12 @@ def _split(attrs):
                   'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']
 
 _convert_map = {
-    'null'          : _variable,
+    '_div_scalar'   : _rename('__div_scalar__'),
+    '_minus_scalar' : _rename('__sub_scalar__'),
+    '_mul_scalar'   : _rename('__mul_scalar__'),
+    '_plus_scalar'  : _rename('__add_scalar__'),
+    '_rdiv_scalar'  : _rename('__rdiv_scalar__'),
+    '_rminus_scalar': _rename('__rsub_scalar__'),
     'Activation'    : _activations,
     'BatchNorm'     : _batch_norm,
     'BatchNorm_v1'  : _batch_norm,
@@ -202,7 +227,7 @@ def _split(attrs):
     'sum_axis'      : _rename('sum'),
 }
 
-def _convert_symbol(op_name, attrs,
+def _convert_symbol(op_name, inputs, attrs,
                     identity_list=None,
                     convert_map=None):
     """Convert from mxnet op to nnvm op.
@@ -213,6 +238,8 @@ def _convert_symbol(op_name, attrs,
     ----------
     op_name : str
         Operator name, such as Convolution, FullyConnected
+    inputs : list of nnvm.Symbol
+        List of input symbols.
     attrs : dict
         Dict of operator attributes
     identity_list : list
@@ -224,21 +251,19 @@
 
     Returns
     -------
-    (op_name, attrs)
-        Converted (op_name, attrs) for nnvm.
+    sym : nnvm.Symbol
+        Converted nnvm Symbol
     """
     identity_list = identity_list if identity_list else _identity_list
     convert_map = convert_map if convert_map else _convert_map
     if op_name in identity_list:
-        pass
+        op = _get_nnvm_op(op_name)
+        sym = op(*inputs, **attrs)
     elif op_name in convert_map:
-        op_name, attrs = convert_map[op_name](attrs)
+        sym = convert_map[op_name](inputs, attrs)
     else:
         _raise_not_supported('Operator: ' + op_name)
-    op = getattr(_sym, op_name, None)
-    if not op:
-        raise RuntimeError("Unable to map op_name {} to nnvm.sym".format(op_name))
-    return op, attrs
+    return sym
 
 def _is_mxnet_group_symbol(symbol):
     """Internal check for mxnet group symbol."""
@@ -274,28 +299,20 @@ def _from_mxnet_impl(symbol, graph):
     node = graph.get(name, None)
     if node:
         return node
+    attr = symbol.list_attr()
     # op_name = symbol.attr('op_name')
-    if symbol.get_children():
+    childs = symbol.get_children()
+    if childs:
         op_name = symbol.attr('op_name')
-    else:
-        op_name = json.loads(symbol.tojson())['nodes'][0]['op']
-    attr = symbol.list_attr()
-    new_op, new_attr = _convert_symbol(op_name, attr)
-    if new_op == _sym.Variable:
-        node = new_op(name=name, **new_attr)
-    else:
-        childs = symbol.get_children()
         childs = [_from_mxnet_impl(c, graph) for c in _as_list(childs)]
         childs = [x for y in childs for x in _as_list(y)]  # expand group symbol
-        if new_op == _sym.dense and 'flatten' in new_attr:
-            if new_attr['flatten']:
-                childs[0] = _sym.flatten(childs[0])
-            new_attr.pop('flatten')
-        node = new_op(name=name, *childs, **new_attr)
+        node = _convert_symbol(op_name, childs, attr)
+    else:
+        op_name = json.loads(symbol.tojson())['nodes'][0]['op']
+        node = _sym.Variable(name=name, **attr)
     graph[name] = node
     return node
 
-
 def from_mxnet(symbol, arg_params=None, aux_params=None):
     """Convert from MXNet's model into compatible NNVM format.
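
Note on the activation rewrites above: NNVM has no native elu, rrelu, or softrelu operators, so the frontend composes them from primitives. elu uses the identity elu(x) = -slope * relu(1 - exp(x)) + relu(x); rrelu is folded at inference time into a leaky_relu whose alpha is the midpoint of lower_bound and upper_bound; softrelu is the softplus log(1 + exp(x)). A minimal NumPy sketch (illustration only, not part of the patch) checking the elu identity:

import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

x = np.linspace(-5.0, 5.0, 101)
slope = 0.25

# reference elu: x for x > 0, slope * (exp(x) - 1) otherwise
elu_ref = np.where(x > 0, x, slope * (np.exp(x) - 1))
# decomposition used by _leaky_relu above
elu_dec = -slope * relu(1 - np.exp(x)) + relu(x)

np.testing.assert_allclose(elu_ref, elu_dec, rtol=1e-6)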
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index ca7a0156a8b5..0f9747538ce9 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -46,7 +46,7 @@ def get_tvm_output(symbol, x, args, auxs, target, ctx, dtype='float32'):
     assert "data" not in args
     for target, ctx in ctx_list():
         tvm_out = get_tvm_output(mx_symbol, x, args, auxs, target, ctx, dtype)
-        np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5)
+        np.testing.assert_allclose(mx_out, tvm_out, rtol=1e-5, atol=1e-5)
 
 def test_forward_mlp():
     mlp = model_zoo.mx_mlp
@@ -62,7 +62,40 @@ def test_forward_resnet():
         mx_sym = model_zoo.mx_resnet[n]
         verify_mxnet_frontend_impl(mx_sym)
 
+def test_forward_elu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data, act_type='elu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_rrelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.LeakyReLU(data, act_type='rrelu', lower_bound=0.3, upper_bound=0.7)
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_softrelu():
+    data = mx.sym.var('data')
+    data = mx.sym.concat(data, -data, dim=1)  # negative part explicitly
+    mx_sym = mx.sym.Activation(data, act_type='softrelu')
+    verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 6, 100, 100))
+
+def test_forward_fc_flatten():
+    # test the flatten option of FullyConnected, added in mxnet 0.11.1
+    data = mx.sym.var('data')
+    try:
+        mx_sym = mx.sym.FullyConnected(data, num_hidden=100, flatten=True)
+        verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
+        mx_sym = mx.sym.FullyConnected(mx.sym.Flatten(data), num_hidden=100, flatten=False)
+        verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 100))
+    except Exception:  # flatten kwarg not supported by older MXNet versions
+        pass
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
     test_forward_resnet()
+    test_forward_elu()
+    test_forward_rrelu()
+    test_forward_softrelu()
+    test_forward_fc_flatten()
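
For context, a rough end-to-end sketch of how the converted graph is exercised (the llvm target, input shape, and graph_runtime flow are illustrative assumptions following the usual nnvm/tvm test setup, not part of the patch):

import mxnet as mx
import numpy as np
import tvm
from tvm.contrib import graph_runtime
import nnvm.compiler
from nnvm.frontend import from_mxnet

# an MXNet graph that hits the new elu conversion path
data = mx.sym.var('data')
mx_sym = mx.sym.LeakyReLU(data, act_type='elu')

# convert the MXNet symbol to NNVM and compile for CPU
sym, params = from_mxnet(mx_sym)
shape = {'data': (1, 3, 100, 100)}
graph, lib, params = nnvm.compiler.build(sym, 'llvm', shape, params=params)

# run the compiled module on random input
module = graph_runtime.create(graph, lib, tvm.cpu())
x = np.random.uniform(-1, 1, shape['data']).astype('float32')
module.set_input('data', tvm.nd.array(x))
module.set_input(**params)
module.run()
out = module.get_output(0, tvm.nd.empty(shape['data']))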