Fix JIT compilation with TensorFlow >= 2.14.0 #197

Merged 1 commit on Dec 29, 2023
48 changes: 35 additions & 13 deletions torchquad/integration/utils.py
@@ -193,20 +193,11 @@ def _check_integration_domain(integration_domain):
         raise ValueError("integration_domain.shape[0] needs to be 1 or larger.")
     if num_bounds != 2:
         raise ValueError("integration_domain must have 2 values per boundary")
-    # Skip the values check if an integrator.integrate method is JIT
-    # compiled with JAX
-    if any(
-        nam in type(integration_domain).__name__ for nam in ["Jaxpr", "JVPTracer"]
-    ):
+    # The boundary values check does not work if the code is JIT compiled
+    # with JAX or TensorFlow.
+    if _is_compiling(integration_domain):
         return dim
-    boundaries_are_invalid = (
-        anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0
-    )
-    # Skip the values check if an integrator.integrate method is
-    # compiled with tensorflow.function
-    if type(boundaries_are_invalid).__name__ == "Tensor":
-        return dim
-    if boundaries_are_invalid:
+    if anp.min(integration_domain[:, 1] - integration_domain[:, 0]) < 0.0:
         raise ValueError("integration_domain has invalid boundary values")
     return dim
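The boundary check has to be skipped during compilation because, while TensorFlow traces a tf.function (or JAX traces a jitted function), the comparison above yields a symbolic tensor rather than a Python bool, so its concrete truth value is not available at trace time. A minimal standalone sketch of this behavior in plain TensorFlow (not torchquad code; tf.reduce_min stands in for the anp.min call above):

import tensorflow as tf

@tf.function(jit_compile=True)
def check_domain(domain):
    # During tracing this comparison produces a symbolic tensor, not a Python
    # bool, so its concrete truth value cannot be inspected at trace time.
    invalid = tf.reduce_min(domain[:, 1] - domain[:, 0]) < 0.0
    print(type(invalid).__name__)  # executed once while tracing
    return domain

check_domain(tf.constant([[0.0, 1.0], [1.0, 2.0]]))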

@@ -263,6 +254,37 @@ def wrap(*args, **kwargs):
     return wrap


+def _is_compiling(x):
+    """
+    Check if code is currently being compiled with PyTorch, JAX or TensorFlow
+
+    Args:
+        x (backend tensor): A tensor currently used for computations
+    Returns:
+        bool: True if code is currently being compiled, False otherwise
+    """
+    backend = infer_backend(x)
+    if backend == "jax":
+        return any(nam in type(x).__name__ for nam in ["Jaxpr", "JVPTracer"])
+    if backend == "torch":
+        import torch
+
+        if hasattr(torch.jit, "is_tracing"):
+            # We ignore torch.jit.is_scripting() since we do not support
+            # compilation to TorchScript
+            return torch.jit.is_tracing()
+        # torch.jit.is_tracing() is unavailable below PyTorch version 1.11.0
+        return type(x.shape[0]).__name__ == "Tensor"
+    if backend == "tensorflow":
+        import tensorflow as tf
+
+        if hasattr(tf, "is_symbolic_tensor"):
+            return tf.is_symbolic_tensor(x)
+        # tf.is_symbolic_tensor() is unavailable below TensorFlow version 2.13.0
+        return type(x).__name__ == "Tensor"
+    return False
+
+
 def _torch_trace_without_warnings(*args, **kwargs):
     """Execute `torch.jit.trace` on the passed arguments and hide tracer warnings

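For illustration, the backend-specific detection that _is_compiling builds on can be observed directly. A standalone sketch, assuming both JAX and TensorFlow >= 2.13 are installed (not part of the PR):

import jax
import jax.numpy as jnp
import tensorflow as tf

def report_jax(x):
    # Inside jax.jit the argument is a tracer (e.g. DynamicJaxprTracer) whose
    # class name contains "Jaxpr"; called eagerly it is a concrete array.
    print(type(x).__name__)
    return x

report_jax(jnp.ones(3))           # concrete array class, no "Jaxpr" in the name
jax.jit(report_jax)(jnp.ones(3))  # tracer class containing "Jaxpr"

def report_tf(x):
    # tf.is_symbolic_tensor is False for eager tensors and True for the
    # graph-mode tensors seen while tf.function traces the function.
    print(tf.is_symbolic_tensor(x))
    return x

report_tf(tf.ones(3))               # prints False
tf.function(report_tf)(tf.ones(3))  # prints True during tracing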
38 changes: 38 additions & 0 deletions torchquad/tests/utils_integration_test.py
@@ -12,6 +12,7 @@
     _linspace_with_grads,
     _add_at_indices,
     _setup_integration_domain,
+    _is_compiling,
 )
 from utils.set_precision import set_precision
 from utils.enable_cuda import enable_cuda
@@ -196,11 +197,48 @@ def test_setup_integration_domain():
     _run_tests_with_all_backends(_run_setup_integration_domain_tests)


+def _run_is_compiling_tests(dtype_name, backend):
+    """
+    Test _is_compiling with the given dtype and numerical backend
+    """
+    dtype = to_backend_dtype(dtype_name, like=backend)
+    x = anp.array([[0.0, 1.0], [1.0, 2.0]], dtype=dtype, like=backend)
+    assert not _is_compiling(
+        x
+    ), f"_is_compiling has a false positive with backend {backend}"
+
+    def check_compiling(x):
+        assert _is_compiling(
+            x
+        ), f"_is_compiling has a false negative with backend {backend}"
+        return x
+
+    if backend == "jax":
+        import jax
+
+        jax.jit(check_compiling)(x)
+    elif backend == "torch":
+        import torch
+
+        torch.jit.trace(check_compiling, (x,), check_trace=False)(x)
+    elif backend == "tensorflow":
+        import tensorflow as tf
+
+        tf.function(check_compiling, jit_compile=True)(x)
+        tf.function(check_compiling, jit_compile=False)(x)
+
+
+def test_is_compiling():
+    """Test _is_compiling with all possible configurations"""
+    _run_tests_with_all_backends(_run_is_compiling_tests)
+
+
 if __name__ == "__main__":
     try:
         # used to run this test individually
         test_linspace_with_grads()
         test_add_at_indices()
         test_setup_integration_domain()
+        test_is_compiling()
     except KeyboardInterrupt:
         pass
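The PyTorch branch of the test works the same way: torch.jit.trace executes the Python function once while torch.jit.is_tracing() reports True. A minimal standalone sketch (not part of the PR):

import torch

def report(x):
    # False when called eagerly; True while torch.jit.trace records the call,
    # which is what the PyTorch branch of _is_compiling relies on.
    print(torch.jit.is_tracing())
    return x

x = torch.tensor([[0.0, 1.0], [1.0, 2.0]])
report(x)                                         # prints False
torch.jit.trace(report, (x,), check_trace=False)  # prints True during tracing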