From cf1968c0a533e9d91904dcdb67ddf61710f4d6e1 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Tue, 10 Oct 2023 23:27:28 +0000 Subject: [PATCH 1/4] Support async checkpointing through CheckpointManager --- test/spmd/test_xla_distributed_checkpoint.py | 38 +++++++++++++++++++ .../distributed_checkpoint/manager.py | 34 ++++++++++++++++- 2 files changed, 71 insertions(+), 1 deletion(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 74bc27cdf98..f763c041bc0 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -412,6 +412,44 @@ def test_manager_max_to_keep(self, tmpdir): self.assertTrue(chkpt_mgr.save(20, state_dict)) self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10}) + @run_with_tmpdir + def test_manager_async_checkpoint(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + state_dict = self._get_sharded_model().state_dict() + + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps not divisible by 10 should not be saved + for step in range(1, 10): + self.assertFalse(chkpt_mgr.save_async(step, state_dict)) + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps divisible by 10 should be saved + saved = set() + for step in range(0, 100, 10): + self.assertTrue(chkpt_mgr.save_async(step, state_dict)) + saved.add(step) + + # Delete the checkpoint manager to block this thread until all pending + # async checkpoints are complete. + del chkpt_mgr + + # The manager should track all steps which were asynchronously saved. + chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + self.assertEqual(set(chkpt_mgr.all_steps()), saved) + + # Load a checkpoint into a new state_dict + new_state_dict = self._get_sharded_model().state_dict() + self.assertFalse( + any( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + chkpt_mgr.restore(0, new_state_dict) + self.assertTrue( + all( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index abf56bff940..c41a403514d 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -2,10 +2,13 @@ import logging import os import pickle +import queue +import threading import torch.distributed as dist import torch.distributed.checkpoint as dist_cp import torch_xla.runtime as xr import torch_xla.experimental.distributed_checkpoint as xc +import traceback from dataclasses import dataclass from datetime import datetime @@ -14,6 +17,7 @@ from os.path import basename from typing import Deque, List, Optional, Union from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +from ._helpers import _sharded_cpu_state_dict # TODO(jonbolin): Import path will change from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter @@ -125,6 +129,13 @@ def __init__(self, self._tracked_chkpts = self._load_tracked_chkpts() + self._async_queue = queue.Queue(maxsize=async_queue_size) + self._chkpt_thread = threading.Thread(target=self._async_worker, daemon=True) + self._chkpt_thread.start() + + # Create a CPU process group to coordinate the checkpoint. + self.pg = dist.new_group(backend='gloo') + def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: """ Loads a list of all tracked checkpoints from the storage backend. @@ -143,6 +154,20 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}') return deque(sorted(all_chkpts, key=lambda m: m.ts)) + def __del__(self): + # Ensure pending checkpoints are finished + self._async_queue.join() + + def _async_worker(self): + while True: + try: + step, state_dict = self._async_queue.get() + self.save(step, state_dict, force=True) + except: + traceback.print_exc() + finally: + self._async_queue.task_done() + def _get_path(self, step: int) -> str: return os.path.join(self.base_path, str(step)) @@ -193,6 +218,7 @@ def save(self, state_dict=state_dict, storage_writer=FsspecWriter(path), planner=xc.SPMDSavePlanner(), + process_group=self.pg, ) metadata = _CheckpointMetadata(step=step, ts=datetime.now()) with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f: @@ -225,7 +251,12 @@ def save_async(self, Returns: True if a checkpoint was taken and False otherwise. """ - raise NotImplementedError + if self.should_save(step) or force: + # Move the state_dict to CPU + cpu_state_dict = _sharded_cpu_state_dict(state_dict) + self._async_queue.put((step, cpu_state_dict)) + return True + return False def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: """ @@ -245,6 +276,7 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: state_dict=state_dict, storage_reader=FsspecReader(path), planner=xc.SPMDLoadPlanner(), + process_group=self.pg, ) def all_steps(self) -> List[int]: From 7b6c94b3c2053f46aa08296501898073cfb1d3b3 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Thu, 12 Oct 2023 02:09:30 +0000 Subject: [PATCH 2/4] Allow threads to exit when CheckpointManager is freed --- test/spmd/test_xla_distributed_checkpoint.py | 37 ++++++++++++++++--- .../distributed_checkpoint/manager.py | 35 +++++++++++++----- 2 files changed, 57 insertions(+), 15 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index f763c041bc0..044cc968c0b 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -4,6 +4,7 @@ import tempfile import unittest import test_xla_sharding_base +import threading import torch import torch.distributed as dist @@ -413,7 +414,35 @@ def test_manager_max_to_keep(self, tmpdir): self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10}) @run_with_tmpdir - def test_manager_async_checkpoint(self, tmpdir): + def test_manager_async(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + state_dict = self._get_sharded_model().state_dict() + + # Patch the manager's save method to block until this thread signals. + cond = threading.Condition() + old_save = chkpt_mgr.save + + def patched_save(*args, **kwargs): + cond.wait() + old_save(*args, **kwargs) + + with unittest.mock.patch.object(chkpt_mgr, 'save', patched_save): + chkpt_mgr.save_async(10, state_dict) + + # No new steps should be tracked immediately after calling save_async + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Trigger the actual checkpoint in the background thread and wait for + # completion. + with cond: + cond.notify() + chkpt_mgr.join() + + # The manager should track all steps which were asynchronously saved. + self.assertEqual(set(chkpt_mgr.all_steps()), {10}) + + @run_with_tmpdir + def test_manager_async_step_tracking(self, tmpdir): chkpt_mgr = CheckpointManager(tmpdir, save_period=10) state_dict = self._get_sharded_model().state_dict() @@ -430,12 +459,10 @@ def test_manager_async_checkpoint(self, tmpdir): self.assertTrue(chkpt_mgr.save_async(step, state_dict)) saved.add(step) - # Delete the checkpoint manager to block this thread until all pending - # async checkpoints are complete. - del chkpt_mgr + # Join to allow pending async checkpoints to complete + chkpt_mgr.join() # The manager should track all steps which were asynchronously saved. - chkpt_mgr = CheckpointManager(tmpdir, save_period=10) self.assertEqual(set(chkpt_mgr.all_steps()), saved) # Load a checkpoint into a new state_dict diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index c41a403514d..2b87ab10bba 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -98,7 +98,8 @@ def __init__(self, path: str, save_interval: int, max_to_keep: Optional[int] = 0, - async_queue_size: Optional[int] = 1): + async_queue_size: Optional[int] = 1, + process_group: dist.ProcessGroup = None): """ Create a checkpoint manager that reads and writes checkpoints into the provided directory. @@ -117,6 +118,9 @@ def __init__(self, network issues which slow down the active checkpoint. Default: 1, which only allows a single async checkpoint to be pending at a time. + process_group: The process group to use when coordinating the checkpoint. + Default: None, in which case a subgroup of the default process + group will be created. """ assert dist.is_initialized(), "A process group is required." assert save_interval > 0, "save_interval must be positive" @@ -128,13 +132,15 @@ def __init__(self, self.max_to_keep = max_to_keep self._tracked_chkpts = self._load_tracked_chkpts() - self._async_queue = queue.Queue(maxsize=async_queue_size) - self._chkpt_thread = threading.Thread(target=self._async_worker, daemon=True) + self._alive = threading.Event() + self._alive.set() + self._chkpt_thread = threading.Thread( + target=self._async_worker, daemon=True) self._chkpt_thread.start() - # Create a CPU process group to coordinate the checkpoint. - self.pg = dist.new_group(backend='gloo') + # Create a new group if none is provided + self.pg = process_group or dist.new_group() def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: """ @@ -155,14 +161,19 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: return deque(sorted(all_chkpts, key=lambda m: m.ts)) def __del__(self): - # Ensure pending checkpoints are finished - self._async_queue.join() + self._alive.clear() + # Send a sentinel value to tell the worker to exit, and wait for pending + # checkpoints to complete. + self._async_queue.put(None) + self._chkpt_thread.join() def _async_worker(self): - while True: + while self._alive.is_set(): try: - step, state_dict = self._async_queue.get() - self.save(step, state_dict, force=True) + item = self._async_queue.get() + if item: + step, state_dict = item + self.save(step, state_dict, force=True) except: traceback.print_exc() finally: @@ -284,3 +295,7 @@ def all_steps(self) -> List[int]: List all steps tracked by the CheckpointManager. """ return sorted(x.step for x in self._tracked_chkpts) + + def join(self): + """ Wait for all pending async checkpoints to complete. """ + self._async_queue.join() From 702fbb5661089420b7e6b90e08bb0e391643a4db Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Thu, 12 Oct 2023 02:13:30 +0000 Subject: [PATCH 3/4] Use rank from tracked process group --- torch_xla/experimental/distributed_checkpoint/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 2b87ab10bba..476bae5b882 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -193,7 +193,7 @@ def _release_oldest_checkpoints(self): Delete oldest checkpoints until the number of tracked checkpoints is below self.max_to_keep. This operation is only execution on the rank 0 process. """ - if dist.get_rank() == 0 and self.max_to_keep > 0: + if dist.get_rank(self.pg) == 0 and self.max_to_keep > 0: while len(self._tracked_chkpts) > self.max_to_keep: oldest_chkpt = self._tracked_chkpts.popleft() self._delete_chkpt_at_step(oldest_chkpt.step) From fdecad3b6441fbbea9a2c6d25b888ec040745354 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Thu, 12 Oct 2023 22:10:01 +0000 Subject: [PATCH 4/4] Add TODO --- test/spmd/test_xla_distributed_checkpoint.py | 4 ++-- torch_xla/experimental/distributed_checkpoint/manager.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 044cc968c0b..37c0224a0f4 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -415,7 +415,7 @@ def test_manager_max_to_keep(self, tmpdir): @run_with_tmpdir def test_manager_async(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) state_dict = self._get_sharded_model().state_dict() # Patch the manager's save method to block until this thread signals. @@ -443,7 +443,7 @@ def patched_save(*args, **kwargs): @run_with_tmpdir def test_manager_async_step_tracking(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) state_dict = self._get_sharded_model().state_dict() self.assertEqual(chkpt_mgr.all_steps(), []) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 476bae5b882..9e5cde711b8 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -140,6 +140,7 @@ def __init__(self, self._chkpt_thread.start() # Create a new group if none is provided + # TODO(jonbolin): Verify subgroup on GPU backend self.pg = process_group or dist.new_group() def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: