Backport fixes from master to v8.0.x #662

Merged: 7 commits, May 18, 2022
thinc/backends/cupy_ops.py (20 additions, 18 deletions)
@@ -20,7 +20,8 @@
 from . import _custom_kernels
 from ..types import DeviceTypes
 from ..util import torch2xp, tensorflow2xp, mxnet2xp
-from ..util import is_torch_array, is_tensorflow_array, is_mxnet_array
+from ..util import is_torch_gpu_array, is_tensorflow_gpu_array, is_mxnet_gpu_array
+from ..util import is_cupy_array


 @registry.ops("CupyOps")
@@ -72,29 +73,20 @@ def gemm(self, x, y, out=None, trans1=False, trans2=False):
         return out

     def asarray(self, data, dtype=None):
-        # This is sort of frustrating, but we can't easily otherwise pass
-        # forward "unset".
-        dtype = {"dtype": dtype} if dtype is not None else {}
-
         # We'll try to perform a zero-copy conversion if possible.
-        array = None
-        cast_array = False
-        if isinstance(data, cupy.ndarray):
-            array = self.xp.asarray(data, **dtype)
-        elif is_torch_array(data) and data.device.type == "cuda":
+        if is_cupy_array(data):
+            array = data
+        elif is_torch_gpu_array(data):
             array = torch2xp(data)
-            cast_array = True
-        elif is_tensorflow_array(data) and "GPU:" in data.device:
+        elif is_tensorflow_gpu_array(data):
             array = tensorflow2xp(data)
-            cast_array = True
-        elif is_mxnet_array(data) and data.context.device_type != "cpu":
+        elif is_mxnet_gpu_array(data):
             array = mxnet2xp(data)
-            cast_array = True
         else:
-            array = self.xp.array(data, **dtype)
+            array = self.xp.array(data)

-        if cast_array and dtype != {}:
-            array = array.astype(dtype["dtype"])
+        if dtype is not None:
+            array = array.astype(dtype=dtype, copy=False)

         return array

@@ -263,6 +255,10 @@ def scatter_add(self, table, indices, values):
     def adam(
         self, weights, gradient, mom1, mom2, beta1, beta2, eps, learn_rate, mod_rate=1.0
     ):
+        _check_compatible_shape(weights, gradient)
+        _check_compatible_shape(weights, mom1)
+        _check_compatible_shape(weights, mom2)
+
         adam_kernel(
             gradient, learn_rate, 1 - beta1, 1 - beta2, eps, weights, mom1, mom2
         )
@@ -285,3 +281,9 @@ def position_encode(self, N, D, period=10000, out=None):
     )
 else:
     adam_kernel = None
+
+
+def _check_compatible_shape(u, v):
+    if u.shape != v.shape:
+        msg = f"arrays have incompatible shapes: {u.shape} and {v.shape}"
+        raise ValueError(msg)
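
To make the effect of the reworked CupyOps.asarray concrete, here is a minimal sketch (not part of the diff) of the behaviour it aims for, assuming cupy and a CUDA device are available:

    import cupy
    from thinc.api import CupyOps

    ops = CupyOps()
    existing = cupy.zeros((2, 3), dtype="float32")
    # A cupy array with no dtype request is now passed through without a copy.
    assert ops.asarray(existing) is existing
    # A dtype request goes through astype(..., copy=False), so a new array is
    # only allocated when the dtype actually has to change.
    converted = ops.asarray(existing, dtype="float64")
    assert converted.dtype == cupy.float64
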
thinc/backends/numpy_ops.pyx (23 additions, 11 deletions)
@@ -62,19 +62,20 @@ class NumpyOps(Ops):

     def asarray(self, data, dtype=None):
         if isinstance(data, self.xp.ndarray):
-            if dtype is not None:
-                return self.xp.asarray(data, dtype=dtype)
-            else:
-                return self.xp.asarray(data)
+            array = data
         elif hasattr(data, 'numpy'):
             # Handles PyTorch Tensor
-            return data.numpy()
+            array = data.numpy()
         elif hasattr(data, "get"):
-            return data.get()
-        elif dtype is not None:
-            return self.xp.array(data, dtype=dtype)
+            array = data.get()
         else:
-            return self.xp.array(data)
+            array = self.xp.array(data)
+
+        if dtype is not None:
+            array = array.astype(dtype=dtype, copy=False)
+
+        return array
+

     def alloc(self, shape: Shape, *, dtype: Optional[DTypes] = "float32") -> ArrayXd:
         return self.xp.zeros(shape, dtype=dtype)
@@ -345,9 +346,14 @@ class NumpyOps(Ops):

     @cython.boundscheck(False)
     @cython.wraparound(False)
-    def adam(self, np.ndarray weights, np.ndarray gradient, np.ndarray mom1,
-             np.ndarray mom2, const float beta1, const float beta2, float eps,
+    def adam(self, np.ndarray[np.float32_t] weights, np.ndarray[np.float32_t] gradient,
+             np.ndarray[np.float32_t] mom1, np.ndarray[np.float32_t] mom2,
+             const float beta1, const float beta2, float eps,
              float learn_rate, float mod_rate=1.):
+        _check_compatible_shape(weights, gradient)
+        _check_compatible_shape(weights, mom1)
+        _check_compatible_shape(weights, mom2)
+
         _adam_momentum(<float*>gradient.data, <float*>mom1.data, <float*>mom2.data,
             weights.shape[0], beta1, beta2, eps, learn_rate)
         VecVec.add_i(<float*>weights.data,
@@ -1258,3 +1264,9 @@ cdef void MurmurHash3_x86_128_uint64(
     out[1] = h1 >> 32
     out[2] = h2 & 0xffffffffu
     out[3] = h2 >> 32
+
+
+def _check_compatible_shape(u: np.ndarray, v: np.ndarray):
+    if u.shape != v.shape:
+        msg = f"arrays have incompatible shapes: {u.shape} and {v.shape}"
+        raise ValueError(msg)
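
The NumpyOps.asarray rewrite mirrors the CupyOps change. A minimal sketch (not part of the diff) of the expected behaviour on CPU:

    import numpy
    from thinc.api import NumpyOps

    ops = NumpyOps()
    data = numpy.arange(6, dtype="float32").reshape(2, 3)
    # An ndarray with no dtype request is passed through unchanged.
    assert ops.asarray(data) is data
    # With a dtype request, astype(..., copy=False) only copies when needed.
    converted = ops.asarray(data, dtype="float64")
    assert converted.dtype == numpy.float64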

thinc/backends/ops.py (10 additions, 0 deletions)
@@ -966,6 +966,10 @@ def adam(
         learn_rate: float,
         mod_rate: float = 1.0,
     ) -> Tuple[Floats1d, Floats1d, Floats1d, Floats1d]:
+        _check_compatible_shape(weights, gradient)
+        _check_compatible_shape(weights, mom1)
+        _check_compatible_shape(weights, mom2)
+
         # Internals for optimizer
         mom1 *= beta1
         mom2 *= beta2
@@ -1396,3 +1400,9 @@ def gaussian_cdf(ops: Ops, X: FloatsType) -> FloatsType:
 def gaussian_pdf(ops: Ops, X: FloatsType) -> FloatsType:
     """Gaussian PDF for distribution with mean 0 and stdev 1."""
     return INV_SQRT_2PI * ops.xp.exp(-0.5 * X * X)
+
+
+def _check_compatible_shape(u: FloatsXd, v: FloatsXd):
+    if u.shape != v.shape:
+        msg = f"arrays have incompatible shapes: {u.shape} and {v.shape}"
+        raise ValueError(msg)
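
A quick illustration of what the new guard catches (a minimal sketch, not part of the diff; it uses NumpyOps, but any Ops backend now behaves the same way):

    import numpy
    from thinc.api import NumpyOps

    ops = NumpyOps()
    weights = numpy.zeros(4, dtype="float32")
    gradient = numpy.zeros(3, dtype="float32")  # deliberately the wrong shape
    mom1 = numpy.zeros(4, dtype="float32")
    mom2 = numpy.zeros(4, dtype="float32")
    try:
        ops.adam(weights, gradient, mom1, mom2, 0.9, 0.999, 1e-8, 0.001)
    except ValueError as err:
        print(err)  # arrays have incompatible shapes: (4,) and (3,)
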
thinc/tests/backends/test_ops.py (16 additions, 0 deletions)
@@ -120,6 +120,22 @@ def test_ops_consistency(op):
         assert str(p1) == str(p2), attr


+@pytest.mark.parametrize("ops", ALL_OPS)
+def test_adam_incorrect_inputs(ops):
+    one = ops.xp.zeros(1, dtype="f")
+    two = ops.xp.zeros(2, dtype="f")
+
+    ops.adam(one, one, one, one, 0.0, 0.0, 0.0, 0.0)
+    with pytest.raises(ValueError):
+        ops.adam(two, one, one, one, 0.0, 0.0, 0.0, 0.0)
+    with pytest.raises(ValueError):
+        ops.adam(one, two, one, one, 0.0, 0.0, 0.0, 0.0)
+    with pytest.raises(ValueError):
+        ops.adam(one, one, two, one, 0.0, 0.0, 0.0, 0.0)
+    with pytest.raises(ValueError):
+        ops.adam(one, one, one, two, 0.0, 0.0, 0.0, 0.0)
+
+
 @pytest.mark.parametrize("ops", ALL_OPS)
 def test_alloc(ops):
     float_methods = (ops.alloc1f, ops.alloc2f, ops.alloc3f, ops.alloc4f)
thinc/tests/layers/test_tensorflow_wrapper.py (24 additions, 4 deletions)
@@ -2,7 +2,7 @@
 import pytest
 from thinc.api import Adam, ArgsKwargs, Linear, Model, TensorFlowWrapper
 from thinc.api import get_current_ops, keras_subclass, tensorflow2xp, xp2tensorflow
-from thinc.util import has_cupy, has_tensorflow, to_categorical
+from thinc.util import has_tensorflow, to_categorical, gpu_is_available

 from ..util import check_input_converters, make_tempdir

@@ -64,11 +64,16 @@ def model(tf_model):
 def test_tensorflow_wrapper_roundtrip_conversion():
     import tensorflow as tf

-    xp_tensor = numpy.zeros((2, 3), dtype="f")
+    ops = get_current_ops()
+    xp_tensor = ops.alloc2f(2, 3)
     tf_tensor = xp2tensorflow(xp_tensor)
     assert isinstance(tf_tensor, tf.Tensor)
     new_xp_tensor = tensorflow2xp(tf_tensor)
-    assert numpy.array_equal(xp_tensor, new_xp_tensor)
+    # The converted tensor will be backed by Cupy, so
+    # we'll need to convert it to current backend's repr.
+    new_xp_tensor = ops.asarray(new_xp_tensor)
+
+    assert ops.xp.array_equal(xp_tensor, new_xp_tensor)


 @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
@@ -99,8 +104,12 @@ def test_tensorflow_wrapper_predict(model, X):
 @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
 def test_tensorflow_wrapper_train_overfits(model, X, Y, answer):
     optimizer = Adam()
+    ops = get_current_ops()
     for i in range(100):
         guesses, backprop = model(X, is_train=True)
+        # Ensure that the tensor is type-compatible with the current backend.
+        guesses = ops.asarray(guesses)
+
         d_guesses = (guesses - Y) / guesses.shape[0]
         backprop(d_guesses)
         model.finish_update(optimizer)
@@ -114,8 +123,12 @@ def test_tensorflow_wrapper_accumulate_gradients(model, X, Y, answer):

     optimizer = Adam()
     gradients = []
+    ops = get_current_ops()
     for i in range(3):
         guesses, backprop = model(X, is_train=True)
+        # Ensure that the tensor is type-compatible with the current backend.
+        guesses = ops.asarray(guesses)
+
         d_guesses = (guesses - Y) / guesses.shape[0]
         backprop(d_guesses)
         shim_grads = [tf.identity(var) for var in model.shims[0].gradients]
@@ -173,6 +186,9 @@ def call(self, inputs) -> tf.Tensor:
     optimizer = Adam()
     for i in range(50):
         guesses, backprop = model(X, is_train=True)
+        # Ensure that the tensor is type-compatible with the current backend.
+        guesses = ops.asarray(guesses)
+
         d_guesses = (guesses - Y) / guesses.shape[0]
         backprop(d_guesses)
         model.finish_update(optimizer)
@@ -323,10 +339,14 @@ def test_tensorflow_wrapper_from_bytes(model, X):
 @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
 def test_tensorflow_wrapper_use_params(model, X, Y, answer):
     optimizer = Adam()
+    ops = get_current_ops()
     with model.use_params(optimizer.averages):
         assert model.predict(X).argmax() is not None
     for i in range(10):
         guesses, backprop = model.begin_update(X)
+        # Ensure that the tensor is type-compatible with the current backend.
+        guesses = ops.asarray(guesses)
+
         d_guesses = (guesses - Y) / guesses.shape[0]
         backprop(d_guesses)
         model.finish_update(optimizer)
@@ -342,7 +362,7 @@ def test_tensorflow_wrapper_to_cpu(tf_model):


 @pytest.mark.skipif(not has_tensorflow, reason="needs TensorFlow")
-@pytest.mark.skipif(not has_cupy, reason="needs cupy")
+@pytest.mark.skipif(not gpu_is_available(), reason="needs GPU/cupy")
 def test_tensorflow_wrapper_to_gpu(model, X):
     model.to_gpu(0)

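
The conversion pattern these tests now follow can be summarised in a short sketch (not part of the diff, assuming TensorFlow is installed): tensorflow2xp may hand back a cupy-backed array when running on GPU, so the result is normalised with ops.asarray before comparing.

    from thinc.api import get_current_ops, tensorflow2xp, xp2tensorflow

    ops = get_current_ops()
    xp_tensor = ops.alloc2f(2, 3)
    tf_tensor = xp2tensorflow(xp_tensor)
    new_xp_tensor = ops.asarray(tensorflow2xp(tf_tensor))
    assert ops.xp.array_equal(xp_tensor, new_xp_tensor)
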
thinc/tests/util.py (12 additions, 5 deletions)
@@ -5,6 +5,7 @@
 from thinc.api import Linear, Ragged, Padded, ArgsKwargs
 import numpy
 import pytest
+from thinc.util import has_cupy, is_cupy_array, is_numpy_array


 @contextlib.contextmanager
@@ -95,18 +96,24 @@ def check_input_converters(Y, backprop, data, n_args, kwargs_keys, type_):
     assert all(isinstance(arg, type_) for arg in Y.args)
     assert all(isinstance(arg, type_) for arg in Y.kwargs.values())
     dX = backprop(Y)
+
+    def is_supported_backend_array(arr):
+        return is_cupy_array(arr) or is_numpy_array(arr)
+
     input_type = type(data) if not isinstance(data, list) else tuple
-    assert isinstance(dX, input_type)
+    assert isinstance(dX, input_type) or is_supported_backend_array(dX)
+
     if isinstance(data, dict):
         assert list(dX.keys()) == kwargs_keys
-        assert all(isinstance(arr, numpy.ndarray) for arr in dX.values())
+        assert all(is_supported_backend_array(arr) for arr in dX.values())
     elif isinstance(data, (list, tuple)):
         assert isinstance(dX, tuple)
-        assert all(isinstance(arr, numpy.ndarray) for arr in dX)
+        assert all(is_supported_backend_array(arr) for arr in dX)
     elif isinstance(data, ArgsKwargs):
         assert len(dX.args) == n_args
         assert list(dX.kwargs.keys()) == kwargs_keys
-        assert all(isinstance(arg, numpy.ndarray) for arg in dX.args)
-        assert all(isinstance(arg, numpy.ndarray) for arg in dX.kwargs.values())
+
+        assert all(is_supported_backend_array(arg) for arg in dX.args)
+        assert all(is_supported_backend_array(arg) for arg in dX.kwargs.values())
     elif not isinstance(data, numpy.ndarray):
         pytest.fail(f"Bad data type: {dX}")
thinc/util.py (15 additions, 0 deletions)
@@ -71,6 +71,9 @@ def get_array_module(arr):  # pragma: no cover


 def gpu_is_available():
+    if not has_cupy:
+        return False
+
     try:
         cupy.cuda.runtime.getDeviceCount()
         return True
@@ -124,6 +127,10 @@ def is_torch_array(obj: Any) -> bool:  # pragma: no cover
         return False


+def is_torch_gpu_array(obj: Any) -> bool:  # pragma: no cover
+    return is_torch_array(obj) and obj.is_cuda
+
+
 def is_tensorflow_array(obj: Any) -> bool:  # pragma: no cover
     if not has_tensorflow:
         return False
@@ -133,6 +140,10 @@ def is_tensorflow_array(obj: Any) -> bool:  # pragma: no cover
         return False


+def is_tensorflow_gpu_array(obj: Any) -> bool:  # pragma: no cover
+    return is_tensorflow_array(obj) and "GPU:" in obj.device
+
+
 def is_mxnet_array(obj: Any) -> bool:  # pragma: no cover
     if not has_mxnet:
         return False
@@ -142,6 +153,10 @@ def is_mxnet_array(obj: Any) -> bool:  # pragma: no cover
         return False


+def is_mxnet_gpu_array(obj: Any) -> bool:  # pragma: no cover
+    return is_mxnet_array(obj) and obj.context.device_type != "cpu"
+
+
 def to_numpy(data):  # pragma: no cover
     if isinstance(data, numpy.ndarray):
         return data
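
Taken together, the new helpers let callers check device placement without poking at framework internals, and gpu_is_available no longer assumes cupy is importable. A minimal sketch (not part of the diff, assuming torch is installed; the CUDA branch additionally assumes torch can see the same GPU that cupy reports):

    import torch
    from thinc.util import gpu_is_available, is_torch_array, is_torch_gpu_array

    t = torch.zeros(2, 3)
    assert is_torch_array(t)
    assert not is_torch_gpu_array(t)  # CPU tensor
    if gpu_is_available():
        # gpu_is_available() checks cupy; moving the tensor also requires a
        # CUDA-enabled torch build.
        assert is_torch_gpu_array(t.cuda())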