Skip to content

Commit

Permalink
add jacobian / hessian pytree tests (fixes #173)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Jan 7, 2019
1 parent e4f56d7 commit ca27f0a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
5 changes: 5 additions & 0 deletions jax/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,11 @@ def assertAllClose(self, x, y, check_dtypes, atol=None, rtol=None):
self.assertEqual(len(x), len(y))
for x_elt, y_elt in zip(x, y):
self.assertAllClose(x_elt, y_elt, check_dtypes, atol=atol, rtol=rtol)
elif isinstance(x, dict):
self.assertIsInstance(y, dict)
self.assertEqual(set(x.keys()), set(y.keys()))
for k in x.keys():
self.assertAllClose(x[k], y[k], check_dtypes, atol=atol, rtol=rtol)
else:
is_array = lambda x: hasattr(x, '__array__') or onp.isscalar(x)
self.assertTrue(is_array(x))
Expand Down
47 changes: 46 additions & 1 deletion tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from jax import test_util as jtu

import jax.numpy as np
from jax import jit, grad, device_get, device_put, jacfwd, jacrev
from jax import jit, grad, device_get, device_put, jacfwd, jacrev, hessian
from jax import api
from jax.core import Primitive
from jax.interpreters.partial_eval import def_abstract_eval
Expand Down Expand Up @@ -260,6 +260,15 @@ def test_jacobian(self):
f = lambda x: np.tanh(np.dot(A, x))
assert onp.allclose(jacfwd(f)(x), jacrev(f)(x))

@jtu.skip_on_devices("tpu")
def test_hessian(self):
R = onp.random.RandomState(0).randn
A = R(4, 4)
x = R(4)

f = lambda x: np.dot(x, np.dot(A, x))
assert onp.allclose(hessian(f)(x), A + A.T)

def test_std_basis(self):
basis = api._std_basis(np.zeros(3))
assert getattr(basis, "shape", None) == (3, 3)
Expand All @@ -276,6 +285,42 @@ def test_std_basis(self):
assert getattr(basis[1][0], "shape", None) == (16, 3)
assert getattr(basis[1][1], "shape", None) == (16, 3, 4)

@jtu.skip_on_devices("tpu")
def test_jacobian_on_pytrees(self):
for jacfun in [jacfwd, jacrev]:
ans = jacfun(lambda x, y: (x, y))(0., 1.)
expected = (1., 0.)
self.assertAllClose(ans, expected, check_dtypes=False)

ans = jacfun(lambda x, y: (x, y), 1)(0., 1.)
expected = (0., 1.)
self.assertAllClose(ans, expected, check_dtypes=False)

ans = jacfun(lambda x, y: (x, y), (0, 1))(0., 1.)
expected = ((1., 0.),
(0., 1.),)
self.assertAllClose(ans, expected, check_dtypes=False)

ans = jacfun(lambda x: x[:2])((1., 2., 3.))
expected = ((1., 0., 0.),
(0., 1., 0.))
self.assertAllClose(ans, expected, check_dtypes=False)

R = onp.random.RandomState(0).randn
x = R(2)
y = R(3)
ans = jacfun(lambda x, y: {'x': x, 'xy': np.outer(x, y)})(x, y)
expected = {'x': onp.eye(2),
'xy': onp.kron(onp.eye(2), y[:, None]).reshape(2, 3, 2)}
self.assertAllClose(ans, expected, check_dtypes=False)

@jtu.skip_on_devices("tpu")
def test_hessian_on_pytrees(self):
ans = hessian(lambda x: np.array(x)**2)((1., 2.))
expected = ((onp.array([2., 0.]), onp.array([0., 0.])),
(onp.array([0., 0.]), onp.array([0., 2.])))
self.assertAllClose(ans, expected, check_dtypes=False)


if __name__ == '__main__':
absltest.main()

0 comments on commit ca27f0a

Please sign in to comment.