From 2ae45423462e6d57fa093ccd6d26dfef2e800edb Mon Sep 17 00:00:00 2001 From: Niket Kumar Bhumihar Date: Wed, 6 Nov 2024 13:19:02 -0800 Subject: [PATCH] Move Orbax `path` package under `_src`. PiperOrigin-RevId: 693839844 --- checkpoint/CHANGELOG.md | 3 ++- checkpoint/orbax/__init__.py | 2 +- checkpoint/orbax/checkpoint/__init__.py | 2 +- .../base_pytree_checkpoint_handler.py | 2 +- .../handlers/composite_checkpoint_handler.py | 8 ++++--- .../composite_checkpoint_handler_test.py | 2 +- .../checkpoint/{ => _src}/path/__init__.py | 0 .../checkpoint/{ => _src}/path/async_utils.py | 2 +- .../checkpoint/{ => _src}/path/atomicity.py | 4 ++-- .../{ => _src}/path/atomicity_test.py | 4 ++-- .../checkpoint/{ => _src}/path/deleter.py | 2 +- .../{ => _src}/path/deleter_test.py | 4 ++-- .../{ => _src}/path/format_utils.py | 0 .../{ => _src}/path/format_utils_test.py | 2 +- .../{path/utils.py => _src/path/locking.py} | 6 ++--- .../orbax/checkpoint/{ => _src}/path/step.py | 0 .../checkpoint/{ => _src}/path/step_test.py | 4 ++-- .../orbax/checkpoint/_src/path/utils.py | 2 +- .../_src/serialization/type_handlers.py | 4 ++-- .../orbax/checkpoint/async_checkpointer.py | 4 ++-- .../orbax/checkpoint/checkpoint_manager.py | 6 ++--- .../orbax/checkpoint/checkpoint_utils.py | 2 +- .../orbax/checkpoint/checkpoint_utils_test.py | 2 +- checkpoint/orbax/checkpoint/checkpointer.py | 2 +- .../emergency/checkpoint_manager.py | 2 +- checkpoint/orbax/checkpoint/path.py | 23 +++++++++++++++++++ .../orbax/checkpoint/standard_checkpointer.py | 2 +- checkpoint/orbax/checkpoint/test_utils.py | 4 ++-- checkpoint/orbax/checkpoint/utils.py | 14 +++++------ 29 files changed, 69 insertions(+), 45 deletions(-) rename checkpoint/orbax/checkpoint/{ => _src}/path/__init__.py (100%) rename checkpoint/orbax/checkpoint/{ => _src}/path/async_utils.py (96%) rename checkpoint/orbax/checkpoint/{ => _src}/path/atomicity.py (99%) rename checkpoint/orbax/checkpoint/{ => _src}/path/atomicity_test.py (97%) rename checkpoint/orbax/checkpoint/{ => _src}/path/deleter.py (99%) rename checkpoint/orbax/checkpoint/{ => _src}/path/deleter_test.py (94%) rename checkpoint/orbax/checkpoint/{ => _src}/path/format_utils.py (100%) rename checkpoint/orbax/checkpoint/{ => _src}/path/format_utils_test.py (98%) rename checkpoint/orbax/checkpoint/{path/utils.py => _src/path/locking.py} (94%) rename checkpoint/orbax/checkpoint/{ => _src}/path/step.py (100%) rename checkpoint/orbax/checkpoint/{ => _src}/path/step_test.py (99%) create mode 100644 checkpoint/orbax/checkpoint/path.py diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 0b0cee54..7f9e348a 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Move `type_handlers` to `_src/serialization` - Add notes to Barrier error `XlaRuntimeError(DEADLINE_EXCEEDED)` with actionable info. +- Make `NameFormat.find_all` impls concurrent. +- Move `path` package under `_src` package. ## [0.8.0] - 2024-10-29 @@ -37,7 +39,6 @@ exported in the same way. - De-duplicate `get_ts_context` usages and move to ts_utils. - Move `logging` to `_src`. - Move `metadata` to `_src`. -- Make `NameFormat.find_all` impls concurrent. ## [0.7.0] - 2024-10-07 diff --git a/checkpoint/orbax/__init__.py b/checkpoint/orbax/__init__.py index 6ce16605..ac1bcddc 100644 --- a/checkpoint/orbax/__init__.py +++ b/checkpoint/orbax/__init__.py @@ -37,7 +37,7 @@ from orbax.checkpoint import version # TODO(cpgaffney): Import the public multihost API. from orbax.checkpoint._src.multihost import multihost -from orbax.checkpoint.path import step +from orbax.checkpoint._src.path import step from orbax.checkpoint.future import Future diff --git a/checkpoint/orbax/checkpoint/__init__.py b/checkpoint/orbax/checkpoint/__init__.py index 6ce16605..ac1bcddc 100644 --- a/checkpoint/orbax/checkpoint/__init__.py +++ b/checkpoint/orbax/checkpoint/__init__.py @@ -37,7 +37,7 @@ from orbax.checkpoint import version # TODO(cpgaffney): Import the public multihost API. from orbax.checkpoint._src.multihost import multihost -from orbax.checkpoint.path import step +from orbax.checkpoint._src.path import step from orbax.checkpoint.future import Future diff --git a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py index ee2e21dc..bf5e782d 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/base_pytree_checkpoint_handler.py @@ -43,11 +43,11 @@ from orbax.checkpoint._src.handlers import async_checkpoint_handler from orbax.checkpoint._src.metadata import tree as tree_metadata from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.serialization import serialization from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint._src.tree import utils as tree_utils -from orbax.checkpoint.path import format_utils import tensorstore as ts diff --git a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py index 83d7bdad..dfd5638c 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler.py @@ -63,8 +63,8 @@ from orbax.checkpoint._src.handlers import checkpoint_handler from orbax.checkpoint._src.handlers import handler_registration from orbax.checkpoint._src.handlers import proto_checkpoint_handler -from orbax.checkpoint.path import atomicity -from orbax.checkpoint.path import utils as path_utils +from orbax.checkpoint._src.path import atomicity +from orbax.checkpoint._src.path import locking CheckpointArgs = checkpoint_args.CheckpointArgs Future = future.Future @@ -721,7 +721,7 @@ def _existing_items(self, directory: epath.Path) -> List[str]: return [ p.name for p in directory.iterdir() - if p.is_dir() and p != path_utils.lockdir(directory) + if p.is_dir() and p != locking.lockdir(directory) ] def restore( @@ -883,7 +883,9 @@ def close(self): ) class CompositeArgs(Composite, CheckpointArgs): """Args for wrapping multiple checkpoint items together.""" + ... + # Returned object of CompositeCheckpointHandler is an alias of CompositeArgs. CompositeResults = CompositeArgs diff --git a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py index ec634a47..a7ec0563 100644 --- a/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py +++ b/checkpoint/orbax/checkpoint/_src/handlers/composite_checkpoint_handler_test.py @@ -27,7 +27,7 @@ from orbax.checkpoint._src.handlers import standard_checkpoint_handler from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.multihost import multihost -from orbax.checkpoint.path import step +from orbax.checkpoint._src.path import step CompositeArgs = composite_checkpoint_handler.CompositeArgs JsonCheckpointHandler = json_checkpoint_handler.JsonCheckpointHandler diff --git a/checkpoint/orbax/checkpoint/path/__init__.py b/checkpoint/orbax/checkpoint/_src/path/__init__.py similarity index 100% rename from checkpoint/orbax/checkpoint/path/__init__.py rename to checkpoint/orbax/checkpoint/_src/path/__init__.py diff --git a/checkpoint/orbax/checkpoint/path/async_utils.py b/checkpoint/orbax/checkpoint/_src/path/async_utils.py similarity index 96% rename from checkpoint/orbax/checkpoint/path/async_utils.py rename to checkpoint/orbax/checkpoint/_src/path/async_utils.py index 504c78ad..5f5b2a95 100644 --- a/checkpoint/orbax/checkpoint/path/async_utils.py +++ b/checkpoint/orbax/checkpoint/_src/path/async_utils.py @@ -18,7 +18,7 @@ from etils import epath from orbax.checkpoint._src import asyncio_utils -from orbax.checkpoint.path import step as step_lib +from orbax.checkpoint._src.path import step as step_lib # TODO(b/360190539): This functionality should be provided by either an external diff --git a/checkpoint/orbax/checkpoint/path/atomicity.py b/checkpoint/orbax/checkpoint/_src/path/atomicity.py similarity index 99% rename from checkpoint/orbax/checkpoint/path/atomicity.py rename to checkpoint/orbax/checkpoint/_src/path/atomicity.py index 678ee2b3..5d55d250 100644 --- a/checkpoint/orbax/checkpoint/path/atomicity.py +++ b/checkpoint/orbax/checkpoint/_src/path/atomicity.py @@ -64,9 +64,9 @@ from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata from orbax.checkpoint._src.multihost import counters from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.path import async_utils +from orbax.checkpoint._src.path import step as step_lib from orbax.checkpoint._src.path import utils -from orbax.checkpoint.path import async_utils -from orbax.checkpoint.path import step as step_lib TMP_DIR_SUFFIX = step_lib.TMP_DIR_SUFFIX diff --git a/checkpoint/orbax/checkpoint/path/atomicity_test.py b/checkpoint/orbax/checkpoint/_src/path/atomicity_test.py similarity index 97% rename from checkpoint/orbax/checkpoint/path/atomicity_test.py rename to checkpoint/orbax/checkpoint/_src/path/atomicity_test.py index 31080140..49ddd5d4 100644 --- a/checkpoint/orbax/checkpoint/path/atomicity_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/atomicity_test.py @@ -18,8 +18,8 @@ from etils import epath from orbax.checkpoint import test_utils from orbax.checkpoint._src.multihost import multihost -from orbax.checkpoint.path import atomicity -from orbax.checkpoint.path import step as step_lib +from orbax.checkpoint._src.path import atomicity +from orbax.checkpoint._src.path import step as step_lib AtomicRenameTemporaryPath = atomicity.AtomicRenameTemporaryPath CommitFileTemporaryPath = atomicity.CommitFileTemporaryPath diff --git a/checkpoint/orbax/checkpoint/path/deleter.py b/checkpoint/orbax/checkpoint/_src/path/deleter.py similarity index 99% rename from checkpoint/orbax/checkpoint/path/deleter.py rename to checkpoint/orbax/checkpoint/_src/path/deleter.py index 7a8af7f5..955fbc42 100644 --- a/checkpoint/orbax/checkpoint/path/deleter.py +++ b/checkpoint/orbax/checkpoint/_src/path/deleter.py @@ -22,7 +22,7 @@ from etils import epath import jax from orbax.checkpoint import utils -from orbax.checkpoint.path import step as step_lib +from orbax.checkpoint._src.path import step as step_lib _THREADED_DELETE_DURATION = ( '/jax/orbax/checkpoint_manager/threaded_checkpoint_deleter/duration' diff --git a/checkpoint/orbax/checkpoint/path/deleter_test.py b/checkpoint/orbax/checkpoint/_src/path/deleter_test.py similarity index 94% rename from checkpoint/orbax/checkpoint/path/deleter_test.py rename to checkpoint/orbax/checkpoint/_src/path/deleter_test.py index 2ee24fdc..8bf6a4f7 100644 --- a/checkpoint/orbax/checkpoint/path/deleter_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/deleter_test.py @@ -17,8 +17,8 @@ from absl.testing import absltest from absl.testing import parameterized from etils import epath -from orbax.checkpoint.path import deleter as deleter_lib -from orbax.checkpoint.path import step as step_lib +from orbax.checkpoint._src.path import deleter as deleter_lib +from orbax.checkpoint._src.path import step as step_lib class CheckpointDeleterTest(parameterized.TestCase): diff --git a/checkpoint/orbax/checkpoint/path/format_utils.py b/checkpoint/orbax/checkpoint/_src/path/format_utils.py similarity index 100% rename from checkpoint/orbax/checkpoint/path/format_utils.py rename to checkpoint/orbax/checkpoint/_src/path/format_utils.py diff --git a/checkpoint/orbax/checkpoint/path/format_utils_test.py b/checkpoint/orbax/checkpoint/_src/path/format_utils_test.py similarity index 98% rename from checkpoint/orbax/checkpoint/path/format_utils_test.py rename to checkpoint/orbax/checkpoint/_src/path/format_utils_test.py index e69c17ae..f2efce40 100644 --- a/checkpoint/orbax/checkpoint/path/format_utils_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/format_utils_test.py @@ -22,7 +22,7 @@ from orbax.checkpoint._src.handlers import pytree_checkpoint_handler from orbax.checkpoint._src.handlers import standard_checkpoint_handler from orbax.checkpoint._src.metadata import checkpoint as checkpoint_metadata -from orbax.checkpoint.path import format_utils +from orbax.checkpoint._src.path import format_utils diff --git a/checkpoint/orbax/checkpoint/path/utils.py b/checkpoint/orbax/checkpoint/_src/path/locking.py similarity index 94% rename from checkpoint/orbax/checkpoint/path/utils.py rename to checkpoint/orbax/checkpoint/_src/path/locking.py index 075342c8..ce36d5e5 100644 --- a/checkpoint/orbax/checkpoint/path/utils.py +++ b/checkpoint/orbax/checkpoint/_src/path/locking.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Path utility functions for Orbax.""" +"""Manages locking of checkpoint step dirs.""" # TODO(b/337137764): Add unit tests. # TODO(b/337137764): If needed, export the functions from @@ -24,8 +24,8 @@ from etils import epath from orbax.checkpoint._src import asyncio_utils -from orbax.checkpoint.path import async_utils -from orbax.checkpoint.path import step as step_lib +from orbax.checkpoint._src.path import async_utils +from orbax.checkpoint._src.path import step as step_lib _LOCK_ITEM_NAME = 'LOCKED' diff --git a/checkpoint/orbax/checkpoint/path/step.py b/checkpoint/orbax/checkpoint/_src/path/step.py similarity index 100% rename from checkpoint/orbax/checkpoint/path/step.py rename to checkpoint/orbax/checkpoint/_src/path/step.py diff --git a/checkpoint/orbax/checkpoint/path/step_test.py b/checkpoint/orbax/checkpoint/_src/path/step_test.py similarity index 99% rename from checkpoint/orbax/checkpoint/path/step_test.py rename to checkpoint/orbax/checkpoint/_src/path/step_test.py index fc614da3..fd4b0692 100644 --- a/checkpoint/orbax/checkpoint/path/step_test.py +++ b/checkpoint/orbax/checkpoint/_src/path/step_test.py @@ -21,8 +21,8 @@ from etils import epath from orbax.checkpoint import test_utils from orbax.checkpoint._src.metadata import checkpoint -from orbax.checkpoint.path import atomicity -from orbax.checkpoint.path import step as step_lib +from orbax.checkpoint._src.path import atomicity +from orbax.checkpoint._src.path import step as step_lib class StandardNameFormatTest(parameterized.TestCase): diff --git a/checkpoint/orbax/checkpoint/_src/path/utils.py b/checkpoint/orbax/checkpoint/_src/path/utils.py index 415c856d..63ae7dbb 100644 --- a/checkpoint/orbax/checkpoint/_src/path/utils.py +++ b/checkpoint/orbax/checkpoint/_src/path/utils.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Utils for snapshotting.""" +"""Utils for path constructs.""" import os import time diff --git a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py index 90e574d6..d67c13ba 100644 --- a/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py +++ b/checkpoint/orbax/checkpoint/_src/serialization/type_handlers.py @@ -42,10 +42,10 @@ from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.multihost import multislice +from orbax.checkpoint._src.path import async_utils +from orbax.checkpoint._src.path import format_utils from orbax.checkpoint._src.serialization import serialization from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils -from orbax.checkpoint.path import async_utils -from orbax.checkpoint.path import format_utils import tensorstore as ts diff --git a/checkpoint/orbax/checkpoint/async_checkpointer.py b/checkpoint/orbax/checkpoint/async_checkpointer.py index 97cb16c7..2df417c9 100644 --- a/checkpoint/orbax/checkpoint/async_checkpointer.py +++ b/checkpoint/orbax/checkpoint/async_checkpointer.py @@ -31,8 +31,8 @@ from orbax.checkpoint._src.metadata import checkpoint from orbax.checkpoint._src.multihost import counters from orbax.checkpoint._src.multihost import multihost -from orbax.checkpoint.path import async_utils -from orbax.checkpoint.path import atomicity +from orbax.checkpoint._src.path import async_utils +from orbax.checkpoint._src.path import atomicity BarrierSyncFn = multihost.BarrierSyncFn diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index b1c15fe2..81a0078d 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -44,13 +44,13 @@ from orbax.checkpoint._src.handlers import proto_checkpoint_handler from orbax.checkpoint._src.metadata import checkpoint from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.path import atomicity +from orbax.checkpoint._src.path import deleter +from orbax.checkpoint._src.path import step as step_lib from orbax.checkpoint._src.path import utils as path_utils from orbax.checkpoint.logging import abstract_logger from orbax.checkpoint.logging import standard_logger from orbax.checkpoint.logging import step_statistics -from orbax.checkpoint.path import atomicity -from orbax.checkpoint.path import deleter -from orbax.checkpoint.path import step as step_lib from typing_extensions import Self # for Python version < 3.11 diff --git a/checkpoint/orbax/checkpoint/checkpoint_utils.py b/checkpoint/orbax/checkpoint/checkpoint_utils.py index 488bbd45..c6a6f30b 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_utils.py +++ b/checkpoint/orbax/checkpoint/checkpoint_utils.py @@ -25,9 +25,9 @@ from orbax.checkpoint import utils from orbax.checkpoint._src.metadata import value as value_metadata from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.path import step as step_lib from orbax.checkpoint._src.path.snapshot import snapshot as snapshot_lib from orbax.checkpoint._src.serialization import type_handlers -from orbax.checkpoint.path import step as step_lib PyTree = Any diff --git a/checkpoint/orbax/checkpoint/checkpoint_utils_test.py b/checkpoint/orbax/checkpoint/checkpoint_utils_test.py index 09e6656d..9b63b2f8 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_utils_test.py +++ b/checkpoint/orbax/checkpoint/checkpoint_utils_test.py @@ -27,7 +27,7 @@ from orbax.checkpoint import utils from orbax.checkpoint._src.handlers import pytree_checkpoint_handler from orbax.checkpoint._src.metadata import value as value_metadata -from orbax.checkpoint.path import step as step_lib +from orbax.checkpoint._src.path import step as step_lib RestoreArgs = pytree_checkpoint_handler.RestoreArgs diff --git a/checkpoint/orbax/checkpoint/checkpointer.py b/checkpoint/orbax/checkpoint/checkpointer.py index 0f9a719e..7620e1f9 100644 --- a/checkpoint/orbax/checkpoint/checkpointer.py +++ b/checkpoint/orbax/checkpoint/checkpointer.py @@ -30,7 +30,7 @@ from orbax.checkpoint._src.handlers import composite_checkpoint_handler from orbax.checkpoint._src.metadata import checkpoint from orbax.checkpoint._src.multihost import multihost -from orbax.checkpoint.path import atomicity +from orbax.checkpoint._src.path import atomicity from typing_extensions import Self # for Python version < 3.11 diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py index b31ca31a..a838836e 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py @@ -48,12 +48,12 @@ from orbax.checkpoint._src.multihost import counters from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.multihost import multislice +from orbax.checkpoint._src.path import step as step_lib from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint.experimental.emergency import multihost as emergency_multihost from orbax.checkpoint.logging import abstract_logger from orbax.checkpoint.logging import standard_logger from orbax.checkpoint.logging import step_statistics -from orbax.checkpoint.path import step as step_lib from typing_extensions import Self # for Python version < 3.11 diff --git a/checkpoint/orbax/checkpoint/path.py b/checkpoint/orbax/checkpoint/path.py new file mode 100644 index 00000000..97313760 --- /dev/null +++ b/checkpoint/orbax/checkpoint/path.py @@ -0,0 +1,23 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines exported symbols from orbax.checkpoint.path package.""" + +# pylint: disable=unused-import + +from orbax.checkpoint._src.path import async_utils +from orbax.checkpoint._src.path import atomicity +from orbax.checkpoint._src.path import deleter +from orbax.checkpoint._src.path import format_utils +from orbax.checkpoint._src.path import step diff --git a/checkpoint/orbax/checkpoint/standard_checkpointer.py b/checkpoint/orbax/checkpoint/standard_checkpointer.py index b3dce2ad..40b8b1ca 100644 --- a/checkpoint/orbax/checkpoint/standard_checkpointer.py +++ b/checkpoint/orbax/checkpoint/standard_checkpointer.py @@ -20,7 +20,7 @@ from orbax.checkpoint import options as options_lib from orbax.checkpoint._src.handlers import standard_checkpoint_handler from orbax.checkpoint._src.metadata import checkpoint -from orbax.checkpoint.path import atomicity +from orbax.checkpoint._src.path import atomicity StandardCheckpointHandler = ( diff --git a/checkpoint/orbax/checkpoint/test_utils.py b/checkpoint/orbax/checkpoint/test_utils.py index 8dc6e73d..d6c8f3d2 100644 --- a/checkpoint/orbax/checkpoint/test_utils.py +++ b/checkpoint/orbax/checkpoint/test_utils.py @@ -42,12 +42,12 @@ from orbax.checkpoint._src.multihost import counters from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.multihost import multislice +from orbax.checkpoint._src.path import atomicity +from orbax.checkpoint._src.path import step as step_lib from orbax.checkpoint._src.serialization import serialization from orbax.checkpoint._src.serialization import tensorstore_utils as ts_utils from orbax.checkpoint._src.serialization import type_handlers from orbax.checkpoint._src.tree import utils as tree_utils -from orbax.checkpoint.path import atomicity -from orbax.checkpoint.path import step as step_lib class MuNu(NamedTuple): diff --git a/checkpoint/orbax/checkpoint/utils.py b/checkpoint/orbax/checkpoint/utils.py index 25c38c38..400d4db4 100644 --- a/checkpoint/orbax/checkpoint/utils.py +++ b/checkpoint/orbax/checkpoint/utils.py @@ -27,10 +27,10 @@ import jax import numpy as np from orbax.checkpoint._src.multihost import multihost +from orbax.checkpoint._src.path import async_utils +from orbax.checkpoint._src.path import locking +from orbax.checkpoint._src.path import step as step_lib from orbax.checkpoint._src.tree import utils as tree_utils -from orbax.checkpoint.path import async_utils -from orbax.checkpoint.path import step as step_lib -from orbax.checkpoint.path import utils as path_utils TMP_DIR_SUFFIX = step_lib.TMP_DIR_SUFFIX @@ -54,8 +54,8 @@ async_makedirs = async_utils.async_makedirs async_write_bytes = async_utils.async_write_bytes async_exists = async_utils.async_exists -lockdir = path_utils.lockdir -is_locked = path_utils.is_locked +lockdir = locking.lockdir +is_locked = locking.is_locked is_gcs_path = step_lib.is_gcs_path @@ -115,9 +115,7 @@ def name_from_leaf_placeholder(placeholder: str) -> str: def all_leaves_are_placeholders(tree: PyTree) -> bool: """Determines if all leaves in `tree` are placeholders.""" - return all( - leaf_is_placeholder(leaf) for leaf in jax.tree.leaves(tree) - ) + return all(leaf_is_placeholder(leaf) for leaf in jax.tree.leaves(tree)) def pytree_structure(directory: epath.PathLike) -> PyTree: