-
Notifications
You must be signed in to change notification settings - Fork 118
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add the unstack operation to keras core #597
Conversation
I am not sure we should add this operation --
|
I will add a symbolic version of this op. Initially, I thought that keras is inferring the shape even if it is >>> import tensorflow as tf
2023-07-25 22:27:40.429346: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
>>> from tensorflow import keras
>>> x = keras.Input([2,3,4])
>>> tf.unstack(x, axis=0)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/google/home/tirthp/oss/virtualenvs/keras-cv-dev/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
raise e.with_traceback(filtered_tb) from None
File "/usr/local/google/home/tirthp/oss/virtualenvs/keras-cv-dev/lib/python3.10/site-packages/keras/layers/core/tf_op_layer.py", line 119, in handle
return TFOpLambda(op)(*args, **kwargs)
File "/usr/local/google/home/tirthp/oss/virtualenvs/keras-cv-dev/lib/python3.10/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
ValueError: Exception encountered when calling layer "tf.unstack" (type TFOpLambda).
Cannot infer argument `num` from shape (None, 2, 3, 4)
Call arguments received by layer "tf.unstack" (type TFOpLambda):
• value=tf.Tensor(shape=(None, 2, 3, 4), dtype=float32)
• num=None
• axis=0
• name=unstack |
I agree, but there is no way to replicate the op in TensorFlow or PyTorch without using the unstack operation i.e. tensorflow as torch don't allow just unpacking a tensor like |
I guess you could have symbolic support but require a defined shape? |
Yes, exactly. |
keras_core/ops/core_test.py
Outdated
for o in out: | ||
self.assertEqual(o.shape, (2, None)) | ||
|
||
pytest.raises(ValueError, lambda: core.unstack(x, axis=axis)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use self.assertRaisesRegex
and check (part of) the error message
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you!
I recently got a use-case for the
unstack
operation while working with the segment anything model tensorflow port. This PR adds the op.Note that, if the input tensor is symbolic, it is not straight-forward to compute the output spec since the output list length might be arbitrary. So, this PR only adds support for non-symbolic tensors for the time being. Support for symbolic tensors can be added as a follow-up enhancement.cc @ianstenbit