From 5e135a6a69d95c6a9b4009da04b383d94fb05b41 Mon Sep 17 00:00:00 2001
From: James Martens
Date: Tue, 26 Nov 2024 07:14:02 -0800
Subject: [PATCH] Going back to old debug mode behavior, but with fixed
 handling of non-broadcast "scalar" params passed to staged functions. Also
 putting disable_jit around method calls to prevent compilation in JAX
 control flow constructs.

PiperOrigin-RevId: 700332069
---
 kfac_jax/_src/utils/staging.py | 68 +++++++++++++++++++++++++---------
 1 file changed, 50 insertions(+), 18 deletions(-)

diff --git a/kfac_jax/_src/utils/staging.py b/kfac_jax/_src/utils/staging.py
index 5efe18d..24d7500 100644
--- a/kfac_jax/_src/utils/staging.py
+++ b/kfac_jax/_src/utils/staging.py
@@ -15,10 +15,13 @@
 
 import functools
 import numbers
+import operator
 from typing import Any, Callable, Sequence
 
 import jax
 from jax import lax
+import jax.numpy as jnp
+
 from kfac_jax._src.utils import misc
 from kfac_jax._src.utils import parallel
 from kfac_jax._src.utils import types
@@ -97,6 +100,8 @@ def multi_device(self) -> bool:
   @property
   def pmap_axis_name(self) -> str | None:
     """The name of the `jax.pmap` axis to use for staged methods."""
+    if self.debug:
+      return None
     return self._pmap_axis_name
 
   @property
@@ -144,6 +149,12 @@ def pmean_if_pmap_wrapper(
   return func
 
 
+def _is_scalar(x: Any) -> bool:
+  return isinstance(x, numbers.Number) or (
+      isinstance(x, jax.Array) and not x.shape
+  )
+
+
 def staged(
     method: Callable[..., TArrayTree],
     static_argnums: int | Sequence[int] | None = None,
@@ -186,6 +197,8 @@ def try(self, x):
   else:
     donate_argnums: tuple[int, ...] = tuple(donate_argnums)
 
+  original_static_argnums = static_argnums or ()
+
   # shift static_argnums by 1 and include instance (self)
   static_argnums = (0,) + tuple(i + 1 for i in (static_argnums or ()))
   # shift donate_argnums by 1 and include state
@@ -203,24 +216,52 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
       return method(instance, *args)
 
     with instance.staging_context():
+      if instance.multi_device and instance.debug:
+        # In this case we want to call `method` once for each device index.
+        # Note that this might not always produce sensible behavior, and will
+        # depend on the details of the method and if it has side effects on the
+        # state of the class. Note that pmean operations won't happen, since the
+        # actual output of pmapped methods won't be numerically correct.
+
+        bcast_argnums = [
+            i for i in range(len(args)) if (i in original_static_argnums
+                                            or _is_scalar(args[i]))]
 
-      if instance.multi_device:
+        outs = []
+        non_bcast_args = [args[i] if i not in bcast_argnums else None
+                          for i in range(len(args))]
+
+        for i in range(jax.local_device_count()):
+
+          non_bcast_args_i = jax.tree_util.tree_map(
+              operator.itemgetter(i), non_bcast_args)
+
+          args_i = [
+              non_bcast_args_i[j] if j not in bcast_argnums else args[j]
+              for j in range(len(args))
+          ]
+
+          with jax.disable_jit():
+            outs.append(method(instance, *args_i))
+
+        outs = jax.tree_util.tree_map(lambda *args_: jnp.stack(args_), *outs)
+
+      elif instance.debug:
+        with jax.disable_jit():
+          outs = method(instance, *args)
+
+      elif instance.multi_device:
         # Compute in_axes so we broadcast any argument that is a scalar
         in_axes = [None]
         for i in range(len(args)):
-
-          if (isinstance(args[i], numbers.Number) or
-              (isinstance(args[i], jax.Array) and not args[i].shape)):
-            # Single scalar
+          if _is_scalar(args[i]):
             in_axes.append(None)
-
           else:
             in_axes.append(0)
 
         in_axes = tuple(in_axes)
 
         key = (instance.pmap_axis_name, in_axes)
-
         func = pmap_funcs.get(key)
 
         if func is None:
@@ -233,19 +274,10 @@ def decorated(instance: "WithStagedMethods", *args: Any) -> TArrayTree:
           )
           pmap_funcs[key] = func
 
-        if instance.debug:
-          with jax.disable_jit():
-            outs = func(instance, *args)
-        else:
-          outs = func(instance, *args)
+        outs = func(instance, *args)
 
       else:
-
-        if instance.debug:
-          with jax.disable_jit():
-            outs = jitted_func(instance, *args)
-        else:
-          outs = jitted_func(instance, *args)
+        outs = jitted_func(instance, *args)
 
     return outs
 