diff --git a/thinc/backends/cupy_ops.py b/thinc/backends/cupy_ops.py
index 9f420fcd6..18e448dfd 100644
--- a/thinc/backends/cupy_ops.py
+++ b/thinc/backends/cupy_ops.py
@@ -281,6 +281,10 @@ def scatter_add(self, table, indices, values):
     def adam(
         self, weights, gradient, mom1, mom2, beta1, beta2, eps, learn_rate, mod_rate=1.0
     ):
+        _check_compatible_shape(weights, gradient)
+        _check_compatible_shape(weights, mom1)
+        _check_compatible_shape(weights, mom2)
+
         adam_kernel(
             gradient, learn_rate, 1 - beta1, 1 - beta2, eps, weights, mom1, mom2
         )
@@ -303,3 +307,9 @@ def position_encode(self, N, D, period=10000, out=None):
     )
 else:
     adam_kernel = None
+
+
+def _check_compatible_shape(u, v):
+    if u.shape != v.shape:
+        msg = f"arrays have incompatible shapes: {u.shape} and {v.shape}"
+        raise ValueError(msg)
diff --git a/thinc/backends/numpy_ops.pyx b/thinc/backends/numpy_ops.pyx
index 1455aad1c..f30db6630 100644
--- a/thinc/backends/numpy_ops.pyx
+++ b/thinc/backends/numpy_ops.pyx
@@ -460,9 +460,14 @@ class NumpyOps(Ops):
 
     @cython.boundscheck(False)
     @cython.wraparound(False)
-    def adam(self, np.ndarray weights, np.ndarray gradient, np.ndarray mom1,
-             np.ndarray mom2, const float beta1, const float beta2, float eps,
+    def adam(self, np.ndarray[np.float32_t] weights, np.ndarray[np.float32_t] gradient,
+             np.ndarray[np.float32_t] mom1, np.ndarray[np.float32_t] mom2,
+             const float beta1, const float beta2, float eps,
              float learn_rate, float mod_rate=1.):
+        _check_compatible_shape(weights, gradient)
+        _check_compatible_shape(weights, mom1)
+        _check_compatible_shape(weights, mom2)
+
         _adam_momentum(<float*>gradient.data, <float*>mom1.data, <float*>mom2.data,
                        weights.shape[0], beta1, beta2, eps, learn_rate)
         VecVec.add_i(<float*>weights.data,
diff --git a/thinc/backends/ops.py b/thinc/backends/ops.py
index 4179b4842..e5a232d9b 100644
--- a/thinc/backends/ops.py
+++ b/thinc/backends/ops.py
@@ -1112,6 +1112,10 @@ def adam(
         learn_rate: float,
         mod_rate: float = 1.0,
     ) -> Tuple[Floats1d, Floats1d, Floats1d, Floats1d]:
+        _check_compatible_shape(weights, gradient)
+        _check_compatible_shape(weights, mom1)
+        _check_compatible_shape(weights, mom2)
+
         # Internals for optimizer
         mom1 *= beta1
         mom2 *= beta2
@@ -1570,3 +1574,9 @@ def gaussian_cdf(ops: Ops, X: FloatsType) -> FloatsType:
 def gaussian_pdf(ops: Ops, X: FloatsType) -> FloatsType:
     """Gaussian PDF for distribution with mean 0 and stdev 1."""
     return INV_SQRT_2PI * ops.xp.exp(-0.5 * X * X)
+
+
+def _check_compatible_shape(u: FloatsXd, v: FloatsXd):
+    if u.shape != v.shape:
+        msg = f"arrays have incompatible shapes: {u.shape} and {v.shape}"
+        raise ValueError(msg)
diff --git a/thinc/tests/backends/test_ops.py b/thinc/tests/backends/test_ops.py
index cdc319b91..e095142b1 100644
--- a/thinc/tests/backends/test_ops.py
+++ b/thinc/tests/backends/test_ops.py
@@ -127,6 +127,22 @@ def test_ops_consistency(op):
         assert str(p1) == str(p2), attr
 
 
+@pytest.mark.parametrize("ops", ALL_OPS)
+def test_adam_incorrect_inputs(ops):
+    one = ops.xp.zeros(1, dtype="f")
+    two = ops.xp.zeros(2, dtype="f")
+
+    ops.adam(one, one, one, one, 0.0, 0.0, 0.0, 0.0)
+    with pytest.raises(ValueError):
+        ops.adam(two, one, one, one, 0.0, 0.0, 0.0, 0.0)
+    with pytest.raises(ValueError):
+        ops.adam(one, two, one, one, 0.0, 0.0, 0.0, 0.0)
+    with pytest.raises(ValueError):
+        ops.adam(one, one, two, one, 0.0, 0.0, 0.0, 0.0)
+    with pytest.raises(ValueError):
+        ops.adam(one, one, one, two, 0.0, 0.0, 0.0, 0.0)
+
+
 @pytest.mark.parametrize("ops", ALL_OPS)
 def test_alloc(ops):
     float_methods = (ops.alloc1f, ops.alloc2f, ops.alloc3f, ops.alloc4f)
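
For reference, a minimal sketch of how the new check surfaces to callers, assuming thinc's public `NumpyOps` and `alloc1f` API; the mismatched `mom2` shape below is deliberate, and the printed message mirrors the `ValueError` raised by `_check_compatible_shape` in this diff:

```python
from thinc.api import NumpyOps

ops = NumpyOps()
weights = ops.alloc1f(4)
gradient = ops.alloc1f(4)
mom1 = ops.alloc1f(4)
mom2 = ops.alloc1f(2)  # wrong shape on purpose

try:
    # beta1, beta2, eps, learn_rate as positional args, matching Ops.adam
    ops.adam(weights, gradient, mom1, mom2, 0.9, 0.999, 1e-8, 0.001)
except ValueError as err:
    print(err)  # arrays have incompatible shapes: (4,) and (2,)
```

Before this change, a shape mismatch could pass through silently (or crash in the Cython/CUDA kernels); failing fast with a `ValueError` makes the contract explicit across all three backends.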