Skip to content

Commit

Permalink
Allow threads to exit when CheckpointManager is freed
Browse files Browse the repository at this point in the history
  • Loading branch information
jonb377 committed Oct 12, 2023
1 parent 32144fc commit b9e1952
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 15 deletions.
37 changes: 32 additions & 5 deletions test/spmd/test_xla_distributed_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import tempfile
import unittest
import test_xla_sharding_base
import threading

import torch
import torch.distributed as dist
Expand Down Expand Up @@ -413,7 +414,35 @@ def test_manager_max_to_keep(self, tmpdir):
self.assertEqual(set(chkpt_mgr.all_steps()), {20, 10})

@run_with_tmpdir
def test_manager_async_checkpoint(self, tmpdir):
def test_manager_async(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
state_dict = self._get_sharded_model().state_dict()

# Patch the manager's save method to block until this thread signals.
cond = threading.Condition()
old_save = chkpt_mgr.save

def patched_save(*args, **kwargs):
cond.wait()
old_save(*args, **kwargs)

with unittest.mock.patch.object(chkpt_mgr, 'save', patched_save):
chkpt_mgr.save_async(10, state_dict)

# No new steps should be tracked immediately after calling save_async
self.assertEqual(chkpt_mgr.all_steps(), [])

# Trigger the actual checkpoint in the background thread and wait for
# completion.
with cond:
cond.notify()
chkpt_mgr.join()

# The manager should track all steps which were asynchronously saved.
self.assertEqual(set(chkpt_mgr.all_steps()), {10})

@run_with_tmpdir
def test_manager_async_step_tracking(self, tmpdir):
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
state_dict = self._get_sharded_model().state_dict()

Expand All @@ -430,12 +459,10 @@ def test_manager_async_checkpoint(self, tmpdir):
self.assertTrue(chkpt_mgr.save_async(step, state_dict))
saved.add(step)

# Delete the checkpoint manager to block this thread until all pending
# async checkpoints are complete.
del chkpt_mgr
# Join to allow pending async checkpoints to complete
chkpt_mgr.join()

# The manager should track all steps which were asynchronously saved.
chkpt_mgr = CheckpointManager(tmpdir, save_period=10)
self.assertEqual(set(chkpt_mgr.all_steps()), saved)

# Load a checkpoint into a new state_dict
Expand Down
35 changes: 25 additions & 10 deletions torch_xla/experimental/distributed_checkpoint/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ def __init__(self,
path: str,
save_period: int,
max_to_keep: Optional[int] = 0,
async_queue_size: Optional[int] = 1):
async_queue_size: Optional[int] = 1,
process_group: dist.ProcessGroup = None):
"""
Create a checkpoint manager that reads and writes checkpoints into
the provided directory.
Expand All @@ -117,6 +118,9 @@ def __init__(self,
network issues which slow down the active checkpoint.
Default: 1, which only allows a single async checkpoint to be
pending at a time.
process_group: The process group to use when coordinating the checkpoint.
Default: None, in which case a subgroup of the default process
group will be created.
"""
assert dist.is_initialized(), "A process group is required."
assert save_period > 0, "save_period must be positive"
Expand All @@ -128,13 +132,15 @@ def __init__(self,
self.max_to_keep = max_to_keep

self._tracked_chkpts = self._load_tracked_chkpts()

self._async_queue = queue.Queue(maxsize=async_queue_size)
self._chkpt_thread = threading.Thread(target=self._async_worker, daemon=True)
self._alive = threading.Event()
self._alive.set()
self._chkpt_thread = threading.Thread(
target=self._async_worker, daemon=True)
self._chkpt_thread.start()

# Create a CPU process group to coordinate the checkpoint.
self.pg = dist.new_group(backend='gloo')
# Create a new group if none is provided
self.pg = process_group or dist.new_group()

def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
"""
Expand All @@ -155,14 +161,19 @@ def _load_tracked_chkpts(self) -> Deque[_CheckpointMetadata]:
return deque(sorted(all_chkpts, key=lambda m: m.ts))

def __del__(self):
# Ensure pending checkpoints are finished
self._async_queue.join()
self._alive.clear()
# Send a sentinel value to tell the worker to exit, and wait for pending
# checkpoints to complete.
self._async_queue.put((None, None))
self._chkpt_thread.join()

def _async_worker(self):
while True:
while self._alive.is_set():
try:
step, state_dict = self._async_queue.get()
self.save(step, state_dict, force=True)
item = self._async_queue.get()
if item:
step, state_dict = item
self.save(step, state_dict, force=True)
except:
traceback.print_exc()
finally:
Expand Down Expand Up @@ -284,3 +295,7 @@ def all_steps(self) -> List[int]:
List all steps tracked by the CheckpointManager.
"""
return sorted(x.step for x in self._tracked_chkpts)

def join(self):
""" Wait for all pending async checkpoints to complete. """
self._async_queue.join()

0 comments on commit b9e1952

Please sign in to comment.