Skip to content

Commit

Permalink
Merge pull request #6146 from google:clz
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 364039346
  • Loading branch information
jax authors committed Mar 20, 2021
2 parents a29d07f + 97aca25 commit 3c377a2
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 1 deletion.
6 changes: 6 additions & 0 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ def population_count(x: Array) -> Array:
r"""Elementwise popcount, count the number of set bits in each element."""
return population_count_p.bind(x)

def clz(x: Array) -> Array:
r"""Elementwise count-leading-zeros."""
return clz_p.bind(x)

def add(x: Array, y: Array) -> Array:
r"""Elementwise addition: :math:`x + y`."""
return add_p.bind(x, y)
Expand Down Expand Up @@ -2531,6 +2535,8 @@ def _integer_pow_jvp(g, x, *, y):

population_count_p = standard_unop(_int, 'population_count')

clz_p = standard_unop(_int, 'clz')

def _add_transpose(t, x, y):
# The following linearity assertion is morally true, but because in some cases we
# instantiate zeros for convenience, it doesn't always hold.
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,7 +797,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):

# Primitives that are not yet implemented must be explicitly declared here.
tf_not_yet_impl = [
"reduce", "rng_uniform",
"reduce", "rng_uniform", "clz",

"igamma_grad_a",
"random_gamma_grad",
Expand Down
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
ceil_p,
clamp,
clamp_p,
clz,
collapse,
complex,
complex_p,
Expand Down
9 changes: 9 additions & 0 deletions jax/lax_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,15 @@ def population_count(x):
x = (x & m[5]) + ((x >> 32) & m[5]) # put count of each 64 bits into those 64 bits
return x.astype(dtype)

def clz(x):
assert np.issubdtype(x.dtype, np.integer)
nbits = np.iinfo(x.dtype).bits
mask = (2 ** np.arange(nbits, dtype=x.dtype))[::-1]
bits = (x[..., None] & mask).astype(np.bool_)
out = np.argmax(bits, axis=-1).astype(x.dtype)
out[x == 0] = nbits
return out

eq = np.equal
ne = np.not_equal
ge = np.greater_equal
Expand Down
1 change: 1 addition & 0 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def op_record(op, nargs, dtypes, rng_factory, tol=None):
op_record("bitwise_or", 2, bool_dtypes, jtu.rand_small),
op_record("bitwise_xor", 2, bool_dtypes, jtu.rand_small),
op_record("population_count", 1, int_dtypes + uint_dtypes, jtu.rand_int),
op_record("clz", 1, int_dtypes + uint_dtypes, jtu.rand_int),

op_record("add", 2, default_dtypes + complex_dtypes, jtu.rand_small),
op_record("sub", 2, default_dtypes + complex_dtypes, jtu.rand_small),
Expand Down

0 comments on commit 3c377a2

Please sign in to comment.