Skip to content

Commit

Permalink
Stop writing msgpack file for new checkpoints and update empty nodes …
Browse files Browse the repository at this point in the history
…handling so that it no longer depends on this file.

PiperOrigin-RevId: 649179323
  • Loading branch information
cpgaffney1 authored and Flax Authors committed Jul 3, 2024
1 parent 1367acb commit 0fb1777
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from collections.abc import Callable, Iterable

from etils import epath
import jax
import orbax.checkpoint as ocp
from absl import logging
Expand Down Expand Up @@ -76,7 +77,7 @@

# Orbax main checkpoint file name.
ORBAX_CKPT_FILENAME = 'checkpoint'
ORBAX_MANIFEST_OCDBT = 'manifest.ocdbt'
ORBAX_METADATA_FILENAME = '_METADATA'

PyTree = Any

Expand Down Expand Up @@ -123,7 +124,8 @@ def _safe_remove(path: str):

def _is_orbax_checkpoint(path: str) -> bool:
return io.exists(os.path.join(path, ORBAX_CKPT_FILENAME)) or io.exists(
os.path.join(path, ORBAX_MANIFEST_OCDBT)
os.path.join(path, ORBAX_METADATA_FILENAME)
or ocp.type_handlers.is_ocdbt_checkpoint(epath.Path(path))
)


Expand Down

0 comments on commit 0fb1777

Please sign in to comment.