diff --git a/thinc/tests/layers/test_tensorflow_wrapper.py b/thinc/tests/layers/test_tensorflow_wrapper.py index 948e1f035..1c10b8242 100644 --- a/thinc/tests/layers/test_tensorflow_wrapper.py +++ b/thinc/tests/layers/test_tensorflow_wrapper.py @@ -64,11 +64,12 @@ def model(tf_model): def test_tensorflow_wrapper_roundtrip_conversion(): import tensorflow as tf - xp_tensor = numpy.zeros((2, 3), dtype="f") + ops = get_current_ops() + xp_tensor = ops.alloc2f(2, 3, zeros=True) tf_tensor = xp2tensorflow(xp_tensor) assert isinstance(tf_tensor, tf.Tensor) - new_xp_tensor = tensorflow2xp(tf_tensor) - assert numpy.array_equal(xp_tensor, new_xp_tensor) + new_xp_tensor = tensorflow2xp(tf_tensor, ops=ops) + assert ops.xp.array_equal(xp_tensor, new_xp_tensor) @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow") @@ -99,8 +100,12 @@ def test_tensorflow_wrapper_predict(model, X): @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow") def test_tensorflow_wrapper_train_overfits(model, X, Y, answer): optimizer = Adam() + ops = get_current_ops() for i in range(100): guesses, backprop = model(X, is_train=True) + # Ensure that the tensor is type-compatible with the current backend. + guesses = ops.asarray(guesses) + d_guesses = (guesses - Y) / guesses.shape[0] backprop(d_guesses) model.finish_update(optimizer) @@ -114,8 +119,12 @@ def test_tensorflow_wrapper_accumulate_gradients(model, X, Y, answer): optimizer = Adam() gradients = [] + ops = get_current_ops() for i in range(3): guesses, backprop = model(X, is_train=True) + # Ensure that the tensor is type-compatible with the current backend. + guesses = ops.asarray(guesses) + d_guesses = (guesses - Y) / guesses.shape[0] backprop(d_guesses) shim_grads = [tf.identity(var) for var in model.shims[0].gradients] @@ -173,6 +182,9 @@ def call(self, inputs) -> tf.Tensor: optimizer = Adam() for i in range(50): guesses, backprop = model(X, is_train=True) + # Ensure that the tensor is type-compatible with the current backend. + guesses = ops.asarray(guesses) + d_guesses = (guesses - Y) / guesses.shape[0] backprop(d_guesses) model.finish_update(optimizer) @@ -323,10 +335,14 @@ def test_tensorflow_wrapper_from_bytes(model, X): @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow") def test_tensorflow_wrapper_use_params(model, X, Y, answer): optimizer = Adam() + ops = get_current_ops() with model.use_params(optimizer.averages): assert model.predict(X).argmax() is not None for i in range(10): guesses, backprop = model.begin_update(X) + # Ensure that the tensor is type-compatible with the current backend. + guesses = ops.asarray(guesses) + d_guesses = (guesses - Y) / guesses.shape[0] backprop(d_guesses) model.finish_update(optimizer) diff --git a/thinc/tests/util.py b/thinc/tests/util.py index 867e4839e..7440a4b6e 100644 --- a/thinc/tests/util.py +++ b/thinc/tests/util.py @@ -5,6 +5,7 @@ from thinc.api import Linear, Ragged, Padded, ArgsKwargs import numpy import pytest +from thinc.util import has_cupy, is_cupy_array, is_numpy_array @contextlib.contextmanager @@ -95,18 +96,24 @@ def check_input_converters(Y, backprop, data, n_args, kwargs_keys, type_): assert all(isinstance(arg, type_) for arg in Y.args) assert all(isinstance(arg, type_) for arg in Y.kwargs.values()) dX = backprop(Y) + + def is_supported_backend_array(arr): + return is_cupy_array(arr) or is_numpy_array(arr) + input_type = type(data) if not isinstance(data, list) else tuple - assert isinstance(dX, input_type) + assert isinstance(dX, input_type) or is_supported_backend_array(dX) + if isinstance(data, dict): assert list(dX.keys()) == kwargs_keys - assert all(isinstance(arr, numpy.ndarray) for arr in dX.values()) + assert all(is_supported_backend_array(arr) for arr in dX.values()) elif isinstance(data, (list, tuple)): assert isinstance(dX, tuple) - assert all(isinstance(arr, numpy.ndarray) for arr in dX) + assert all(is_supported_backend_array(arr) for arr in dX) elif isinstance(data, ArgsKwargs): assert len(dX.args) == n_args assert list(dX.kwargs.keys()) == kwargs_keys - assert all(isinstance(arg, numpy.ndarray) for arg in dX.args) - assert all(isinstance(arg, numpy.ndarray) for arg in dX.kwargs.values()) + + assert all(is_supported_backend_array(arg) for arg in dX.args) + assert all(is_supported_backend_array(arg) for arg in dX.kwargs.values()) elif not isinstance(data, numpy.ndarray): pytest.fail(f"Bad data type: {dX}")