diff --git a/keras_core/backend/common/keras_tensor_test.py b/keras_core/backend/common/keras_tensor_test.py index f93596deb..9a35898e7 100644 --- a/keras_core/backend/common/keras_tensor_test.py +++ b/keras_core/backend/common/keras_tensor_test.py @@ -1,7 +1,7 @@ import numpy as np import tensorflow as tf -from jax import numpy as jnp +from keras_core import backend from keras_core import ops from keras_core import testing from keras_core.backend.common import keras_tensor @@ -37,10 +37,13 @@ def test_invalid_usage(self): ): np.array(x) - with self.assertRaisesRegex( - ValueError, "cannot be used as input to a JAX function" - ): - jnp.array(x) + if backend.backend() == "jax": + from jax import numpy as jnp + + with self.assertRaisesRegex( + ValueError, "cannot be used as input to a JAX function" + ): + jnp.array(x) with self.assertRaisesRegex( ValueError, "cannot be used as input to a TensorFlow function" diff --git a/keras_core/trainers/data_adapters/array_data_adapter_test.py b/keras_core/trainers/data_adapters/array_data_adapter_test.py index babb97df6..a94babe34 100644 --- a/keras_core/trainers/data_adapters/array_data_adapter_test.py +++ b/keras_core/trainers/data_adapters/array_data_adapter_test.py @@ -1,8 +1,6 @@ -import jax import numpy as np import pandas import tensorflow as tf -import torch from absl.testing import parameterized from keras_core import backend @@ -11,9 +9,7 @@ class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase): - @parameterized.parameters( - [("np",), ("tf",), ("jax",), ("torch",), ("pandas")] - ) + @parameterized.parameters([("np",), ("tf",), ("backend",), ("pandas",)]) def test_basic_flow(self, array_type): if array_type == "np": x = np.random.random((34, 4)) @@ -21,12 +17,19 @@ def test_basic_flow(self, array_type): elif array_type == "tf": x = tf.random.normal((34, 4)) y = tf.random.normal((34, 2)) - elif array_type == "jax": - x = jax.numpy.ones((34, 4)) - y = jax.numpy.ones((34, 2)) - elif array_type == "torch": - x = torch.ones((34, 4)) - y = torch.ones((34, 2)) + elif array_type == "backend": + if backend.backend() == "jax": + import jax + + x = jax.numpy.ones((34, 4)) + y = jax.numpy.ones((34, 2)) + elif backend.backend() == "torch": + import torch + + x = torch.ones((34, 4)) + y = torch.ones((34, 2)) + else: + return # skip TF already addressed elif array_type == "pandas": x = pandas.DataFrame(np.random.random((34, 4))) y = pandas.DataFrame(np.random.random((34, 2)))