From 445919931e1dd1489c11d44006349c9673b3729b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Wed, 26 Jul 2023 11:22:36 +0200 Subject: [PATCH] Don't include local_uint_constant_indices rewrite in JAX mode due to XLA bug --- pytensor/compile/mode.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py index e736f93dd1..514f8f48c4 100644 --- a/pytensor/compile/mode.py +++ b/pytensor/compile/mode.py @@ -451,7 +451,14 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs): JAXLinker(), RewriteDatabaseQuery( include=["fast_run", "jax"], - exclude=["cxx_only", "BlasOpt", "fusion", "inplace"], + # TODO: "local_uint_constant_indices" can be reintroduced once https://github.com/google/jax/issues/16836 is fixed. + exclude=[ + "cxx_only", + "BlasOpt", + "fusion", + "inplace", + "local_uint_constant_indices", + ], ), ) NUMBA = Mode(