-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix squeeze in tf2jax for empty tuple axis arguments.
tf.squeeze and jnp.squeeze have different behavior when axis=(). TF will squeeze all axes, and jnp will squeeze no axes. This change will set axis=None when an empty tuple is passed into jnp.squeeze to ensure that the behavior matches that of TF. PiperOrigin-RevId: 555249518
- Loading branch information
TF2JAXDev
authored and
TF2JAXDev
committed
Aug 9, 2023
1 parent
7adfe09
commit 3e4bb64
Showing
2 changed files
with
14 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters