Skip to content

Commit

Permalink
Support async checkpointing through CheckpointManager (#5697)
Browse files Browse the repository at this point in the history
* Support async checkpointing through CheckpointManager

* Allow threads to exit when CheckpointManager is freed

* Use rank from tracked process group

* Add TODO
  • Loading branch information
jonb377 authored and bhavya01 committed Apr 22, 2024
1 parent 624fac9 commit de07fe9
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 3 deletions.
65 changes: 65 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile
import unittest
import test_xla_sharding_base
import threading

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -412,6 +413,70 @@ def test_manager_max_to_keep(self, tmpdir):
self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})

@run_with_tmpdir
def test_manager_async(self, tmpdir):
    """save_async must return immediately and track the step only after the
    background checkpoint actually completes."""
    chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
    state_dict = self._get_sharded_model().state_dict()

    # Patch the manager's save method to block until this thread signals.
    # NOTE: Condition.wait() must be called while holding the condition's
    # lock, otherwise it raises RuntimeError. wait_for with a predicate also
    # protects against a missed notify (the worker arriving after notify()).
    cond = threading.Condition()
    released = False
    old_save = chkpt_mgr.save

    def patched_save(*args, **kwargs):
        with cond:
            cond.wait_for(lambda: released)
        old_save(*args, **kwargs)

    with unittest.mock.patch.object(chkpt_mgr, 'save', patched_save):
        chkpt_mgr.save_async(10, state_dict)

        # No new steps should be tracked immediately after calling save_async
        self.assertEqual(chkpt_mgr.all_steps(), [])

        # Trigger the actual checkpoint in the background thread and wait for
        # completion.
        with cond:
            released = True
            cond.notify()
        chkpt_mgr.join()

    # The manager should track all steps which were asynchronously saved.
    self.assertEqual(set(chkpt_mgr.all_steps()), {10})

@run_with_tmpdir
def test_manager_async_step_tracking(self, tmpdir):
    """Async saves honor the save_interval, and restored checkpoints match
    the state that was saved."""
    chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
    state_dict = self._get_sharded_model().state_dict()

    self.assertEqual(chkpt_mgr.all_steps(), [])

    # Saves at steps which don't hit the save interval are rejected.
    for step in range(1, 10):
        self.assertFalse(chkpt_mgr.save_async(step, state_dict))
    self.assertEqual(chkpt_mgr.all_steps(), [])

    # Every multiple of the save interval is accepted.
    expected_steps = set()
    for step in range(0, 100, 10):
        expected_steps.add(step)
        self.assertTrue(chkpt_mgr.save_async(step, state_dict))

    # Block until the background thread drains all pending checkpoints.
    chkpt_mgr.join()

    # All accepted steps must now be tracked by the manager.
    self.assertEqual(set(chkpt_mgr.all_steps()), expected_steps)

    # Restoring step 0 must reproduce the saved state in a freshly
    # initialized model whose weights previously differed.
    new_state_dict = self._get_sharded_model().state_dict()
    self.assertFalse(
        any(
            torch.allclose(v, new_state_dict[k])
            for k, v in state_dict.items()))
    chkpt_mgr.restore(0, new_state_dict)
    self.assertTrue(
        all(
            torch.allclose(v, new_state_dict[k])
            for k, v in state_dict.items()))


if __name__ == '__main__':
test = unittest.main()
Expand Down
54 changes: 51 additions & 3 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import logging
import os
import pickle
import queue
import threading
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc
import traceback

from dataclasses import dataclass
from datetime import datetime
Expand All @@ -14,6 +17,7 @@
from os.path import basename
from typing import Deque, List, Optional, Union
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from ._helpers import _sharded_cpu_state_dict

# TODO(jonbolin): Import path will change
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
Expand Down Expand Up @@ -94,7 +98,8 @@ def __init__(self,
path: str,
save_interval: int,
max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1):
async_queue_size: Optional[int] = 1,
process_group: dist.ProcessGroup = None):
"""
Create a checkpoint manager that reads and writes checkpoints into
the provided directory.
Expand All @@ -113,6 +118,9 @@ def __init__(self,
network issues which slow down the active checkpoint.
Default: 1, which only allows a single async checkpoint to be
pending at a time.
process_group: The process group to use when coordinating the checkpoint.
Default: None, in which case a subgroup of the default process
group will be created.
"""
assert dist.is_initialized(), "A process group is required."
assert save_interval > 0, "save_interval must be positive"
Expand All @@ -124,6 +132,16 @@ def __init__(self,
self.max_to_keep = max_to_keep

self._tracked_chkpts = self._load_tracked_chkpts()
self._async_queue = queue.Queue(maxsize=async_queue_size)
self._alive = threading.Event()
self._alive.set()
self._chkpt_thread = threading.Thread(
target=self._async_worker, daemon=True)
self._chkpt_thread.start()

# Create a new group if none is provided
# TODO(jonbolin): Verify subgroup on GPU backend
self.pg = process_group or dist.new_group()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
"""
Expand All @@ -143,6 +161,25 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}')
return deque(sorted(all_chkpts, key=lambda m: m.ts))

def __del__(self):
    """Best-effort shutdown of the background checkpoint worker.

    Signals the worker thread to exit and waits for any pending async
    checkpoints to be flushed. Each attribute access is guarded with
    getattr because __del__ can run on a partially-constructed instance
    (e.g. if __init__ raised before the thread was created), and an
    AttributeError here would surface as an "Exception ignored" warning.
    """
    alive = getattr(self, '_alive', None)
    if alive is not None:
        alive.clear()
    # Send a sentinel value to tell the worker to exit, and wait for pending
    # checkpoints to complete.
    async_queue = getattr(self, '_async_queue', None)
    if async_queue is not None:
        async_queue.put(None)
    chkpt_thread = getattr(self, '_chkpt_thread', None)
    if chkpt_thread is not None:
        chkpt_thread.join()

def _async_worker(self):
    """Background loop servicing queued async checkpoint requests.

    Runs until self._alive is cleared. Each queue item is either a
    (step, state_dict) tuple to be checkpointed, or None — a sentinel sent
    by __del__ purely to wake this thread so it can observe the exit flag.
    """
    while self._alive.is_set():
        # get() is outside the try block so task_done() is only called for
        # items that were actually dequeued.
        item = self._async_queue.get()
        try:
            if item is not None:
                step, state_dict = item
                # force=True: should_save was already checked in save_async.
                self.save(step, state_dict, force=True)
        except Exception:
            # Keep the worker alive so later checkpoints still run; a bare
            # except would also swallow SystemExit/KeyboardInterrupt.
            traceback.print_exc()
        finally:
            self._async_queue.task_done()

def _get_path(self, step: int) -> str:
    """Return the directory under base_path holding the checkpoint for `step`."""
    step_dir = str(step)
    return os.path.join(self.base_path, step_dir)

Expand All @@ -157,7 +194,7 @@ def _release_oldest_checkpoints(self):
Delete oldest checkpoints until the number of tracked checkpoints is below
self.max_to_keep. This operation is only executed on the rank 0 process.
"""
if dist.get_rank() == 0 and self.max_to_keep > 0:
if dist.get_rank(self.pg) == 0 and self.max_to_keep > 0:
while len(self._tracked_chkpts) > self.max_to_keep:
oldest_chkpt = self._tracked_chkpts.popleft()
self._delete_chkpt_at_step(oldest_chkpt.step)
Expand Down Expand Up @@ -193,6 +230,7 @@ def save(self,
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
process_group=self.pg,
)
metadata = _CheckpointMetadata(step=step, ts=datetime.now())
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f:
Expand Down Expand Up @@ -225,7 +263,12 @@ def save_async(self,
Returns:
True if a checkpoint was taken and False otherwise.
"""
raise NotImplementedError
if self.should_save(step) or force:
# Move the state_dict to CPU
cpu_state_dict = _sharded_cpu_state_dict(state_dict)
self._async_queue.put((step, cpu_state_dict))
return True
return False

def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
"""
Expand All @@ -245,10 +288,15 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
state_dict=state_dict,
storage_reader=FsspecReader(path),
planner=xc.SPMDLoadPlanner(),
process_group=self.pg,
)

def all_steps(self) -> List[int]:
    """
    List all steps tracked by the CheckpointManager, in ascending order.
    """
    steps = [meta.step for meta in self._tracked_chkpts]
    steps.sort()
    return steps

def join(self):
    """Block until every pending async checkpoint has been processed."""
    # Delegates to the queue: returns once all enqueued items have had
    # task_done() called for them by the worker thread.
    self._async_queue.join()

0 comments on commit de07fe9

Please sign in to comment.