Skip to content

Commit

Permalink
elide trivial convert_element_types
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Nov 16, 2021
1 parent 476ca94 commit c4e9bb0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 21 deletions.
32 changes: 14 additions & 18 deletions docs/jaxpr.rst
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,16 @@ For example:
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
d:i32[] = clamp 0 c 2
e:f32[] = cond[
c:i32[] = clamp 0 a 2
d:f32[] = cond[
branches=(
{ lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
{ lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
{ lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
{ lambda ; e:f32[]. let f:f32[] = add e 1.0 in (f,) }
{ lambda ; g:f32[]. let h:f32[] = sub g 2.0 in (h,) }
{ lambda ; i:f32[]. let j:f32[] = add i 3.0 in (j,) }
)
linear=(False,)
] d b
in (e,) }
] c b
in (d,) }

The cond primitive has a number of parameters:

Expand Down Expand Up @@ -372,10 +371,9 @@ For the example consider the function ``func11`` below
d:f32[] e:f32[16] = scan[
jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
j:f32[] = mul h i
k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
l:f32[] = add k j
m:f32[] = add l f
in (m, g) }
k:f32[] = add g j
l:f32[] = add k f
in (l, g) }
length=16
linear=(False, False, False, False)
num_carry=1
Expand Down Expand Up @@ -418,17 +416,15 @@ which the computation should run. For example
call_jaxpr={ lambda ; d:f32[] e:f32[]. let
f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
g:f32[1] = mul d f
h:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
i:f32[1] = add h g
in (i,) }
h:f32[1] = add e g
in (h,) }
device=None
donated_invars=(False, False)
inline=False
name=inner
] a b
j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
k:f32[1] = add j c
in (k,) }
i:f32[1] = add a c
in (i,) }


XLA_pmap
Expand Down
17 changes: 14 additions & 3 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,20 +1215,31 @@ def find_progenitors(self, tracer):
def _prune_convert_element_types(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
new_eqns = []
var_subs: Dict[Var, Var] = {}
for eqn in jaxpr.eqns:
# apply invar substitutions
eqn = JaxprEqn([var_subs.get(v, v) for v in eqn.invars], eqn.outvars,
eqn.primitive, eqn.params, eqn.source_info)
if eqn.primitive is core.convert_element_type_p:
c = consts.get(eqn.invars[0])
if type(c) in core.literalable_types and not np.shape(c):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(c, eqn.params['new_dtype'])
continue
if c is not None and dtypes.dtype(c) == eqn.params['new_dtype']:
# don't stage out no-op convert_element_type calls as clutter
elif c is not None and dtypes.dtype(c) == eqn.params['new_dtype']:
# don't include no-op convert_element_type calls on consts; instead,
# just make the outvar refer to the constant, and skip the eqn
consts[eqn.outvars[0]] = c
continue
elif c is None and eqn.invars[0].aval.dtype == eqn.params['new_dtype']:
# don't include no-op convert_element_type calls on vars; instead,
# substitute all occurences of the outvar with the invar, and skip eqn
var_subs[eqn.outvars[0]] = eqn.invars[0]
continue
new_eqns.append(eqn)
new_constvars, new_constvals = unzip2(consts.items())
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
new_outvars = [var_subs.get(v, v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals

def _inline_literals(jaxpr, constvals):
Expand Down
13 changes: 13 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3176,6 +3176,19 @@ def test_jnp_array_doesnt_device_put(self):
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count[0], 0)

def test_elide_trivial_convert_element_types(self):
if config.x64_enabled:
x = jnp.float64(1.)
else:
x = jnp.float32(1.)

jaxpr = api.make_jaxpr(lambda x, y: x + y)(x, 2.)
self.assertLen(jaxpr.eqns, 1)

cet = partial(lax.convert_element_type, new_dtype=x.dtype)
jaxpr = api.make_jaxpr(lambda x: cet(cet(cet(x))))(1.)
self.assertLen(jaxpr.eqns, 0)


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit c4e9bb0

Please sign in to comment.