Skip to content

Commit

Permalink
Support async checkpointing through CheckpointManager
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 12, 2023
1 parent 4430877 commit 32144fc
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
38 changes: 38 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,44 @@ def test_manager_max_to_keep(self, tmpdir):
self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})

@run_with_tmpdir
def test_manager_async_checkpoint(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
state_dict = self._get_sharded_model().state_dict()

self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps not divisible by 10 should not be saved
for step in range(1, 10):
self.assertFalse(chkpt_mgr.save_async(step, state_dict))
self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps divisible by 10 should be saved
saved = set()
for step in range(0, 100, 10):
self.assertTrue(chkpt_mgr.save_async(step, state_dict))
saved.add(step)

# Delete the checkpoint manager to block this thread until all pending
# async checkpoints are complete.
del chkpt_mgr

# The manager should track all steps which were asynchronously saved.
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
self.assertEqual(set(chkpt_mgr.all_steps()), saved)

# Load a checkpoint into a new state_dict
new_state_dict = self._get_sharded_model().state_dict()
self.assertFalse(
any(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))
chkpt_mgr.restore(0, new_state_dict)
self.assertTrue(
all(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))


if __name__ == '__main__':
test = unittest.main()
Expand Down
34 changes: 33 additions & 1 deletion torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
import logging
import os
import pickle
import queue
import threading
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc
import traceback

from dataclasses import dataclass
from datetime import datetime
Expand All @@ -14,6 +17,7 @@
from os.path import basename
from typing import Deque, List, Optional, Union
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from ._helpers import _sharded_cpu_state_dict

# TODO(jonbolin): Import path will change
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
Expand Down Expand Up @@ -125,6 +129,13 @@ def __init__(self,

self._tracked_chkpts = self._load_tracked_chkpts()

self._async_queue = queue.Queue(maxsize=async_queue_size)
self._chkpt_thread = threading.Thread(target=self._async_worker, daemon=True)
self._chkpt_thread.start()

# Create a CPU process group to coordinate the checkpoint.
self.pg = dist.new_group(backend='gloo')

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
"""
Loads a list of all tracked checkpoints from the storage backend.
Expand All @@ -143,6 +154,20 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}')
return deque(sorted(all_chkpts, key=lambda m: m.ts))

def __del__(self):
# Ensure pending checkpoints are finished
self._async_queue.join()

def _async_worker(self):
while True:
try:
step, state_dict = self._async_queue.get()
self.save(step, state_dict, force=True)
except:
traceback.print_exc()
finally:
self._async_queue.task_done()

def _get_path(self, step: int) -> str:
return os.path.join(self.base_path, str(step))

Expand Down Expand Up @@ -193,6 +218,7 @@ def save(self,
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
process_group=self.pg,
)
metadata = _CheckpointMetadata(step=step, ts=datetime.now())
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f:
Expand Down Expand Up @@ -225,7 +251,12 @@ def save_async(self,
Returns:
True if a checkpoint was taken and False otherwise.
"""
raise NotImplementedError
if self.should_save(step) or force:
# Move the state_dict to CPU
cpu_state_dict = _sharded_cpu_state_dict(state_dict)
self._async_queue.put((step, cpu_state_dict))
return True
return False

def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
"""
Expand All @@ -245,6 +276,7 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
state_dict=state_dict,
storage_reader=FsspecReader(path),
planner=xc.SPMDLoadPlanner(),
process_group=self.pg,
)

def all_steps(self) -> List[int]:
Expand Down

0 comments on commit 32144fc

Please sign in to comment.