Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address deprecation in torch.distributed.checkpoint #6773

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/spmd.md
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ train_loader = pl.MpDeviceLoader(

PyTorch/XLA SPMD is compatible with the [torch.distributed.checkpoint](https://pytorch.org/docs/stable/distributed.checkpoint.html) library through a dedicated `Planner` instance. Users are able to synchronously save and load checkpoints through this common interface.

The SPMDSavePlanner and SPMDLoadPlanner ([src](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint.py)) classes enable the `save_state_dict` and `load_state_dict` functions to operate directly on the shards of an `XLAShardedTensor`, enabling all of the benefits of distributed checkpointing in SPMD training.
The SPMDSavePlanner and SPMDLoadPlanner ([src](https://github.com/pytorch/xla/blob/master/torch_xla/experimental/distributed_checkpoint.py)) classes enable the `save` and `load` functions to operate directly on the shards of an `XLAShardedTensor`, enabling all of the benefits of distributed checkpointing in SPMD training.

Here is a demonstration of the synchronous distributed checkpointing API:

Expand All @@ -249,7 +249,7 @@ state_dict = {
"optim": optim.state_dict(),
}

dist_cp.save_state_dict(
dist_cp.save(
state_dict=state_dict,
storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
planner=xc.SPMDSavePlanner(),
Expand All @@ -262,7 +262,7 @@ state_dict = {
"model": model.state_dict(),
}

dist_cp.load_state_dict(
dist_cp.load(
state_dict=state_dict,
storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
planner=xc.SPMDLoadPlanner(),
Expand Down
16 changes: 5 additions & 11 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ def _save_and_restore(self,
save_planner=None,
load_planner=None,
is_sharded_cpu_state_dict=False,
no_dist=True,
chkpt_path=None):
"""
Checkpoint model_in using the provided save_planner and load into model_out
Expand All @@ -90,24 +89,22 @@ def _save_and_restore(self,
if is_sharded_cpu_state_dict:
model_in_state_dict = _sharded_cpu_state_dict(model_in_state_dict)
model_out_state_dict = model_out.state_dict()
dist_cp.save_state_dict(
dist_cp.save(
state_dict=model_in_state_dict,
storage_writer=dist_cp.FileSystemWriter(
chkpt_path,
per_thread_copy_ahead=0,
),
planner=save_planner,
no_dist=no_dist,
)
# Load the checkpoint using the provided load planner
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
self.assertFalse(torch.allclose(p1.cpu(), p2.cpu()))

dist_cp.load_state_dict(
dist_cp.load(
state_dict=model_out_state_dict,
storage_reader=dist_cp.FileSystemReader(chkpt_path),
planner=load_planner,
no_dist=no_dist,
)
for p1, p2 in zip(model_in.parameters(), model_out.parameters()):
self.assertTrue(torch.allclose(p1.cpu(), p2.cpu()))
Expand Down Expand Up @@ -142,15 +139,13 @@ def test_resharding_different_device_mesh(self):
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner())

@unittest.skipUnless(
{'CHKPT_PATH', 'MASTER_ADDR', 'MASTER_PORT', 'RANK', 'WORLD_SIZE'
} <= os.environ.keys(),
'CHKPT_PATH and distributed config must be set for multihost checkpoint')
@unittest.skipUnless('CHKPT_PATH' in os.environ,
'CHKPT_PATH must be set for multihost checkpoint')
def test_multihost_checkpoint(self):
torch.manual_seed(42)

# Initialize the default CPU process group from the environment.
dist.init_process_group()
dist.init_process_group(backend='gloo', init_method='xla://')

model1 = self._get_sharded_model(mesh_shape=(1, self.n_devices))
model2 = self._get_sharded_model(mesh_shape=(self.n_devices, 1))
Expand All @@ -160,7 +155,6 @@ def test_multihost_checkpoint(self):
model2,
save_planner=SPMDSavePlanner(),
load_planner=SPMDLoadPlanner(),
no_dist=False,
chkpt_path=os.environ['CHKPT_PATH'])

# Destroy the CPU process group after the test
Expand Down
6 changes: 3 additions & 3 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def _save(self, step, state_dict):
path = self._get_path(step)
# Delete any existing checkpoint at the current step.
self._delete_chkpt_at_step(step)
dist_cp.save_state_dict(
dist_cp.save(
state_dict=state_dict,
storage_writer=FsspecWriter(
path,
Expand All @@ -244,7 +244,7 @@ def should_save(self, step: int) -> bool:
"""
preemption_detected = False
if self.chkpt_on_preemption and self.reached_preemption(step):
logging.warn(
logging.warning(
f"Preemption sync point reached at step {step}. Triggering a checkpoint."
)
preemption_detected = True
Expand Down Expand Up @@ -319,7 +319,7 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
tracked_steps = set(x.step for x in self._tracked_chkpts)
assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}'
path = self._get_path(step)
dist_cp.load_state_dict(
dist_cp.load(
state_dict=state_dict,
storage_reader=FsspecReader(path),
planner=xc.SPMDLoadPlanner(),
Expand Down
Loading