diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 7b3d5eb86fbd..a47d34e0899f 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -15,9 +15,10 @@
     create_default_local_save_plan,
     create_default_global_save_plan,
 )
-from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner
-from torch_xla.experimental._distributed_checkpoint_helpers import (
-    _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor)
+from torch_xla.experimental.checkpoint import SPMDLoadPlanner, SPMDSavePlanner
+from torch_xla.experimental.checkpoint._helpers import (_sharded_cpu_state_dict,
+                                                        _CpuShards,
+                                                        _is_sharded_tensor)
 
 
 class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):
diff --git a/torch_xla/experimental/checkpoint/__init__.py b/torch_xla/experimental/checkpoint/__init__.py
new file mode 100644
index 000000000000..7c91aba0126d
--- /dev/null
+++ b/torch_xla/experimental/checkpoint/__init__.py
@@ -0,0 +1,6 @@
+from .planners import SPMDSavePlanner, SPMDLoadPlanner
+
+__all__ = [
+    "SPMDSavePlanner",
+    "SPMDLoadPlanner",
+]
diff --git a/torch_xla/experimental/_distributed_checkpoint_helpers.py b/torch_xla/experimental/checkpoint/_helpers.py
similarity index 100%
rename from torch_xla/experimental/_distributed_checkpoint_helpers.py
rename to torch_xla/experimental/checkpoint/_helpers.py
diff --git a/torch_xla/experimental/distributed_checkpoint.py b/torch_xla/experimental/checkpoint/planners.py
similarity index 99%
rename from torch_xla/experimental/distributed_checkpoint.py
rename to torch_xla/experimental/checkpoint/planners.py
index 09be65d4b0a9..8700da28c30d 100644
--- a/torch_xla/experimental/distributed_checkpoint.py
+++ b/torch_xla/experimental/checkpoint/planners.py
@@ -35,16 +35,11 @@
 from torch.distributed.checkpoint.utils import find_state_dict_object
 from torch.utils._pytree import tree_map
 from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard
-from torch_xla.experimental._distributed_checkpoint_helpers import (
+from torch_xla.experimental.checkpoint._helpers import (
     FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor,
     set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards)
 from typing import Any, Dict, List, Tuple, Union
 
-__all__ = [
-    "SPMDSavePlanner",
-    "SPMDLoadPlanner",
-]
-
 
 class SPMDSavePlanner(SavePlanner):
   """
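
For context, a minimal sketch of how callers consume the planners after this rename. It assumes the standard torch.distributed.checkpoint entry points (save_state_dict/load_state_dict with FileSystemWriter/FileSystemReader); the Linear model and the /tmp/ckpt path are placeholders, and a real SPMD run would checkpoint XLA-sharded tensors rather than a plain CPU module:

import torch
import torch.distributed.checkpoint as dist_cp
from torch_xla.experimental.checkpoint import SPMDLoadPlanner, SPMDSavePlanner

# Placeholder module; in a real SPMD run the state_dict would contain
# XLA-sharded tensors produced by torch_xla's sharding APIs.
model = torch.nn.Linear(8, 8)
state_dict = {"model": model.state_dict()}

# Save through the SPMD-aware planner at its new import path.
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter("/tmp/ckpt"),
    planner=SPMDSavePlanner(),
)

# load_state_dict restores into state_dict in place; the matching
# load planner resolves saved shards back onto the local devices.
dist_cp.load_state_dict(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader("/tmp/ckpt"),
    planner=SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])

Only the import path changes here; the planner classes themselves are moved verbatim (99%/100% similarity), with __all__ relocated from planners.py into the package __init__.py.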