Skip to content

Commit

Permalink
elide trivial convert_element_types
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Nov 16, 2021
1 parent 476ca94 commit 4d9f277
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
17 changes: 14 additions & 3 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,20 +1215,31 @@ def find_progenitors(self, tracer):
def _prune_convert_element_types(jaxpr, constvals):
consts = dict(zip(jaxpr.constvars, constvals))
new_eqns = []
var_subs: Dict[Var, Var] = {}
for eqn in jaxpr.eqns:
# apply invar substitutions
eqn = JaxprEqn([var_subs.get(v, v) for v in eqn.invars], eqn.outvars,
eqn.primitive, eqn.params, eqn.source_info)
if eqn.primitive is core.convert_element_type_p:
c = consts.get(eqn.invars[0])
if type(c) in core.literalable_types and not np.shape(c):
# constant-fold dtype conversion of literals to be inlined
consts[eqn.outvars[0]] = np.array(c, eqn.params['new_dtype'])
continue
if c is not None and dtypes.dtype(c) == eqn.params['new_dtype']:
# don't stage out no-op convert_element_type calls as clutter
elif c is not None and dtypes.dtype(c) == eqn.params['new_dtype']:
# don't include no-op convert_element_type calls on consts; instead,
# just make the outvar refer to the constant, and skip the eqn
consts[eqn.outvars[0]] = c
continue
elif c is None and eqn.invars[0].aval.dtype == eqn.params['new_dtype']:
# don't include no-op convert_element_type calls on vars; instead,
# substitute all occurences of the outvar with the invar, and skip eqn
var_subs[eqn.outvars[0]] = eqn.invars[0]
continue
new_eqns.append(eqn)
new_constvars, new_constvals = unzip2(consts.items())
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, jaxpr.outvars, new_eqns)
new_outvars = [var_subs.get(v, v) for v in jaxpr.outvars]
new_jaxpr = Jaxpr(new_constvars, jaxpr.invars, new_outvars, new_eqns)
return new_jaxpr, new_constvals

def _inline_literals(jaxpr, constvals):
Expand Down
13 changes: 13 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3176,6 +3176,19 @@ def test_jnp_array_doesnt_device_put(self):
api.make_jaxpr(lambda: jnp.array(3))()
self.assertEqual(count[0], 0)

def test_elide_trivial_convert_element_types(self):
if config.x64_enabled:
x = jnp.float64(1.)
else:
x = jnp.float32(1.)

jaxpr = api.make_jaxpr(lambda x, y: x + y)(x, 2.)
self.assertLen(jaxpr.eqns, 1)

cet = partial(lax.convert_element_type, new_dtype=x.dtype)
jaxpr = api.make_jaxpr(lambda x: cet(cet(cet(x))))(1.)
self.assertLen(jaxpr.eqns, 0)


class RematTest(jtu.JaxTestCase):

Expand Down

0 comments on commit 4d9f277

Please sign in to comment.