Skip to content

Commit

Permalink
Remove some outdated backward-compatibility code.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650786879
  • Loading branch information
IvyZX authored and Flax Authors committed Jul 10, 2024
1 parent 1b58348 commit 692e9c0
Showing 1 changed file with 7 additions and 28 deletions.
35 changes: 7 additions & 28 deletions flax/training/orbax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@

"""Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated."""

import dataclasses
import inspect
import warnings
from typing import Any

Expand Down Expand Up @@ -80,42 +78,23 @@ def find_sharding(x):
]
):
return jax.tree_util.tree_map(
lambda x: ocp.RestoreArgs(restore_type=np.ndarray), target
lambda x: ocp.RestoreArgs(restore_type=np.ndarray), target
)

# JAX arrays: find sharding from the given target and create RestoreArgs

# TODO(ivyzheng): remove after Orbax new release.
ocp_kwargs = {}
if (
'set_global_shape'
in inspect.signature(ocp.checkpoint_utils.construct_restore_args).parameters
):
ocp_kwargs['set_global_shape'] = False

sharding_tree = jax.tree_util.tree_map(find_sharding, target)
if mesh is not None:
warnings.warn(
(
'restore_args_from_target(): `mesh` arg is deprecated. Simply'
' calling the function with target pytree should suffice.'
),
DeprecationWarning,
(
'restore_args_from_target(): `mesh` arg is deprecated. Simply'
' calling the function with target pytree should suffice.'
),
DeprecationWarning,
)

def substitute_embedding(s):
return jax.sharding.NamedSharding(mesh, s.spec)

sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree)
restore_args = ocp.checkpoint_utils.construct_restore_args(
target, sharding_tree, **ocp_kwargs
target, sharding_tree, set_global_shape=False
)
# TODO(ivyzheng): remove after Orbax new release.
if not ocp_kwargs:
restore_args = jax.tree_util.tree_map(
lambda ra: dataclasses.replace(ra, global_shape=None)
if isinstance(ra, ocp.ArrayRestoreArgs)
else ra,
restore_args,
)
return restore_args

0 comments on commit 692e9c0

Please sign in to comment.