Skip to content

Commit

Permalink
Track creation time in metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 12, 2023
1 parent d248eec commit 4430877
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 17 deletions.
9 changes: 9 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,15 @@ def test_manager_max_to_keep(self, tmpdir):
self.assertTrue(chkpt_mgr.save(30, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20})

# The oldest is selected by creation timestamp, not step
self.assertTrue(chkpt_mgr.save(10, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10})

# The deletion order should persist across executions
chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2)
self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})


if __name__ == '__main__':
test = unittest.main()
Expand Down
71 changes: 54 additions & 17 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,35 @@
import heapq
import fsspec
import logging
import os
import pickle
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc

from dataclasses import dataclass
from datetime import datetime
from collections import deque
from fsspec.core import url_to_fs
from os.path import basename
from typing import List, Optional, Union
from typing import Deque, List, Optional, Union
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE

# TODO(jonbolin): Import path will change
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter

# File to track manager-specific metadata within each checkpoint path
_MANAGER_METADATA_FILE = '.manager_metadata'


@dataclass
class _CheckpointMetadata:
# The step at which the checkpoint was taken
step: int

# The time at which the checkpoint was taken
ts: datetime


class CheckpointManager:
"""
Expand Down Expand Up @@ -106,31 +123,44 @@ def __init__(self,
self.save_period = save_period
self.max_to_keep = max_to_keep

# Cache tracked steps in a heap for efficient clearing.
self._tracked_steps = self._load_tracked_steps()
heapq.heapify(self._tracked_steps)
self._tracked_chkpts = self._load_tracked_chkpts()

def _load_tracked_steps(self) -> List[int]:
""" Loads a list of all tracked steps from the storage backend. """
def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
    """
    Loads all tracked checkpoints from the storage backend.

    Every entry listed under self.base_path is expected to hold a pickled
    _CheckpointMetadata in its _MANAGER_METADATA_FILE. Entries whose metadata
    cannot be read, or does not unpickle to a _CheckpointMetadata, are skipped
    and reported together in a single warning.

    Returns:
      A deque of checkpoint metadata ordered from the oldest to the newest
      creation timestamp.
    """
    all_chkpts = []
    invalid_paths = []
    fs, raw_path = url_to_fs(self.base_path)
    for path in fs.ls(raw_path, detail=False):
        metadata_path = os.path.join(path, _MANAGER_METADATA_FILE)
        try:
            # Read through `fs` itself: the paths returned by fs.ls belong to that
            # filesystem and may not carry the URL protocol, so re-resolving them
            # with fsspec.open could hit a different (local) filesystem.
            with fs.open(metadata_path, 'rb') as f:
                # NOTE(review): pickle.load can execute arbitrary code; this is
                # only safe while base_path is trusted storage.
                metadata = pickle.load(f)
        except Exception:
            # Deliberately best-effort, but do not swallow KeyboardInterrupt or
            # SystemExit as a bare `except:` would.
            invalid_paths.append(path)
            continue
        if isinstance(metadata, _CheckpointMetadata):
            all_chkpts.append(metadata)
        else:
            # Anything else would break the timestamp sort below.
            invalid_paths.append(path)

    if invalid_paths:
        logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}')
    return deque(sorted(all_chkpts, key=lambda m: m.ts))

def _get_path(self, step: int) -> str:
return os.path.join(self.base_path, str(step))

def _delete_chkpt_at_step(self, step):
    """
    Recursively remove the checkpoint stored for `step`, if one exists.

    Args:
      step: The step whose checkpoint path (see _get_path) should be removed.
    """
    path = self._get_path(step)
    fs, raw_path = url_to_fs(path)
    # Nothing to do when no checkpoint was ever written at this step.
    if fs.exists(raw_path):
        fs.rm(raw_path, recursive=True)

def _release_oldest_checkpoints(self):
    """
    Delete the oldest checkpoints until at most self.max_to_keep checkpoints
    remain tracked. This operation is only executed on the rank 0 process, and
    only when self.max_to_keep is positive.
    """
    # NOTE(review): entries are popped only inside the rank-0 guard, so other
    # ranks keep tracking checkpoints that rank 0 has deleted - confirm that
    # this is intended.
    if dist.get_rank() == 0 and self.max_to_keep > 0:
        while len(self._tracked_chkpts) > self.max_to_keep:
            # The deque is ordered oldest-first, so the left end is the next
            # checkpoint to evict.
            oldest_chkpt = self._tracked_chkpts.popleft()
            self._delete_chkpt_at_step(oldest_chkpt.step)

def should_save(self, step: int) -> bool:
"""
Expand All @@ -157,12 +187,17 @@ def save(self,
"""
if self.should_save(step) or force:
path = self._get_path(step)
# Delete any existing checkpoint at the current step.
self._delete_chkpt_at_step(step)
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
)
heapq.heappush(self._tracked_steps, step)
metadata = _CheckpointMetadata(step=step, ts=datetime.now())
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f:
pickle.dump(metadata, f)
self._tracked_chkpts.append(metadata)
self._release_oldest_checkpoints()
return True
return False
Expand Down Expand Up @@ -203,6 +238,8 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
state_dict: The state dict to restore the checkpoint into. Values are
updated in-place within the state_dict.
"""
tracked_steps = set(x.step for x in self._tracked_chkpts)
assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}'
path = self._get_path(step)
dist_cp.load_state_dict(
state_dict=state_dict,
Expand All @@ -214,4 +251,4 @@ def all_steps(self) -> List[int]:
"""
List all steps tracked by the CheckpointManager.
"""
return sorted(self._tracked_steps)
return sorted(x.step for x in self._tracked_chkpts)

0 comments on commit 4430877

Please sign in to comment.