Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the unstack operation to keras core #597

Merged
merged 4 commits into from
Jul 26, 2023

Conversation

tirthasheshpatel
Copy link
Contributor

@tirthasheshpatel tirthasheshpatel commented Jul 25, 2023

I recently got a use-case for the unstack operation while working with the segment anything model tensorflow port. This PR adds the op.

Note that, if the input tensor is symbolic, it is not straight-forward to compute the output spec since the output list length might be arbitrary. So, this PR only adds support for non-symbolic tensors for the time being. Support for symbolic tensors can be added as a follow-up enhancement.

cc @ianstenbit

@fchollet
Copy link
Contributor

I am not sure we should add this operation --

  1. All our ops are both eager and symbolic, so it would be inconsistent.
  2. It's not in the NumPy API (whereas stack is) and not in JAX / torch

@tirthasheshpatel
Copy link
Contributor Author

tirthasheshpatel commented Jul 25, 2023

1. All our ops are both eager and symbolic, so it would be inconsistent.

I will add a symbolic version of this op. Initially, I thought that keras is inferring the shape even if it is None in the axis which the user wishes to unstack. But this is not the case: Keras just errors out:

>>> import tensorflow as tf
2023-07-25 22:27:40.429346: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
>>> from tensorflow import keras
>>> x = keras.Input([2,3,4])
>>> tf.unstack(x, axis=0)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/google/home/tirthp/oss/virtualenvs/keras-cv-dev/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/usr/local/google/home/tirthp/oss/virtualenvs/keras-cv-dev/lib/python3.10/site-packages/keras/layers/core/tf_op_layer.py", line 119, in handle
    return TFOpLambda(op)(*args, **kwargs)
  File "/usr/local/google/home/tirthp/oss/virtualenvs/keras-cv-dev/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
ValueError: Exception encountered when calling layer "tf.unstack" (type TFOpLambda).

Cannot infer argument `num` from shape (None, 2, 3, 4)

Call arguments received by layer "tf.unstack" (type TFOpLambda):
  • value=tf.Tensor(shape=(None, 2, 3, 4), dtype=float32)
  • num=Noneaxis=0name=unstack

@tirthasheshpatel
Copy link
Contributor Author

2. It's not in the NumPy API (whereas stack is) and not in JAX / torch

I agree, but there is no way to replicate the op in TensorFlow or PyTorch without using the unstack operation i.e. tensorflow as torch don't allow just unpacking a tensor like a, b, c = ops.random.normal((3, 4)). Also, it is very easy to implement the op in NumPy and Jax. So, I am more on the side of adding it than not.

@fchollet
Copy link
Contributor

I will add a symbolic version of this op. Initially, I thought that keras is inferring the shape even if it is None in the axis which the user wishes to unstack. But this is not the case: Keras just errors out:

I guess you could have symbolic support but require a defined shape?

@tirthasheshpatel
Copy link
Contributor Author

I guess you could have symbolic support but require a defined shape?

Yes, exactly.

for o in out:
self.assertEqual(o.shape, (2, None))

pytest.raises(ValueError, lambda: core.unstack(x, axis=axis))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use self.assertRaisesRegex and check (part of) the error message

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Contributor

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thank you!

@fchollet fchollet merged commit 2601ce3 into keras-team:main Jul 26, 2023
@tirthasheshpatel tirthasheshpatel deleted the add-unstack branch July 26, 2023 16:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants