diff --git a/flax/training/orbax_utils.py b/flax/training/orbax_utils.py index d30c124771..c8ed84d4fb 100644 --- a/flax/training/orbax_utils.py +++ b/flax/training/orbax_utils.py @@ -14,8 +14,6 @@ """Utils for Orbax Checkpointing, available even after Flax Checkpointing is deprecated.""" -import dataclasses -import inspect import warnings from typing import Any @@ -80,42 +78,23 @@ def find_sharding(x): ] ): return jax.tree_util.tree_map( - lambda x: ocp.RestoreArgs(restore_type=np.ndarray), target + lambda x: ocp.RestoreArgs(restore_type=np.ndarray), target ) # JAX arrays: find sharding from the given target and create RestoreArgs - - # TODO(ivyzheng): remove after Orbax new release. - ocp_kwargs = {} - if ( - 'set_global_shape' - in inspect.signature(ocp.checkpoint_utils.construct_restore_args).parameters - ): - ocp_kwargs['set_global_shape'] = False - sharding_tree = jax.tree_util.tree_map(find_sharding, target) if mesh is not None: warnings.warn( - ( - 'restore_args_from_target(): `mesh` arg is deprecated. Simply' - ' calling the function with target pytree should suffice.' - ), - DeprecationWarning, + ( + 'restore_args_from_target(): `mesh` arg is deprecated. Simply' + ' calling the function with target pytree should suffice.' + ), + DeprecationWarning, ) - def substitute_embedding(s): return jax.sharding.NamedSharding(mesh, s.spec) - sharding_tree = jax.tree_util.tree_map(substitute_embedding, sharding_tree) restore_args = ocp.checkpoint_utils.construct_restore_args( - target, sharding_tree, **ocp_kwargs + target, sharding_tree, set_global_shape=False ) - # TODO(ivyzheng): remove after Orbax new release. - if not ocp_kwargs: - restore_args = jax.tree_util.tree_map( - lambda ra: dataclasses.replace(ra, global_shape=None) - if isinstance(ra, ocp.ArrayRestoreArgs) - else ra, - restore_args, - ) return restore_args