don't generate lazy iota/eye/tri/delta with omnistaging #4535

Merged 5 commits on Oct 14, 2020
35 changes: 24 additions & 11 deletions jax/experimental/jax2tf/jax2tf.py
@@ -25,7 +25,6 @@
from jax import core
from jax import custom_derivatives
from jax import dtypes
from jax import lax
from jax import lax_linalg
from jax import linear_util as lu
from jax import numpy as jnp
@@ -38,8 +37,10 @@
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax.lax import lax
from jax.lax import lax_control_flow
from jax.lax import lax_fft
from jax.lax import lax_parallel

import numpy as np
import tensorflow as tf # type: ignore[import]
@@ -609,23 +610,24 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
tf_not_yet_impl = [
lax.reduce_p, lax.rng_uniform_p,

lax.linear_solve_p,
lax_control_flow.linear_solve_p,
lax_linalg.lu_p,
lax_linalg.triangular_solve_p,

lax.igamma_grad_a_p,
lax.random_gamma_grad_p,

# Not high priority?
lax.after_all_p, lax.all_to_all_p, lax.create_token_p, lax.cummax_p, lax.cummin_p,
lax.infeed_p, lax.outfeed_p, lax.pmax_p, lax.pmin_p, lax.ppermute_p, lax.psum_p,
lax.axis_index_p,
lax.after_all_p, lax_parallel.all_to_all_p, lax.create_token_p, lax.cummax_p,
lax.cummin_p, lax.infeed_p, lax.outfeed_p, lax_parallel.pmax_p,
lax_parallel.pmin_p, lax_parallel.ppermute_p, lax_parallel.psum_p,
lax_parallel.axis_index_p,

pxla.xla_pmap_p,
]

try:
tf_impl[lax.lax.tie_in_p] = lambda x, y: y
tf_impl[lax.tie_in_p] = lambda x, y: y
except AttributeError:
pass
tf_impl[ad_util.stop_gradient_p] = tf.stop_gradient
@@ -688,6 +690,17 @@ def _population_count(x):
tf_impl[lax.mul_p] = wrap_binary_op(tf.math.multiply)


def _iota(*, dtype, shape, dimension):
size = shape[dimension]
# Some dtypes are unsupported, like uint32, so we just fall back to int32.
# TODO(mattjj, necula): improve tf.range dtype handling
vec = tf.range(tf.cast(size, tf.int32), dtype=tf.int32)
vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))]
return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape), dtype)

tf_impl[lax.iota_p] = _iota
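
As a rough standalone restatement of what this rule computes (a sketch assuming TensorFlow 2.x; iota_like is a hypothetical name, not the converter's entry point):

import tensorflow as tf

def iota_like(dtype, shape, dimension):
    # Count along `dimension`, keep that axis and broadcast across the rest,
    # then cast to the requested dtype (mirroring the int32 fallback above).
    vec = tf.range(shape[dimension], dtype=tf.int32)
    vec_shape = [-1 if i == dimension else 1 for i in range(len(shape))]
    return tf.cast(tf.broadcast_to(tf.reshape(vec, vec_shape), shape), dtype)

print(iota_like(tf.float32, (2, 3), 1))  # [[0. 1. 2.] [0. 1. 2.]]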


def _div(lhs, rhs):
if lhs.dtype.is_integer:
quotient = tf.math.floordiv(lhs, rhs)
@@ -1227,8 +1240,8 @@ def _select_and_gather_add(tangents: TfVal,
const = lambda dtype, x: tf.constant(np.array(x), dtype)

if double_word_reduction:
word_dtype = lax.lax._UINT_DTYPES[nbits]
double_word_dtype = lax.lax._UINT_DTYPES[nbits * 2]
word_dtype = lax._UINT_DTYPES[nbits]
double_word_dtype = lax._UINT_DTYPES[nbits * 2]

# Packs two values into a tuple.
def pack(a, b):
@@ -1536,7 +1549,7 @@ def _cond(index: TfVal, *operands: TfVal,
for jaxpr in branches]
return tf.switch_case(index, branches_tf)

tf_impl[lax.cond_p] = _cond
tf_impl[lax_control_flow.cond_p] = _cond


def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr,
@@ -1603,10 +1616,10 @@ def select_one_carry(new_c: TfVal, c: TfVal) -> TfVal:
(init_pred_b, *init_carry))
return res_carry

tf_impl[lax.while_p] = _while
tf_impl[lax_control_flow.while_p] = _while

# We use the scan impl rule to rewrite in terms of while.
tf_impl_with_avals[lax.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl)
tf_impl_with_avals[lax_control_flow.scan_p] = _convert_jax_impl(lax_control_flow._scan_impl)
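
Concretely, a converted lax.scan is first rewritten by _scan_impl into a while loop and then handled by the rule above; a hedged usage sketch (assumes jax2tf.convert and TensorFlow are available in this environment):

import numpy as np
from jax import lax
from jax.experimental import jax2tf

def cumsum(xs):
    # Running sum expressed with lax.scan; the converter lowers it via while_p.
    def step(carry, x):
        return carry + x, carry + x
    _, ys = lax.scan(step, np.float32(0.), xs)
    return ys

f_tf = jax2tf.convert(cumsum)
print(f_tf(np.arange(4, dtype=np.float32)))  # [0. 1. 3. 6.]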

def _top_k(operand: TfVal, k: int) -> Tuple[TfVal, TfVal]:
# Some types originally incompatible with tf.math.top_k can be promoted
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/correctness_stats.py
@@ -91,7 +91,7 @@ def _to_np_dtype(dtype) -> NpDType:
pass
return np.dtype(dtype)

if args[0] is not core.unit:
if args and args[0] is not core.unit:
np_dtype = _to_np_dtype(args[0].dtype)
else:
np_dtype = None
99 changes: 71 additions & 28 deletions jax/lax/lax.py
@@ -1407,55 +1407,73 @@ def iota(dtype: DType, size: int) -> Array:
<https://www.tensorflow.org/xla/operation_semantics#iota>`_
operator.
"""
size = size if type(size) is masking.Poly else int(size)
shape = canonicalize_shape((size,))
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.iota(dtype, shape[0])
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
if config.omnistaging_enabled:
dtype = dtypes.canonicalize_dtype(dtype)
size = core.concrete_or_error(int, size, "size argument of lax.iota")
return iota_p.bind(dtype=dtype, shape=(size,), dimension=0)
else:
size = size if type(size) is masking.Poly else int(size)
shape = canonicalize_shape((size,))
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.iota(dtype, shape[0])
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())

def broadcasted_iota(dtype: DType, shape: Shape, dimension: int) -> Array:
"""Convenience wrapper around ``iota``."""
dtype = dtypes.canonicalize_dtype(dtype)
shape = canonicalize_shape(shape)
dimension = int(dimension)
return broadcast_in_dim(iota(dtype, shape[dimension]), shape, [dimension])
dimension = core.concrete_or_error(
int, dimension, "dimension argument of lax.broadcasted_iota")
return iota_p.bind(dtype=dtype, shape=shape, dimension=dimension)
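
For example, the dimension argument selects which axis the counting runs along (a sketch; exact array printing may differ):

import numpy as np
from jax import lax

print(lax.broadcasted_iota(np.int32, (2, 3), 0))  # [[0 0 0]
                                                  #  [1 1 1]]
print(lax.broadcasted_iota(np.int32, (2, 3), 1))  # [[0 1 2]
                                                  #  [0 1 2]]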

def _eye(dtype: DType, shape: Shape, offset: int) -> Array:
"""Like numpy.eye, create a 2D array with ones on a diagonal.

This function exists for creating lazy identity matrices; that is,
materialization of the array is delayed and it may be fused into consumers to
avoid materialization at all."""
"""Like numpy.eye, create a 2D array with ones on a diagonal."""
N, M = tuple(map(int, shape))
offset = int(offset)
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.eye(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
if config.omnistaging_enabled:
bool_eye = eq(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
broadcasted_iota(np.int32, (N, M), 1))
return convert_element_type_p.bind(bool_eye, new_dtype=dtype,
old_dtype=np.bool_)
else:
lazy_expr = lazy.eye(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
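
For intuition, the omnistaging branch computes eq(row_index + offset, col_index), which matches numpy's eye with diagonal k=offset; a small NumPy sketch of the same comparison (illustrative only, not the code path JAX runs):

import numpy as np

N, M, offset = 3, 4, 1
rows = np.arange(N)[:, None]   # broadcasted_iota along dimension 0
cols = np.arange(M)[None, :]   # broadcasted_iota along dimension 1
assert np.array_equal((rows + offset == cols).astype(np.float32),
                      np.eye(N, M, k=offset, dtype=np.float32))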

def _delta(dtype: DType, shape: Shape, axes: Sequence[int]) -> Array:
"""This function exists for creating lazy Kronecker delta arrays, particularly
for use in jax.numpy.einsum to express traces. It differs from ``eye`` in that
it can create arrays of any rank, but doesn't allow offsets."""
"""This utility function exists for creating Kronecker delta arrays."""
shape = tuple(map(int, shape))
axes = tuple(map(int, axes))
dtype = dtypes.canonicalize_dtype(dtype)
base_shape = tuple(np.take(shape, axes))
lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
if config.omnistaging_enabled:
iotas = [broadcasted_iota(np.uint32, base_shape, i)
for i in range(len(base_shape))]
eyes = [eq(i1, i2) for i1, i2 in zip(iotas[:-1], iotas[1:])]
result = convert_element_type_p.bind(_reduce(operator.and_, eyes),
new_dtype=dtype, old_dtype=np.bool_)
return broadcast_in_dim(result, shape, axes)
else:
lazy_expr = lazy.broadcast(lazy.delta(dtype, base_shape), shape, axes)
aval = ShapedArray(shape, dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
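
The omnistaging branch here requires all indexed axes to agree pairwise, i.e. a rank-n Kronecker delta; a NumPy sketch of the same idea (illustrative only):

import operator
from functools import reduce
import numpy as np

rank, n = 3, 3                                    # a (3, 3, 3) delta
iotas = [np.arange(n).reshape([-1 if j == i else 1 for j in range(rank)])
         for i in range(rank)]                    # index along each axis
eyes = [i1 == i2 for i1, i2 in zip(iotas[:-1], iotas[1:])]
delta = reduce(operator.and_, eyes).astype(np.float32)
assert delta[1, 1, 1] == 1.0 and delta[0, 1, 1] == 0.0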

def _tri(dtype: DType, shape: Shape, offset: int) -> Array:
"""Like numpy.tri, create a 2D array with ones below a diagonal.
This function exists for creating lazy triangular matrices, particularly for
use in jax.numpy.tri."""
"""Like numpy.tri, create a 2D array with ones below a diagonal."""
N, M = tuple(map(int, shape))
offset = int(offset)
dtype = dtypes.canonicalize_dtype(dtype)
lazy_expr = lazy.tri(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
if config.omnistaging_enabled:
bool_tri = ge(add(broadcasted_iota(np.int32, (N, M), 0), np.int32(offset)),
broadcasted_iota(np.int32, (N, M), 1))
return convert_element_type_p.bind(bool_tri, old_dtype=np.int32,
new_dtype=dtype)
else:
lazy_expr = lazy.tri(dtype, (N, M), offset)
aval = ShapedArray((N, M), dtype)
return xla.DeviceArray(aval, None, lazy_expr, xla.DeviceConstant())
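
Analogously, the comparison ge(row_index + offset, col_index) reproduces numpy's tri with k=offset (illustrative NumPy sketch):

import numpy as np

N, M, offset = 4, 5, 1
rows, cols = np.arange(N)[:, None], np.arange(M)[None, :]
assert np.array_equal((rows + offset >= cols).astype(np.float32),
                      np.tri(N, M, k=offset, dtype=np.float32))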

def stop_gradient(x):
"""Stops gradient computation.
@@ -5795,6 +5813,7 @@ def _outfeed_translation_rule(c, token, *xs):
outfeed_p.def_abstract_eval(_outfeed_abstract_eval)
xla.translations[outfeed_p] = _outfeed_translation_rule


def rng_uniform(a, b, shape):
"""Stateful PRNG generator. Experimental and its use is discouraged.

@@ -5827,6 +5846,30 @@ def _rng_uniform_translation_rule(c, a, b, *, shape):
rng_uniform_p.def_abstract_eval(_rng_uniform_abstract_eval)
xla.translations[rng_uniform_p] = _rng_uniform_translation_rule


def _iota_abstract_eval(*, dtype, shape, dimension):
_check_shapelike("iota", "shape", shape)
if not any(dtypes.issubdtype(dtype, t) for t in _num):
msg = 'iota does not accept dtype {}. Accepted dtypes are subtypes of {}.'
typename = str(np.dtype(dtype).name)
accepted_typenames = (t.__name__ for t in _num)
raise TypeError(msg.format(typename, ', '.join(accepted_typenames)))
if not 0 <= dimension < len(shape):
raise ValueError("iota dimension must be between 0 and len(shape), got "
f"dimension={dimension} for shape {shape}")
return ShapedArray(shape, dtype)

def _iota_translation_rule(c, dtype, shape, dimension):
etype = xla_client.dtype_to_etype(dtype)
xla_shape = xc.Shape.array_shape(etype, shape)
return xops.Iota(c, xla_shape, dimension)

iota_p = Primitive('iota')
iota_p.def_impl(partial(xla.apply_primitive, iota_p))
iota_p.def_abstract_eval(_iota_abstract_eval)
xla.translations[iota_p] = _iota_translation_rule
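
With omnistaging enabled, lax.iota therefore stages out through this primitive instead of producing a lazy device constant; roughly (a sketch, jaxpr pretty-printing varies by JAX version):

import numpy as np
import jax
from jax import lax

print(jax.make_jaxpr(lambda: lax.iota(np.int32, 5))())
# { lambda ; . let b = iota[ dimension=0 dtype=int32 shape=(5,) ] in (b,) }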


### util

_ndim = np.ndim
1 change: 1 addition & 0 deletions jax/random.py
@@ -300,6 +300,7 @@ def _fold_in(key, data):
return threefry_2x32(key, PRNGKey(data))


@partial(jit, static_argnums=(1, 2))
def _random_bits(key, bit_width, shape):
"""Sample uniform random bits of given width and shape using PRNG key."""
if not _is_prng_key(key):
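
The new decorator on _random_bits makes bit_width and shape compile-time constants, so the helper is compiled once per distinct pair; a hypothetical stand-in showing the pattern (make_bits is not a real JAX function):

from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1, 2))
def make_bits(key, bit_width, shape):
    # `bit_width` and `shape` are ordinary Python values here, so they can
    # shape the output; jit retraces once per distinct (bit_width, shape).
    del key, bit_width  # placeholder body
    return jnp.zeros(shape, dtype=jnp.uint32)

print(make_bits(jax.random.PRNGKey(0), 32, (2, 3)).shape)  # (2, 3)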
10 changes: 5 additions & 5 deletions tests/api_test.py
@@ -33,7 +33,7 @@
import jax
import jax.numpy as jnp
from jax import float0, jit, grad, device_put, jacfwd, jacrev, hessian
from jax import api, core, lax, lax_reference
from jax import api, core, lax, lax_reference, lazy
from jax.core import Primitive
from jax.interpreters import ad
from jax.interpreters import xla
@@ -1450,7 +1450,7 @@ def f(x):
def test_dtype_warning(self):
# cf. issue #1230
if FLAGS.jax_enable_x64:
return # test only applies when x64 is disabled
raise unittest.SkipTest("test only applies when x64 is disabled")

def check_warning(warn, nowarn):
with warnings.catch_warnings(record=True) as w:
@@ -2443,14 +2443,14 @@ def f(x):
assert python_should_be_executing
return jnp.sum(x)

x = jnp.arange(10, dtype=jnp.int32)
assert xla.is_device_constant(x) # lazy iota
x = jnp.zeros(10, dtype=jnp.int32)
assert not lazy.is_trivial(x._lazy_expr)

python_should_be_executing = True
_ = f(x)

python_should_be_executing = False # should not recompile
x = np.arange(10, dtype=np.int32)
x = np.zeros(10, dtype=np.int32)
_ = f(x)

@parameterized.parameters(jtu.cases_from_list(range(10000)))
5 changes: 3 additions & 2 deletions tests/lax_numpy_test.py
@@ -3651,8 +3651,9 @@ def testArange(self):
type(lax.iota(np.int32, 77)))

# test laziness for int dtypes
self.assertTrue(xla.is_device_constant(jnp.arange(77)))
self.assertTrue(xla.is_device_constant(jnp.arange(77, dtype=jnp.int32)))
if not config.omnistaging_enabled:
self.assertTrue(xla.is_device_constant(jnp.arange(77)))
self.assertTrue(xla.is_device_constant(jnp.arange(77, dtype=jnp.int32)))

def testArangeJit(self):
ans = api.jit(lambda: jnp.arange(5))()
Expand Down