def _prune_convert_element_types(jaxpr, constvals):
  """Remove convert_element_type equations that need not be staged out.

  Walks the equations of ``jaxpr`` in order and drops a
  ``convert_element_type`` equation in three situations:

  * its operand is a constant whose type is in ``core.literalable_types``
    and whose shape is empty: the conversion is performed right away with
    ``np.array`` and the result is recorded as a constant for the outvar;
  * its operand is any other constant that already has the requested dtype:
    the outvar is bound to the very same constant value;
  * its operand is not a constant and its aval already has the requested
    dtype: every later use of the outvar (in subsequent equations and in the
    jaxpr's outvars) is replaced by the operand.

  All other equations are kept, with their invars rewritten according to the
  substitutions collected so far.

  Args:
    jaxpr: the jaxpr to prune.
    constvals: constant values, aligned with ``jaxpr.constvars``.

  Returns:
    A ``(new_jaxpr, new_constvals)`` pair; the constvars of ``new_jaxpr``
    line up with ``new_constvals``.
  """
  known_consts = dict(zip(jaxpr.constvars, constvals))
  # outvar of an elided no-op conversion -> the variable that replaces it
  renamed: Dict[Var, Var] = {}
  kept_eqns = []

  def resolve(var):
    return renamed.get(var, var)

  for eqn in jaxpr.eqns:
    # Rewrite operands first so that chains of elided conversions collapse
    # onto the original variable.
    eqn = JaxprEqn([resolve(v) for v in eqn.invars], eqn.outvars,
                   eqn.primitive, eqn.params, eqn.source_info)
    if eqn.primitive is not core.convert_element_type_p:
      kept_eqns.append(eqn)
      continue

    operand = eqn.invars[0]
    const = known_consts.get(operand)
    wanted = eqn.params['new_dtype']
    # NOTE(review): "no-op" is decided from the dtype alone; if this primitive
    # carries other params that affect the result type, confirm that they do
    # not need to be compared here as well.
    if type(const) in core.literalable_types and not np.shape(const):
      # Scalar literal-like constant: fold the conversion now so the result
      # can later be inlined.
      known_consts[eqn.outvars[0]] = np.array(const, wanted)
    elif const is not None:
      if dtypes.dtype(const) == wanted:
        # Constant already has the target dtype: alias the outvar to it.
        known_consts[eqn.outvars[0]] = const
      else:
        kept_eqns.append(eqn)
    elif operand.aval.dtype == wanted:
      # Non-constant operand already has the target dtype: forward it.
      renamed[eqn.outvars[0]] = operand
    else:
      kept_eqns.append(eqn)

  out_constvars, out_constvals = unzip2(known_consts.items())
  out_vars = [resolve(v) for v in jaxpr.outvars]
  pruned = Jaxpr(out_constvars, jaxpr.invars, out_vars, kept_eqns)
  return pruned, out_constvals
def test_elide_trivial_convert_element_types(self):
  """Jaxprs should not contain convert_element_type calls that are no-ops."""
  # Use the widest float dtype that the current configuration allows.
  x = jnp.float64(1.) if config.x64_enabled else jnp.float32(1.)

  # Adding a Python scalar must leave exactly one equation in the jaxpr.
  add_jaxpr = api.make_jaxpr(lambda x, y: x + y)(x, 2.)
  self.assertLen(add_jaxpr.eqns, 1)

  # A chain of conversions to the dtype of x must vanish entirely.
  cet = partial(lax.convert_element_type, new_dtype=x.dtype)
  chain_jaxpr = api.make_jaxpr(lambda x: cet(cet(cet(x))))(1.)
  self.assertLen(chain_jaxpr.eqns, 0)