Skip to content

Commit

Permalink
Don't include local_uint_constant_indices rewrite in JAX mode due to …
Browse files Browse the repository at this point in the history
…XLA bug
  • Loading branch information
ricardoV94 committed Jul 26, 2023
1 parent 14d2454 commit 4459199
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,14 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
JAXLinker(),
RewriteDatabaseQuery(
include=["fast_run", "jax"],
exclude=["cxx_only", "BlasOpt", "fusion", "inplace"],
# TODO: "local_uint_constant_indices" can be reintroduced once https://github.com/google/jax/issues/16836 is fixed.
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"local_uint_constant_indices",
],
),
)
NUMBA = Mode(
Expand Down

0 comments on commit 4459199

Please sign in to comment.