From 5315602d094986af6f411a14c45b74e74c8fd760 Mon Sep 17 00:00:00 2001 From: Xi Wang Date: Fri, 13 Mar 2020 02:14:28 +0800 Subject: [PATCH] fix np.clip scalar input case (#17788) --- python/mxnet/numpy/multiarray.py | 5 +++++ tests/python/unittest/test_numpy_op.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 4396cfa5207f..09712c03ea96 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -6306,6 +6306,11 @@ def clip(a, a_min, a_max, out=None): >>> np.clip(a, 3, 6, out=a) array([3., 3., 3., 3., 4., 5., 6., 6., 6., 6.], dtype=float32) """ + from numbers import Number + if isinstance(a, Number): + # In case input is a scalar, the computation would fall back to native numpy. + # The value returned would be a python scalar. + return _np.clip(a, a_min, a_max, out=None) return _mx_nd_np.clip(a, a_min, a_max, out=out) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b8e93437b2af..e33df0ef91f2 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -3487,6 +3487,16 @@ def __init__(self, a_min=None, a_max=None): def hybrid_forward(self, F, x): return x.clip(self._a_min, self._a_max) + + # Test scalar case + for _, a_min, a_max, throw_exception in workloads: + a = _np.random.uniform() # A scalar + if throw_exception: + # No need to test the exception case here. + continue + mx_ret = np.clip(a, a_min, a_max) + np_ret = _np.clip(a, a_min, a_max) + assert_almost_equal(mx_ret, np_ret, atol=1e-4, rtol=1e-3, use_broadcast=False) for shape, a_min, a_max, throw_exception in workloads: for dtype in dtypes: