Skip to content

Commit

Permalink
Going back to old debug mode behavior, but with fixed handling of non…
Browse files Browse the repository at this point in the history
…-broadcast "scalar" params passed to staged functions. Also putting disable_jit around method calls to prevent compilation in JAX control flow constructs.

PiperOrigin-RevId: 700332069
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 26, 2024
1 parent 1ab9aca commit 5e135a6
Showing 1 changed file with 50 additions and 18 deletions.
68 changes: 50 additions & 18 deletions kfac_jax/_src/utils/staging.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@

import functools
import numbers
import operator
from typing import Any, Callable, Sequence

import jax
from jax import lax
import jax.numpy as jnp

from kfac_jax._src.utils import misc
from kfac_jax._src.utils import parallel
from kfac_jax._src.utils import types
Expand Down Expand Up @@ -97,6 +100,8 @@ def multi_device(self) -> bool:
@property
def pmap_axis_name(self) -> str | None:
"""The name of the `jax.pmap` axis to use for staged methods."""
if self.debug:
return None
return self._pmap_axis_name

@property
Expand Down Expand Up @@ -144,6 +149,12 @@ def pmean_if_pmap_wrapper(
return func


def _is_scalar(x: Any) -> bool:
return isinstance(x, numbers.Number) or (
isinstance(x, jax.Array) and not x.shape
)


def staged(
method: Callable[..., TArrayTree],
static_argnums: int | Sequence[int] | None = None,
Expand Down Expand Up @@ -186,6 +197,8 @@ def try(self, x):
else:
donate_argnums: tuple[int, ...] = tuple(donate_argnums)

original_static_argnums = static_argnums or ()

# shift static_argnums by 1 and include instance (self)
static_argnums = (0,) + tuple(i + 1 for i in (static_argnums or ()))
# shift donate_argnums by 1 and include state
Expand All @@ -203,24 +216,52 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
return method(instance, *args)

with instance.staging_context():
if instance.multi_device and instance.debug:
# In this case we want to call `method` once for each device index.
# Note that this might not always produce sensible behavior, and will
# depend on the details of the method and if it has side effects on the
# state of the class. Note that pmean operations won't happen, since the
# actual output of pmapped methods won't be numerically correct.

bcast_argnums = [
i for i in range(len(args)) if (i in original_static_argnums
or _is_scalar(args[i]))]

if instance.multi_device:
outs = []
non_bcast_args = [args[i] if i not in bcast_argnums else None
for i in range(len(args))]

for i in range(jax.local_device_count()):

non_bcast_args_i = jax.tree_util.tree_map(
operator.itemgetter(i), non_bcast_args)

args_i = [
non_bcast_args_i[j] if j not in bcast_argnums else args[j]
for j in range(len(args))
]

with jax.disable_jit():
outs.append(method(instance, *args_i))

outs = jax.tree_util.tree_map(lambda *args_: jnp.stack(args_), *outs)

elif instance.debug:
with jax.disable_jit():
outs = method(instance, *args)

elif instance.multi_device:

# Compute in_axes so we broadcast any argument that is a scalar
in_axes = [None]
for i in range(len(args)):

if (isinstance(args[i], numbers.Number) or
(isinstance(args[i], jax.Array) and not args[i].shape)):
# Single scalar
if _is_scalar(args[i]):
in_axes.append(None)

else:
in_axes.append(0)

in_axes = tuple(in_axes)
key = (instance.pmap_axis_name, in_axes)

func = pmap_funcs.get(key)

if func is None:
Expand All @@ -233,19 +274,10 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
)
pmap_funcs[key] = func

if instance.debug:
with jax.disable_jit():
outs = func(instance, *args)
else:
outs = func(instance, *args)
outs = func(instance, *args)

else:

if instance.debug:
with jax.disable_jit():
outs = jitted_func(instance, *args)
else:
outs = jitted_func(instance, *args)
outs = jitted_func(instance, *args)

return outs

Expand Down

0 comments on commit 5e135a6

Please sign in to comment.