Commit
Merge pull request #4535 from google:lazy-simplification
PiperOrigin-RevId: 337183224
jax authors committed Oct 14, 2020
2 parents c4f08ff + 990dc57 commit fb01f59
Showing 6 changed files with 105 additions and 47 deletions.
35 changes: 24 additions & 11 deletions jax/experimental/jax2tf/jax2tf.py
@@ -25,7 +25,6 @@
from jax import core
from jax import custom_derivatives
from jax import dtypes
from jax import lax
from jax import lax_linalg
from jax import linear_util as lu
from jax import numpy as jnp
@@ -38,8 +37,10 @@
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.lax import lax
from jax.lax import lax_control_flow
from jax.lax import lax_fft
from jax.lax import lax_parallel

import numpy as np
import tensorflow as tf # type: ignore[import]
@@ -609,23 +610,24 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
tf_not_yet_impl = [
lax.reduce_p, lax.rng_uniform_p,

lax.linear_solve_p,
lax_control_flow.linear_solve_p,
lax_linalg.lu_p,
lax_linalg.triangular_solve_p,

lax.igamma_grad_a_p,
lax.random_gamma_grad_p,

# Not high priority?
lax.after_all_p, lax.all_to_all_p, lax.create_token_p, lax.cummax_p, lax.cummin_p,
lax.infeed_p, lax.outfeed_p, lax.pmax_p, lax.pmin_p, lax.ppermute_p, lax.psum_p,
lax.axis_index_p,
lax.after_all_p, lax_parallel.all_to_all_p, lax.create_token_p, lax.cummax_p,
lax.cummin_p, lax.infeed_p, lax.outfeed_p, lax_parallel.pmax_p,
lax_parallel.pmin_p, lax_parallel.ppermute_p, lax_parallel.psum_p,
lax_parallel.axis_index_p,

pxla.xla_pmap_p,
]

try:
tf_impl[lax.lax.tie_in_p] = lambda x, y: y
tf_impl[lax.tie_in_p] = lambda x, y: y
except AttributeError:
pass
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
@@ -688,6 +690,17 @@ def _population_count(x):
tf_impl[lax.mul_p] = wrap_binary_op(tf.math.multiply)


def _iota(*, dtype, shape, dimension):
size = shape[dimension]
  # Some dtypes are unsupported, like uint32, so we just fall back to int32.
# TODO(mattjj, necula): improve tf.range dtype handling
vec = tf.range(tf.cast(size, tf.int32), dtype=tf.int32)
vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))]
return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape), dtype)

tf_impl[lax.iota_p] = _iota


def _div(lhs, rhs):
if lhs.dtype.is_integer:
quotient = tf.math.floordiv(lhs, rhs)
@@ -1227,8 +1240,8 @@ def _select_and_gather_add(tangents: TfVal,
const = lambda dtype, x: tf.constant(np.array(x), dtype)

if double_word_reduction:
word_dtype = lax.lax._UINT_DTYPES[nbits]
double_word_dtype = lax.lax._UINT_DTYPES[nbits * 2]
word_dtype = lax._UINT_DTYPES[nbits]
double_word_dtype = lax._UINT_DTYPES[nbits * 2]

# Packs two values into a tuple.
def pack(a, b):
@@ -1536,7 +1549,7 @@ def _cond(index: TfVal, *operands: TfVal,
for jaxpr in branches]
return tf.switch_case(index, branches_tf)

tf_impl[lax.cond_p] = _cond
tf_impl[lax_control_flow.cond_p] = _cond


def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
@@ -1603,10 +1616,10 @@ def select_one_carry(new_c: TfVal, c: TfVal) -> TfVal:
(init_pred_b, *init_carry))
return res_carry

tf_impl[lax.while_p] = _while
tf_impl[lax_control_flow.while_p] = _while

# We use the scan impl rule to rewrite in terms of while.
tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl)
tf_impl_with_avals[lax_control_flow.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl)

def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
# Some types originally incompatible with tf.math.top_k can be promoted
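
Since `iota` is now a real primitive rather than a lazy device constant, jax2tf needs the `_iota` rule above, which builds the result from `tf.range`, a reshape, and a broadcast. A minimal sketch (not part of the diff) of exercising it through `jax2tf.convert`:

import numpy as np
from jax import lax
from jax.experimental import jax2tf

def column_ids(x):
  # Adds the column index to every element of a (2, 3) array.
  return x + lax.broadcasted_iota(np.int32, (2, 3), 1)

column_ids_tf = jax2tf.convert(column_ids)           # iota_p lowers through the new _iota rule
print(column_ids_tf(np.zeros((2, 3), np.int32)))     # each row is [0, 1, 2]
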
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/correctness_stats.py
@@ -91,7 +91,7 @@ def _to_np_dtype(dtype) -> NpDType:
pass
return np.dtype(dtype)

if args[0] is not core.unit:
if args and args[0] is not core.unit:
np_dtype = _to_np_dtype(args[0].dtype)
else:
np_dtype = None
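
The added `args and` guard matters because some primitives now reach this code path with no positional arguments at all; the new `iota_p`, for example, is bound purely through keyword parameters. A minimal sketch of the case the guard handles (illustrative values, not from the diff):

args = ()   # e.g. iota_p.bind(dtype=np.int32, shape=(3,), dimension=0) passes no positional args
np_dtype = _to_np_dtype(args[0].dtype) if args and args[0] is not core.unit else None
assert np_dtype is None   # without the guard, args[0] would raise IndexError here
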
99 changes: 71 additions & 28 deletions jax/lax/lax.py
@@ -1407,55 +1407,73 @@ def iota(dtype: DType, size: int) -> Array:
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
operator.
"""
size = size if type(size) is masking.Poly else int(size)
shape = canonicalize_shape((size,))
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.iota(dtype, shape[0])
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
if config.omnistaging_enabled:
dtype = dtypes.canonicalize_dtype(dtype)
size = core.concrete_or_error(int, size, "size argument of lax.iota")
return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)
else:
size = size if type(size) is masking.Poly else int(size)
shape = canonicalize_shape((size,))
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.iota(dtype, shape[0])
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
"""Convenience wrapper around ``iota``."""
dtype = dtypes.canonicalize_dtype(dtype)
shape = canonicalize_shape(shape)
dimension = int(dimension)
return broadcast_in_dim(iota(dtype, shape[dimension]), shape, [dimension])
dimension = core.concrete_or_error(
int, dimension, "dimension argument of lax.broadcasted_iota")
return iota_p.bind(dtype=dtype, shape=shape, dimension=dimension)

def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
"""Like numpy.eye, create a 2D array with ones on a diagonal.
This function exists for creating lazy identity matrices; that is,
materialization of the array is delayed and it may be fused into consumers to
avoid materialization at all."""
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
N, M = tuple(map(int, shape))
offset = int(offset)
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.eye(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
if config.omnistaging_enabled:
bool_eye = eq(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
broadcasted_iota(np.int32, (N, M), 1))
return convert_element_type_p.bind(bool_eye, new_dtype=dtype,
old_dtype=np.bool_)
else:
lazy_expr = lazy.eye(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
"""This function exists for creating lazy Kronecker delta arrays, particularly
for use in jax.numpy.einsum to express traces. It differs from ``eye`` in that
it can create arrays of any rank, but doesn't allow offsets."""
"""This utility function exists for creating Kronecker delta arrays."""
shape = tuple(map(int, shape))
axes = tuple(map(int, axes))
dtype = dtypes.canonicalize_dtype(dtype)
base_shape = tuple(np.take(shape, axes))
lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
if config.omnistaging_enabled:
iotas = [broadcasted_iota(np.uint32, base_shape, i)
for i in range(len(base_shape))]
eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
result = convert_element_type_p.bind(_reduce(operator.and_, eyes),
new_dtype=dtype, old_dtype=np.bool_)
return broadcast_in_dim(result, shape, axes)
else:
lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
"""Like numpy.tri, create a 2D array with ones below a diagonal.
This function exists for creating lazy triangular matrices, particularly for
use in jax.numpy.tri."""
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
N, M = tuple(map(int, shape))
offset = int(offset)
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.tri(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
if config.omnistaging_enabled:
bool_tri = ge(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
broadcasted_iota(np.int32, (N, M), 1))
return convert_element_type_p.bind(bool_tri, old_dtype=np.int32,
new_dtype=dtype)
else:
lazy_expr = lazy.tri(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

def stop_gradient(x):
"""Stops gradient computation.
@@ -5779,6 +5797,7 @@ def _outfeed_translation_rule(c, token, *xs):
outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
xla.translations[outfeed_p] = _outfeed_translation_rule


def rng_uniform(a, b, shape):
"""Stateful PRNG generator. Experimental and its use is discouraged.
@@ -5811,6 +5830,30 @@ def _rng_uniform_translation_rule(c, a, b, *, shape):
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
xla.translations[rng_uniform_p] = _rng_uniform_translation_rule


def _iota_abstract_eval(*, dtype, shape, dimension):
_check_shapelike("iota", "shape", shape)
if not any(dtypes.issubdtype(dtype, t) for t in _num):
msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
typename = str(np.dtype(dtype).name)
accepted_typenames = (t.__name__ for t in _num)
raise TypeError(msg.format(typename, ', '.join(accepted_typenames)))
if not 0 <= dimension < len(shape):
raise ValueError("iota dimension must be between 0 and len(shape), got "
f"dimension={dimension} for shape {shape}")
return ShapedArray(shape, dtype)

def _iota_translation_rule(c, dtype, shape, dimension):
etype = xla_client.dtype_to_etype(dtype)
xla_shape = xc.Shape.array_shape(etype, shape)
return xops.Iota(c, xla_shape, dimension)

iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
xla.translations[iota_p] = _iota_translation_rule


### util

_ndim = np.ndim
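
Under omnistaging the lazy `eye`/`tri`/`delta` constants are gone: `_eye` and `_tri` are now built by comparing two broadcasted iotas and converting the boolean mask to the requested dtype. A quick sketch (not part of the diff) checking that construction against NumPy:

import numpy as np
from jax import lax

N, M, offset = 4, 5, 1
rows = lax.broadcasted_iota(np.int32, (N, M), 0)   # row index at every position
cols = lax.broadcasted_iota(np.int32, (N, M), 1)   # column index at every position

eye = (rows + offset == cols).astype(np.float32)   # mirrors _eye: eq(iota_0 + offset, iota_1)
tri = (rows + offset >= cols).astype(np.float32)   # mirrors _tri: ge(iota_0 + offset, iota_1)

assert np.array_equal(eye, np.eye(N, M, k=offset))
assert np.array_equal(tri, np.tri(N, M, k=offset))
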
1 change: 1 addition & 0 deletions jax/random.py
@@ -304,6 +304,7 @@ def _fold_in(key, data):
return threefry_2x32(key, PRNGKey(data))


@partial(jit, static_argnums=(1, 2))
def _random_bits(key, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_prng_key(key):
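
`_random_bits` is now wrapped in `jit` with `bit_width` and `shape` marked static, so the bit-generation computation is staged out and compiled once per distinct (bit_width, shape) pair instead of running op by op. A toy sketch of the same pattern (the function below is illustrative, not JAX's implementation):

from functools import partial
from jax import jit
import jax.numpy as jnp

@partial(jit, static_argnums=(1, 2))
def toy_random_bits(key, bit_width, shape):
  # bit_width and shape are ordinary Python values inside the trace, so they can
  # drive dtype and shape decisions; a new (bit_width, shape) pair recompiles.
  dtype = jnp.uint32 if bit_width == 32 else jnp.uint8
  return jnp.zeros(shape, dtype) + key[0]

bits = toy_random_bits(jnp.array([0, 42], dtype=jnp.uint32), 32, (2, 3))
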
10 changes: 5 additions & 5 deletions tests/api_test.py
@@ -33,7 +33,7 @@
import jax
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax import api, core, lax, lax_reference
from jax import api, core, lax, lax_reference, lazy
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import xla
@@ -1477,7 +1477,7 @@ def f(x):
def test_dtype_warning(self):
# cf. issue #1230
if FLAGS.jax_enable_x64:
return # test only applies when x64 is disabled
raise unittest.SkipTest("test only applies when x64 is disabled")

def check_warning(warn, nowarn):
with warnings.catch_warnings(record=True) as w:
@@ -2470,14 +2470,14 @@ def f(x):
assert python_should_be_executing
return jnp.sum(x)

x = jnp.arange(10, dtype=jnp.int32)
assert xla.is_device_constant(x) # lazy iota
x = jnp.zeros(10, dtype=jnp.int32)
assert not lazy.is_trivial(x._lazy_expr)

python_should_be_executing = True
_ = f(x)

python_should_be_executing = False # should not recompile
x = np.arange(10, dtype=np.int32)
x = np.zeros(10, dtype=np.int32)
_ = f(x)

@parameterized.parameters(jtu.cases_from_list(range(10000)))
5 changes: 3 additions & 2 deletions tests/lax_numpy_test.py
@@ -3651,8 +3651,9 @@ def testArange(self):
type(lax.iota(np.int32, 77)))

# test laziness for int dtypes
self.assertTrue(xla.is_device_constant(jnp.arange(77)))
self.assertTrue(xla.is_device_constant(jnp.arange(77, dtype=jnp.int32)))
if not config.omnistaging_enabled:
self.assertTrue(xla.is_device_constant(jnp.arange(77)))
self.assertTrue(xla.is_device_constant(jnp.arange(77, dtype=jnp.int32)))

def testArangeJit(self):
ans = api.jit(lambda: jnp.arange(5))()
