Skip to content

Commit

Permalink
Fix a bug in Slice for size==-1.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 682596718
  • Loading branch information
shaobohou authored and TF2JAXDev committed Oct 5, 2024
1 parent 633a306 commit 582c077
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
5 changes: 4 additions & 1 deletion tf2jax/_src/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1807,7 +1807,10 @@ def _func(
sizes: jnp.ndarray,
) -> jnp.ndarray:
"""`begins` and `sizes` must be concrete arrays."""
slices = [slice(b, b + s) for b, s in safe_zip(begins, sizes)]
slices = [
slice(b, (b + s) if s != -1 else None)
for b, s in safe_zip(begins, sizes)
]
return x[tuple(slices)]

return _func
Expand Down
6 changes: 5 additions & 1 deletion tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1721,7 +1721,11 @@ def size_static():

@chex.variants(with_jit=True, without_jit=True)
def test_slice(self):
inputs, begins, sizes = [np.array([[1, 2], [3, 4], [5, 6]]), [1, 1], [2, 1]]
inputs, begins, sizes = [
np.array([[1, 2], [3, 4], [5, 6]]),
[1, 1],
[2, -1],
]

def slice_fn(xs):
return tf.raw_ops.Slice(input=xs, begin=begins, size=sizes)
Expand Down

0 comments on commit 582c077

Please sign in to comment.