Make an AxisData struct that bundles axis name, size, and spmd name. #23796

Open · wants to merge 1 commit into base: main
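This change threads a single axis_data value through the batching machinery instead of passing spmd_axis_name, axis_size, and axis_name separately. The AxisData definition itself lives in jax/_src/interpreters/batching.py and is not part of this diff; the following is only a minimal sketch of what such a struct could look like, assuming a frozen dataclass whose field order matches the positional call sites below (e.g. batching.AxisData(axis_name, axis_size_, spmd_axis_name) in api.py):

from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class AxisData:
  # Sketch only: field names are assumptions inferred from the positional
  # constructor calls in this diff, not the actual definition in batching.py.
  name: Any       # vmap/batching axis name
  size: Any       # axis (batch dimension) size; may be symbolic
  spmd_name: Any  # spmd_axis_name, or None when unused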
jax/_src/ad_checkpoint.py (4 additions & 5 deletions)
@@ -701,19 +701,18 @@ def transposed(*args_flat):
   transposed_jaxpr = core.ClosedJaxpr(transposed_jaxpr_, consts)
   return transposed_jaxpr, cell.in_cts_zero  # pytype: disable=attribute-error

-def remat_vmap(spmd_axis_name, axis_size, axis_name, main_type, args, dims, *,
-               jaxpr, **params):
+def remat_vmap(axis_data, main_type, args, dims, *, jaxpr, **params):
   assert not jaxpr.constvars
   jaxpr_batched_, out_batched = batching.batch_jaxpr_axes(
-      pe.close_jaxpr(jaxpr), axis_size, dims,
+      pe.close_jaxpr(jaxpr), axis_data, dims,
       [batching.zero_if_mapped] * len(jaxpr.outvars),
-      axis_name=axis_name, spmd_axis_name=spmd_axis_name, main_type=main_type)
+      main_type=main_type)
   jaxpr_batched, consts = jaxpr_batched_.jaxpr, jaxpr_batched_.consts
   if consts:
     jaxpr_batched = pe.convert_constvars_jaxpr(jaxpr_batched)
   out_dims = [0 if b else None for b in out_batched]
   return remat_p.bind(*consts, *args, jaxpr=jaxpr_batched, **params), out_dims
-batching.axis_primitive_batchers[remat_p] = partial(remat_vmap, None)
+batching.axis_primitive_batchers[remat_p] = remat_vmap
 batching.spmd_axis_primitive_batchers[remat_p] = remat_vmap

 # TODO(mattjj,sharadmv): de-duplicate with pe.dce_jaxpr_call_rule
jax/_src/api.py (3 additions & 3 deletions)
@@ -983,10 +983,10 @@ def vmap_f(*args, **kwargs):
     axis_size_ = (axis_size if axis_size is not None else
                   _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
     try:
+      axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name)
       out_flat = batching.batch(
-          flat_fun, axis_name, axis_size_, in_axes_flat,
-          lambda: flatten_axes("vmap out_axes", out_tree(), out_axes),
-          spmd_axis_name=spmd_axis_name
+          flat_fun, axis_data, in_axes_flat,
+          lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
       ).call_wrapped(*args_flat)
     except batching.SpecMatchError as e:
       out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)
jax/_src/custom_batching.py (2 additions & 2 deletions)
@@ -139,8 +139,8 @@ def maybe_bdim_at_front(x, bdim):
 # `f` is pytree-flattened
 def vmap_unrestricted(f: lu.WrappedFun, *args, in_axes, axis_name, axis_size):
   f, out_axes = batching.batch_subtrace(f)
-  f = batching._batch_outer(f, axis_name, axis_size, in_axes,
-                            batching.BatchTrace, None)
+  axis_data = batching.AxisData(axis_name, axis_size, None)
+  f = batching._batch_outer(f, axis_data, in_axes, batching.BatchTrace)
   outs = f.call_wrapped(*args)
   return outs, out_axes()

jax/_src/custom_derivatives.py (10 additions & 16 deletions)
@@ -921,35 +921,31 @@ def _custom_vjp_call_jaxpr_jvp(
 ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp

 def _custom_vjp_call_jaxpr_vmap(
-    spmd_axis_name, axis_size, axis_name, main_type, args, in_dims, *,
+    axis_data, main_type, args, in_dims, *,
     fun_jaxpr: core.ClosedJaxpr,
     fwd_jaxpr_thunk: Callable[..., tuple[core.Jaxpr, Sequence[Any]]],
     num_consts: int, bwd: Callable, out_trees: Callable, symbolic_zeros: bool):
   args = [batching.moveaxis(x, d, 0) if d is not not_mapped and d != 0
           else x for x, d in zip(args, in_dims)]

   in_batched = [d is not not_mapped for d in in_dims]
   _, args_batched = split_list(in_batched, [num_consts])
   batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
-      fun_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name,
-      main_type)
+      fun_jaxpr, axis_data, in_batched, False, main_type)
   out_dims1 = [0 if b else not_mapped for b in out_batched]
   out_dims2 = []

   @pe._memoize
   def batched_fwd_jaxpr_thunk(*zeros):
     fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk(*zeros))  # consts can be tracers
     batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
-        fwd_jaxpr, axis_size, args_batched, False, axis_name, spmd_axis_name,
-        main_type)
+        fwd_jaxpr, axis_data, args_batched, False, main_type)
     out_dims2.append([0 if b else not_mapped for b in out_batched])
     return batched_fwd_jaxpr.jaxpr, batched_fwd_jaxpr.consts

   fwd_args_batched = [0 if b else not_mapped for b in args_batched]
   fwd_out_dims = lambda: out_dims2[0]
   batched_bwd = batching.batch_custom_vjp_bwd(
-      bwd, axis_name, axis_size, fwd_out_dims, fwd_args_batched, main_type,
-      spmd_axis_name)
+      bwd, axis_data, fwd_out_dims, fwd_args_batched, main_type)

   batched_outs = custom_vjp_call_jaxpr_p.bind(
       *args, fun_jaxpr=batched_fun_jaxpr,
@@ -959,8 +955,8 @@ def batched_fwd_jaxpr_thunk(*zeros):
   return batched_outs, out_dims
 batching.spmd_axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
     _custom_vjp_call_jaxpr_vmap
-batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = partial(
-    _custom_vjp_call_jaxpr_vmap, None)
+batching.axis_primitive_batchers[custom_vjp_call_jaxpr_p] = \
+    _custom_vjp_call_jaxpr_vmap

 xla.register_initial_style_primitive(custom_vjp_call_jaxpr_p)

@@ -1532,7 +1528,7 @@ def _remat_opt_abstract_eval(*args, fwd_jaxpr: core.ClosedJaxpr, **_):
   return fwd_jaxpr.out_avals, fwd_jaxpr.effects

 def _remat_opt_vmap(
-    spmd_axis_name, axis_size, axis_name, main_type, args, in_dims,
+    axis_data, main_type, args, in_dims,
     *,
     num_consts: int,
     num_res: int,
@@ -1544,8 +1540,7 @@

   in_batched = [d is not not_mapped for d in in_dims]
   batched_fwd_jaxpr, out_batched = batching.batch_jaxpr(
-      fwd_jaxpr, axis_size, in_batched, False,
-      axis_name, spmd_axis_name, main_type)
+      fwd_jaxpr, axis_data, in_batched, False, main_type)
   extra_consts = batched_fwd_jaxpr.consts
   batched_fwd_jaxpr = pe.close_jaxpr(
       pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr))
@@ -1557,8 +1552,7 @@
   def batched_fun_jaxpr_thunk():
     fun_jaxpr = core.ClosedJaxpr(*fun_jaxpr_thunk())
     batched_fun_jaxpr, out_batched = batching.batch_jaxpr(
-        fun_jaxpr, axis_size, prim_batched, False, axis_name, spmd_axis_name,
-        main_type)
+        fun_jaxpr, axis_data, prim_batched, False, main_type)
     return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts

   batched_outs = remat_opt_p.bind(*extra_consts, *args,
@@ -1667,7 +1661,7 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn):
 mlir.register_lowering(remat_opt_p, mlir.lower_fun(
     _remat_opt_impl, multiple_results=True))
 batching.spmd_axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
-batching.axis_primitive_batchers[remat_opt_p] = partial(_remat_opt_vmap, None)
+batching.axis_primitive_batchers[remat_opt_p] = _remat_opt_vmap
 ad.primitive_jvps[remat_opt_p] = _remat_opt_jvp
 ad.primitive_transposes[remat_opt_p] = _remat_opt_transpose
 pe.dce_rules[remat_opt_p] = _remat_opt_dce
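One consequence visible throughout the diff: because axis_data now carries the spmd name, the same rule can be registered in both batching.axis_primitive_batchers and batching.spmd_axis_primitive_batchers, and the partial(rule, None) wrappers that previously filled in the missing spmd_axis_name disappear. A rough sketch of the before/after pattern, using a hypothetical primitive my_p and rule names invented for illustration:

# Before: the rule took the axis pieces separately, so the plain-axis table
# had to pre-bind spmd_axis_name=None.
def my_batcher_old(spmd_axis_name, axis_size, axis_name, main_type,
                   args, dims, **params):
  ...

# batching.axis_primitive_batchers[my_p] = functools.partial(my_batcher_old, None)
# batching.spmd_axis_primitive_batchers[my_p] = my_batcher_old

# After: one axis_data argument bundles axis name, size, and spmd name,
# so the identical rule goes into both tables.
def my_batcher(axis_data, main_type, args, dims, **params):
  ...

# batching.axis_primitive_batchers[my_p] = my_batcher
# batching.spmd_axis_primitive_batchers[my_p] = my_batcher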