Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support synchronous saving and loading in CheckpointManager #5693

Merged
merged 7 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import os
import sys
import tempfile
Expand All @@ -15,11 +16,23 @@
create_default_local_save_plan,
create_default_global_save_plan,
)
from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner
from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager
from torch_xla.experimental.distributed_checkpoint._helpers import (
_sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor)


# Wrapper to manage a temporary directory for the wrapped test
def run_with_tmpdir(f):

@functools.wraps(f)
def run(*args, **kwargs):
with tempfile.TemporaryDirectory() as tmpdir:
kwargs.setdefault('tmpdir', tmpdir)
f(*args, **kwargs)

return run


class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest):

@classmethod
Expand Down Expand Up @@ -319,6 +332,87 @@ def test_sharded_cpu_state_dict(self):
self.assertTrue(param.device == torch.device("cpu"))


class CheckpointManagerTest(DistributedCheckpointTestBase):

def setUp(self):
super().setUp()
# Initialize the a minimal process group
dist.init_process_group(
jonb377 marked this conversation as resolved.
Show resolved Hide resolved
backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0)

def tearDown(self):
super().tearDown()
# Destroy the CPU process group after the test
dist.destroy_process_group()

@run_with_tmpdir
def test_manager_checkpointing(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
state_dict = self._get_sharded_model().state_dict()

# Take a checkpoint on step 0
self.assertTrue(chkpt_mgr.save(0, state_dict))

# Load the checkpoint into a new state_dict
new_state_dict = self._get_sharded_model().state_dict()
self.assertFalse(
any(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))
chkpt_mgr.restore(0, new_state_dict)
self.assertTrue(
all(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))

@run_with_tmpdir
def test_manager_step_tracking(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
state_dict = self._get_sharded_model().state_dict()

# No steps are being tracked initially
self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps not divisible by 10 should not be saved
for step in range(1, 10):
self.assertFalse(chkpt_mgr.save(step, state_dict))
self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps divisible by 10 should be saved
saved = set()
for step in range(0, 100, 10):
self.assertTrue(chkpt_mgr.save(step, state_dict))
saved.add(step)
self.assertEqual(set(chkpt_mgr.all_steps()), saved)

@run_with_tmpdir
def test_manager_max_to_keep(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2)
state_dict = self._get_sharded_model().state_dict()

# No steps are being tracked initially
self.assertEqual(chkpt_mgr.all_steps(), [])

self.assertTrue(chkpt_mgr.save(10, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {10})

self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {10, 20})

# The oldest checkpoint should be erased
self.assertTrue(chkpt_mgr.save(30, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20})

# The oldest is selected by creation timestamp, not step
self.assertTrue(chkpt_mgr.save(10, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10})

# The deletion order should persist across executions
chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2)
self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
122 changes: 114 additions & 8 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,35 @@
import fsspec
import logging
import os
import pickle
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc

from typing import List, Optional
from dataclasses import dataclass
from datetime import datetime
from collections import deque
from fsspec.core import url_to_fs
from os.path import basename
from typing import Deque, List, Optional, Union
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE

# TODO(jonbolin): Import path will change
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The import path will change when the API becomes public in the upstream. @alanwaketan @yeounoh do you have any thoughts on how to handle this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's okay. The upstream test will break our CI in the upstream, and then we can have a companion change to fix it.


# File to track manager-specific metadata within each checkpoint path
_MANAGER_METADATA_FILE = '.manager_metadata'


@dataclass
class _CheckpointMetadata:
# The step at which the checkpoint was taken
step: int

# The time at which the checkpoint was taken
ts: datetime


class CheckpointManager:
"""
Expand Down Expand Up @@ -53,10 +79,21 @@ class CheckpointManager:
https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py
"""

# The base path to write checkpoints to. Each checkpoint taken by the manager
# will be written into a subdirectory of this path, identified by the
# checkpoint's step.
base_path: Union[str, os.PathLike]

# The period to take checkpoints, in steps.
save_period: int
jonb377 marked this conversation as resolved.
Show resolved Hide resolved

# The maximum number of checkpoints to keep.
max_to_keep: int

def __init__(self,
path: str,
save_period: int,
max_to_keep: Optional[int] = -1,
max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1):
"""
Create a checkpoint manager that reads and writes checkpoints into
Expand All @@ -68,7 +105,7 @@ def __init__(self,
max_to_keep: The maximum number of checkpoints to be tracked by the
CheckpointManager. When a new checkpoint will be taken, the
checkpoint for the lowest tracked step will be deleted.
Default: -1, indicating no upper bound on the number of checkpoints.
Default: 0, indicating no upper bound on the number of checkpoints.
async_queue_size: The size of the execution queue which processes async
checkpoints. This should be a small value to ensure training doesn't
get too far ahead of the last finished checkpoint, but increasing
Expand All @@ -77,14 +114,61 @@ def __init__(self,
Default: 1, which only allows a single async checkpoint to be
pending at a time.
"""
raise NotImplementedError
assert dist.is_initialized(), "A process group is required."
assert save_period > 0, "save_period must be positive"
assert async_queue_size > 0, "async_queue_size must be positive"
assert max_to_keep >= 0, "max_to_keep must be non-negative"

self.base_path = path
self.save_period = save_period
self.max_to_keep = max_to_keep

self._tracked_chkpts = self._load_tracked_chkpts()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
"""
Loads a list of all tracked checkpoints from the storage backend.
"""
all_chkpts = []
invalid_paths = []
fs, raw_path = url_to_fs(self.base_path)
for path in fs.ls(raw_path, detail=False):
try:
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f:
all_chkpts.append(pickle.load(f))
except:
invalid_paths.append(path)

if invalid_paths:
logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}')
return deque(sorted(all_chkpts, key=lambda m: m.ts))

def _get_path(self, step: int) -> str:
return os.path.join(self.base_path, str(step))

def _delete_chkpt_at_step(self, step):
path = self._get_path(step)
fs, raw_path = url_to_fs(path)
if fs.exists(raw_path):
fs.rm(raw_path, recursive=True)

def _release_oldest_checkpoints(self):
"""
Delete oldest checkpoints until the number of tracked checkpoints is below
self.max_to_keep. This operation is only execution on the rank 0 process.
"""
if dist.get_rank() == 0 and self.max_to_keep > 0:
while len(self._tracked_chkpts) > self.max_to_keep:
oldest_chkpt = self._tracked_chkpts.popleft()
self._delete_chkpt_at_step(oldest_chkpt.step)

def should_save(self, step: int) -> bool:
"""
Returns true if a checkpoint should be saved for the current step or if
a preemption has been detected.
"""
raise NotImplementedError
# TODO(jonbolin): Support preemption notice for auto checkpointing
return step % self.save_period == 0

def save(self,
step,
Expand All @@ -101,7 +185,22 @@ def save(self,
Returns:
True if a checkpoint was taken and False otherwise.
"""
raise NotImplementedError
if self.should_save(step) or force:
path = self._get_path(step)
# Delete any existing checkpoint at the current step.
self._delete_chkpt_at_step(step)
dist_cp.save_state_dict(
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
)
metadata = _CheckpointMetadata(step=step, ts=datetime.now())
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f:
pickle.dump(metadata, f)
self._tracked_chkpts.append(metadata)
self._release_oldest_checkpoints()
return True
return False

def save_async(self,
step: int,
Expand Down Expand Up @@ -139,10 +238,17 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
state_dict: The state dict to restore the checkpoint into. Values are
updated in-place within the state_dict.
"""
raise NotImplementedError
tracked_steps = set(x.step for x in self._tracked_chkpts)
assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}'
path = self._get_path(step)
dist_cp.load_state_dict(
state_dict=state_dict,
storage_reader=FsspecReader(path),
planner=xc.SPMDLoadPlanner(),
)

def all_steps(self) -> List[int]:
"""
List all steps tracked by the CheckpointManager.
"""
raise NotImplementedError
return sorted(x.step for x in self._tracked_chkpts)