diff --git a/docs/spmd.md b/docs/spmd.md
index c139c613d8d..5d6e554092d 100644
--- a/docs/spmd.md
+++ b/docs/spmd.md
@@ -235,7 +235,7 @@ train_loader = pl.MpDeviceLoader(
 
 PyTorch/XLA SPMD is compatible with the [torch.distributed.checkpoint](https://pytorch.org/docs/stable/distributed.checkpoint.html) library through a dedicated `Planner` instance. Users are able to synchronously save and load checkpoints through this common interface.
 
-The SPMDSavePlanner and SPMDLoadPlanner ([src](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint.py)) classes enable the `save_state_dict` and `load_state_dict` functions to operate directly on the shards of an `XLAShardedTensor`, enabling all of the benefits of distributed checkpointing in SPMD training.
+The SPMDSavePlanner and SPMDLoadPlanner ([src](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint.py)) classes enable the `save` and `load` functions to operate directly on the shards of an `XLAShardedTensor`, enabling all of the benefits of distributed checkpointing in SPMD training.
 
 Here is a demonstration of the synchronous distributed checkpointing API:
 
@@ -249,7 +249,7 @@ state_dict = {
     "optim": optim.state_dict(),
 }
 
-dist_cp.save_state_dict(
+dist_cp.save(
     state_dict=state_dict,
     storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
     planner=xc.SPMDSavePlanner(),
@@ -262,7 +262,7 @@ state_dict = {
     "model": model.state_dict(),
 }
 
-dist_cp.load_state_dict(
+dist_cp.load(
     state_dict=state_dict,
     storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
     planner=xc.SPMDLoadPlanner(),
diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py
index 98c465e0718..f4467e6a85f 100644
--- a/test/spmd/test_xla_distributed_checkpoint.py
+++ b/test/spmd/test_xla_distributed_checkpoint.py
@@ -76,7 +76,6 @@ def _save_and_restore(self,
                          save_planner=None,
                          load_planner=None,
                          is_sharded_cpu_state_dict=False,
-                         no_dist=True,
                          chkpt_path=None):
     """
     Checkpoint model_in using the provided save_planner and load into model_out
@@ -90,24 +89,22 @@ def _save_and_restore(self,
     if is_sharded_cpu_state_dict:
       model_in_state_dict = _sharded_cpu_state_dict(model_in_state_dict)
     model_out_state_dict = model_out.state_dict()
-    dist_cp.save_state_dict(
+    dist_cp.save(
         state_dict=model_in_state_dict,
         storage_writer=dist_cp.FileSystemWriter(
             chkpt_path,
             per_thread_copy_ahead=0,
         ),
         planner=save_planner,
-        no_dist=no_dist,
     )
 
     # Load the checkpoint using the provided load planner
     for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
       self.assertFalse(torch.allclose(p1.cpu(), p2.cpu()))
-    dist_cp.load_state_dict(
+    dist_cp.load(
         state_dict=model_out_state_dict,
         storage_reader=dist_cp.FileSystemReader(chkpt_path),
         planner=load_planner,
-        no_dist=no_dist,
     )
     for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
       self.assertTrue(torch.allclose(p1.cpu(), p2.cpu()))
@@ -142,15 +139,13 @@ def test_resharding_different_device_mesh(self):
         save_planner=SPMDSavePlanner(),
         load_planner=SPMDLoadPlanner())
 
-  @unittest.skipUnless(
-      {'CHKPT_PATH', 'MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE'
-      } <= os.environ.keys(),
-      'CHKPT_PATH and distributed config must be set for multihost checkpoint')
+  @unittest.skipUnless('CHKPT_PATH' in os.environ,
+                       'CHKPT_PATH must be set for multihost checkpoint')
   def test_multihost_checkpoint(self):
     torch.manual_seed(42)
 
     # Initialize the default CPU process group from the environment.
-    dist.init_process_group()
+    dist.init_process_group(backend='gloo', init_method='xla://')
 
     model1 = self._get_sharded_model(mesh_shape=(1, self.n_devices))
     model2 = self._get_sharded_model(mesh_shape=(self.n_devices, 1))
@@ -160,7 +155,6 @@ def test_multihost_checkpoint(self):
         model2,
         save_planner=SPMDSavePlanner(),
         load_planner=SPMDLoadPlanner(),
-        no_dist=False,
         chkpt_path=os.environ['CHKPT_PATH'])
 
     # Destroy the CPU process group after the test
diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py
index 13b6abfacfc..4ce57b5fb38 100644
--- a/torch_xla/experimental/distributed_checkpoint/manager.py
+++ b/torch_xla/experimental/distributed_checkpoint/manager.py
@@ -218,7 +218,7 @@ def _save(self, step, state_dict):
     path = self._get_path(step)
     # Delete any existing checkpoint at the current step.
     self._delete_chkpt_at_step(step)
-    dist_cp.save_state_dict(
+    dist_cp.save(
         state_dict=state_dict,
         storage_writer=FsspecWriter(
             path,
@@ -244,7 +244,7 @@ def should_save(self, step: int) -> bool:
     """
     preemption_detected = False
     if self.chkpt_on_preemption and self.reached_preemption(step):
-      logging.warn(
+      logging.warning(
           f"Preemption sync point reached at step {step}. Triggering a checkpoint."
       )
       preemption_detected = True
@@ -319,7 +319,7 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
     tracked_steps = set(x.step for x in self._tracked_chkpts)
     assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}'
     path = self._get_path(step)
-    dist_cp.load_state_dict(
+    dist_cp.load(
        state_dict=state_dict,
        storage_reader=FsspecReader(path),
        planner=xc.SPMDLoadPlanner(),
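
Note: the documentation hunks above only show fragments of the updated example. Below is a minimal end-to-end sketch of the renamed synchronous API stitched together from those fragments. It assumes `model` and `optim` have already been created and sharded as described earlier in the SPMD guide, and `CHECKPOINT_DIR` is a placeholder for a writable directory:

```python
# Minimal sketch: `model` and `optim` are assumed to be an SPMD-sharded module and
# optimizer created as in the SPMD guide; CHECKPOINT_DIR is a placeholder path.
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

CHECKPOINT_DIR = "/tmp/spmd_checkpoint"  # placeholder, not part of the PR

# Save: SPMDSavePlanner lets dist_cp.save write the local shards of each
# XLAShardedTensor in the state_dict.
state_dict = {
    "model": model.state_dict(),
    "optim": optim.state_dict(),
}
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)

# Load: dist_cp.load restores in place into the provided state_dict, so build it
# from the destination model first, then apply it with load_state_dict.
state_dict = {
    "model": model.state_dict(),
}
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
model.load_state_dict(state_dict["model"])
```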