diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b980391193df..9ae539cc9b29 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -335,6 +335,10 @@ def population_count(x: Array) -> Array: r"""Elementwise popcount, count the number of set bits in each element.""" return population_count_p.bind(x) +def clz(x: Array) -> Array: + r"""Elementwise count-leading-zeros.""" + return clz_p.bind(x) + def add(x: Array, y: Array) -> Array: r"""Elementwise addition: :math:`x + y`.""" return add_p.bind(x, y) @@ -2531,6 +2535,8 @@ def _integer_pow_jvp(g, x, *, y): population_count_p = standard_unop(_int, 'population_count') +clz_p = standard_unop(_int, 'clz') + def _add_transpose(t, x, y): # The following linearity assertion is morally true, but because in some cases we # instantiate zeros for convenience, it doesn't always hold. diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index ce6eba2826a5..0ccc9938f1f0 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -797,7 +797,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): # Primitives that are not yet implemented must be explicitly declared here. tf_not_yet_impl = [ - "reduce", "rng_uniform", + "reduce", "rng_uniform", "clz", "igamma_grad_a", "random_gamma_grad", diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index 4f884a0a5593..3206e2bf89c3 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -75,6 +75,7 @@ ceil_p, clamp, clamp_p, + clz, collapse, complex, complex_p, diff --git a/jax/lax_reference.py b/jax/lax_reference.py index 9b839c360ee3..ca75554e0877 100644 --- a/jax/lax_reference.py +++ b/jax/lax_reference.py @@ -149,6 +149,15 @@ def population_count(x): x = (x & m[5]) + ((x >> 32) & m[5]) # put count of each 64 bits into those 64 bits return x.astype(dtype) +def clz(x): + assert np.issubdtype(x.dtype, np.integer) + nbits = np.iinfo(x.dtype).bits + mask = (2 ** np.arange(nbits, dtype=x.dtype))[::-1] + bits = (x[..., None] & mask).astype(np.bool_) + out = np.argmax(bits, axis=-1).astype(x.dtype) + out[x == 0] = nbits + return out + eq = np.equal ne = np.not_equal ge = np.greater_equal diff --git a/tests/lax_test.py b/tests/lax_test.py index 79d3d206c521..2eaf6db35a3b 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -146,6 +146,7 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None): op_record("bitwise_or", 2, bool_dtypes, jtu.rand_small), op_record("bitwise_xor", 2, bool_dtypes, jtu.rand_small), op_record("population_count", 1, int_dtypes + uint_dtypes, jtu.rand_int), + op_record("clz", 1, int_dtypes + uint_dtypes, jtu.rand_int), op_record("add", 2, default_dtypes + complex_dtypes, jtu.rand_small), op_record("sub", 2, default_dtypes + complex_dtypes, jtu.rand_small),