Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JIT compilation with TensorFlow >= 2.14.0 #197

Merged
merged 1 commit into from
Dec 29, 2023

Conversation

FHof
Copy link
Collaborator

@FHof FHof commented Dec 25, 2023

Description

The previous check if code is currently being compiled no longer works with new TensorFlow versions because the Tensor type is now called SymbolicTensor.

This change adds a helper function to check if code is being compiled for JAX, TensorFlow or PyTorch. If tf.is_symbolic_tensor() is available, i.e. if the TensorFlow version is high enough, we use this function to check if code is being compiled.

To avoid inconsistencies between backends,
the check for integration domain values is disabled if code is being compiled with PyTorch even if the check works with PyTorch.

Resolved Issues

Fixes #195

How Has This Been Tested?

  • There is a new test for the new function and the other tests are executed by the CI, too.
  • I have manually tested type(x.shape[0]).__name__ == "Tensor" with PyTorch 2.0.1 but not PyTorch 1.10.0
  • I have not tested if _is_compiling works for TensorFlow < 2.14.0; since it does the same as the previous condition in _check_integration_domain, I assume that it works.

The previous check if code is currently being compiled no longer works with new TensorFlow versions because the `Tensor` type is now called `SymbolicTensor`.

This change adds a helper function to check if code is being compiled for JAX, TensorFlow or PyTorch.
If tf.is_symbolic_tensor() is available, i.e. if the TensorFlow version is high enough,
we use this function to check if code is being compiled.

To avoid inconsistencies between backends,
the check for integration domain values is disabled if code is being compiled with PyTorch even if the check works with PyTorch.
@FHof FHof mentioned this pull request Dec 25, 2023
Copy link
Collaborator

@gomezzz gomezzz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 👍 thanks so much :)

@gomezzz gomezzz merged commit 7914a8e into esa:main Dec 29, 2023
3 checks passed
@gomezzz gomezzz deleted the hofmeier/is_compiled_check branch December 29, 2023 10:32
This pull request was closed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Regression in tests with TF
2 participants