diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 29ed825d015..1081ce0b188 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -356,7 +356,8 @@ def tearDown(self): @run_with_tmpdir def test_manager_checkpointing(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, chkpt_on_preemption=False) state_dict = self._get_sharded_model().state_dict() # Take a checkpoint on step 0 @@ -376,7 +377,8 @@ def test_manager_checkpointing(self, tmpdir): @run_with_tmpdir def test_manager_step_tracking(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, chkpt_on_preemption=False) state_dict = self._get_sharded_model().state_dict() # No steps are being tracked initially @@ -396,7 +398,8 @@ def test_manager_step_tracking(self, tmpdir): @run_with_tmpdir def test_manager_max_to_keep(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2) + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, max_to_keep=2, chkpt_on_preemption=False) state_dict = self._get_sharded_model().state_dict() # No steps are being tracked initially @@ -417,13 +420,15 @@ def test_manager_max_to_keep(self, tmpdir): self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10}) # The deletion order should persist across executions - chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2) + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, max_to_keep=2, chkpt_on_preemption=False) self.assertTrue(chkpt_mgr.save(20, state_dict)) self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10}) @run_with_tmpdir def test_manager_async(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, chkpt_on_preemption=False) state_dict = self._get_sharded_model().state_dict() # Patch the manager's save method to block until this thread signals. @@ -451,7 +456,8 @@ def patched_save(*args, **kwargs): @run_with_tmpdir def test_manager_async_step_tracking(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) + chkpt_mgr = CheckpointManager( + tmpdir, save_interval=10, chkpt_on_preemption=False) state_dict = self._get_sharded_model().state_dict() self.assertEqual(chkpt_mgr.all_steps(), []) @@ -522,6 +528,24 @@ def test_preemption_sync_manager(self): # Scope the PreemptionSyncManager to the lifespan of the test. torch_xla._XLAC._deactivate_preemption_sync_manager() + @unittest.skipUnless(xr.device_type() == 'TPU', + 'TPU required for worker IP discovery') + @run_with_tmpdir + def test_auto_checkpoint(self, tmpdir): + # Create a checkpoint manager with a long save interval + chkpt_mgr = CheckpointManager(tmpdir, save_interval=100) + state_dict = self._get_sharded_model().state_dict() + + preemption_step = 10 + # Skip step 0 so the manager will track no checkpoints before preemption + for step in range(1, preemption_step): + self.assertFalse(chkpt_mgr.save(step, state_dict)) + + with unittest.mock.patch('torch_xla._XLAC._sync_point_reached', + lambda x: True): + self.assertTrue(chkpt_mgr.save(preemption_step, state_dict)) + self.assertTrue(chkpt_mgr.reached_preemption(step)) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 9e5cde711b8..0eaf184910a 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -6,6 +6,7 @@ import threading import torch.distributed as dist import torch.distributed.checkpoint as dist_cp +import torch_xla import torch_xla.runtime as xr import torch_xla.experimental.distributed_checkpoint as xc import traceback @@ -94,12 +95,16 @@ class CheckpointManager: # The maximum number of checkpoints to keep. max_to_keep: int + # Whether a checkpoint should be taken when a preemption is detected. + chkpt_on_preemption: bool + def __init__(self, path: str, save_interval: int, max_to_keep: Optional[int] = 0, async_queue_size: Optional[int] = 1, - process_group: dist.ProcessGroup = None): + process_group: dist.ProcessGroup = None, + chkpt_on_preemption: bool = True): """ Create a checkpoint manager that reads and writes checkpoints into the provided directory. @@ -121,6 +126,9 @@ def __init__(self, process_group: The process group to use when coordinating the checkpoint. Default: None, in which case a subgroup of the default process group will be created. + chkpt_on_preemption: Whether or not to take a checkpoint when a + preemption has been detected. + Default: True """ assert dist.is_initialized(), "A process group is required." assert save_interval > 0, "save_interval must be positive" @@ -130,6 +138,7 @@ def __init__(self, self.base_path = path self.save_interval = save_interval self.max_to_keep = max_to_keep + self.chkpt_on_preemption = chkpt_on_preemption self._tracked_chkpts = self._load_tracked_chkpts() self._async_queue = queue.Queue(maxsize=async_queue_size) @@ -143,6 +152,13 @@ def __init__(self, # TODO(jonbolin): Verify subgroup on GPU backend self.pg = process_group or dist.new_group() + if self.chkpt_on_preemption: + # Initialize the distributed runtime for preemption detection + master_ip = xr.get_master_ip() + torch_xla._XLAC._ensure_xla_coordinator_initialized( + xr.process_index(), xr.process_count(), master_ip) + torch_xla._XLAC._activate_preemption_sync_manager() + def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: """ Loads a list of all tracked checkpoints from the storage backend. @@ -201,11 +217,19 @@ def _release_oldest_checkpoints(self): def should_save(self, step: int) -> bool: """ - Returns true if a checkpoint should be saved for the current step or if - a preemption has been detected. + Returns true if a checkpoint should be saved for the current step. A + checkpoint should be taken if any of the following conditions are met: + - The step aligns with the CheckpointManager's save_interval. + - The CheckpointManager was created with the `chkpt_on_preemption` option + and a preemption has been detected. """ - # TODO(jonbolin): Support preemption notice for auto checkpointing - return step % self.save_interval == 0 + preemption_detected = False + if self.chkpt_on_preemption and self.reached_preemption(step): + logging.warn( + f"Preemption sync point reached at step {step}. Triggering a checkpoint." + ) + preemption_detected = True + return step % self.save_interval == 0 or preemption_detected def save(self, step, @@ -300,3 +324,10 @@ def all_steps(self) -> List[int]: def join(self): """ Wait for all pending async checkpoints to complete. """ self._async_queue.join() + + def reached_preemption(self, step: int) -> bool: + """ Returns True if a preemption has been detected at the given step. """ + assert self.chkpt_on_preemption, ( + "Preemption detection not enabled. Please set `chkpt_on_preemption` " + " when creating the CheckpointManager") + return torch_xla._XLAC._sync_point_reached(step)