From d0b42f2ce8006430e0e312c1431c53668ed732d7 Mon Sep 17 00:00:00 2001
From: John QiangZhang
Date: Wed, 15 Feb 2023 23:40:12 -0800
Subject: [PATCH] Fix a simple bug in call_tf.replace_non_float and add unit
 tests for floating and complex data types.

PiperOrigin-RevId: 510055139
---
 jax/experimental/jax2tf/call_tf.py            |  14 ++-
 jax/experimental/jax2tf/tests/call_tf_test.py | 102 ++++++++++++++++++
 2 files changed, 112 insertions(+), 4 deletions(-)

diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py
index 40fcb749e73f..33a8661b8aca 100644
--- a/jax/experimental/jax2tf/call_tf.py
+++ b/jax/experimental/jax2tf/call_tf.py
@@ -31,6 +31,7 @@
 import jax
 from jax import dlpack
 from jax import dtypes
+from jax import numpy as jnp
 from jax import tree_util
 from jax._src import util
 from jax._src import ad_util
@@ -200,9 +201,9 @@ def tf_vjp_fun(args_tf, ct_res_tf):
       """Invoke TF gradient."""
 
       # TF does not like us to watch non-float vars
-      def replace_non_float(arg):
-        if np.issubdtype(arg.dtype.as_numpy_dtype, np.inexact):
-          return arg
+      def replace_non_float(arg_tf):
+        if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex:
+          return arg_tf
         else:
           # When watched, this will be ignored. When used in results it will
           # result in a floating 0. gradient, which JAX will ignore (and
@@ -273,7 +274,12 @@ def _res_tf_to_jax(res_tf: TfVal):
         res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
         return jax.dlpack.from_dlpack(res_dlpack)
 
-    return jax.device_put(np.asarray(res_tf))
+    # When working with a bfloat16 scalar tf.Tensor, np.asarray() can fail.
+    # To handle this special case, we create a numpy copy.
+    if res_tf.shape == tf.TensorShape([]) and res_tf.dtype == tf.bfloat16:
+      return jax.device_put(jnp.array(res_tf.numpy()))
+    else:
+      return jax.device_put(np.asarray(res_tf))
 
   return list(map(_res_tf_to_jax, res_tf_flat))
diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py
index 6e8f2046e0f6..08bf41a6032f 100644
--- a/jax/experimental/jax2tf/tests/call_tf_test.py
+++ b/jax/experimental/jax2tf/tests/call_tf_test.py
@@ -117,6 +117,13 @@ def test_eval_devicearray_no_copy(self):
     self.assertAllClose(x, res)
     self.assertTrue(np.shares_memory(x, res))
 
+    x = jnp.array(3.0, dtype=jnp.bfloat16)
+    res = jax2tf.call_tf(lambda x: x)(x)
+    self.assertAllClose(x, res)
+    # bfloat16 scalar will create a copy.
+    with self.assertRaises(AssertionError):
+      self.assertTrue(np.shares_memory(x, res))
+
   @_parameterized_jit
   def test_eval_pytree(self, with_jit=True):
@@ -868,6 +875,101 @@ def fun_jax(x, y):
     res = tf.function(converted_fun, jit_compile=True, autograph=False)(x, y)
     self.assertAllClose(expected, res.numpy(), atol=1e-5, rtol=1e-5)
 
+  @parameterized.named_parameters(
+      dict(
+          testcase_name=f"_{dtype.__name__}",
+          dtype=dtype,
+      )
+      for dtype in set(jtu.dtypes.all_floating)
+  )
+  def test_all_floating_input_gradient(self, dtype):
+    def tf_f(x):
+      res = tf.math.sin(x)
+      return tf.reduce_sum(res)
+
+    jax_f = jax2tf.call_tf(tf_f)
+    tf_f_rt = jax2tf.convert(jax_f)
+    x = jnp.array([5.0, 6.0, 7.0]).astype(dtype)
+
+    def assert_all_close_support_bfloat16(baseline, candidate):
+      def conversion(x):
+        # Convert scalar to array and bfloat16 to float32 to support
+        # self.assertAllClose numpy array comparison.
+        if x.shape == tf.TensorShape([]):
+          x = tf.convert_to_tensor([x])
+        if dtype == jnp.bfloat16:
+          x = tf.cast(x, tf.float32)
+        return x
+
+      baseline = jax.tree_util.tree_map(conversion, baseline)
+      candidate = jax.tree_util.tree_map(conversion, candidate)
+      self.assertAllClose(baseline, candidate)
+
+    # Eager mode
+    assert_all_close_support_bfloat16(tf_f(x), tf_f_rt(x))
+
+    # Compiled function mode
+    assert_all_close_support_bfloat16(
+        tf.function(tf_f)(x), tf.function(tf_f_rt)(x)
+    )
+
+    # Compiled function mode with jit_compile=True
+    assert_all_close_support_bfloat16(
+        tf.function(tf_f, jit_compile=True)(x),
+        tf.function(tf_f_rt, jit_compile=True)(x),
+    )
+
+    # Round-trip test for the gradient
+    grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f))
+    grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))
+
+    # Eager mode
+    assert_all_close_support_bfloat16(grad_fun_jax(x), grad_fun_jax_rt(x))
+
+    # Jit mode
+    assert_all_close_support_bfloat16(
+        jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x)
+    )
+
+  @parameterized.named_parameters(
+      dict(
+          testcase_name=f"_{dtype.__name__}",
+          dtype=dtype,
+      )
+      for dtype in set(jtu.dtypes.complex)
+  )
+  def test_complex_input_gradient(self, dtype):
+    def tf_f(x):
+      res = tf.math.sin(x)
+      return tf.reduce_sum(res)
+
+    x = jnp.array([(5.0 + 4.0j), (6.0 + 3.0j), (7.0 + 8.0j)]).astype(dtype)
+
+    jax_f = jax2tf.call_tf(tf_f)
+    tf_f_rt = jax2tf.convert(jax_f)
+
+    # Eager mode
+    self.assertAllClose(tf_f(x), tf_f_rt(x))
+
+    # tf.function context
+    self.assertAllClose(tf.function(tf_f)(x), tf.function(tf_f_rt)(x))
+
+    # tf.function context with jit_compile=True
+    self.assertAllClose(
+        tf.function(tf_f, jit_compile=True)(x),
+        tf.function(tf_f_rt, jit_compile=True)(x),
+    )
+
+    # Round-trip test for the gradient
+    grad_fun_jax = jax.grad(jax2tf.call_tf(tf_f), holomorphic=True)
+    grad_fun_jax_rt = jax2tf.call_tf(jax2tf.convert(grad_fun_jax))
+
+    # Eager mode
+    self.assertAllClose(grad_fun_jax(x), grad_fun_jax_rt(x))
+
+    # Jit mode
+    self.assertAllClose(jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x))
+
 
 class RoundTripToTfTest(tf_test_util.JaxToTfTestCase):
   "Reloading output of call_tf into TF with jax2tf."
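
Usage sketch (illustration only, not part of the patch; assumes recent jax and tensorflow releases): it exercises the path this change fixes by taking a JAX gradient through a TF function wrapped with jax2tf.call_tf on a bfloat16 input, mirroring the toy function used in the new tests.

import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def tf_f(x):
  # Same toy function as the new tests: a scalar reduction of sin(x).
  return tf.reduce_sum(tf.math.sin(x))

# call_tf makes the TF function callable from JAX; jax.grad then goes through
# the TF gradient path, where replace_non_float now checks TF dtypes directly.
grad_fun = jax.grad(jax2tf.call_tf(tf_f))

x = jnp.array([5.0, 6.0, 7.0], dtype=jnp.bfloat16)
print(grad_fun(x))  # approximately cos(x), returned as a bfloat16 array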