From 4e65a6f0a9d792ebcd22ea29efffa485f1d76f8c Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Sat, 10 Oct 2020 21:08:52 -0700
Subject: [PATCH 1/4] don't generate lazy iota/eye/tri/delta omnistaging

---
 jax/lax/lax.py          | 99 +++++++++++++++++++++++++++++------------
 jax/random.py           |  1 +
 tests/api_test.py       | 10 ++---
 tests/lax_numpy_test.py |  5 ++-
 4 files changed, 80 insertions(+), 35 deletions(-)

diff --git a/jax/lax/lax.py b/jax/lax/lax.py
index 839cb0d6856a..9b239d525192 100644
--- a/jax/lax/lax.py
+++ b/jax/lax/lax.py
@@ -1409,55 +1409,73 @@ def iota(dtype: DType, size: int) -> Array:
   `_ operator.
   """
-  size = size if type(size) is masking.Poly else int(size)
-  shape = canonicalize_shape((size,))
-  dtype = dtypes.canonicalize_dtype(dtype)
-  lazy_expr = lazy.iota(dtype, shape[0])
-  aval = ShapedArray(shape, dtype)
-  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  if config.omnistaging_enabled:
+    dtype = dtypes.canonicalize_dtype(dtype)
+    size = core.concrete_or_error(int, size, "size argument of lax.iota")
+    return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)
+  else:
+    size = size if type(size) is masking.Poly else int(size)
+    shape = canonicalize_shape((size,))
+    dtype = dtypes.canonicalize_dtype(dtype)
+    lazy_expr = lazy.iota(dtype, shape[0])
+    aval = ShapedArray(shape, dtype)
+    return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
 
 def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
   """Convenience wrapper around ``iota``."""
   dtype = dtypes.canonicalize_dtype(dtype)
   shape = canonicalize_shape(shape)
-  dimension = int(dimension)
-  return broadcast_in_dim(iota(dtype, shape[dimension]), shape, [dimension])
+  dimension = core.concrete_or_error(
+      int, dimension, "dimension argument of lax.broadcasted_iota")
+  return iota_p.bind(dtype=dtype, shape=shape, dimension=dimension)
 
 def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
-  """Like numpy.eye, create a 2D array with ones on a diagonal.
-
-  This function exists for creating lazy identity matrices; that is,
-  materialization of the array is delayed and it may be fused into consumers to
-  avoid materialization at all."""
+  """Like numpy.eye, create a 2D array with ones on a diagonal."""
   N, M = tuple(map(int, shape))
   offset = int(offset)
   dtype = dtypes.canonicalize_dtype(dtype)
-  lazy_expr = lazy.eye(dtype, (N, M), offset)
-  aval = ShapedArray((N, M), dtype)
-  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  if config.omnistaging_enabled:
+    bool_eye = eq(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
+                  broadcasted_iota(np.int32, (N, M), 1))
+    return convert_element_type_p.bind(bool_eye, new_dtype=dtype,
+                                       old_dtype=np.bool_)
+  else:
+    lazy_expr = lazy.eye(dtype, (N, M), offset)
+    aval = ShapedArray((N, M), dtype)
+    return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
 
 def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
-  """This function exists for creating lazy Kronecker delta arrays, particularly
-  for use in jax.numpy.einsum to express traces. It differs from ``eye`` in that
-  it can create arrays of any rank, but doesn't allow offsets."""
+  """This utility function exists for creating Kronecker delta arrays."""
   shape = tuple(map(int, shape))
   axes = tuple(map(int, axes))
   dtype = dtypes.canonicalize_dtype(dtype)
   base_shape = tuple(np.take(shape, axes))
-  lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
-  aval = ShapedArray(shape, dtype)
-  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  if config.omnistaging_enabled:
+    iotas = [broadcasted_iota(np.uint32, base_shape, i)
+             for i in range(len(base_shape))]
+    eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
+    result = convert_element_type_p.bind(_reduce(operator.and_, eyes),
+                                         new_dtype=dtype, old_dtype=np.bool_)
+    return broadcast_in_dim(result, shape, axes)
+  else:
+    lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
+    aval = ShapedArray(shape, dtype)
+    return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
 
 def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
-  """Like numpy.tri, create a 2D array with ones below a diagonal.
-  This function exists for creating lazy triangular matrices, particularly for
-  use in jax.numpy.tri."""
+  """Like numpy.tri, create a 2D array with ones below a diagonal."""
   N, M = tuple(map(int, shape))
   offset = int(offset)
   dtype = dtypes.canonicalize_dtype(dtype)
-  lazy_expr = lazy.tri(dtype, (N, M), offset)
-  aval = ShapedArray((N, M), dtype)
-  return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
+  if config.omnistaging_enabled:
+    bool_tri = ge(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
+                  broadcasted_iota(np.int32, (N, M), 1))
+    return convert_element_type_p.bind(bool_tri, old_dtype=np.int32,
+                                       new_dtype=dtype)
+  else:
+    lazy_expr = lazy.tri(dtype, (N, M), offset)
+    aval = ShapedArray((N, M), dtype)
+    return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
 
 def stop_gradient(x):
   """Stops gradient computation.
@@ -5797,6 +5815,7 @@ def _outfeed_translation_rule(c, token, *xs):
 outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
 xla.translations[outfeed_p] = _outfeed_translation_rule
 
+
 def rng_uniform(a, b, shape):
   """Stateful PRNG generator. Experimental and its use is discouraged.
 
@@ -5829,6 +5848,30 @@ def _rng_uniform_translation_rule(c, a, b, *, shape):
 rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
 xla.translations[rng_uniform_p] = _rng_uniform_translation_rule
 
+
+
+def _iota_abstract_eval(*, dtype, shape, dimension):
+  _check_shapelike("iota", "shape", shape)
+  if not any(dtypes.issubdtype(dtype, t) for t in _num):
+    msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
+    typename = str(np.dtype(dtype).name)
+    accepted_typenames = (t.__name__ for t in _num)
+    raise TypeError(msg.format(typename, ', '.join(accepted_typenames)))
+  if not 0 <= dimension < len(shape):
+    raise ValueError("iota dimension must be between 0 and len(shape), got "
+                     f"dimension={dimension} for shape {shape}")
+  return ShapedArray(shape, dtype)
+
+def _iota_translation_rule(c, dtype, shape, dimension):
+  etype = xla_client.dtype_to_etype(dtype)
+  xla_shape = xc.Shape.array_shape(etype, shape)
+  return xops.Iota(c, xla_shape, dimension)
+
+iota_p = Primitive('iota')
+iota_p.def_impl(partial(xla.apply_primitive, iota_p))
+iota_p.def_abstract_eval(_iota_abstract_eval)
+xla.translations[iota_p] = _iota_translation_rule
+
+
 ### util
 
 _ndim = np.ndim
diff --git a/jax/random.py b/jax/random.py
index 3823c8c114aa..98e5fa64ad38 100644
--- a/jax/random.py
+++ b/jax/random.py
@@ -300,6 +300,7 @@ def _fold_in(key, data):
   return threefry_2x32(key, PRNGKey(data))
 
 
+@partial(jit, static_argnums=(1, 2))
 def _random_bits(key, bit_width, shape):
   """Sample uniform random bits of given width and shape using PRNG key."""
   if not _is_prng_key(key):
diff --git a/tests/api_test.py b/tests/api_test.py
index e9d54d9e4368..8e8e052e06db 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -33,7 +33,7 @@
 import jax
 import jax.numpy as jnp
 from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
-from jax import api, core, lax, lax_reference
+from jax import api, core, lax, lax_reference, lazy
 from jax.core import Primitive
 from jax.interpreters import ad
 from jax.interpreters import xla
@@ -1450,7 +1450,7 @@ def f(x):
   def test_dtype_warning(self):
     # cf. issue #1230
     if FLAGS.jax_enable_x64:
-      return  # test only applies when x64 is disabled
+      raise unittest.SkipTest("test only applies when x64 is disabled")
 
     def check_warning(warn, nowarn):
       with warnings.catch_warnings(record=True) as w:
@@ -2443,14 +2443,14 @@ def f(x):
       assert python_should_be_executing
       return jnp.sum(x)
 
-    x = jnp.arange(10, dtype=jnp.int32)
-    assert xla.is_device_constant(x)  # lazy iota
+    x = jnp.zeros(10, dtype=jnp.int32)
+    assert not lazy.is_trivial(x._lazy_expr)
 
     python_should_be_executing = True
    _ = f(x)
 
     python_should_be_executing = False  # should not recompile
-    x = np.arange(10, dtype=np.int32)
+    x = np.zeros(10, dtype=np.int32)
     _ = f(x)
 
   @parameterized.parameters(jtu.cases_from_list(range(10000)))
diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py
index 076283b9a9f4..07f20b0a5e69 100644
--- a/tests/lax_numpy_test.py
+++ b/tests/lax_numpy_test.py
@@ -3623,8 +3623,9 @@ def testArange(self):
                      type(lax.iota(np.int32, 77)))
 
     # test laziness for int dtypes
-    self.assertTrue(xla.is_device_constant(jnp.arange(77)))
-    self.assertTrue(xla.is_device_constant(jnp.arange(77, dtype=jnp.int32)))
+    if not config.omnistaging_enabled:
+      self.assertTrue(xla.is_device_constant(jnp.arange(77)))
+      self.assertTrue(xla.is_device_constant(jnp.arange(77, dtype=jnp.int32)))
 
   def testArangeJit(self):
     ans = api.jit(lambda: jnp.arange(5))()

From b402b87081f77960fe9ac6181ab5674d1ae801b7 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Sat, 10 Oct 2020 21:51:51 -0700
Subject: [PATCH 2/4] fix type error in jax2tf

---
 jax/experimental/jax2tf/jax2tf.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index c807e366bdf6..f3d298f1d98d 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -218,7 +218,7 @@ def converted_fun_flat_with_custom_gradient(*args_flat: TfVal) -> TfVal:
 def _interpret_fun(fun: lu.WrappedFun,
                    in_vals: Sequence[TfValOrUnit]) -> Sequence[TfValOrUnit]:
   new_main = core.new_base_main if config.omnistaging_enabled else core.new_main
-  with new_main(TensorFlowTrace) as main:
+  with new_main(TensorFlowTrace) as main:  # type: ignore
     fun = _interpret_subtrace(fun, main)
     out_vals: Sequence[TfValOrUnit] = fun.call_wrapped(*in_vals)
     del main

From b99e350aed0e2c59cdcc5a1fc69d8dd8a7932f9f Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Wed, 14 Oct 2020 14:30:09 -0700
Subject: [PATCH 3/4] add iota_p support to jax2tf

---
 jax/experimental/jax2tf/jax2tf.py             | 29 ++++++++++++++-----
 .../jax2tf/tests/correctness_stats.py         |  2 +-
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index f3d298f1d98d..33b3f8c3cf37 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -24,7 +24,6 @@
 from jax import core
 from jax import custom_derivatives
 from jax import dtypes
-from jax import lax
 from jax import lax_linalg
 from jax import linear_util as lu
 from jax import numpy as jnp
@@ -36,8 +35,10 @@
 from jax.interpreters import partial_eval as pe
 from jax.interpreters import pxla
 from jax.interpreters import xla
+from jax.lax import lax
 from jax.lax import lax_control_flow
 from jax.lax import lax_fft
+from jax.lax import lax_parallel
 import numpy as np
 import tensorflow as tf  # type: ignore[import]
 
@@ -447,7 +448,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
 tf_not_yet_impl = [
   lax.reduce_p, lax.rng_uniform_p,
-  lax.linear_solve_p,
+  lax_control_flow.linear_solve_p,
 
   lax_linalg.lu_p, lax_linalg.triangular_solve_p,
 
   lax.random_gamma_grad_p,
 
   # Not high priority?
-  lax.after_all_p, lax.all_to_all_p, lax.create_token_p, lax.cummax_p, lax.cummin_p,
-  lax.infeed_p, lax.outfeed_p, lax.pmax_p, lax.pmin_p, lax.ppermute_p, lax.psum_p,
-  lax.axis_index_p,
+  lax.after_all_p, lax_parallel.all_to_all_p, lax.create_token_p, lax.cummax_p,
+  lax.cummin_p, lax.infeed_p, lax.outfeed_p, lax_parallel.pmax_p,
+  lax_parallel.pmin_p, lax_parallel.ppermute_p, lax_parallel.psum_p,
+  lax_parallel.axis_index_p,
   pxla.xla_pmap_p,
 ]
 
@@ -526,6 +528,17 @@ def _population_count(x):
 tf_impl[lax.mul_p] = wrap_binary_op(tf.math.multiply)
 
 
+def _iota(*, dtype, shape, dimension):
+  size = shape[dimension]
+  # Some dtypes are unsupported, like uint32, so we just fall back to int32.
+  # TODO(mattjj, necula): improve tf.range dtype handling
+  vec = tf.range(tf.cast(size, tf.int32), dtype=tf.int32)
+  vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))]
+  return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape), dtype)
+
+tf_impl[lax.iota_p] = _iota
+
+
 def _div(lhs, rhs):
   if lhs.dtype.is_integer:
     quotient = tf.math.floordiv(lhs, rhs)
@@ -1249,7 +1262,7 @@ def _cond(index: TfVal, *operands: TfValOrUnit,
   res_tf: Sequence[TfVal] = tf.switch_case(index, branches_tf)
   return _tfval_add_unit(res_tf, branches[0].out_avals)
 
-tf_impl[lax.cond_p] = _cond
+tf_impl[lax_control_flow.cond_p] = _cond
 
 
 def _while(*args: TfValOrUnit, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
@@ -1317,10 +1330,10 @@ def select_one_carry(new_c: TfVal, c: TfVal) -> TfVal:
                              _tfval_remove_unit((init_pred_b, *init_carry)))
   return _tfval_add_unit(res_carry, body_jaxpr.out_avals)
 
-tf_impl[lax.while_p] = _while
+tf_impl[lax_control_flow.while_p] = _while
 
 # We use the scan impl rule to rewrite in terms of while.
-tf_impl[lax.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl)
+tf_impl[lax_control_flow.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl)
 
 def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
   # Some types originally incompatible with tf.math.top_k can be promoted
diff --git a/jax/experimental/jax2tf/tests/correctness_stats.py b/jax/experimental/jax2tf/tests/correctness_stats.py
index 9c34ebbd41d5..929cf32b9cbc 100644
--- a/jax/experimental/jax2tf/tests/correctness_stats.py
+++ b/jax/experimental/jax2tf/tests/correctness_stats.py
@@ -91,7 +91,7 @@ def _to_np_dtype(dtype) -> NpDType:
       pass
     return np.dtype(dtype)
 
-  if args[0] is not core.unit:
+  if args and args[0] is not core.unit:
     np_dtype = _to_np_dtype(args[0].dtype)
   else:
     np_dtype = None

From a5f906a9f69acff76203db82a6262fac7d5708c6 Mon Sep 17 00:00:00 2001
From: Matthew Johnson
Date: Wed, 14 Oct 2020 14:40:29 -0700
Subject: [PATCH 4/4] jax2tf import fix

---
 jax/experimental/jax2tf/jax2tf.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 33b3f8c3cf37..5e0231f7e0b9 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -465,7 +465,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
 ]
 
 try:
-  tf_impl[lax.lax.tie_in_p] = lambda x, y: y
+  tf_impl[lax.tie_in_p] = lambda x, y: y
 except AttributeError:
   pass
 tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
@@ -920,8 +920,8 @@ def _select_and_gather_add(tangents: TfVal,
   const = lambda dtype, x: tf.constant(np.array(x), dtype)
 
   if double_word_reduction:
-    word_dtype = lax.lax._UINT_DTYPES[nbits]
-    double_word_dtype = lax.lax._UINT_DTYPES[nbits * 2]
+    word_dtype = lax._UINT_DTYPES[nbits]
+    double_word_dtype = lax._UINT_DTYPES[nbits * 2]
 
     # Packs two values into a tuple.
     def pack(a, b):
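
Note on the construction used in PATCH 1/4: under omnistaging the eye/tri helpers above build their masks from `broadcasted_iota` comparisons instead of lazy expressions. The following standalone sketch reproduces that construction using only public `jax.lax` APIs; the helper names `eye_via_iota` and `tri_via_iota` are illustrative only and do not appear in the patch, and it uses `lax.convert_element_type` in place of the internal `convert_element_type_p.bind` call.

# Standalone sketch (illustrative names) of the iota-based eye/tri construction
# used in the omnistaging branches of PATCH 1/4.
import numpy as np
from jax import lax

def eye_via_iota(dtype, N, M, offset=0):
  # Ones where row + offset == column, i.e. on the (shifted) diagonal.
  rows = lax.broadcasted_iota(np.int32, (N, M), 0)
  cols = lax.broadcasted_iota(np.int32, (N, M), 1)
  return lax.convert_element_type(lax.eq(lax.add(rows, np.int32(offset)), cols), dtype)

def tri_via_iota(dtype, N, M, offset=0):
  # Ones where row + offset >= column, i.e. at and below the (shifted) diagonal.
  rows = lax.broadcasted_iota(np.int32, (N, M), 0)
  cols = lax.broadcasted_iota(np.int32, (N, M), 1)
  return lax.convert_element_type(lax.ge(lax.add(rows, np.int32(offset)), cols), dtype)

assert np.array_equal(eye_via_iota(np.float32, 3, 4), np.eye(3, 4, dtype=np.float32))
assert np.array_equal(tri_via_iota(np.float32, 3, 4), np.tri(3, 4, dtype=np.float32))

Expressing these constructions as iota comparisons lets them be staged out like any other traced computation, which is why the lazy DeviceConstant path is kept only for the non-omnistaging branch.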