diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 276571e5979..74bc27cdf98 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -1,3 +1,4 @@ +import functools import os import sys import tempfile @@ -15,11 +16,23 @@ create_default_local_save_plan, create_default_global_save_plan, ) -from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner +from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager from torch_xla.experimental.distributed_checkpoint._helpers import ( _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) +# Wrapper to manage a temporary directory for the wrapped test +def run_with_tmpdir(f): + + @functools.wraps(f) + def run(*args, **kwargs): + with tempfile.TemporaryDirectory() as tmpdir: + kwargs.setdefault('tmpdir', tmpdir) + f(*args, **kwargs) + + return run + + class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest): @classmethod @@ -319,6 +332,87 @@ def test_sharded_cpu_state_dict(self): self.assertTrue(param.device == torch.device("cpu")) +class CheckpointManagerTest(DistributedCheckpointTestBase): + + def setUp(self): + super().setUp() + # Initialize the a minimal process group + dist.init_process_group( + backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0) + + def tearDown(self): + super().tearDown() + # Destroy the CPU process group after the test + dist.destroy_process_group() + + @run_with_tmpdir + def test_manager_checkpointing(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) + state_dict = self._get_sharded_model().state_dict() + + # Take a checkpoint on step 0 + self.assertTrue(chkpt_mgr.save(0, state_dict)) + + # Load the checkpoint into a new state_dict + new_state_dict = self._get_sharded_model().state_dict() + self.assertFalse( + any( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + chkpt_mgr.restore(0, new_state_dict) + self.assertTrue( + all( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + + @run_with_tmpdir + def test_manager_step_tracking(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) + state_dict = self._get_sharded_model().state_dict() + + # No steps are being tracked initially + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps not divisible by 10 should not be saved + for step in range(1, 10): + self.assertFalse(chkpt_mgr.save(step, state_dict)) + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps divisible by 10 should be saved + saved = set() + for step in range(0, 100, 10): + self.assertTrue(chkpt_mgr.save(step, state_dict)) + saved.add(step) + self.assertEqual(set(chkpt_mgr.all_steps()), saved) + + @run_with_tmpdir + def test_manager_max_to_keep(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2) + state_dict = self._get_sharded_model().state_dict() + + # No steps are being tracked initially + self.assertEqual(chkpt_mgr.all_steps(), []) + + self.assertTrue(chkpt_mgr.save(10, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {10}) + + self.assertTrue(chkpt_mgr.save(20, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {10, 20}) + + # The oldest checkpoint should be erased + self.assertTrue(chkpt_mgr.save(30, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20}) + + # The oldest is selected by creation timestamp, not step + self.assertTrue(chkpt_mgr.save(10, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10}) + + # The deletion order should persist across executions + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2) + self.assertTrue(chkpt_mgr.save(20, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10}) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index cd36cbe1eb6..abf56bff940 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -1,9 +1,35 @@ +import fsspec +import logging +import os +import pickle +import torch.distributed as dist import torch.distributed.checkpoint as dist_cp +import torch_xla.runtime as xr import torch_xla.experimental.distributed_checkpoint as xc -from typing import List, Optional +from dataclasses import dataclass +from datetime import datetime +from collections import deque +from fsspec.core import url_to_fs +from os.path import basename +from typing import Deque, List, Optional, Union from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +# TODO(jonbolin): Import path will change +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter + +# File to track manager-specific metadata within each checkpoint path +_MANAGER_METADATA_FILE = '.manager_metadata' + + +@dataclass +class _CheckpointMetadata: + # The step at which the checkpoint was taken + step: int + + # The time at which the checkpoint was taken + ts: datetime + class CheckpointManager: """ @@ -53,10 +79,21 @@ class CheckpointManager: https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py """ + # The base path to write checkpoints to. Each checkpoint taken by the manager + # will be written into a subdirectory of this path, identified by the + # checkpoint's step. + base_path: Union[str, os.PathLike] + + # The interval to take checkpoints, in steps. + save_interval: int + + # The maximum number of checkpoints to keep. + max_to_keep: int + def __init__(self, path: str, - save_period: int, - max_to_keep: Optional[int] = -1, + save_interval: int, + max_to_keep: Optional[int] = 0, async_queue_size: Optional[int] = 1): """ Create a checkpoint manager that reads and writes checkpoints into @@ -64,11 +101,11 @@ def __init__(self, Args: path: The base path for the CheckpointManager to write checkpoints into. - save_period: The number of steps between saving checkpoints. + save_interval: The number of steps between saving checkpoints. max_to_keep: The maximum number of checkpoints to be tracked by the CheckpointManager. When a new checkpoint will be taken, the checkpoint for the lowest tracked step will be deleted. - Default: -1, indicating no upper bound on the number of checkpoints. + Default: 0, indicating no upper bound on the number of checkpoints. async_queue_size: The size of the execution queue which processes async checkpoints. This should be a small value to ensure training doesn't get too far ahead of the last finished checkpoint, but increasing @@ -77,14 +114,61 @@ def __init__(self, Default: 1, which only allows a single async checkpoint to be pending at a time. """ - raise NotImplementedError + assert dist.is_initialized(), "A process group is required." + assert save_interval > 0, "save_interval must be positive" + assert async_queue_size > 0, "async_queue_size must be positive" + assert max_to_keep >= 0, "max_to_keep must be non-negative" + + self.base_path = path + self.save_interval = save_interval + self.max_to_keep = max_to_keep + + self._tracked_chkpts = self._load_tracked_chkpts() + + def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: + """ + Loads a list of all tracked checkpoints from the storage backend. + """ + all_chkpts = [] + invalid_paths = [] + fs, raw_path = url_to_fs(self.base_path) + for path in fs.ls(raw_path, detail=False): + try: + with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f: + all_chkpts.append(pickle.load(f)) + except: + invalid_paths.append(path) + + if invalid_paths: + logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}') + return deque(sorted(all_chkpts, key=lambda m: m.ts)) + + def _get_path(self, step: int) -> str: + return os.path.join(self.base_path, str(step)) + + def _delete_chkpt_at_step(self, step): + path = self._get_path(step) + fs, raw_path = url_to_fs(path) + if fs.exists(raw_path): + fs.rm(raw_path, recursive=True) + + def _release_oldest_checkpoints(self): + """ + Delete oldest checkpoints until the number of tracked checkpoints is below + self.max_to_keep. This operation is only execution on the rank 0 process. + """ + if dist.get_rank() == 0 and self.max_to_keep > 0: + while len(self._tracked_chkpts) > self.max_to_keep: + oldest_chkpt = self._tracked_chkpts.popleft() + self._delete_chkpt_at_step(oldest_chkpt.step) def should_save(self, step: int) -> bool: """ Returns true if a checkpoint should be saved for the current step or if a preemption has been detected. """ - raise NotImplementedError + # TODO(jonbolin): Support preemption notice for auto checkpointing + return step % self.save_interval == 0 def save(self, step, @@ -101,7 +185,22 @@ def save(self, Returns: True if a checkpoint was taken and False otherwise. """ - raise NotImplementedError + if self.should_save(step) or force: + path = self._get_path(step) + # Delete any existing checkpoint at the current step. + self._delete_chkpt_at_step(step) + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=FsspecWriter(path), + planner=xc.SPMDSavePlanner(), + ) + metadata = _CheckpointMetadata(step=step, ts=datetime.now()) + with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f: + pickle.dump(metadata, f) + self._tracked_chkpts.append(metadata) + self._release_oldest_checkpoints() + return True + return False def save_async(self, step: int, @@ -139,10 +238,17 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: state_dict: The state dict to restore the checkpoint into. Values are updated in-place within the state_dict. """ - raise NotImplementedError + tracked_steps = set(x.step for x in self._tracked_chkpts) + assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}' + path = self._get_path(step) + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=FsspecReader(path), + planner=xc.SPMDLoadPlanner(), + ) def all_steps(self) -> List[int]: """ List all steps tracked by the CheckpointManager. """ - raise NotImplementedError + return sorted(x.step for x in self._tracked_chkpts)