diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index 39175db696f8..ebf04e37337a 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -1429,7 +1429,7 @@ def _partial_eval_jaxpr_custom_rule( with core.extend_axis_env_nd(mesh.shape.items()): jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \ pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable) - jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged) + jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged) ins_known, _ = partition_list(unks_in, eqn.invars) out_binders_known, _ = partition_list(unks_out, eqn.outvars) _, ins_staged = partition_list(inst_in, eqn.invars) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 3d704cae06c2..ef4cbec6b377 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1135,6 +1135,18 @@ def body(q, k, v): jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2) + def test_axis_env_extension_regression(self): + def foo(x): + i = jax.lax.axis_index('x') + return jnp.exp(x) + i.astype('float') + + @partial(jax.remat, policy=lambda *args, **kwargs: True) + def bar(x): + return shard_map(foo, mesh=Mesh(jax.devices(), ['x']), in_specs=(P('x'),), + out_specs=P('x'), check_rep=False)(x) + + jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash + class FunSpec(NamedTuple): name: str