Skip to content

Commit

Permalink
Attempt to land #6400 again.
Browse files Browse the repository at this point in the history
This PR changes `jax.numpy.array()` to avoid creating any on-device arrays during tracing. As a consequence, calls to `jnp.array()` in a traced context, such as under `jax.jit`, will always be staged into the trace.

This change may break code that depends on the current (undocumented and unintentional) behavior of `jnp.array()` to perform shape or index calculations that must be known statically (at trace time). The workaround for such cases is to use classic NumPy to perform shape/index calculations.

PiperOrigin-RevId: 395998020
  • Loading branch information
hawkinsp authored and jax authors committed Sep 21, 2021
1 parent 9a5cf7c commit 06a0123
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 35 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Many more `jax.numpy` functions now require array-like inputs, and will error
if passed a list ({jax-issue}`#7747` {jax-issue}`#7802` {jax-issue}`#7907`).
See {jax-issue}`#7737` for a discussion of the rationale behind this change.
* When inside a transformation such as `jax.jit`, `jax.numpy.array` always
stages the array it produces into the traced computation. Previously
`jax.numpy.array` would sometimes produce an on-device array, even under
a `jax.jit` decorator. This change may break code that used JAX arrays to
perform shape or index computations that must be known statically; the
workaround is to perform such computations using classic NumPy arrays
instead.
* New features:
* Added {func}`jax.numpy.insert` implementation ({jax-issue}`#7936`).

Expand Down
30 changes: 12 additions & 18 deletions docs/jaxpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,8 @@ For the example consider the function ``func11`` below
f = convert_element_type[ new_dtype=float32
weak_type=False ] b
g = add f e
h = convert_element_type[ new_dtype=float32
weak_type=False ] a
i = add g h
in (i, b) }
h = add g a
in (h, b) }
length=16
linear=(False, False, False, False)
num_carry=1
Expand Down Expand Up @@ -424,13 +422,11 @@ which the computation should run. For example
call_jaxpr={ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
d = convert_element_type[ new_dtype=float32
weak_type=False ] a
e = mul d c
f = convert_element_type[ new_dtype=float32
d = mul a c
e = convert_element_type[ new_dtype=float32
weak_type=False ] b
g = add f e
in (g,) }
f = add e d
in (f,) }
device=None
donated_invars=(False, False)
inline=False
Expand Down Expand Up @@ -461,16 +457,14 @@ captured using the ``xla_pmap`` primitive. Consider this example
axis_size=1
backend=None
call_jaxpr={ lambda ; a b.
let c = convert_element_type[ new_dtype=float32
weak_type=False ] a
d = add b c
e = broadcast_in_dim[ broadcast_dimensions=( )
let c = add b a
d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
f = add d e
g = psum[ axes=('rows',)
e = add c d
f = psum[ axes=('rows',)
axis_index_groups=None ] b
h = div f g
in (h,) }
g = div e f
in (g,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
Expand Down
11 changes: 7 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
from jax.interpreters import pxla
from jax import lax
from jax._src.lax.lax import _device_put_raw
from jax._src.ops import scatter
from jax._src.util import (unzip2, prod as _prod, subvals, safe_zip, ceil_of_ratio,
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
Expand Down Expand Up @@ -3280,25 +3279,29 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
# large integers; see discussion in https://github.com/google/jax/pull/6047.
object = _np_array(object, dtype=dtype, ndmin=ndmin, copy=False)

# call _np_array a second time with canonicalized dtype
dtype = dtypes.canonicalize_dtype(object.dtype)
object = _np_array(object, dtype=dtype, copy=False)

assert type(object) not in dtypes.python_scalar_dtypes

if type(object) is np.ndarray:
_inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
lax._check_user_dtype_supported(_inferred_dtype, "array")
out = _device_put_raw(object, weak_type=weak_type)
out = _np_array(object, copy=copy, dtype=dtype)
if dtype: assert _dtype(out) == dtype
elif isinstance(object, (DeviceArray, core.Tracer)):
if isinstance(object, DeviceArray) and copy:
# We perform a copy by bouncing back to the host
# TODO(phawkins): add a device runtime function to copy a buffer
out = _device_put_raw(_np_asarray(object), weak_type=weak_type)
out = _np_asarray(object)
else:
out = object
elif isinstance(object, (list, tuple)):
if object:
out = stack([asarray(elt, dtype=dtype) for elt in object])
else:
out = _device_put_raw(_np_array([], dtype=dtype))
out = _np_array([], dtype=dtype)
else:
try:
view = memoryview(object)
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
import scipy.special as osp_special

import jax
from jax._src import api
from jax import jit
from jax import lax, core
Expand Down Expand Up @@ -140,7 +141,8 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
if not np.issubdtype(out.dtype, np.complexfloating):
# Use jnp.array(nan) to avoid false positives in debug_nans
# (see https://github.com/google/jax/issues/7634)
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
with jax.debug_nans(False):
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
return out


Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/call_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@
from jax import core
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import util
from jax._src import ad_util
from jax._src.lax.lax import _device_put_raw
from jax.interpreters import xla
from jax.lib import xla_client
from . import jax2tf as jax2tf_internal
Expand Down Expand Up @@ -217,7 +217,7 @@ def _res_tf_to_jax(res_tf: TfVal):
res_dlpack = tf.experimental.dlpack.to_dlpack(res_tf)
return jax.dlpack.from_dlpack(res_dlpack)

return jnp.asarray(np.asarray(res_tf))
return _device_put_raw(np.asarray(res_tf))

return list(map(_res_tf_to_jax, res_tf_flat))

Expand Down
38 changes: 28 additions & 10 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,7 @@ def to_jaxpr(self, in_tracers, out_tracers):
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
jaxpr, constvals = _prune_convert_element_types(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
out_avals = [t.aval for t in out_tracers]
return jaxpr, out_avals, constvals
Expand All @@ -1192,7 +1193,28 @@ def find_progenitors(self, tracer):
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
return invar_positions, const_eqns

def _prune_convert_element_types(jaxpr, constvals):
  """Fold away ``convert_element_type`` equations applied to known constants.

  Walks ``jaxpr.eqns`` in order. A ``convert_element_type`` equation is
  dropped, and its output variable recorded as a constant, when its first
  input is a known constant and either:

  * the constant's type is in ``core.literalable_types`` and it has an empty
    shape, in which case the output is bound to the constant converted with
    ``np.array(..., new_dtype)``; or
  * the constant's dtype already equals ``new_dtype``, in which case the
    output is bound to the very same constant.

  Every other equation is kept unchanged and in its original position.

  Args:
    jaxpr: the Jaxpr whose equations are scanned.
    constvals: values paired positionally with ``jaxpr.constvars``.

  Returns:
    A pair ``(new_jaxpr, new_constvals)``. The new Jaxpr keeps the original
    ``invars`` and ``outvars``; its constvars are the original ones plus the
    output variables of every folded equation.
  """
  known = dict(zip(jaxpr.constvars, constvals))
  kept_eqns = []
  for eqn in jaxpr.eqns:
    if eqn.primitive is not core.convert_element_type_p:
      kept_eqns.append(eqn)
      continue
    operand = known.get(eqn.invars[0])
    if type(operand) in core.literalable_types and not np.shape(operand):
      # Scalar literal: constant-fold the dtype conversion so the result can
      # be inlined later.
      known[eqn.outvars[0]] = np.array(operand, eqn.params['new_dtype'])
    elif operand is not None and dtypes.dtype(operand) == eqn.params['new_dtype']:
      # No-op conversion: alias the output to the existing constant rather
      # than staging the equation out as clutter.
      known[eqn.outvars[0]] = operand
    else:
      kept_eqns.append(eqn)
  pruned_constvars, pruned_constvals = unzip2(known.items())
  pruned_jaxpr = Jaxpr(pruned_constvars, jaxpr.invars, jaxpr.outvars, kept_eqns)
  return pruned_jaxpr, pruned_constvals

def _inline_literals(jaxpr, constvals):
# This function also ensures variables are labeled in a canonical ordering,
# prunes unused constants, and inserts `dropvar` symbols.
consts = dict(zip(jaxpr.constvars, constvals))
newvar = core.gensym()
newvars = {}
Expand All @@ -1206,20 +1228,16 @@ def lit(var: Var) -> Optional[Any]:
return None

used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
new_constvars = [var(v) for v in jaxpr.constvars if not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals)
if v in used and not lit(v)]
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(v) or var(v) for v in eqn.invars]
if (eqn.primitive is core.convert_element_type_p and type(invars[0]) is Literal):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(invars[0].val, eqn.params['new_dtype'])
else:
# might do DCE here, but we won't until we're more careful about effects
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals
Expand Down
5 changes: 5 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3005,6 +3005,11 @@ def test_vmap_caching(self):

self.assertEqual(count[0], n)

def test_jnp_array_doesnt_device_put(self):
  """Tracing ``jnp.array(3)`` must leave the device-put counter at zero."""
  def make_scalar():
    return jnp.array(3)

  with jtu.count_device_put() as device_puts:
    api.make_jaxpr(make_scalar)()
  self.assertEqual(device_puts[0], 0)


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 06a0123

Please sign in to comment.