
Commit

Fix the simple bug in call_tf.replace_non_float and add a unit test for floating and complex data types.

PiperOrigin-RevId: 510055139
maxwillzq authored and jax authors committed Feb 16, 2023
1 parent 26045c4 commit d0b42f2
Showing 2 changed files with 112 additions and 4 deletions.
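
The change swaps the NumPy-based inexact check in replace_non_float for TensorFlow's own dtype flags (is_floating / is_complex), which is what lets extension floating dtypes such as bfloat16 take part in the TF gradient, and it adds a fallback for converting scalar bfloat16 results back to JAX. A minimal sketch of the round-trip gradient pattern the new tests exercise; tf_f and the sample input mirror the tests and are illustrative only:

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# A toy TF function; sin + reduce_sum mirrors the pattern used in the new tests.
def tf_f(x):
  return tf.reduce_sum(tf.math.sin(x))

x = jnp.array([5.0, 6.0, 7.0], dtype=jnp.bfloat16)

# Gradient of a TF function called from JAX. Before this change, the
# NumPy-based check in replace_non_float could fail to flag bfloat16 as
# inexact, so the input would not be watched for the TF gradient.
grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f))
print(grad_fun_jax(x))

# Round trip: convert the JAX wrapper back to TF and call it from JAX again,
# as the new tests do.
grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))
print(grad_fun_jax_rt(x))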
14 changes: 10 additions & 4 deletions jax/experimental/jax2tf/call_tf.py
@@ -31,6 +31,7 @@
import jax
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import util
from jax._src import ad_util
@@ -200,9 +201,9 @@ def tf_vjp_fun(args_tf, ct_res_tf):
"""Invoke TF gradient."""

# TF does not like us to watch non-float vars
def replace_non_float(arg):
if np.issubdtype(arg.dtype.as_numpy_dtype, np.inexact):
return arg
def replace_non_float(arg_tf):
if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex:
return arg_tf
else:
# When watched, this will be ignored. When used in results it will
# result in a floating 0. gradient, which JAX will ignore (and
@@ -273,7 +274,12 @@ def _res_tf_to_jax(res_tf: TfVal):
res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
return jax.dlpack.from_dlpack(res_dlpack)

return jax.device_put(np.asarray(res_tf))
# When working with a bfloat16 scalar tf.Tensor, np.asarray() can fail.
# To handle this special case, we create a numpy copy.
if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
return jax.device_put(jnp.array(res_tf.numpy()))
else:
return jax.device_put(np.asarray(res_tf))

return list(map(_res_tf_to_jax, res_tf_flat))
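
As the comment in the new branch notes, np.asarray() can fail on a scalar bfloat16 tf.Tensor, so that case is materialized via .numpy() and copied into a JAX array. A rough, self-contained sketch of this conversion path, assuming TensorFlow and JAX are installed (the function name is illustrative, not part of the commit):

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

def res_tf_to_jax_sketch(res_tf: tf.Tensor):
  # Special case from this commit: a scalar bfloat16 tensor is copied via
  # .numpy() because np.asarray() can fail on it.
  if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
    return jax.device_put(jnp.array(res_tf.numpy()))
  # General path: hand the tensor to JAX as a NumPy array.
  return jax.device_put(np.asarray(res_tf))

print(res_tf_to_jax_sketch(tf.constant(3.0, dtype=tf.bfloat16)))       # copy path
print(res_tf_to_jax_sketch(tf.constant([1.0, 2.0], dtype=tf.float32)))  # asarray path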

102 changes: 102 additions & 0 deletions jax/experimental/jax2tf/tests/call_tf_test.py
@@ -117,6 +117,13 @@ def test_eval_devicearray_no_copy(self):
self.assertAllClose(x, res)
self.assertTrue(np.shares_memory(x, res))

x = jnp.array(3.0, dtype=jnp.bfloat16)
res = jax2tf.call_tf(lambda x: x)(x)
self.assertAllClose(x, res)
# bfloat16 scalar will create a copy.
with self.assertRaises(AssertionError):
self.assertTrue(np.shares_memory(x, res))

@_parameterized_jit
def test_eval_pytree(self, with_jit=True):

@@ -868,6 +875,101 @@ def fun_jax(x, y):
res = tf.function(converted_fun, jit_compile=True, autograph=False)(x, y)
self.assertAllClose(expected, res.numpy(), atol=1e-5, rtol=1e-5)

@parameterized.named_parameters(
dict(
testcase_name=f"_{dtype.__name__}",
dtype=dtype,
)
for dtype in set(jtu.dtypes.all_floating)
)
def test_all_floating_input_gradient(self, dtype):
def tf_f(x):
res = tf.math.sin(x)
return tf.reduce_sum(res)

jax_f = jax2tf.call_tf(tf_f)
tf_f_rt = jax2tf.convert(jax_f)
x = jnp.array([5.0, 6.0, 7.0]).astype(dtype)

def assert_all_close_support_bfloat16(baseline, candidate):
def conversion(x):
# convert scalar to array and bfloat16 to float32
# to support self.assertAllClose numpy array comparison.
if x.shape == tf.TensorShape([]):
x = tf.convert_to_tensor([x])
if dtype == jnp.bfloat16:
x = tf.cast(x, tf.float32)
return x

baseline = jax.tree_util.tree_map(conversion, baseline)
candidate = jax.tree_util.tree_map(conversion, candidate)
self.assertAllClose(baseline, candidate)

# Eager mode
assert_all_close_support_bfloat16(tf_f(x), tf_f_rt(x))

# Compiled function mode
assert_all_close_support_bfloat16(
tf.function(tf_f)(x), tf.function(tf_f_rt)(x)
)

# Compiled function mode with jit_compile=True
assert_all_close_support_bfloat16(
tf.function(tf_f, jit_compile=True)(x),
tf.function(tf_f_rt, jit_compile=True)(x),
)

# RoundTrip test for the gradient
grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f))
grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))

# Eager mode
assert_all_close_support_bfloat16(grad_fun_jax(x), grad_fun_jax_rt(x))

# Jit mode
assert_all_close_support_bfloat16(
jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x)
)

@parameterized.named_parameters(
dict(
testcase_name=f"_{dtype.__name__}",
dtype=dtype,
)
for dtype in set(jtu.dtypes.complex)
)
def test_complex_input_gradient(self, dtype):
def tf_f(x):
res = tf.math.sin(x)
return tf.reduce_sum(res)

x = jnp.array([(5.0 + 4.0j), (6.0 + 3.0j), (7.0 + 8.0j)]).astype(dtype)

jax_f = jax2tf.call_tf(tf_f)
tf_f_rt = jax2tf.convert(jax_f)

# Eager mode
self.assertAllClose(tf_f(x), tf_f_rt(x))

# tf.function context
self.assertAllClose(tf.function(tf_f)(x), tf.function(tf_f_rt)(x))

# tf.function context with jit_compile=True
self.assertAllClose(
tf.function(tf_f, jit_compile=True)(x),
tf.function(tf_f_rt, jit_compile=True)(x),
)

# RoundTrip test for the gradient
grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f), holomorphic=True)
grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))

# Eager mode
self.assertAllClose(grad_fun_jax(x), grad_fun_jax_rt(x))

# Jit mode
self.assertAllClose(jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x))
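
jax.grad rejects functions with complex (non-real) outputs unless holomorphic=True is passed, which is why the complex gradient above is built with that flag. A small self-contained sketch of the same pattern, with an illustrative function name and input:

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Illustrative complex-valued TF function; sin is holomorphic, like the test's tf_f.
def tf_g(x):
  return tf.reduce_sum(tf.math.sin(x))

x = jnp.array([5.0 + 4.0j, 6.0 + 3.0j], dtype=jnp.complex64)

# Without holomorphic=True, jax.grad raises an error for complex outputs.
grad_g = jax.grad(jax2tf.call_tf(tf_g), holomorphic=True)
print(grad_g(x))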


class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
"Reloading output of call_tf into TF with jax2tf."
