Fix JIT compilation with TensorFlow >= 2.14.0 #197
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
The previous check if code is currently being compiled no longer works with new TensorFlow versions because the
Tensor
type is now calledSymbolicTensor
.This change adds a helper function to check if code is being compiled for JAX, TensorFlow or PyTorch. If tf.is_symbolic_tensor() is available, i.e. if the TensorFlow version is high enough, we use this function to check if code is being compiled.
To avoid inconsistencies between backends,
the check for integration domain values is disabled if code is being compiled with PyTorch even if the check works with PyTorch.
Resolved Issues
Fixes #195
How Has This Been Tested?
type(x.shape[0]).__name__ == "Tensor"
with PyTorch 2.0.1 but not PyTorch 1.10.0_is_compiling
works for TensorFlow < 2.14.0; since it does the same as the previous condition in_check_integration_domain
, I assume that it works.