Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
yijunc committed Jul 2, 2020
1 parent e639c36 commit 02d4fbf
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 66 deletions.
6 changes: 0 additions & 6 deletions tests/python/unittest/test_higher_order_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ def grad_grad_op(x):
check_second_order_unary(array, arccos, grad_grad_op)


@xfail_when_nonstandard_decimal_separator
@with_seed()
def test_arctan():
def arctan(x):
Expand Down Expand Up @@ -217,7 +216,6 @@ def grad_grad_op(x):
check_second_order_unary(array, arccosh, grad_grad_op)


@xfail_when_nonstandard_decimal_separator
@with_seed()
def test_arctanh():
def arctanh(x):
Expand Down Expand Up @@ -294,7 +292,6 @@ def grad_grad_op(x):
check_second_order_unary(array, log2, grad_grad_op)


@xfail_when_nonstandard_decimal_separator
@with_seed()
def test_log10():
def log10(x):
Expand All @@ -309,7 +306,6 @@ def grad_grad_op(x):
check_second_order_unary(array, log10, grad_grad_op)


@xfail_when_nonstandard_decimal_separator
@with_seed()
def test_square():
def grad_grad_op(x):
Expand Down Expand Up @@ -461,7 +457,6 @@ def grad_grad_op(x):
check_second_order_unary(array, cbrt, grad_grad_op)


@xfail_when_nonstandard_decimal_separator
@with_seed()
def test_rsqrt():
def rsqrt(x):
Expand All @@ -482,7 +477,6 @@ def grad_grad_op(x):
check_second_order_unary(array, rsqrt, grad_grad_op)


@xfail_when_nonstandard_decimal_separator
@with_seed()
def test_rcbrt():
def rcbrt(x):
Expand Down
60 changes: 0 additions & 60 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2553,66 +2553,6 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
continue
check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)

@with_seed()
@use_np
def test_np_binary_scalar_funcs():
itypes = [np.int8, np.int32, np.int64]
def check_binary_scalar_func(func, low, high, lshape, lgrad, ltype, scalar_is_int, hybridize):
class TestBinaryScalar(HybridBlock):
def __init__(self, func, scalar):
super(TestBinaryScalar, self).__init__()
self._func = func
self._scalar = scalar

def hybrid_forward(self, F, a, *args, **kwargs):
return getattr(F.np, self._func)(a, self._scalar)

np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype)
np_test_x2 = int(_np.random.uniform(low, high)) if scalar_is_int else _np.random.uniform(low, high)
mx_test_x1 = np.array(np_test_x1, dtype=ltype)
mx_test_x2 = np_test_x2
np_func = getattr(_np, func)
mx_func = TestBinaryScalar(func, mx_test_x2)
if hybridize:
mx_func.hybridize()
rtol = 1e-2 if ltype is np.float16 else 1e-3
atol = 1e-3 if ltype is np.float16 else 1e-5
if ltype not in itypes:
if lgrad:
mx_test_x1.attach_grad()
np_out = np_func(np_test_x1, np_test_x2)
with mx.autograd.record():
y = mx_func(mx_test_x1)
assert y.shape == np_out.shape
assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol)
if lgrad:
y.backward()
assert_almost_equal(mx_test_x1.grad.asnumpy(),
collapse_sum_like(lgrad(y.asnumpy(), np_test_x1, np_test_x2), mx_test_x1.shape),
rtol=rtol, atol=atol, equal_nan=True, use_broadcast=False)

# Test imperative
np_out = getattr(_np, func)(np_test_x1, np_test_x2)
mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2)
assert mx_out.shape == np_out.shape
assert mx_out.asnumpy().dtype == np_out.dtype
assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol)

funcs = {
'add': (-1.0, 1.0, None),
'subtract': (-1.0, 1.0, None),
'multiply': (-1.0, 1.0, lambda y, x1, x2: _np.broadcast_to(x2, y.shape)),
'power': (1.0, 5.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2),
}

shapes = [(3, 2), (3, 0), (3, 1), (0, 2), (2, 3, 4)]
ltypes = [np.int32, np.int64, np.float16, np.float32, np.float64]
flags = [True, False]
for func, func_data in funcs.items():
low, high, lgrad = func_data
for shape, ltype, is_int, hybridize in itertools.product(shapes, ltypes, flags, flags):
check_binary_scalar_func(func, low, high, shape, lgrad, ltype, is_int, hybridize)


@with_seed()
@use_np
Expand Down

0 comments on commit 02d4fbf

Please sign in to comment.