Skip to content

Commit

Permalink
Support autocheckpointing
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 27, 2023
1 parent d9ba7ca commit 5fdce13
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
16 changes: 16 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,22 @@ def test_preemption_sync_manager(self):
# Scope the distributed runtime to the lifespan of the test.
torch_xla._XLAC._deactivate_preemption_sync_manager()

@run_with_tmpdir
def test_auto_checkpoint(self, tmpdir):
# Create a checkpoint manager with a long save interval
chkpt_mgr = CheckpointManager(tmpdir, save_interval=100)
state_dict = self._get_sharded_model().state_dict()

preemption_step = 10
# Skip step 0 so the manager will track no checkpoints before preemption
for step in range(1, preemption_step):
self.assertFalse(chkpt_mgr.save(step, state_dict))

with unittest.mock.patch('torch_xla._XLAC._sync_point_reached',
lambda x: True):
self.assertTrue(chkpt_mgr.save(preemption_step, state_dict))
self.assertTrue(chkpt_mgr.reached_preemption(step))


if __name__ == '__main__':
test = unittest.main()
Expand Down
34 changes: 31 additions & 3 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc
import traceback
Expand Down Expand Up @@ -94,12 +95,16 @@ class CheckpointManager:
# The maximum number of checkpoints to keep.
max_to_keep: int

# Whether a checkpoint should be taken when a preemption is detected.
chkpt_on_preemption: bool

def __init__(self,
path: str,
save_interval: int,
max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1,
process_group: dist.ProcessGroup = None):
process_group: dist.ProcessGroup = None,
chkpt_on_preemption: bool = True):
"""
Create a checkpoint manager that reads and writes checkpoints into
the provided directory.
Expand All @@ -121,6 +126,9 @@ def __init__(self,
process_group: The process group to use when coordinating the checkpoint.
Default: None, in which case a subgroup of the default process
group will be created.
chkpt_on_preemption: Whether or not to take a checkpoint when a
preemption has been detected.
Default: True
"""
assert dist.is_initialized(), "A process group is required."
assert save_interval > 0, "save_interval must be positive"
Expand All @@ -130,6 +138,7 @@ def __init__(self,
self.base_path = path
self.save_interval = save_interval
self.max_to_keep = max_to_keep
self.chkpt_on_preemption = chkpt_on_preemption

self._tracked_chkpts = self._load_tracked_chkpts()
self._async_queue = queue.Queue(maxsize=async_queue_size)
Expand All @@ -143,6 +152,13 @@ def __init__(self,
# TODO(jonbolin): Verify subgroup on GPU backend
self.pg = process_group or dist.new_group()

if self.chkpt_on_preemption:
# Initialize the distributed runtime for preemption detection
master_ip = xr.get_master_ip()
torch_xla._XLAC._ensure_xla_coordinator_initialized(
xr.process_index(), xr.process_count(), master_ip)
torch_xla._XLAC._activate_preemption_sync_manager()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
"""
Loads a list of all tracked checkpoints from the storage backend.
Expand Down Expand Up @@ -204,8 +220,13 @@ def should_save(self, step: int) -> bool:
Returns true if a checkpoint should be saved for the current step or if
a preemption has been detected.
"""
# TODO(jonbolin): Support preemption notice for auto checkpointing
return step % self.save_interval == 0
preemption_detected = False
if self.chkpt_on_preemption and self.reached_preemption(step):
logging.warn(
f"Preemption sync point reached at step {step}. Triggering a checkpoint."
)
preemption_detected = True
return step % self.save_interval == 0 or preemption_detected

def save(self,
step,
Expand Down Expand Up @@ -300,3 +321,10 @@ def all_steps(self) -> List[int]:
def join(self):
""" Wait for all pending async checkpoints to complete. """
self._async_queue.join()

def reached_preemption(self, step: int) -> bool:
""" Returns True if a preemption has been detected at the given step. """
assert self.chkpt_on_preemption, (
"Preemption detection not enabled. Please set `chkpt_on_preemption` "
" when creating the CheckpointManager")
return torch_xla._XLAC._sync_point_reached(step)

0 comments on commit 5fdce13

Please sign in to comment.