Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Raise an error on non-hashable static arguments for jax.jit and xla_c…
…omputation. Up to now, Jax was silently wrapping the object to ensure objects which are not hashable will be hashed using `id` and compared using `is`: ``` class WrapHashably(object): __slots__ = ["val"] def __init__(self, val): self.val = val def __hash__(self): return id(self.val) def __eq__(self, other): return self.val is other.val ``` This means that when providing different instances of objects that are non hashable, a recompilation was always occurring. This can be non-intuitive, for example with: @partial(jax.jit, static_argnums=(1,)) def sum(a, b): return a+ b sum(np.asarray([1,2,3]), np.asarray([4,5,6]) # The next line will recompile, because the 1-indexed argument is non # hashable and thus compared by identity with different instances sum(np.asarray([1,2,3]), np.asarray([4,5,6]) or more simply np.pad(a, [2, 3], 'constant', constant_values=(4, 6)) ^^^^^^ non-hashable static argument. The same problems can occur with any non-hashable types such as lists, dicts, etc. Even JAX itself was having some issues with this (which shows the behaviour was non-trivial to reason about). If this commit breaks you, you usually have one of the following options: - If specifying numpy array or jnp arrays arguments as static, you probably simply need to make them non static. - When using non-hashable values, such as list, dicts or sets, you can simply use non-mutable versions, with tuples, frozendict, and frozenset. - You can also change the way the function is defined, to capture these non-hashable arguments by closure, returning the jitted function. PiperOrigin-RevId: 339351798
- Loading branch information
cb48f42
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just ran into this in a Flax example. This is a non-backwards-compatible change, it ought to be mentioned in the CHANGELOG.md
cb48f42
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jblespiau Perhaps add a mention to CHANGELOG?
cb48f42
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sent https://critique-ng.corp.google.com/cl/340395327