From c060b6e6ef54786cb52b7d01763d24651aa1c524 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Thu, 28 Sep 2023 23:55:56 +0000 Subject: [PATCH] Move distributed checkpointing to a subdirectory of experimental --- test/spmd/test_xla_distributed_checkpoint.py | 7 ++++--- torch_xla/experimental/checkpoint/__init__.py | 6 ++++++ .../_helpers.py} | 0 .../{distributed_checkpoint.py => checkpoint/planners.py} | 7 +------ 4 files changed, 11 insertions(+), 9 deletions(-) create mode 100644 torch_xla/experimental/checkpoint/__init__.py rename torch_xla/experimental/{_distributed_checkpoint_helpers.py => checkpoint/_helpers.py} (100%) rename torch_xla/experimental/{distributed_checkpoint.py => checkpoint/planners.py} (99%) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 7b3d5eb86fbd..a47d34e0899f 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -15,9 +15,10 @@ create_default_local_save_plan, create_default_global_save_plan, ) -from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner -from torch_xla.experimental._distributed_checkpoint_helpers import ( - _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) +from torch_xla.experimental.checkpoint import SPMDLoadPlanner, SPMDSavePlanner +from torch_xla.experimental.checkpoint._helpers import (_sharded_cpu_state_dict, + _CpuShards, + _is_sharded_tensor) class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest): diff --git a/torch_xla/experimental/checkpoint/__init__.py b/torch_xla/experimental/checkpoint/__init__.py new file mode 100644 index 000000000000..7c91aba0126d --- /dev/null +++ b/torch_xla/experimental/checkpoint/__init__.py @@ -0,0 +1,6 @@ +from .planners import SPMDSavePlanner, SPMDLoadPlanner + +__all__ = [ + "SPMDSavePlanner", + "SPMDLoadPlanner", +] diff --git a/torch_xla/experimental/_distributed_checkpoint_helpers.py b/torch_xla/experimental/checkpoint/_helpers.py similarity index 100% rename from torch_xla/experimental/_distributed_checkpoint_helpers.py rename to torch_xla/experimental/checkpoint/_helpers.py diff --git a/torch_xla/experimental/distributed_checkpoint.py b/torch_xla/experimental/checkpoint/planners.py similarity index 99% rename from torch_xla/experimental/distributed_checkpoint.py rename to torch_xla/experimental/checkpoint/planners.py index 09be65d4b0a9..8700da28c30d 100644 --- a/torch_xla/experimental/distributed_checkpoint.py +++ b/torch_xla/experimental/checkpoint/planners.py @@ -35,16 +35,11 @@ from torch.distributed.checkpoint.utils import find_state_dict_object from torch.utils._pytree import tree_map from torch_xla.experimental.xla_sharding import XLAShardedTensor, XLAShard -from torch_xla.experimental._distributed_checkpoint_helpers import ( +from torch_xla.experimental.checkpoint._helpers import ( FLATTEN_MAPPING, flatten_state_dict, dedup_tensors, _is_sharded_tensor, set_element, narrow_tensor_by_index, _unwrap_xla_sharded_tensor, _CpuShards) from typing import Any, Dict, List, Tuple, Union -__all__ = [ - "SPMDSavePlanner", - "SPMDLoadPlanner", -] - class SPMDSavePlanner(SavePlanner): """