Skip to content

Commit

Permalink
Fix test for specific backend imports (#591)
Browse files Browse the repository at this point in the history
  • Loading branch information
sampathweb authored Jul 24, 2023
1 parent 628db59 commit 133bbdb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
13 changes: 8 additions & 5 deletions keras_core/backend/common/keras_tensor_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import tensorflow as tf
from jax import numpy as jnp

from keras_core import backend
from keras_core import ops
from keras_core import testing
from keras_core.backend.common import keras_tensor
Expand Down Expand Up @@ -37,10 +37,13 @@ def test_invalid_usage(self):
):
np.array(x)

with self.assertRaisesRegex(
ValueError, "cannot be used as input to a JAX function"
):
jnp.array(x)
if backend.backend() == "jax":
from jax import numpy as jnp

with self.assertRaisesRegex(
ValueError, "cannot be used as input to a JAX function"
):
jnp.array(x)

with self.assertRaisesRegex(
ValueError, "cannot be used as input to a TensorFlow function"
Expand Down
25 changes: 14 additions & 11 deletions keras_core/trainers/data_adapters/array_data_adapter_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import jax
import numpy as np
import pandas
import tensorflow as tf
import torch
from absl.testing import parameterized

from keras_core import backend
Expand All @@ -11,22 +9,27 @@


class TestArrayDataAdapter(testing.TestCase, parameterized.TestCase):
@parameterized.parameters(
[("np",), ("tf",), ("jax",), ("torch",), ("pandas")]
)
@parameterized.parameters([("np",), ("tf",), ("backend",), ("pandas",)])
def test_basic_flow(self, array_type):
if array_type == "np":
x = np.random.random((34, 4))
y = np.random.random((34, 2))
elif array_type == "tf":
x = tf.random.normal((34, 4))
y = tf.random.normal((34, 2))
elif array_type == "jax":
x = jax.numpy.ones((34, 4))
y = jax.numpy.ones((34, 2))
elif array_type == "torch":
x = torch.ones((34, 4))
y = torch.ones((34, 2))
elif array_type == "backend":
if backend.backend() == "jax":
import jax

x = jax.numpy.ones((34, 4))
y = jax.numpy.ones((34, 2))
elif backend.backend() == "torch":
import torch

x = torch.ones((34, 4))
y = torch.ones((34, 2))
else:
return # skip TF already addressed
elif array_type == "pandas":
x = pandas.DataFrame(np.random.random((34, 4)))
y = pandas.DataFrame(np.random.random((34, 2)))
Expand Down

0 comments on commit 133bbdb

Please sign in to comment.