diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index e55fa1af90e8..ac6ee1561c47 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1308,25 +1308,31 @@ def test_norm(ctx=default_context()): def l1norm(input_data, axis=0, keepdims=False): return np.sum(abs(input_data), axis=axis, keepdims=keepdims) - def l2norm(input_data, axis=0, keepdims=False): + def l2norm(input_data, axis=0, keepdims=False): return sp_norm(input_data, axis=axis, keepdims=keepdims) in_data_dim = random_sample([4,5,6], 1)[0] - in_data_shape = rand_shape_nd(in_data_dim) - np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32) - mx_arr = mx.nd.array(np_arr, ctx=ctx) - for ord in [1,2]: - for keep_dims in [True, False]: - for i in range(4): - npy_out = l1norm(np_arr, i, keep_dims) if ord==1 else l2norm(np_arr, i, keep_dims) - mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims) - assert npy_out.shape == mx_out.shape - mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy()) - if (i < 3): - npy_out = l1norm(np_arr, (i, i+1), keep_dims) if ord==1 else l2norm(np_arr, (i, i+1), keep_dims) - mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i+1), keepdims=keep_dims) + for force_reduce_dim1 in [True, False]: + in_data_shape = rand_shape_nd(in_data_dim) + if force_reduce_dim1: + in_data_shape = in_data_shape[:3] + (1, ) + in_data_shape[4:] + np_arr = np.random.uniform(-1, 1, in_data_shape).astype(np.float32) + mx_arr = mx.nd.array(np_arr, ctx=ctx) + for ord in [1, 2]: + for keep_dims in [True, False]: + for i in range(4): + npy_out = l1norm(np_arr, i, keep_dims) if ord == 1 else l2norm( + np_arr, i, keep_dims) + mx_out = mx.nd.norm(mx_arr, ord=ord, axis=i, keepdims=keep_dims) assert npy_out.shape == mx_out.shape mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy()) + if (i < 3): + npy_out = l1norm(np_arr, (i, i + 1), keep_dims) if ord == 1 else l2norm( + np_arr, (i, i + 1), keep_dims) + mx_out = mx.nd.norm(mx_arr, ord=ord, axis=(i, i + 1), keepdims=keep_dims) + assert npy_out.shape == mx_out.shape + mx.test_utils.assert_almost_equal(npy_out, mx_out.asnumpy()) + @with_seed() def test_ndarray_cpu_shared_ctx():