From 351cce6bf75fc9395e24ba4259a45f2926e06709 Mon Sep 17 00:00:00 2001 From: Ke Han Date: Tue, 11 Feb 2020 17:26:06 +0000 Subject: [PATCH] [Numpy] Add op fmax * Fix sanity * Fix bug of gpu part, add scalar compute --- python/mxnet/ndarray/numpy/_op.py | 21 +++++++++++- python/mxnet/numpy/multiarray.py | 32 ++++++++++++++++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 8 ++++- src/operator/mshadow_op.h | 14 ++++++++ .../np_elemwise_broadcast_op_extended.cc | 31 ++++++++++++++++++ .../np_elemwise_broadcast_op_extended.cu | 13 ++++++++ src/operator/operator_tune.cc | 1 + .../unittest/test_numpy_interoperability.py | 8 +++++ tests/python/unittest/test_numpy_op.py | 2 ++ 10 files changed, 128 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index b16fc785e2c3..a89e922ec991 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -37,7 +37,7 @@ 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', - 'average', 'mean', 'maximum', 'minimum', + 'average', 'mean', 'maximum', 'fmax', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', @@ -3951,6 +3951,25 @@ def maximum(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) +@set_module('mxnet.ndarray.numpy') +@wrap_np_binary_func +def fmax(x1, x2, out=None, **kwargs): + """ + Returns element-wise maximum of the input arrays with broadcasting. (Ignores NaNs) + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars.""" + return _ufunc_helper(x1, x2, _npi.fmax, _np.fmax, _npi.fmax_scalar, None, out) + + @set_module('mxnet.ndarray.numpy') @wrap_np_binary_func def minimum(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 0853336e2539..d9dc2d2e746e 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -58,7 +58,8 @@ 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort', 'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', - 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', + 'average', 'mean', 'maximum', 'fmax', 'minimum', + 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round', 'arctan2', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', @@ -5756,6 +5757,35 @@ def maximum(x1, x2, out=None, **kwargs): return _mx_nd_np.maximum(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def fmax(x1, x2, out=None, **kwargs): + """ + Returns element-wise maximum of the input arrays with broadcasting. (Ignores NaNs) + + Parameters + ---------- + x1, x2 : scalar or mxnet.numpy.ndarray + The arrays holding the elements to be compared. They must have the same shape, + or shapes that can be broadcast to a single shape. + + Returns + ------- + out : mxnet.numpy.ndarray or scalar + The maximum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Examples + -------- + >>> np.maximum(np.array([2, 3, 4]), np.array([1, 5, 2])) + array([2., 5., 4.]) + + >>> np.maximum(np.eye(2), np.array([0.5, 2])) # broadcasting + array([[1. , 2. ], + [0.5, 2. ]]) + """ + return _mx_nd_np.fmax(x1, x2, out=out) + + @set_module('mxnet.numpy') @wrap_np_binary_func def minimum(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 9bd6954d01e2..bfa7b9cfe189 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -262,6 +262,7 @@ def _register_array_function(): 'arccosh', 'arctanh', 'maximum', + 'fmax', 'minimum', 'ceil', 'trunc', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index a41dc0847d49..a14a12d69cd0 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -44,7 +44,7 @@ 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', - 'average', 'mean', 'maximum', 'minimum', + 'average', 'mean', 'maximum', 'fmax', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', 'around', 'round', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', @@ -3865,6 +3865,12 @@ def maximum(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.maximum, _np.maximum, _npi.maximum_scalar, None, out) +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def fmax(x1, x2, out=None, **kwargs): + return _ufunc_helper(x1, x2, _npi.fmax, _np.fmax, _npi.fmax_scalar, None, out) + + @set_module('mxnet.symbol.numpy') @wrap_np_binary_func def minimum(x1, x2, out=None, **kwargs): diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index fa424ad6d0fc..c69a375eaf1b 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1107,6 +1107,20 @@ struct maximum : public mxnet_op::tunable { } }; +/*! \brief used for computing binary operator fmax */ +struct fmax : public mxnet_op::tunable { + template + MSHADOW_XINLINE static DType Map(DType a, DType b) { + if (IsNan(b)) { + return a; + } else if (IsNan(a)) { + return b; + } else { + return (a > b ? a : b); + } + } +}; + /*! \brief used for computing binary operator minimum */ struct minimum : public mxnet_op::tunable { template diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 70233a596dc7..b8919fc7dcf4 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -371,5 +371,36 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar) .set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) .set_attr("FCompute", BinaryScalarOp::Backward); +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(broadcast_fmax) +.add_alias("_npi_fmax") +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_fmax"}); + +NNVM_REGISTER_OP(_backward_broadcast_fmax) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_fmax_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_fmax_scalar"}) +.add_alias("_FmaxScalar") +.add_alias("_npi_fmax_scalar"); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_fmax_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu index 8f135b3efd03..bc76d2985649 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu @@ -116,5 +116,18 @@ NNVM_REGISTER_OP(_backward_npi_ldexp_scalar) NNVM_REGISTER_OP(_backward_npi_rldexp_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); +NNVM_REGISTER_OP(_npi_fmax) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_fmax) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NVM_REGISTER_OP(_npi_fmax_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_fmax_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + } // namespace op } // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index c0e9a63af892..ae44181d4d87 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -369,6 +369,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::elu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::maximum); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::fmax); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minimum); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::hypot_grad_left); // NOLINT() diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 962a46cef7e2..961a918653c4 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1415,6 +1415,13 @@ def _add_workload_maximum(array_pool): OpArgMngr.add_workload('maximum', array_pool['4x1'], array_pool['1x1x0']) +def _add_workload_fmax(array_pool): + OpArgMngr.add_workload('fmax', array_pool['4x1'], array_pool['1x2']) + OpArgMngr.add_workload('fmax', array_pool['4x1'], 2) + OpArgMngr.add_workload('fmax', 2, array_pool['4x1']) + OpArgMngr.add_workload('fmax', array_pool['4x1'], array_pool['1x1x0']) + + def _add_workload_minimum(array_pool): OpArgMngr.add_workload('minimum', array_pool['4x1'], array_pool['1x2']) OpArgMngr.add_workload('minimum', array_pool['4x1'], 2) @@ -1917,6 +1924,7 @@ def _prepare_workloads(): _add_workload_mod(array_pool) _add_workload_remainder() _add_workload_maximum(array_pool) + _add_workload_fmax(array_pool) _add_workload_minimum(array_pool) _add_workload_negative(array_pool) _add_workload_absolute(array_pool) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 0dc893cc7361..3fcdc23a7b41 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2126,6 +2126,8 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): 'bitwise_or': (-100, 100, [None], None, [[_np.int32]]), 'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)], [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]), + 'fmax': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)], + [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]), 'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)], [lambda y, x1, x2: _np.ones(y.shape) * (x1 > x2)]), 'copysign': (-1, 1,