Support async checkpointing through CheckpointManager #5697

Merged 4 commits on Oct 13, 2023
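For context, here is a minimal sketch of how the async API added by this PR is intended to be used from a training loop. The import path, SPMD/process-group setup, and checkpoint directory are illustrative assumptions, not taken from this diff:

```python
import torch.distributed as dist
import torch_xla.runtime as xr
# Assumed export; the class is defined in
# torch_xla/experimental/distributed_checkpoint/manager.py.
from torch_xla.experimental.distributed_checkpoint import CheckpointManager


def train_with_async_checkpoints(model, num_steps):
  xr.use_spmd()  # assumed: the manager's SPMD planners expect SPMD mode
  # CheckpointManager asserts that a process group is initialized; how that is
  # done is environment-specific and assumed to happen before this point.
  assert dist.is_initialized()

  chkpt_mgr = CheckpointManager("/tmp/checkpoints", save_interval=10)

  for step in range(num_steps):
    ...  # forward / backward / optimizer step elided
    # Returns immediately. When the step qualifies (a multiple of
    # save_interval), the state_dict is snapshotted to CPU and written by the
    # manager's background thread.
    chkpt_mgr.save_async(step, model.state_dict())

  # Block until all pending async checkpoints have been written.
  chkpt_mgr.join()
```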
65 changes: 65 additions & 0 deletions test/spmd/test_xla_distributed_checkpoint.py
@@ -4,6 +4,7 @@
import tempfile
import unittest
import test_xla_sharding_base
import threading

import torch
import torch.distributed as dist
@@ -412,6 +413,70 @@ def test_manager_max_to_keep(self, tmpdir):
self.assertTrue(chkpt_mgr.save(20, state_dict))
self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})

@run_with_tmpdir
def test_manager_async(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
state_dict = self._get_sharded_model().state_dict()

# Patch the manager's save method to block until this thread signals.
cond = threading.Condition()
old_save = chkpt_mgr.save

def patched_save(*args, **kwargs):
  # Condition.wait() must be called while holding the condition's lock.
  with cond:
    cond.wait()
  old_save(*args, **kwargs)

with unittest.mock.patch.object(chkpt_mgr, 'save', patched_save):
chkpt_mgr.save_async(10, state_dict)

# No new steps should be tracked immediately after calling save_async
self.assertEqual(chkpt_mgr.all_steps(), [])

# Trigger the actual checkpoint in the background thread and wait for
# completion.
with cond:
cond.notify()
chkpt_mgr.join()

# The manager should track all steps which were asynchronously saved.
self.assertEqual(set(chkpt_mgr.all_steps()), {10})

@run_with_tmpdir
def test_manager_async_step_tracking(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_interval=10)
state_dict = self._get_sharded_model().state_dict()

self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps not divisible by 10 should not be saved
for step in range(1, 10):
self.assertFalse(chkpt_mgr.save_async(step, state_dict))
self.assertEqual(chkpt_mgr.all_steps(), [])

# Steps divisible by 10 should be saved
saved = set()
for step in range(0, 100, 10):
self.assertTrue(chkpt_mgr.save_async(step, state_dict))
saved.add(step)

# Join to allow pending async checkpoints to complete
chkpt_mgr.join()

# The manager should track all steps which were asynchronously saved.
self.assertEqual(set(chkpt_mgr.all_steps()), saved)

# Load a checkpoint into a new state_dict
new_state_dict = self._get_sharded_model().state_dict()
self.assertFalse(
any(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))
chkpt_mgr.restore(0, new_state_dict)
self.assertTrue(
all(
torch.allclose(v, new_state_dict[k])
for k, v in state_dict.items()))


if __name__ == '__main__':
test = unittest.main()
54 changes: 51 additions & 3 deletions torch_xla/experimental/distributed_checkpoint/manager.py
@@ -2,10 +2,13 @@
import logging
import os
import pickle
import queue
import threading
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch_xla.runtime as xr
import torch_xla.experimental.distributed_checkpoint as xc
import traceback

from dataclasses import dataclass
from datetime import datetime
@@ -14,6 +17,7 @@
from os.path import basename
from typing import Deque, List, Optional, Union
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from ._helpers import _sharded_cpu_state_dict

# TODO(jonbolin): Import path will change
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
@@ -94,7 +98,8 @@ def __init__(self,
path: str,
save_interval: int,
max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1):
async_queue_size: Optional[int] = 1,
process_group: dist.ProcessGroup = None):
"""
Create a checkpoint manager that reads and writes checkpoints into
the provided directory.
@@ -113,6 +118,9 @@
network issues which slow down the active checkpoint.
Default: 1, which only allows a single async checkpoint to be
pending at a time.
process_group: The process group to use when coordinating the checkpoint.
Default: None, in which case a subgroup of the default process
group will be created.
"""
assert dist.is_initialized(), "A process group is required."
assert save_interval > 0, "save_interval must be positive"
@@ -124,6 +132,16 @@
self.max_to_keep = max_to_keep

self._tracked_chkpts = self._load_tracked_chkpts()
self._async_queue = queue.Queue(maxsize=async_queue_size)
self._alive = threading.Event()
self._alive.set()
self._chkpt_thread = threading.Thread(
target=self._async_worker, daemon=True)
self._chkpt_thread.start()

# Create a new group if none is provided
# TODO(jonbolin): Verify subgroup on GPU backend
self.pg = process_group or dist.new_group()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
"""
@@ -143,6 +161,25 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}')
return deque(sorted(all_chkpts, key=lambda m: m.ts))

def __del__(self):
self._alive.clear()
# Send a sentinel value to tell the worker to exit, and wait for pending
# checkpoints to complete.
self._async_queue.put(None)
self._chkpt_thread.join()

def _async_worker(self):
  while self._alive.is_set():
    try:
      item = self._async_queue.get()
      # A None item is the sentinel enqueued by __del__; skip the save so the
      # loop can observe the cleared `_alive` event and exit.
      if item:
        step, state_dict = item
        self.save(step, state_dict, force=True)
    except:
      # Log the failure but keep the worker alive so later checkpoints can
      # still be taken.
      traceback.print_exc()
    finally:
      self._async_queue.task_done()

def _get_path(self, step: int) -> str:
return os.path.join(self.base_path, str(step))

@@ -157,7 +194,7 @@ def _release_oldest_checkpoints(self):
Delete oldest checkpoints until the number of tracked checkpoints is below
self.max_to_keep. This operation is only executed on the rank 0 process.
"""
if dist.get_rank() == 0 and self.max_to_keep > 0:
if dist.get_rank(self.pg) == 0 and self.max_to_keep > 0:
while len(self._tracked_chkpts) > self.max_to_keep:
oldest_chkpt = self._tracked_chkpts.popleft()
self._delete_chkpt_at_step(oldest_chkpt.step)
@@ -193,6 +230,7 @@ def save(self,
state_dict=state_dict,
storage_writer=FsspecWriter(path),
planner=xc.SPMDSavePlanner(),
process_group=self.pg,
)
metadata = _CheckpointMetadata(step=step, ts=datetime.now())
with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f:
@@ -225,7 +263,12 @@ def save_async(self,
Returns:
True if a checkpoint was taken and False otherwise.
"""
raise NotImplementedError
if self.should_save(step) or force:
# Snapshot the state_dict to CPU on the calling thread so training can keep
# mutating device tensors while the background thread writes the checkpoint.
cpu_state_dict = _sharded_cpu_state_dict(state_dict)
self._async_queue.put((step, cpu_state_dict))
return True
return False

def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
"""
@@ -245,10 +288,15 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None:
state_dict=state_dict,
storage_reader=FsspecReader(path),
planner=xc.SPMDLoadPlanner(),
process_group=self.pg,
)

def all_steps(self) -> List[int]:
"""
List all steps tracked by the CheckpointManager.
"""
return sorted(x.step for x in self._tracked_chkpts)

def join(self):
""" Wait for all pending async checkpoints to complete. """
self._async_queue.join()
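
As a closing usage note, a sketch of how the new constructor parameters might be combined; the gloo subgroup and the specific values are illustrative assumptions, not recommendations from this change:

```python
import torch.distributed as dist
# Assumed export path for the manager.
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# A dedicated subgroup keeps checkpoint collectives separate from training
# collectives; with process_group=None the manager creates a subgroup itself.
ckpt_pg = dist.new_group(backend="gloo")

chkpt_mgr = CheckpointManager(
    path="/tmp/checkpoints",  # any fsspec-compatible location
    save_interval=100,        # checkpoint every 100 steps
    max_to_keep=3,            # retain only the three newest checkpoints
    async_queue_size=2,       # up to two async checkpoints may be pending
    process_group=ckpt_pg,
)
```

Because the async queue is a bounded queue.Queue, a save_async call made while async_queue_size checkpoints are still pending blocks the training loop rather than accumulating unbounded host memory.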