diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 747ba63431..b4e19f7099 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -31,6 +31,7 @@ ) from collections.abc import Callable, Iterable +from etils import epath import jax import orbax.checkpoint as ocp from absl import logging @@ -76,7 +77,7 @@ # Orbax main checkpoint file name. ORBAX_CKPT_FILENAME = 'checkpoint' -ORBAX_MANIFEST_OCDBT = 'manifest.ocdbt' +ORBAX_METADATA_FILENAME = '_METADATA' PyTree = Any @@ -123,7 +124,8 @@ def _safe_remove(path: str): def _is_orbax_checkpoint(path: str) -> bool: return io.exists(os.path.join(path, ORBAX_CKPT_FILENAME)) or io.exists( - os.path.join(path, ORBAX_MANIFEST_OCDBT) + os.path.join(path, ORBAX_METADATA_FILENAME) + or ocp.type_handlers.is_ocdbt_checkpoint(epath.Path(path)) )