From 4f88d9f4ec9aee0cf46bf66f695a825d0736147f Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Tue, 10 Oct 2023 01:12:57 +0000 Subject: [PATCH 1/7] Support synchronous saving and loading in CheckpointManager --- test/spmd/test_xla_distributed_checkpoint.py | 87 ++++++++++++++++++- .../distributed_checkpoint/manager.py | 67 ++++++++++++-- 2 files changed, 147 insertions(+), 7 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 276571e5979..f51d7dd95f5 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -1,3 +1,4 @@ +import functools import os import sys import tempfile @@ -15,11 +16,23 @@ create_default_local_save_plan, create_default_global_save_plan, ) -from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner +from torch_xla.experimental.distributed_checkpoint import SPMDLoadPlanner, SPMDSavePlanner, CheckpointManager from torch_xla.experimental.distributed_checkpoint._helpers import ( _sharded_cpu_state_dict, _CpuShards, _is_sharded_tensor) +# Wrapper to manage a temporary directory for the wrapped test +def run_with_tmpdir(f): + + @functools.wraps(f) + def run(*args, **kwargs): + assert 'tmpdir' not in kwargs + with tempfile.TemporaryDirectory() as tmpdir: + f(*args, **kwargs, tmpdir=tmpdir) + + return run + + class DistributedCheckpointTestBase(test_xla_sharding_base.XlaShardingTest): @classmethod @@ -319,6 +332,78 @@ def test_sharded_cpu_state_dict(self): self.assertTrue(param.device == torch.device("cpu")) +class CheckpointManagerTest(DistributedCheckpointTestBase): + + def setUp(self): + super().setUp() + # Initialize the a minimal process group + dist.init_process_group( + init_method='tcp://127.1:8932', world_size=1, rank=0) + + def tearDown(self): + super().tearDown() + # Destroy the CPU process group after the test + dist.destroy_process_group() + + @run_with_tmpdir + def test_manager_checkpointing(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + state_dict = self._get_sharded_model().state_dict() + + # Take a checkpoint on step 0 + self.assertTrue(chkpt_mgr.save(0, state_dict)) + + # Load the checkpoint into a new state_dict + new_state_dict = self._get_sharded_model().state_dict() + self.assertFalse( + any( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + chkpt_mgr.restore(0, new_state_dict) + self.assertTrue( + all( + torch.allclose(v, new_state_dict[k]) + for k, v in state_dict.items())) + + @run_with_tmpdir + def test_manager_step_tracking(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + state_dict = self._get_sharded_model().state_dict() + + # No steps are being tracked initially + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps not divisible by 10 should not be saved + for step in range(1, 10): + self.assertFalse(chkpt_mgr.save(step, state_dict)) + self.assertEqual(chkpt_mgr.all_steps(), []) + + # Steps divisible by 10 should be saved + saved = set() + for step in range(0, 100, 10): + self.assertTrue(chkpt_mgr.save(step, state_dict)) + saved.add(step) + self.assertEqual(set(chkpt_mgr.all_steps()), saved) + + @run_with_tmpdir + def test_manager_max_to_keep(self, tmpdir): + chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2) + state_dict = self._get_sharded_model().state_dict() + + # No steps are being tracked initially + self.assertEqual(chkpt_mgr.all_steps(), []) + + self.assertTrue(chkpt_mgr.save(10, 
state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {10}) + + self.assertTrue(chkpt_mgr.save(20, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {10, 20}) + + # The oldest checkpoint should be erased + self.assertTrue(chkpt_mgr.save(30, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20}) + + if __name__ == '__main__': test = unittest.main() sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index cd36cbe1eb6..0232f2b8f95 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -1,9 +1,17 @@ +import os +import torch.distributed as dist import torch.distributed.checkpoint as dist_cp +import torch_xla.runtime as xr import torch_xla.experimental.distributed_checkpoint as xc -from typing import List, Optional +from fsspec.core import url_to_fs +from os.path import basename +from typing import List, Optional, Union from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE +# TODO(jonbolin): Import path will change +from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter + class CheckpointManager: """ @@ -53,6 +61,14 @@ class CheckpointManager: https://github.com/google/orbax/blob/efc079c4e5b437782a80138913d322cb3ed365c7/checkpoint/orbax/checkpoint/checkpoint_manager.py """ + # The base path to write checkpoints to. Each checkpoint taken by the manager + # will be written into a subdirectory of this path, identified by the + # checkpoint's step. + base_path: Union[str, os.PathLike] + + # The period to take checkpoints, in steps. + save_period: int + def __init__(self, path: str, save_period: int, @@ -77,14 +93,36 @@ def __init__(self, Default: 1, which only allows a single async checkpoint to be pending at a time. """ - raise NotImplementedError + assert dist.is_initialized(), "A process group is required." + + self.base_path = path + self.save_period = save_period + self.max_to_keep = max_to_keep + self.async_queue_size = async_queue_size + assert self.save_period > 0, "save_period must be positive" + assert self.async_queue_size > 0, "async_queue_size must be positive" + assert self.max_to_keep != 0, "max_to_keep must be non-zero" + + def _get_path(self, step: int) -> str: + return os.path.join(self.base_path, str(step)) + + def _release_oldest_checkpoints(self): + if self.max_to_keep > 0: + tracked_steps = sorted(self.all_steps(), reverse=True) + while len(tracked_steps) > self.max_to_keep: + # Delete the oldest checkpoint step to free up space for the new one. + oldest_step = tracked_steps.pop() + path = self._get_path(oldest_step) + fs, raw_path = url_to_fs(path) + fs.rm(raw_path, recursive=True) def should_save(self, step: int) -> bool: """ Returns true if a checkpoint should be saved for the current step or if a preemption has been detected. """ - raise NotImplementedError + # TODO(jonbolin): Support preemption notice for auto checkpointing + return step % self.save_period == 0 def save(self, step, @@ -101,7 +139,16 @@ def save(self, Returns: True if a checkpoint was taken and False otherwise. 
""" - raise NotImplementedError + if self.should_save(step): + path = self._get_path(step) + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=FsspecWriter(path), + planner=xc.SPMDSavePlanner(), + ) + self._release_oldest_checkpoints() + return True + return False def save_async(self, step: int, @@ -139,10 +186,18 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: state_dict: The state dict to restore the checkpoint into. Values are updated in-place within the state_dict. """ - raise NotImplementedError + path = self._get_path(step) + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=FsspecReader(path), + planner=xc.SPMDLoadPlanner(), + ) def all_steps(self) -> List[int]: """ List all steps tracked by the CheckpointManager. """ - raise NotImplementedError + fs, raw_path = url_to_fs(self.base_path) + all_paths = fs.ls(raw_path, detail=False) + all_steps = map(basename, all_paths) + return list(map(int, all_steps)) From 7326cc8ec4607ac3f955ea5905c66768c92c53cb Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Tue, 10 Oct 2023 22:37:57 +0000 Subject: [PATCH 2/7] Use 0 to indicate no upper bound --- .../experimental/distributed_checkpoint/manager.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 0232f2b8f95..dac61ecf4d6 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -69,10 +69,16 @@ class CheckpointManager: # The period to take checkpoints, in steps. save_period: int + # The maximum number of checkpoints to keep. + max_to_keep: int + + # The size of the queue which processes async checkpoints. + async_queue_size: int + def __init__(self, path: str, save_period: int, - max_to_keep: Optional[int] = -1, + max_to_keep: Optional[int] = 0, async_queue_size: Optional[int] = 1): """ Create a checkpoint manager that reads and writes checkpoints into @@ -84,7 +90,7 @@ def __init__(self, max_to_keep: The maximum number of checkpoints to be tracked by the CheckpointManager. When a new checkpoint will be taken, the checkpoint for the lowest tracked step will be deleted. - Default: -1, indicating no upper bound on the number of checkpoints. + Default: 0, indicating no upper bound on the number of checkpoints. async_queue_size: The size of the execution queue which processes async checkpoints. This should be a small value to ensure training doesn't get too far ahead of the last finished checkpoint, but increasing @@ -101,7 +107,7 @@ def __init__(self, self.async_queue_size = async_queue_size assert self.save_period > 0, "save_period must be positive" assert self.async_queue_size > 0, "async_queue_size must be positive" - assert self.max_to_keep != 0, "max_to_keep must be non-zero" + assert self.max_to_keep >= 0, "max_to_keep must be non-negative" def _get_path(self, step: int) -> str: return os.path.join(self.base_path, str(step)) @@ -110,7 +116,7 @@ def _release_oldest_checkpoints(self): if self.max_to_keep > 0: tracked_steps = sorted(self.all_steps(), reverse=True) while len(tracked_steps) > self.max_to_keep: - # Delete the oldest checkpoint step to free up space for the new one. 
+ # Delete the oldest checkpoint step oldest_step = tracked_steps.pop() path = self._get_path(oldest_step) fs, raw_path = url_to_fs(path) From c3240d71f037af719f8ceee93f7db5df7689a935 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Tue, 10 Oct 2023 23:04:14 +0000 Subject: [PATCH 3/7] Don't track async_queue_size --- .../experimental/distributed_checkpoint/manager.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index dac61ecf4d6..9c6696ce019 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -72,9 +72,6 @@ class CheckpointManager: # The maximum number of checkpoints to keep. max_to_keep: int - # The size of the queue which processes async checkpoints. - async_queue_size: int - def __init__(self, path: str, save_period: int, @@ -100,14 +97,13 @@ def __init__(self, pending at a time. """ assert dist.is_initialized(), "A process group is required." + assert save_period > 0, "save_period must be positive" + assert async_queue_size > 0, "async_queue_size must be positive" + assert max_to_keep >= 0, "max_to_keep must be non-negative" self.base_path = path self.save_period = save_period self.max_to_keep = max_to_keep - self.async_queue_size = async_queue_size - assert self.save_period > 0, "save_period must be positive" - assert self.async_queue_size > 0, "async_queue_size must be positive" - assert self.max_to_keep >= 0, "max_to_keep must be non-negative" def _get_path(self, step: int) -> str: return os.path.join(self.base_path, str(step)) @@ -145,7 +141,7 @@ def save(self, Returns: True if a checkpoint was taken and False otherwise. 
""" - if self.should_save(step): + if self.should_save(step) or force: path = self._get_path(step) dist_cp.save_state_dict( state_dict=state_dict, From d248eeca36a0890546ccb55cbadc60d3ad96e19a Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Wed, 11 Oct 2023 22:32:45 +0000 Subject: [PATCH 4/7] Cache tracked steps locally --- test/spmd/test_xla_distributed_checkpoint.py | 6 ++-- .../distributed_checkpoint/manager.py | 30 +++++++++++++------ 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index f51d7dd95f5..c7dd119160b 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -26,9 +26,9 @@ def run_with_tmpdir(f): @functools.wraps(f) def run(*args, **kwargs): - assert 'tmpdir' not in kwargs with tempfile.TemporaryDirectory() as tmpdir: - f(*args, **kwargs, tmpdir=tmpdir) + kwargs.setdefault('tmpdir', tmpdir) + f(*args, **kwargs) return run @@ -338,7 +338,7 @@ def setUp(self): super().setUp() # Initialize the a minimal process group dist.init_process_group( - init_method='tcp://127.1:8932', world_size=1, rank=0) + backend='gloo', init_method='tcp://127.1:8932', world_size=1, rank=0) def tearDown(self): super().tearDown() diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index 9c6696ce019..a8d7a7d6576 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -1,3 +1,4 @@ +import heapq import os import torch.distributed as dist import torch.distributed.checkpoint as dist_cp @@ -105,15 +106,28 @@ def __init__(self, self.save_period = save_period self.max_to_keep = max_to_keep + # Cache tracked steps in a heap for efficient clearing. + self._tracked_steps = self._load_tracked_steps() + heapq.heapify(self._tracked_steps) + + def _load_tracked_steps(self) -> List[int]: + """ Loads a list of all tracked steps from the storage backend. """ + fs, raw_path = url_to_fs(self.base_path) + all_paths = fs.ls(raw_path, detail=False) + all_steps = map(basename, all_paths) + return list(map(int, all_steps)) + def _get_path(self, step: int) -> str: return os.path.join(self.base_path, str(step)) def _release_oldest_checkpoints(self): - if self.max_to_keep > 0: - tracked_steps = sorted(self.all_steps(), reverse=True) - while len(tracked_steps) > self.max_to_keep: - # Delete the oldest checkpoint step - oldest_step = tracked_steps.pop() + """ + Delete oldest checkpoints until the number of tracked checkpoints is below + self.max_to_keep. This operation is only execution on the rank 0 process. + """ + if dist.get_rank() == 0 and self.max_to_keep > 0: + while len(self._tracked_steps) > self.max_to_keep: + oldest_step = heapq.heappop(self._tracked_steps) path = self._get_path(oldest_step) fs, raw_path = url_to_fs(path) fs.rm(raw_path, recursive=True) @@ -148,6 +162,7 @@ def save(self, storage_writer=FsspecWriter(path), planner=xc.SPMDSavePlanner(), ) + heapq.heappush(self._tracked_steps, step) self._release_oldest_checkpoints() return True return False @@ -199,7 +214,4 @@ def all_steps(self) -> List[int]: """ List all steps tracked by the CheckpointManager. 
""" - fs, raw_path = url_to_fs(self.base_path) - all_paths = fs.ls(raw_path, detail=False) - all_steps = map(basename, all_paths) - return list(map(int, all_steps)) + return sorted(self._tracked_steps) From 44308775b5f7628b06ee61cba85af741e6e7f57b Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Thu, 12 Oct 2023 00:35:35 +0000 Subject: [PATCH 5/7] Track creation time in metadata --- test/spmd/test_xla_distributed_checkpoint.py | 9 +++ .../distributed_checkpoint/manager.py | 71 ++++++++++++++----- 2 files changed, 63 insertions(+), 17 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index c7dd119160b..2a4a1f0323e 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -403,6 +403,15 @@ def test_manager_max_to_keep(self, tmpdir): self.assertTrue(chkpt_mgr.save(30, state_dict)) self.assertEqual(set(chkpt_mgr.all_steps()), {30, 20}) + # The oldest is selected by creation timestamp, not step + self.assertTrue(chkpt_mgr.save(10, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10}) + + # The deletion order should persist across executions + chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2) + self.assertTrue(chkpt_mgr.save(20, state_dict)) + self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10}) + if __name__ == '__main__': test = unittest.main() diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index a8d7a7d6576..c005f49de05 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -1,18 +1,35 @@ -import heapq +import fsspec +import logging import os +import pickle import torch.distributed as dist import torch.distributed.checkpoint as dist_cp import torch_xla.runtime as xr import torch_xla.experimental.distributed_checkpoint as xc +from dataclasses import dataclass +from datetime import datetime +from collections import deque from fsspec.core import url_to_fs from os.path import basename -from typing import List, Optional, Union +from typing import Deque, List, Optional, Union from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE # TODO(jonbolin): Import path will change from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter +# File to track manager-specific metadata within each checkpoint path +_MANAGER_METADATA_FILE = '.manager_metadata' + + +@dataclass +class _CheckpointMetadata: + # The step at which the checkpoint was taken + step: int + + # The time at which the checkpoint was taken + ts: datetime + class CheckpointManager: """ @@ -106,31 +123,44 @@ def __init__(self, self.save_period = save_period self.max_to_keep = max_to_keep - # Cache tracked steps in a heap for efficient clearing. - self._tracked_steps = self._load_tracked_steps() - heapq.heapify(self._tracked_steps) + self._tracked_chkpts = self._load_tracked_chkpts() - def _load_tracked_steps(self) -> List[int]: - """ Loads a list of all tracked steps from the storage backend. """ + def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]: + """ + Loads a list of all tracked checkpoints from the storage backend. 
+ """ + all_chkpts = [] + invalid_paths = [] fs, raw_path = url_to_fs(self.base_path) - all_paths = fs.ls(raw_path, detail=False) - all_steps = map(basename, all_paths) - return list(map(int, all_steps)) + for path in fs.ls(raw_path, detail=False): + try: + with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'rb') as f: + all_chkpts.append(pickle.load(f)) + except: + invalid_paths.append(path) + + if invalid_paths: + logging.warning(f'Ignoring invalid checkpoints: {invalid_paths}') + return deque(sorted(all_chkpts, key=lambda m: m.ts)) def _get_path(self, step: int) -> str: return os.path.join(self.base_path, str(step)) + def _delete_chkpt_at_step(self, step): + path = self._get_path(step) + fs, raw_path = url_to_fs(path) + if fs.exists(raw_path): + fs.rm(raw_path, recursive=True) + def _release_oldest_checkpoints(self): """ Delete oldest checkpoints until the number of tracked checkpoints is below self.max_to_keep. This operation is only execution on the rank 0 process. """ if dist.get_rank() == 0 and self.max_to_keep > 0: - while len(self._tracked_steps) > self.max_to_keep: - oldest_step = heapq.heappop(self._tracked_steps) - path = self._get_path(oldest_step) - fs, raw_path = url_to_fs(path) - fs.rm(raw_path, recursive=True) + while len(self._tracked_chkpts) > self.max_to_keep: + oldest_chkpt = self._tracked_chkpts.popleft() + self._delete_chkpt_at_step(oldest_chkpt.step) def should_save(self, step: int) -> bool: """ @@ -157,12 +187,17 @@ def save(self, """ if self.should_save(step) or force: path = self._get_path(step) + # Delete any existing checkpoint at the current step. + self._delete_chkpt_at_step(step) dist_cp.save_state_dict( state_dict=state_dict, storage_writer=FsspecWriter(path), planner=xc.SPMDSavePlanner(), ) - heapq.heappush(self._tracked_steps, step) + metadata = _CheckpointMetadata(step=step, ts=datetime.now()) + with fsspec.open(os.path.join(path, _MANAGER_METADATA_FILE), 'wb') as f: + pickle.dump(metadata, f) + self._tracked_chkpts.append(metadata) self._release_oldest_checkpoints() return True return False @@ -203,6 +238,8 @@ def restore(self, step: int, state_dict: STATE_DICT_TYPE) -> None: state_dict: The state dict to restore the checkpoint into. Values are updated in-place within the state_dict. """ + tracked_steps = set(x.step for x in self._tracked_chkpts) + assert step in tracked_steps, f'Cannot restore from untracked step {step}. Valid steps are: {tracked_steps}' path = self._get_path(step) dist_cp.load_state_dict( state_dict=state_dict, @@ -214,4 +251,4 @@ def all_steps(self) -> List[int]: """ List all steps tracked by the CheckpointManager. """ - return sorted(self._tracked_steps) + return sorted(x.step for x in self._tracked_chkpts) From 3ebc8af5c9fef8d5a969b23d05266b727e41b720 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Thu, 12 Oct 2023 20:07:37 +0000 Subject: [PATCH 6/7] Rename save_period to save_interval --- .../experimental/distributed_checkpoint/manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/torch_xla/experimental/distributed_checkpoint/manager.py b/torch_xla/experimental/distributed_checkpoint/manager.py index c005f49de05..abf56bff940 100644 --- a/torch_xla/experimental/distributed_checkpoint/manager.py +++ b/torch_xla/experimental/distributed_checkpoint/manager.py @@ -84,15 +84,15 @@ class CheckpointManager: # checkpoint's step. base_path: Union[str, os.PathLike] - # The period to take checkpoints, in steps. - save_period: int + # The interval to take checkpoints, in steps. 
+ save_interval: int # The maximum number of checkpoints to keep. max_to_keep: int def __init__(self, path: str, - save_period: int, + save_interval: int, max_to_keep: Optional[int] = 0, async_queue_size: Optional[int] = 1): """ @@ -101,7 +101,7 @@ def __init__(self, Args: path: The base path for the CheckpointManager to write checkpoints into. - save_period: The number of steps between saving checkpoints. + save_interval: The number of steps between saving checkpoints. max_to_keep: The maximum number of checkpoints to be tracked by the CheckpointManager. When a new checkpoint will be taken, the checkpoint for the lowest tracked step will be deleted. @@ -115,12 +115,12 @@ def __init__(self, pending at a time. """ assert dist.is_initialized(), "A process group is required." - assert save_period > 0, "save_period must be positive" + assert save_interval > 0, "save_interval must be positive" assert async_queue_size > 0, "async_queue_size must be positive" assert max_to_keep >= 0, "max_to_keep must be non-negative" self.base_path = path - self.save_period = save_period + self.save_interval = save_interval self.max_to_keep = max_to_keep self._tracked_chkpts = self._load_tracked_chkpts() @@ -168,7 +168,7 @@ def should_save(self, step: int) -> bool: a preemption has been detected. """ # TODO(jonbolin): Support preemption notice for auto checkpointing - return step % self.save_period == 0 + return step % self.save_interval == 0 def save(self, step, From 395059a727fdafb4a55bbefde0749e0c18091489 Mon Sep 17 00:00:00 2001 From: Jon Bolin Date: Thu, 12 Oct 2023 22:10:44 +0000 Subject: [PATCH 7/7] Fix tests --- test/spmd/test_xla_distributed_checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/spmd/test_xla_distributed_checkpoint.py b/test/spmd/test_xla_distributed_checkpoint.py index 2a4a1f0323e..74bc27cdf98 100644 --- a/test/spmd/test_xla_distributed_checkpoint.py +++ b/test/spmd/test_xla_distributed_checkpoint.py @@ -347,7 +347,7 @@ def tearDown(self): @run_with_tmpdir def test_manager_checkpointing(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) state_dict = self._get_sharded_model().state_dict() # Take a checkpoint on step 0 @@ -367,7 +367,7 @@ def test_manager_checkpointing(self, tmpdir): @run_with_tmpdir def test_manager_step_tracking(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_period=10) + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10) state_dict = self._get_sharded_model().state_dict() # No steps are being tracked initially @@ -387,7 +387,7 @@ def test_manager_step_tracking(self, tmpdir): @run_with_tmpdir def test_manager_max_to_keep(self, tmpdir): - chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2) + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2) state_dict = self._get_sharded_model().state_dict() # No steps are being tracked initially @@ -408,7 +408,7 @@ def test_manager_max_to_keep(self, tmpdir): self.assertEqual(set(chkpt_mgr.all_steps()), {30, 10}) # The deletion order should persist across executions - chkpt_mgr = CheckpointManager(tmpdir, save_period=10, max_to_keep=2) + chkpt_mgr = CheckpointManager(tmpdir, save_interval=10, max_to_keep=2) self.assertTrue(chkpt_mgr.save(20, state_dict)) self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})
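
Usage note (not part of the patches above): a minimal sketch of how the synchronous CheckpointManager API introduced in this series might be driven from a training loop. The torch.nn.Linear model, the checkpoint path, and the loop structure below are illustrative assumptions; the tests instead use a sharded SPMD model's state_dict. Only the CheckpointManager constructor arguments (save_interval, max_to_keep) and the should_save, save, restore, and all_steps methods reflect the API added here, and an initialized process group is required before constructing the manager.

```python
import torch
import torch.distributed as dist
from torch_xla.experimental.distributed_checkpoint import CheckpointManager

# CheckpointManager asserts that a process group is initialized; the tests use
# a single-rank gloo group with a local TCP init_method.
dist.init_process_group(
    backend='gloo', init_method='tcp://127.0.0.1:8932', world_size=1, rank=0)

model = torch.nn.Linear(128, 128)  # illustrative stand-in for a sharded model
chkpt_mgr = CheckpointManager(
    '/tmp/checkpoints',  # base path; each step is written to a subdirectory
    save_interval=10,    # take a checkpoint every 10 steps
    max_to_keep=2)       # keep only the two most recent checkpoints

for step in range(100):
    # ... run a training step here ...
    if chkpt_mgr.save(step, model.state_dict()):
        # save() returns True only when the step is actually checkpointed
        # (step % save_interval == 0); after a successful save, rank 0 deletes
        # the oldest tracked checkpoints beyond max_to_keep.
        print(f'checkpoint written for step {step}')

# Restore the newest tracked checkpoint in-place into a state_dict.
latest = chkpt_mgr.all_steps()[-1]  # all_steps() returns steps in sorted order
state_dict = model.state_dict()
chkpt_mgr.restore(latest, state_dict)
model.load_state_dict(state_dict)

dist.destroy_process_group()
```

Note that restore() asserts the requested step is tracked by the manager, so restoring from an untracked step raises an error listing the valid steps.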