Support autocheckpointing in CheckpointManager #5753

Merged 1 commit on Nov 8, 2023
36 changes: 30 additions & 6 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -356,7 +356,8 @@ def tearDown(self):

@run_with_tmpdir
def test_manager_checkpointing(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
chkpt_mgr = CheckpointManager(
tmpdir, save_interval=10, chkpt_on_preemption=False)
state_dict = self._get_sharded_model().state_dict()

# Take a checkpoint on step 0
@@ -376,7 +377,8 @@ def test_manager_checkpointing(self, tmpdir):

@run_with_tmpdir
def test_manager_step_tracking(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
chkpt_mgr = CheckpointManager(
tmpdir, save_interval=10, chkpt_on_preemption=False)
state_dict = self._get_sharded_model().state_dict()

# No steps are being tracked initially
@@ -396,7 +398,8 @@ def test_manager_step_tracking(self, tmpdir):

@run_with_tmpdir
def test_manager_max_to_keep(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2)
chkpt_mgr = CheckpointManager(
tmpdir, save_interval=10, max_to_keep=2, chkpt_on_preemption=False)
state_dict = self._get_sharded_model().state_dict()

# No steps are being tracked initially
@@ -417,13 +420,15 @@ def test_manager_max_to_keep(self, tmpdir):
self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10})

# The deletion order should persist across executions
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2)
chkpt_mgr = CheckpointManager(
tmpdir, save_interval=10, max_to_keep=2, chkpt_on_preemption=False)
self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})

@run_with_tmpdir
def test_manager_async(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
chkpt_mgr = CheckpointManager(
tmpdir, save_interval=10, chkpt_on_preemption=False)
state_dict = self._get_sharded_model().state_dict()

# Patch the manager's save method to block until this thread signals.
@@ -451,7 +456,8 @@ def patched_save(*args, **kwargs):

@run_with_tmpdir
def test_manager_async_step_tracking(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
chkpt_mgr = CheckpointManager(
tmpdir, save_interval=10, chkpt_on_preemption=False)
state_dict = self._get_sharded_model().state_dict()

self.assertEqual(chkpt_mgr.all_steps(), [])
@@ -522,6 +528,24 @@ def test_preemption_sync_manager(self):
# Scope the PreemptionSyncManager to the lifespan of the test.
torch_xla._XLAC._deactivate_preemption_sync_manager()

@unittest.skipUnless(xr.device_type() == 'TPU',
'TPU required for worker IP discovery')
@run_with_tmpdir
def test_auto_checkpoint(self, tmpdir):
# Create a checkpoint manager with a long save interval
chkpt_mgr = CheckpointManager(tmpdir, save_interval=100)
state_dict = self._get_sharded_model().state_dict()

preemption_step = 10
# Skip step 0 so the manager will track no checkpoints before preemption
for step in range(1, preemption_step):
self.assertFalse(chkpt_mgr.save(step, state_dict))

with unittest.mock.patch('torch_xla._XLAC._sync_point_reached',
lambda x: True):
self.assertTrue(chkpt_mgr.save(preemption_step, state_dict))
self.assertTrue(chkpt_mgr.reached_preemption(step))


if __name__ == '__main__':
test = unittest.main()
41 changes: 36 additions & 5 deletions torch_xla/experimental/distributed_checkpoint/manager.py
@@ -6,6 +6,7 @@
import threading
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc
import traceback
@@ -94,12 +95,16 @@ class CheckpointManager:
# The maximum number of checkpoints to keep.
max_to_keep: int

# Whether a checkpoint should be taken when a preemption is detected.
chkpt_on_preemption: bool

def __init__(self,
path: str,
save_interval: int,
max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1,
process_group: dist.ProcessGroup = None):
process_group: dist.ProcessGroup = None,
chkpt_on_preemption: bool = True):
"""
Create a checkpoint manager that reads and writes checkpoints into
the provided directory.
@@ -121,6 +126,9 @@ def __init__(self,
process_group: The process group to use when coordinating the checkpoint.
Default: None, in which case a subgroup of the default process
group will be created.
chkpt_on_preemption: Whether or not to take a checkpoint when a
preemption has been detected.
Default: True
"""
assert dist.is_initialized(), "A process group is required."
assert save_interval > 0, "save_interval must be positive"
@@ -130,6 +138,7 @@ def __init__(self,
self.base_path = path
self.save_interval = save_interval
self.max_to_keep = max_to_keep
self.chkpt_on_preemption = chkpt_on_preemption

self._tracked_chkpts = self._load_tracked_chkpts()
self._async_queue = queue.Queue(maxsize=async_queue_size)
@@ -143,6 +152,13 @@ def __init__(self,
# TODO(jonbolin): Verify subgroup on GPU backend
self.pg = process_group or dist.new_group()

if self.chkpt_on_preemption:
# Initialize the distributed runtime for preemption detection
master_ip = xr.get_master_ip()
torch_xla._XLAC._ensure_xla_coordinator_initialized(
xr.process_index(), xr.process_count(), master_ip)
torch_xla._XLAC._activate_preemption_sync_manager()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
"""
Loads a list of all tracked checkpoints from the storage backend.
@@ -201,11 +217,19 @@ def _release_oldest_checkpoints(self):

def should_save(self, step: int) -> bool:
"""
Returns true if a checkpoint should be saved for the current step or if
a preemption has been detected.
Returns true if a checkpoint should be saved for the current step. A
checkpoint should be taken if any of the following conditions are met:
- The step aligns with the CheckpointManager's save_interval.
- The CheckpointManager was created with the `chkpt_on_preemption` option
and a preemption has been detected.
"""
# TODO(jonbolin): Support preemption notice for auto checkpointing
return step % self.save_interval == 0
preemption_detected = False
if self.chkpt_on_preemption and self.reached_preemption(step):
logging.warn(
f"Preemption sync point reached at step {step}. Triggering a checkpoint."
)
preemption_detected = True
return step % self.save_interval == 0 or preemption_detected

def save(self,
step,
@@ -300,3 +324,10 @@ def all_steps(self) -> List[int]:
def join(self):
""" Wait for all pending async checkpoints to complete. """
self._async_queue.join()

def reached_preemption(self, step: int) -> bool:
""" Returns True if a preemption has been detected at the given step. """
assert self.chkpt_on_preemption, (
    "Preemption detection not enabled. Please set `chkpt_on_preemption` "
    "when creating the CheckpointManager")
return torch_xla._XLAC._sync_point_reached(step)
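
For context, here is a rough sketch of how the auto-checkpointing flag added in this PR might be used from a training loop. It is not part of the change itself: train_step, the model/optimizer placeholders, and the resume path via all_steps()/restore() are assumptions based on the rest of the CheckpointManager API, and a torch.distributed process group must already be initialized before the manager is constructed.

from torch_xla.experimental.distributed_checkpoint import CheckpointManager

def train(model, optimizer, train_loader, ckpt_dir):
  # chkpt_on_preemption defaults to True, so constructing the manager also
  # initializes the XLA coordinator and activates the preemption sync manager.
  chkpt_mgr = CheckpointManager(ckpt_dir, save_interval=100)

  state_dict = {'model': model.state_dict(), 'optim': optimizer.state_dict()}
  tracked = chkpt_mgr.all_steps()
  start_step = 0
  if tracked:
    # After a restart (e.g. following a preemption-triggered checkpoint),
    # resume from the newest tracked step.
    start_step = max(tracked)
    chkpt_mgr.restore(start_step, state_dict)
    model.load_state_dict(state_dict['model'])
    optimizer.load_state_dict(state_dict['optim'])

  for step, batch in enumerate(train_loader, start=start_step + 1):
    train_step(model, optimizer, batch)  # hypothetical per-step training fn
    # save() only writes a checkpoint when should_save(step) is True, i.e. the
    # step hits save_interval or a preemption sync point has been reached.
    chkpt_mgr.save(step, {'model': model.state_dict(),
                          'optim': optimizer.state_dict()})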