Skip to content

Commit

Permalink
Merge pull request #6400 from pschuh:convert_element
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 372431314
  • Loading branch information
jax authors committed May 6, 2021
2 parents 4bbb24d + 9d3e535 commit d0aa875
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 34 deletions.
30 changes: 12 additions & 18 deletions docs/jaxpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,8 @@ For the example consider the function ``func11`` below
f = convert_element_type[ new_dtype=float32
weak_type=False ] b
g = add f e
h = convert_element_type[ new_dtype=float32
weak_type=False ] a
i = add g h
in (i, b) }
h = add g a
in (h, b) }
length=16
linear=(False, False, False, False)
num_carry=1
Expand Down Expand Up @@ -424,13 +422,11 @@ computation should run. For example
call_jaxpr={ lambda ; a b.
let c = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
d = convert_element_type[ new_dtype=float32
weak_type=False ] a
e = mul d c
f = convert_element_type[ new_dtype=float32
d = mul a c
e = convert_element_type[ new_dtype=float32
weak_type=False ] b
g = add f e
in (g,) }
f = add e d
in (f,) }
device=None
donated_invars=(False, False)
inline=False
Expand Down Expand Up @@ -461,16 +457,14 @@ captured using the ``xla_pmap`` primitive. Consider this example
axis_size=1
backend=None
call_jaxpr={ lambda ; a b.
let c = convert_element_type[ new_dtype=float32
weak_type=False ] a
d = add b c
e = broadcast_in_dim[ broadcast_dimensions=( )
let c = add b a
d = broadcast_in_dim[ broadcast_dimensions=( )
shape=(1,) ] 1.0
f = add d e
g = psum[ axes=('rows',)
e = add c d
f = psum[ axes=('rows',)
axis_index_groups=None ] b
h = div f g
in (h,) }
g = div e f
in (g,) }
devices=None
donated_invars=(False, False)
global_arg_shapes=(None,)
Expand Down
11 changes: 7 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
from jax.config import config
from jax.interpreters.xla import DeviceArray, _DeviceArray, _CppDeviceArray
from jax import lax
from jax._src.lax.lax import _device_put_raw
from jax import ops
from jax._src.util import (partial, unzip2, prod as _prod, subvals, safe_zip,
canonicalize_axis as _canonicalize_axis, maybe_named_axis)
Expand Down Expand Up @@ -2940,25 +2939,29 @@ def array(object, dtype=None, copy=True, order="K", ndmin=0):
# large integers; see discussion in https://github.com/google/jax/pull/6047.
object = _np_array(object, dtype=dtype, ndmin=ndmin, copy=False)

# call _np_array a second time with canonicalized dtype
dtype = dtypes.canonicalize_dtype(object.dtype)
object = _np_array(object, dtype=dtype, copy=False)

assert type(object) not in dtypes.python_scalar_dtypes

if type(object) is np.ndarray:
_inferred_dtype = object.dtype and dtypes.canonicalize_dtype(object.dtype)
lax._check_user_dtype_supported(_inferred_dtype, "array")
out = _device_put_raw(object, weak_type=weak_type)
out = _np_array(object, copy=copy, dtype=dtype)
if dtype: assert _dtype(out) == dtype
elif isinstance(object, (DeviceArray, core.Tracer)):
if isinstance(object, DeviceArray) and copy:
# We perform a copy by bouncing back to the host
# TODO(phawkins): add a device runtime function to copy a buffer
out = _device_put_raw(_np_asarray(object), weak_type=weak_type)
out = _np_asarray(object)
else:
out = object
elif isinstance(object, (list, tuple)):
if object:
out = stack([asarray(elt, dtype=dtype) for elt in object])
else:
out = _device_put_raw(_np_array([], dtype=dtype))
out = _np_array([], dtype=dtype)
else:
try:
view = memoryview(object)
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/call_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@
from jax import core
from jax import dlpack
from jax import dtypes
from jax import numpy as jnp
from jax import tree_util
from jax._src import util
from jax._src.lax.lax import _device_put_raw
from jax.interpreters import xla
from jax.lib import xla_bridge
from jax.lib import xla_client
Expand Down Expand Up @@ -180,7 +180,7 @@ def _res_tf_to_jax(res_tf):
return jax.dlpack.from_dlpack(
res_dlpack, backend=xla_bridge.get_backend(res_jax_platform))

return jnp.asarray(np.asarray(res_tf))
return _device_put_raw(np.asarray(res_tf))

return list(map(_res_tf_to_jax, res_tf_flat))

Expand Down
38 changes: 28 additions & 10 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,6 +948,7 @@ def to_jaxpr(self, in_tracers, out_tracers):
outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
constvars, constvals = unzip2(self.constvar_to_val.items())
jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
jaxpr, constvals = _prune_convert_element_types(jaxpr, constvals)
jaxpr, constvals = _inline_literals(jaxpr, constvals)
out_avals = [t.aval for t in out_tracers]
return jaxpr, out_avals, constvals
Expand All @@ -970,7 +971,28 @@ def find_progenitors(self, tracer):
const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
return invar_positions, const_eqns

def _prune_convert_element_types(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
new_eqns = []
for eqn in jaxpr.eqns:
if eqn.primitive is core.convert_element_type_p:
c = consts.get(eqn.invars[0])
if type(c) in core.literalable_types and not np.shape(c):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(c, eqn.params['new_dtype'])
continue
if c is not None and dtypes.dtype(c) == eqn.params['new_dtype']:
# don't stage out no-op convert_element_type calls as clutter
consts[eqn.outvars[0]] = c
continue
new_eqns.append(eqn)
new_constvars, new_constvals = unzip2(consts.items())
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
return new_jaxpr, new_constvals

def _inline_literals(jaxpr, constvals):
# This function also ensures variables are labeled in a canonical ordering,
# prunes unused constants, and inserts `dropvar` symbols.
consts = dict(zip(jaxpr.constvars, constvals))
newvar = core.gensym()
newvars = {}
Expand All @@ -984,20 +1006,16 @@ def lit(var: core.Var) -> Optional[Any]:
return None

used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
new_constvars = [var(v) for v in jaxpr.constvars if not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals) if not lit(v)]
new_constvars = [var(v) for v in jaxpr.constvars if v in used and not lit(v)]
new_constvals = [c for v, c in zip(jaxpr.constvars, constvals)
if v in used and not lit(v)]
new_invars = [var(v) for v in jaxpr.invars]
new_eqns = []
for eqn in jaxpr.eqns:
invars = [lit(v) or var(v) for v in eqn.invars]
if (eqn.primitive is core.convert_element_type_p and type(invars[0]) is Literal):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(invars[0].val, eqn.params['new_dtype'])
else:
# might do DCE here, but we won't until we're more careful about effects
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
outvars = [var(v) if v in used else dropvar for v in eqn.outvars]
new_eqns.append(new_jaxpr_eqn(invars, outvars, eqn.primitive, eqn.params,
eqn.source_info))
new_outvars = [lit(v) or var(v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, new_invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals
Expand Down
5 changes: 5 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2779,6 +2779,11 @@ def f(x):
jaxpr = api.make_jaxpr(f)(3)
self.assertNotIn('xla_call', str(jaxpr))

def test_jnp_array_doesnt_device_put(self):
with jtu.count_device_put() as count:
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count[0], 0)


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit d0aa875

Please sign in to comment.