Skip to content

Commit

Permalink
Tests: Correctly handle GPU-resident Tensorflow tensors (explosion#653)
Browse files Browse the repository at this point in the history
* Tests: Correctly handle GPU-resident Tensorflow tensors

* Simplify `is_supported_backend_array`
  • Loading branch information
shadeMe authored and danieldk committed May 11, 2022
1 parent 8d6196e commit f0b7e43
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
22 changes: 19 additions & 3 deletions thinc/tests/layers/test_tensorflow_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@ def model(tf_model):
def test_tensorflow_wrapper_roundtrip_conversion():
import tensorflow as tf

xp_tensor = numpy.zeros((2, 3), dtype="f")
ops = get_current_ops()
xp_tensor = ops.alloc2f(2, 3, zeros=True)
tf_tensor = xp2tensorflow(xp_tensor)
assert isinstance(tf_tensor, tf.Tensor)
new_xp_tensor = tensorflow2xp(tf_tensor)
assert numpy.array_equal(xp_tensor, new_xp_tensor)
new_xp_tensor = tensorflow2xp(tf_tensor, ops=ops)
assert ops.xp.array_equal(xp_tensor, new_xp_tensor)


@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
Expand Down Expand Up @@ -99,8 +100,12 @@ def test_tensorflow_wrapper_predict(model, X):
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_train_overfits(model, X, Y, answer):
optimizer = Adam()
ops = get_current_ops()
for i in range(100):
guesses, backprop = model(X, is_train=True)
# Ensure that the tensor is type-compatible with the current backend.
guesses = ops.asarray(guesses)

d_guesses = (guesses - Y) / guesses.shape[0]
backprop(d_guesses)
model.finish_update(optimizer)
Expand All @@ -114,8 +119,12 @@ def test_tensorflow_wrapper_accumulate_gradients(model, X, Y, answer):

optimizer = Adam()
gradients = []
ops = get_current_ops()
for i in range(3):
guesses, backprop = model(X, is_train=True)
# Ensure that the tensor is type-compatible with the current backend.
guesses = ops.asarray(guesses)

d_guesses = (guesses - Y) / guesses.shape[0]
backprop(d_guesses)
shim_grads = [tf.identity(var) for var in model.shims[0].gradients]
Expand Down Expand Up @@ -173,6 +182,9 @@ def call(self, inputs) -> tf.Tensor:
optimizer = Adam()
for i in range(50):
guesses, backprop = model(X, is_train=True)
# Ensure that the tensor is type-compatible with the current backend.
guesses = ops.asarray(guesses)

d_guesses = (guesses - Y) / guesses.shape[0]
backprop(d_guesses)
model.finish_update(optimizer)
Expand Down Expand Up @@ -323,10 +335,14 @@ def test_tensorflow_wrapper_from_bytes(model, X):
@pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
def test_tensorflow_wrapper_use_params(model, X, Y, answer):
optimizer = Adam()
ops = get_current_ops()
with model.use_params(optimizer.averages):
assert model.predict(X).argmax() is not None
for i in range(10):
guesses, backprop = model.begin_update(X)
# Ensure that the tensor is type-compatible with the current backend.
guesses = ops.asarray(guesses)

d_guesses = (guesses - Y) / guesses.shape[0]
backprop(d_guesses)
model.finish_update(optimizer)
Expand Down
17 changes: 12 additions & 5 deletions thinc/tests/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from thinc.api import Linear, Ragged, Padded, ArgsKwargs
import numpy
import pytest
from thinc.util import has_cupy, is_cupy_array, is_numpy_array


@contextlib.contextmanager
Expand Down Expand Up @@ -95,18 +96,24 @@ def check_input_converters(Y, backprop, data, n_args, kwargs_keys, type_):
assert all(isinstance(arg, type_) for arg in Y.args)
assert all(isinstance(arg, type_) for arg in Y.kwargs.values())
dX = backprop(Y)

def is_supported_backend_array(arr):
return is_cupy_array(arr) or is_numpy_array(arr)

input_type = type(data) if not isinstance(data, list) else tuple
assert isinstance(dX, input_type)
assert isinstance(dX, input_type) or is_supported_backend_array(dX)

if isinstance(data, dict):
assert list(dX.keys()) == kwargs_keys
assert all(isinstance(arr, numpy.ndarray) for arr in dX.values())
assert all(is_supported_backend_array(arr) for arr in dX.values())
elif isinstance(data, (list, tuple)):
assert isinstance(dX, tuple)
assert all(isinstance(arr, numpy.ndarray) for arr in dX)
assert all(is_supported_backend_array(arr) for arr in dX)
elif isinstance(data, ArgsKwargs):
assert len(dX.args) == n_args
assert list(dX.kwargs.keys()) == kwargs_keys
assert all(isinstance(arg, numpy.ndarray) for arg in dX.args)
assert all(isinstance(arg, numpy.ndarray) for arg in dX.kwargs.values())

assert all(is_supported_backend_array(arg) for arg in dX.args)
assert all(is_supported_backend_array(arg) for arg in dX.kwargs.values())
elif not isinstance(data, numpy.ndarray):
pytest.fail(f"Bad data type: {dX}")

0 comments on commit f0b7e43

Please sign in to comment.