diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 74bc27cdf98..37c0224a0f4 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -4,6 +4,7 @@ import tempfile import unittest import test_xla_sharding_base +import threading import torch import torch.distributed as dist @@ -412,6 +413,70 @@ def test_manager_max_to_keep(self, tmpdir): self.assertTrue(chkpt_mgr.save(20, state_dict)) self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10}) + @run_with_tmpdir + def test_manager_async(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) + state_dict = self._get_sharded_model().state_dict() + + # Patch the manager's save method to block until this thread signals. + cond = threading.Condition() + old_save = chkpt_mgr.save + + def patched_save(*args, **kwargs): + cond.wait() + old_save(*args, **kwargs) + + with unittest.mock.patch.object(chkpt_mgr, 'save', patched_save): + chkpt_mgr.save_async(10, state_dict) + + # No new steps should be tracked immediately after calling save_async + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Trigger the actual checkpoint in the background thread and wait for + # completion. + with cond: + cond.notify() + chkpt_mgr.join() + + # The manager should track all steps which were asynchronously saved. + self.assertEqual(set(chkpt_mgr.all_steps()), {10}) + + @run_with_tmpdir + def test_manager_async_step_tracking(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) + state_dict = self._get_sharded_model().state_dict() + + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps not divisible by 10 should not be saved + for step in range(1, 10): + self.assertFalse(chkpt_mgr.save_async(step, state_dict)) + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps divisible by 10 should be saved + saved = set() + for step in range(0, 100, 10): + self.assertTrue(chkpt_mgr.save_async(step, state_dict)) + saved.add(step) + + # Join to allow pending async checkpoints to complete + chkpt_mgr.join() + + # The manager should track all steps which were asynchronously saved. + self.assertEqual(set(chkpt_mgr.all_steps()), saved) + + # Load a checkpoint into a new state_dict + new_state_dict = self._get_sharded_model().state_dict() + self.assertFalse( + any( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + chkpt_mgr.restore(0, new_state_dict) + self.assertTrue( + all( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index abf56bff940..9e5cde711b8 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -2,10 +2,13 @@ import logging import os import pickle +import queue +import threading import torch.distributed as dist import torch.distributed.checkpoint as dist_cp import torch_xla.runtime as xr import torch_xla.experimental.distributed_checkpoint as xc +import traceback from dataclasses import dataclass from datetime import datetime @@ -14,6 +17,7 @@ from os.path import basename from typing import Deque, List, Optional, Union from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from ._helpers import _sharded_cpu_state_dict # TODO(jonbolin): Import path will change from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter @@ -94,7 +98,8 @@ def __init__(self, path: str, save_interval: int, max_to_keep: Optional[int] = 0, - async_queue_size: Optional[int] = 1): + async_queue_size: Optional[int] = 1, + process_group: dist.ProcessGroup = None): """ Create a checkpoint manager that reads and writes checkpoints into the provided directory. @@ -113,6 +118,9 @@ def __init__(self, network issues which slow down the active checkpoint. Default: 1, which only allows a single async checkpoint to be pending at a time. + process_group: The process group to use when coordinating the checkpoint. + Default: None, in which case a subgroup of the default process + group will be created. """ assert dist.is_initialized(), "A process group is required." assert save_interval > 0, "save_interval must be positive" @@ -124,6 +132,16 @@ def __init__(self, self.max_to_keep = max_to_keep self._tracked_chkpts = self._load_tracked_chkpts() + self._async_queue = queue.Queue(maxsize=async_queue_size) + self._alive = threading.Event() + self._alive.set() + self._chkpt_thread = threading.Thread( + target=self._async_worker, daemon=True) + self._chkpt_thread.start() + + # Create a new group if none is provided + # TODO(jonbolin): Verify subgroup on GPU backend + self.pg = process_group or dist.new_group() def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: """ @@ -143,6 +161,25 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}') return deque(sorted(all_chkpts, key=lambda m: m.ts)) + def __del__(self): + self._alive.clear() + # Send a sentinel value to tell the worker to exit, and wait for pending + # checkpoints to complete. + self._async_queue.put(None) + self._chkpt_thread.join() + + def _async_worker(self): + while self._alive.is_set(): + try: + item = self._async_queue.get() + if item: + step, state_dict = item + self.save(step, state_dict, force=True) + except: + traceback.print_exc() + finally: + self._async_queue.task_done() + def _get_path(self, step: int) -> str: return os.path.join(self.base_path, str(step)) @@ -157,7 +194,7 @@ def _release_oldest_checkpoints(self): Delete oldest checkpoints until the number of tracked checkpoints is below self.max_to_keep. This operation is only execution on the rank 0 process. """ - if dist.get_rank() == 0 and self.max_to_keep > 0: + if dist.get_rank(self.pg) == 0 and self.max_to_keep > 0: while len(self._tracked_chkpts) > self.max_to_keep: oldest_chkpt = self._tracked_chkpts.popleft() self._delete_chkpt_at_step(oldest_chkpt.step) @@ -193,6 +230,7 @@ def save(self, state_dict=state_dict, storage_writer=FsspecWriter(path), planner=xc.SPMDSavePlanner(), + process_group=self.pg, ) metadata = _CheckpointMetadata(step=step, ts=datetime.now()) with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f: @@ -225,7 +263,12 @@ def save_async(self, Returns: True if a checkpoint was taken and False otherwise. """ - raise NotImplementedError + if self.should_save(step) or force: + # Move the state_dict to CPU + cpu_state_dict = _sharded_cpu_state_dict(state_dict) + self._async_queue.put((step, cpu_state_dict)) + return True + return False def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: """ @@ -245,6 +288,7 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: state_dict=state_dict, storage_reader=FsspecReader(path), planner=xc.SPMDLoadPlanner(), + process_group=self.pg, ) def all_steps(self) -> List[int]: @@ -252,3 +296,7 @@ def all_steps(self) -> List[int]: List all steps tracked by the CheckpointManager. """ return sorted(x.step for x in self._tracked_chkpts) + + def join(self): + """ Wait for all pending async checkpoints to complete. """ + self._async_queue.join()