diff --git a/tf2jax/_src/ops.py b/tf2jax/_src/ops.py index 7d088c5..4ad40ce 100644 --- a/tf2jax/_src/ops.py +++ b/tf2jax/_src/ops.py @@ -1807,7 +1807,7 @@ def _func( sizes: jnp.ndarray, ) -> jnp.ndarray: """`begins` and `sizes` must be concrete arrays.""" - slices = [slice(b, b + s) for b, s in safe_zip(begins, sizes)] + slices = [slice(b, b + s if slice != -1 else None) for b, s in safe_zip(begins, sizes)] return x[tuple(slices)] return _func diff --git a/tf2jax/_src/ops_test.py b/tf2jax/_src/ops_test.py index ee7fa32..9a9e276 100644 --- a/tf2jax/_src/ops_test.py +++ b/tf2jax/_src/ops_test.py @@ -1721,7 +1721,7 @@ def size_static(): @chex.variants(with_jit=True, without_jit=True) def test_slice(self): - inputs, begins, sizes = [np.array([[1, 2], [3, 4], [5, 6]]), [1, 1], [2, 1]] + inputs, begins, sizes = [np.array([[1, 2], [3, 4], [5, 6]]), [1, 1], [2, -1]] def slice_fn(xs): return tf.raw_ops.Slice(input=xs, begin=begins, size=sizes)