Skip to content

Commit

Permalink
Fix squeeze in tf2jax for empty tuple axis arguments.
Browse files Browse the repository at this point in the history
tf.squeeze and jnp.squeeze have different behavior when axis=(). TF will squeeze all axes, and jnp will squeeze no axes. This change will set axis=None when an empty tuple is passed into jnp.squeeze to ensure that the behavior matches that of TF.

PiperOrigin-RevId: 555249518
  • Loading branch information
TF2JAXDev authored and TF2JAXDev committed Aug 9, 2023
1 parent 7adfe09 commit 3e4bb64
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
12 changes: 11 additions & 1 deletion tf2jax/_src/numpy_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,22 @@ def broadcast_to(arr, shape):
flip = lambda arr, axis: _get_np(arr).flip(arr, axis=axis)
roll = lambda arr, shift, axis: _get_np(arr).roll(arr, shift=shift, axis=axis)
split = lambda arr, sections, axis: _get_np(arr).split(arr, sections, axis=axis)
squeeze = lambda arr, axis: _get_np(arr).squeeze(arr, axis=axis)
stack = lambda arrs, axis: _get_np(*arrs).stack(arrs, axis=axis)
tile = lambda arr, reps: _get_np(arr, reps).tile(arr, reps=reps)
where = lambda cond, x, y: _get_np(cond, x, y).where(cond, x, y)


def squeeze(arr, axis):
# tf.squeeze and np/jnp.squeeze have different behaviors when axis=().
# - tf.squeeze will squeeze all dimensions.
# - np/jnp.squeeze will not squeeze any dimensions.
# Here we change () to None to ensure that squeeze has the same behavior
# when converted from tf to np/jnp.
if axis == tuple():
axis = None
return _get_np(arr).squeeze(arr, axis=axis)


def moveaxis(
arr,
source: Union[int, Sequence[int]],
Expand Down
5 changes: 3 additions & 2 deletions tf2jax/_src/ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1458,8 +1458,9 @@ def roll_static():
self._test_convert(roll_static, [])

@chex.variants(with_jit=True, without_jit=True)
def test_squeeze(self):
inputs, dims = np.array([[[42], [47]]]), (0, 2)
@parameterized.parameters(((0, 2),), (tuple(),), (None,))
def test_squeeze(self, dims):
inputs = np.array([[[42], [47]]])

def squeeze(x):
return tf.raw_ops.Squeeze(input=x, axis=dims)
Expand Down

0 comments on commit 3e4bb64

Please sign in to comment.