From cb48f42372367141db567ec8b72f741a7fadab2e Mon Sep 17 00:00:00 2001 From: Jean-Baptiste Lespiau Date: Tue, 27 Oct 2020 16:11:41 -0700 Subject: [PATCH] Raise an error on non-hashable static arguments for jax.jit and xla_computation. Up to now, Jax was silently wrapping the object to ensure objects which are not hashable will be hashed using `id` and compared using `is`: ``` class WrapHashably(object): __slots__ = ["val"] def __init__(self, val): self.val = val def __hash__(self): return id(self.val) def __eq__(self, other): return self.val is other.val ``` This means that when providing different instances of objects that are non hashable, a recompilation was always occurring. This can be non-intuitive, for example with: @partial(jax.jit, static_argnums=(1,)) def sum(a, b): return a+ b sum(np.asarray([1,2,3]), np.asarray([4,5,6]) # The next line will recompile, because the 1-indexed argument is non # hashable and thus compared by identity with different instances sum(np.asarray([1,2,3]), np.asarray([4,5,6]) or more simply np.pad(a, [2, 3], 'constant', constant_values=(4, 6)) ^^^^^^ non-hashable static argument. The same problems can occur with any non-hashable types such as lists, dicts, etc. Even JAX itself was having some issues with this (which shows the behaviour was non-trivial to reason about). If this commit breaks you, you usually have one of the following options: - If specifying numpy array or jnp arrays arguments as static, you probably simply need to make them non static. - When using non-hashable values, such as list, dicts or sets, you can simply use non-mutable versions, with tuples, frozendict, and frozenset. - You can also change the way the function is defined, to capture these non-hashable arguments by closure, returning the jitted function. PiperOrigin-RevId: 339351798 --- jax/api_util.py | 11 ++++------- tests/api_test.py | 18 +++++++++++++++--- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/jax/api_util.py b/jax/api_util.py index bbb20226ff1a..8f8cffbb9381 100644 --- a/jax/api_util.py +++ b/jax/api_util.py @@ -88,13 +88,10 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...], try: hash(static_arg) except TypeError: - logging.warning( - "Static argument (index %s) of type %s for function %s is " - "non-hashable. As this can lead to unexpected cache-misses, it " - "will raise an error in a near future.", i, type(static_arg), - f.__name__) - # e.g. ndarrays, DeviceArrays - fixed_args[i] = WrapHashably(static_arg) # type: ignore + raise ValueError( + "Non-hashable static arguments are not supported, as this can lead " + f"to unexpected cache-misses. Static argument (index {i}) of type " + f"{type(static_arg)} for function {f.__name__} is non-hashable.") else: fixed_args[i] = Hashable(static_arg) # type: ignore diff --git a/tests/api_test.py b/tests/api_test.py index b06523475374..a1f62c6de318 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -415,6 +415,18 @@ def test_jit_reference_dropping(self): del g # no more references to x assert x() is None # x is gone + def test_jit_raises_on_first_invocation_on_non_hashable_static_argnum(self): + if self.jit != jax.api._python_jit: + raise unittest.SkipTest("this test only applies to _python_jit") + f = lambda x, y: x + 3 + jitted_f = self.jit(f, static_argnums=(1,)) + + msg = ("Non-hashable static arguments are not supported, as this can lead " + "to unexpected cache-misses. Static argument (index 1) of type " + " for function is non-hashable.") + with self.assertRaisesRegex(ValueError, re.escape(msg)): + jitted_f(1, np.asarray(1)) + def test_cpp_jit_raises_on_non_hashable_static_argnum(self): if version < (0, 1, 58): raise unittest.SkipTest("Disabled because it depends on some future " @@ -428,9 +440,9 @@ def test_cpp_jit_raises_on_non_hashable_static_argnum(self): jitted_f(1, 1) - msg = ( - """Non-hashable static arguments are not supported. An error occured while trying to hash an object of type , 1. The error was: -TypeError: unhashable type: 'numpy.ndarray'""") + msg = ("Non-hashable static arguments are not supported. An error occured " + "while trying to hash an object of type , 1. " + "The error was:\nTypeError: unhashable type: 'numpy.ndarray'") with self.assertRaisesRegex(ValueError, re.escape(msg)): jitted_f(1, np.asarray(1))