checkpoint saver tracks all checkpoints/intervals in state #2819

Merged
merged 14 commits on Jan 9, 2024
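In effect, this change makes CheckpointSaver a stateful callback: every checkpoint it writes is recorded in all_saved_checkpoints_to_timestamp together with the Timestamp at which it was saved, and that mapping is serialized through state_dict() / load_state_dict() so it survives pause-and-resume. A minimal round-trip sketch, using only the public API exercised by the new test further down (the filename is illustrative):

from composer.callbacks import CheckpointSaver
from composer.core import Timestamp

saver = CheckpointSaver()
# Pretend a checkpoint was just written at epoch 1, batch 2.
saver.all_saved_checkpoints_to_timestamp['ep1-ba2-rank0.pt'] = Timestamp(epoch=1, batch=2)

# The mapping is serialized as nested dicts, so it can ride along inside the trainer checkpoint.
payload = saver.state_dict()
assert 'all_saved_checkpoints_to_timestamp' in payload

# A fresh instance rebuilds Timestamp objects from the serialized entries on load.
restored = CheckpointSaver()
restored.load_state_dict(payload)
assert restored.all_saved_checkpoints_to_timestamp['ep1-ba2-rank0.pt'].batch == 2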
43 changes: 30 additions & 13 deletions composer/callbacks/checkpoint_saver.py
@@ -12,14 +12,14 @@
import tempfile
import textwrap
from pathlib import Path
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

from composer.core import Callback, Event, State, Time
from composer.core import Callback, Event, State, Time, Timestamp
from composer.loggers import Logger
from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
checkpoint, create_interval_scheduler, create_symlink_file, dist,
ensure_folder_has_no_conflicting_files, format_name_with_dist,
format_name_with_dist_and_time, is_model_deepspeed, reproducibility, using_torch_2)
format_name_with_dist_and_time, is_model_deepspeed, using_torch_2)
from composer.utils.checkpoint import _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME

log = logging.getLogger(__name__)
@@ -264,6 +264,7 @@ def __init__(

self.overwrite = overwrite
self.saved_checkpoints: List[str] = []
self.all_saved_checkpoints_to_timestamp: Dict[str, Timestamp] = {}
self.num_checkpoints_to_keep = num_checkpoints_to_keep
self.weights_only = weights_only

@@ -303,11 +304,24 @@ def epoch_checkpoint(self, state: State, logger: Logger):
logger,
)

def get_state_dict(self, state):
return {
'state': state.state_dict(),
'rng': reproducibility.get_rng_state(),
}
def state_dict(self) -> Dict[str, Any]:
state_dict = super().state_dict()

all_checkpoints = {}
for save_filename, timestamp in self.all_saved_checkpoints_to_timestamp.items():
all_checkpoints[save_filename] = timestamp.state_dict()

# TODO: consider saving additional state for checkpoint rotation
state_dict['all_saved_checkpoints_to_timestamp'] = all_checkpoints
return state_dict

def load_state_dict(self, state: Dict[str, Any]):
super().load_state_dict(state)
if 'all_saved_checkpoints_to_timestamp' in state:
for save_filename, timestamp_state in state['all_saved_checkpoints_to_timestamp'].items():
new_timestamp = Timestamp()
new_timestamp.load_state_dict(timestamp_state)
self.all_saved_checkpoints_to_timestamp[save_filename] = new_timestamp

def _save_checkpoint(self, state: State, logger: Logger):
self.last_checkpoint_batch = state.timestamp.batch
@@ -319,16 +333,19 @@ def _save_checkpoint(self, state: State, logger: Logger):

# save the checkpoint to the filename
filename_with_placeholders = self.filename.format(state, is_deepspeed, keep_placeholders=True)
save_filename = checkpoint.get_save_filename(state, filename_with_placeholders)
self.all_saved_checkpoints_to_timestamp[save_filename] = state.timestamp

saved_path = checkpoint.save_checkpoint(
saved_path = checkpoint._save_checkpoint(
state=state,
filename=filename_with_placeholders,
save_filename=save_filename,
weights_only=self.weights_only,
)
log.debug(f'Checkpoint locally saved to {saved_path}')

if not saved_path: # not all ranks save
return

metadata_local_file_path = None
if dist.get_global_rank() == 0 and state.fsdp_elastic_sharded_enabled:
metadata_local_file_path = format_name_with_dist_and_time(
@@ -423,10 +440,10 @@ def _rotate_checkpoints(self, sharding_enabled: bool = False):

while len(self.saved_checkpoints) > self.num_checkpoints_to_keep:
prefix_dir = None
checkpoint = self.saved_checkpoints.pop(0)
prefix_dir = str(Path(checkpoint).parent)
checkpoint_to_delete = self.saved_checkpoints.pop(0)
prefix_dir = str(Path(checkpoint_to_delete).parent)
if not sharding_enabled:
os.remove(checkpoint)
os.remove(checkpoint_to_delete)
else:
if dist.get_global_rank() == 0:
shutil.rmtree(prefix_dir)
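Note how the callback change above follows a resolve-the-filename-first pattern: checkpoint.get_save_filename computes the destination before anything is written, so the path can be recorded against the current timestamp even on ranks whose _save_checkpoint call returns None. A rough sketch of that pattern in isolation (the helper name save_and_record is illustrative, not part of this PR):

from typing import Optional

from composer.callbacks import CheckpointSaver
from composer.core import State
from composer.utils import checkpoint


def save_and_record(state: State, saver: CheckpointSaver, filename_with_placeholders: str,
                    weights_only: bool = False) -> Optional[str]:
    # Resolve the concrete filename up front so it can be recorded immediately.
    save_filename = checkpoint.get_save_filename(state, filename_with_placeholders)
    saver.all_saved_checkpoints_to_timestamp[save_filename] = state.timestamp
    # The actual write; returns None on ranks that do not save when sharded checkpointing is enabled.
    return checkpoint._save_checkpoint(
        state=state,
        save_filename=save_filename,
        weights_only=weights_only,
    )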
52 changes: 34 additions & 18 deletions composer/utils/checkpoint.py
@@ -694,7 +694,8 @@ def _remove_paths(obj: Union[list, dict[str, Any]], exclude_paths: list[list[str

# Recurse first, so in the case of a list, the indexing is consistent
for key, paths_to_recurse in keys_to_recurse.items():
_remove_paths(obj[key], paths_to_recurse)
if key in obj:
_remove_paths(obj[key], paths_to_recurse)

# Sort the keys in reverse order, so in the case of a list, the indexing is consistent
keys_to_remove.sort(reverse=True)
@@ -876,9 +877,29 @@ def _restore_checkpoint(
return state_dict.get('rng', None)


def save_checkpoint(
def get_save_filename(
state: State,
filename: str = 'ep{epoch}-ba{batch}-rank{rank}',
) -> str:
if not state.fsdp_sharded_state_dict_enabled:
is_deepspeed = is_model_deepspeed(state.model)
return PartialFilePath(filename).format(state, is_deepspeed)

# Sharded checkpoints get their own little folder.
assert state.sharded_ckpt_prefix_dir is not None
save_dirpath = Path(Path(filename).parent) / Path(state.sharded_ckpt_prefix_dir)
save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp)
# The new name is Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcp when using_torch_2(),
# else Trainer.save_folder / sharded_ckpt_prefix_dir / ba{batch}_rank{dist.get_global_rank()}.pt,
# e.g. path/to/my/checkpoints/ep1-ba2/__1_0.distcp on torch 2, else path/to/my/checkpoints/ep1-ba2/ba2_rank1.pt
ckpt_filename = _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME if using_torch_2() else format_name_with_dist_and_time(
Path(filename).name, state.run_name, state.timestamp)
return str(Path(save_dirpath) / Path(ckpt_filename))


def _save_checkpoint(
state: State,
save_filename: str,
*,
weights_only: bool = False,
) -> Union[str, None]: # noqa: D103
@@ -900,9 +921,6 @@ def save_checkpoint(
'rng': reproducibility.get_rng_state(),
}

log.debug('State dict created.')

# Sharded checkpoints get their own little folder.
if state.fsdp_sharded_state_dict_enabled:
# To load optimizer states with torch 2.0, the optimizer state must be at the top
# level of the state dict because the load_sharded_optimizer_state_dict function
@@ -912,19 +930,7 @@
if using_torch_2():
if not weights_only:
state_dict['optimizers'] = state_dict['state'].pop('optimizers')

# Specify save directory path and save_f
assert state.sharded_ckpt_prefix_dir is not None
save_dirpath = Path(Path(filename).parent) / Path(state.sharded_ckpt_prefix_dir)
save_dirpath = format_name_with_dist_and_time(str(save_dirpath), state.run_name, state.timestamp)
# The new name is Trainer.save_folder / sharded_ckpt_prefix_dir / __{dist.get_global_rank()}_0.distcp when using_torch_2(),
# else Trainer.save_folder / sharded_ckpt_prefix_dir / ba{batch}_rank{dist.get_global_rank()}.pt,
# e.g. path/to/my/checkpoints/ep1-ba2/__1_0.distcp on torch 2, else path/to/my/checkpoints/ep1-ba2/ba2_rank1.pt
ckpt_filename = _TORCH_DISTRIBUTED_CHECKPOINTS_FILENAME if using_torch_2() else format_name_with_dist_and_time(
Path(filename).name, state.run_name, state.timestamp)
save_filename = str(Path(save_dirpath) / Path(ckpt_filename))
else:
save_filename = PartialFilePath(filename).format(state, is_deepspeed)
log.debug('State dict created.')

dirname = os.path.dirname(save_filename)
if dirname:
@@ -1014,6 +1020,16 @@ def _save_deepspeed_model(model, filename: str):
tar.add(tmpdir, arcname='')


def save_checkpoint(
state: State,
filename: str = 'ep{epoch}-ba{batch}-rank{rank}',
*,
weights_only: bool = False,
) -> Union[str, None]: # noqa: D103
save_filename = get_save_filename(state, filename)
return _save_checkpoint(state, save_filename, weights_only=weights_only)


save_checkpoint.__doc__ = f"""Checkpoint the training ``state``.

Args:
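Since save_checkpoint keeps its public signature and simply delegates to get_save_filename plus _save_checkpoint, existing callers should be unaffected. A rough usage sketch of the public entry point (assuming a fully initialized composer.core.State, e.g. from a running Trainer):

from typing import Optional

from composer.core import State
from composer.utils import checkpoint


def save_now(state: State) -> Optional[str]:
    # The default filename template is 'ep{epoch}-ba{batch}-rank{rank}'.
    saved_path = checkpoint.save_checkpoint(state)
    # Not every rank returns a path when elastic sharded checkpointing is enabled.
    if saved_path is not None:
        print(f'Checkpoint written to {saved_path}')
    return saved_path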
43 changes: 43 additions & 0 deletions tests/callbacks/test_checkpoint_saver.py
@@ -0,0 +1,43 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

from composer.callbacks import CheckpointSaver
from composer.core import Timestamp


def test_stateful_checkpoint_saver():
checkpoint_saver = CheckpointSaver()
assert not checkpoint_saver.all_saved_checkpoints_to_timestamp

# empty state dict
empty_state_dict = checkpoint_saver.state_dict()
assert 'all_saved_checkpoints_to_timestamp' in empty_state_dict
assert len(empty_state_dict['all_saved_checkpoints_to_timestamp']) == 0

# backwards compatibility; empty state dict should not raise
checkpoint_saver.load_state_dict({})
assert not checkpoint_saver.all_saved_checkpoints_to_timestamp

# add a checkpoint and confirm it can save and load
checkpoint_saver.all_saved_checkpoints_to_timestamp = {
'example-checkpoint.pt': Timestamp(epoch=1, batch=2),
}
new_state_dict = checkpoint_saver.state_dict()
assert 'all_saved_checkpoints_to_timestamp' in new_state_dict
assert len(new_state_dict['all_saved_checkpoints_to_timestamp']) == 1
assert 'example-checkpoint.pt' in new_state_dict['all_saved_checkpoints_to_timestamp']
ts = new_state_dict['all_saved_checkpoints_to_timestamp']['example-checkpoint.pt']
assert isinstance(ts, dict)
assert ts['epoch'] == 1
assert ts['batch'] == 2
assert ts['sample'] == 0

checkpoint_saver.load_state_dict(new_state_dict)

assert checkpoint_saver.all_saved_checkpoints_to_timestamp
assert 'example-checkpoint.pt' in checkpoint_saver.all_saved_checkpoints_to_timestamp
ts = checkpoint_saver.all_saved_checkpoints_to_timestamp['example-checkpoint.pt']
assert isinstance(ts, Timestamp)
assert ts.epoch == 1
assert ts.batch == 2
assert ts.sample == 0
4 changes: 4 additions & 0 deletions tests/trainer/test_checkpoint.py
@@ -85,6 +85,10 @@ def _assert_checkpoints_equivalent(file1, file2, atol=0.0, rtol=0.0):
if 'DummyStatefulCallback' in ckpt['state']['callbacks']:
del ckpt['state']['callbacks']['DummyStatefulCallback']

# Remove all saved checkpoints to timestamp (accumulates between runs)
del checkpoint_1['state']['callbacks']['CheckpointSaver']['all_saved_checkpoints_to_timestamp']
del checkpoint_2['state']['callbacks']['CheckpointSaver']['all_saved_checkpoints_to_timestamp']

deep_compare(checkpoint_1, checkpoint_2, atol=atol, rtol=rtol)

# deepspeed checkpoints do not have model or optimizer