Skip to content

Commit

Permalink
[shard_map] fix axis env extension bug
Browse files Browse the repository at this point in the history
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
  • Loading branch information
mattjj and sharadmv committed Oct 16, 2023
1 parent 675cb15 commit 3bfe1d2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,7 +1429,7 @@ def _partial_eval_jaxpr_custom_rule(
with core.extend_axis_env_nd(mesh.shape.items()):
jaxpr_known, jaxpr_staged, unks_out, inst_out, num_res = \
pe.partial_eval_jaxpr_custom(jaxpr, unks_in, inst_in, False, False, saveable)
jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged)
jaxpr_known, jaxpr_staged = _add_reshapes(num_res, jaxpr_known, jaxpr_staged)
ins_known, _ = partition_list(unks_in, eqn.invars)
out_binders_known, _ = partition_list(unks_out, eqn.outvars)
_, ins_staged = partition_list(inst_in, eqn.invars)
Expand Down
12 changes: 12 additions & 0 deletions tests/shard_map_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1135,6 +1135,18 @@ def body(q, k, v):

jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2)

def test_axis_env_extension_regression(self):
def foo(x):
i = jax.lax.axis_index('x')
return jnp.exp(x) + i.astype('float')

@partial(jax.remat, policy=lambda *args, **kwargs: True)
def bar(x):
return shard_map(foo, mesh=Mesh(jax.devices(), ['x']), in_specs=(P('x'),),
out_specs=P('x'), check_rep=False)(x)

jax.jit(jax.grad(lambda x: bar(x).sum()))(jnp.arange(8.)) # doesn't crash


class FunSpec(NamedTuple):
name: str
Expand Down

0 comments on commit 3bfe1d2

Please sign in to comment.