Add save_ignore_keys #2868

Merged · merged 10 commits · Jan 16, 2024
Changes from 9 commits
32 changes: 28 additions & 4 deletions composer/callbacks/checkpoint_saver.py
@@ -199,10 +199,6 @@ class CheckpointSaver(Callback):  # noqa: D101
progress). It should return ``True`` if a checkpoint should be saved given the current state and
event.

weights_only (bool): If ``True``, save only the model weights instead of the entire training state.
This parameter must be ``False`` when using DeepSpeed. Default: ``False``.


num_checkpoints_to_keep (int, optional): The number of checkpoints to keep locally. The oldest checkpoints
are removed first. Set to ``-1`` to keep all checkpoints locally. Default: ``-1``.

@@ -214,6 +210,31 @@ class CheckpointSaver(Callback):  # noqa: D101
This parameter only controls how many checkpoints are kept locally; checkpoints are not deleted from
remote file systems.

weights_only (bool): If ``True``, save only the model weights instead of the entire training state.
This parameter must be ``False`` when using DeepSpeed. Default: ``False``.

ignore_keys (List[str] | (Dict) -> None, optional): A list of paths into the checkpoint ``state_dict`` that,
when provided, are removed from the ``state_dict`` before the checkpoint is saved. Each path is a sequence
of keys indexing into ``state_dict``, joined with `/` as a separator (since PyTorch uses `.` in parameter
names). If a prefix is provided, all of its children are also ignored (see Example 2).
See :mod:`composer.core.state` for the structure of ``state_dict``.

Example 1: ``save_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]`` would ignore
layer 1 weights and bias.

Example 2: ``save_ignore_keys = ["state/model/*"]`` would ignore the entire model, which would have the same
effect as the previous example if there was only 1 layer.

Example 3: ``save_ignore_keys = ["state/model/layer*.weights"]`` would ignore all weights in the model.

Example 4: ``save_ignore_keys = ["state/rank_zero_seed", "rng"]`` would remove the seed and RNG state from
the checkpoint, resetting all randomness when resuming from it.

If a callable, it should take one argument, the ``state_dict``. The callable may arbitrarily modify
the ``state_dict`` in place before the checkpoint is saved.

(default: ``None``)

Attributes:
saved_checkpoints (List[Tuple[Timestamp, List[pathlib.Path]]]): The checkpoint timestamps and filepaths.

@@ -243,6 +264,7 @@ def __init__(
overwrite: bool = False,
num_checkpoints_to_keep: int = -1,
weights_only: bool = False,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
):
folder = str(folder)
filename = str(filename)
@@ -267,6 +289,7 @@ def __init__(
self.all_saved_checkpoints_to_timestamp: Dict[str, Timestamp] = {}
self.num_checkpoints_to_keep = num_checkpoints_to_keep
self.weights_only = weights_only
self.ignore_keys = ignore_keys

self.start_batch = None

@@ -363,6 +386,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
state=state,
filename=filename_with_placeholders,
weights_only=self.weights_only,
ignore_keys=self.ignore_keys,
)
log.debug(f'Checkpoint locally saved to {saved_path}')

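For readers skimming the diff, here is a minimal usage sketch of the new ``ignore_keys`` argument on ``CheckpointSaver``. The folder name and the ignored key paths below are illustrative assumptions rather than part of this PR; see :mod:`composer.core.state` for the actual ``state_dict`` layout.

```python
from composer.callbacks import CheckpointSaver

# Glob form: drop the optimizer state from every checkpoint this callback saves.
# 'state/optimizers/*' is an assumed example of a '/'-joined state_dict key path.
saver = CheckpointSaver(
    folder='checkpoints',
    save_interval='1ep',
    ignore_keys=['state/optimizers/*'],
)

# Callable form: mutate the state_dict in place before it is written to disk.
def drop_rng(state_dict):
    state_dict.pop('rng', None)

saver_without_rng = CheckpointSaver(
    folder='checkpoints',
    save_interval='1ep',
    ignore_keys=drop_rng,
)
```

Either instance would then be passed to a ``Trainer`` through its ``callbacks`` argument in the usual way.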
23 changes: 23 additions & 0 deletions composer/trainer/trainer.py
@@ -700,6 +700,27 @@ class Trainer:
state. This parameter has no effect if ``save_folder`` is ``None``. (default: ``False``)

.. seealso:: :class:`~.CheckpointSaver`
save_ignore_keys (List[str] | (Dict) -> None, optional): A list of paths into the checkpoint ``state_dict`` that,
when provided, are removed from the ``state_dict`` before the checkpoint is saved. Each path is a sequence
of keys indexing into ``state_dict``, joined with `/` as a separator (since PyTorch uses `.` in parameter
names). If a prefix is provided, all of its children are also ignored (see Example 2).
See :mod:`composer.core.state` for the structure of ``state_dict``.

Example 1: ``save_ignore_keys = ["state/model/layer1.weights", "state/model/layer1.bias"]`` would ignore
layer 1 weights and bias.

Example 2: ``save_ignore_keys = ["state/model/*"]`` would ignore the entire model, which would have the same
effect as the previous example if there was only 1 layer.

Example 3: ``save_ignore_keys = ["state/model/layer*.weights"]`` would ignore all weights in the model.

Example 4: ``save_ignore_keys = ["state/rank_zero_seed", "rng"]`` would remove the seed and RNG state from
the checkpoint, resetting all randomness when resuming from it.

If a callable, it should take one argument, the ``state_dict``. The callable may arbitrarily modify
the ``state_dict`` in place before the checkpoint is saved.

(default: ``None``)
save_num_checkpoints_to_keep (int, optional): The number of checkpoints to keep locally. The oldest checkpoints
are removed first. Set to ``-1`` to keep all checkpoints locally. (default: ``-1``)

@@ -866,6 +887,7 @@ def __init__(
save_overwrite: bool = False,
save_interval: Union[str, int, Time, Callable[[State, Event], bool]] = '1ep',
save_weights_only: bool = False,
save_ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
save_num_checkpoints_to_keep: int = -1,
save_metrics: bool = False,

@@ -1150,6 +1172,7 @@ def __init__(
latest_remote_file_name=latest_remote_file_name,
overwrite=save_overwrite,
weights_only=save_weights_only,
ignore_keys=save_ignore_keys,
save_interval=save_interval,
num_checkpoints_to_keep=save_num_checkpoints_to_keep,
)
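The same behavior is reachable through the new ``save_ignore_keys`` convenience argument documented above. A short sketch, assuming ``model`` and ``train_dataloader`` are an already-constructed ComposerModel and DataLoader (both placeholders here), that drops the seed and RNG state from every saved checkpoint as in Example 4:

```python
from composer import Trainer

# `model` and `train_dataloader` are placeholders; only the save_* arguments
# are the point of this sketch.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='2ep',
    save_folder='checkpoints',
    save_interval='1ep',
    save_ignore_keys=['state/rank_zero_seed', 'rng'],
)
trainer.fit()

# Checkpoints under 'checkpoints/' now omit the seed and RNG state, so a run
# resumed from one of them re-seeds instead of replaying the original randomness.
```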
15 changes: 13 additions & 2 deletions composer/utils/checkpoint.py
@@ -16,7 +16,7 @@
import warnings
from importlib import import_module
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

import torch
from packaging import version
@@ -938,6 +938,7 @@ def _save_checkpoint(
save_filename: str,
*,
weights_only: bool = False,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
) -> Union[str, None]: # noqa: D103

is_deepspeed = is_model_deepspeed(state.model)
Expand All @@ -957,6 +958,15 @@ def _save_checkpoint(
'rng': reproducibility.get_rng_state(),
}

if ignore_keys:
# Filter provided list of key paths
if not callable(ignore_keys):
ignore_keys = glob_filter(ignore_keys)
# Call function to modify state_dict
ignore_keys(state_dict)
# Ensure state exists
state_dict['state'] = state_dict.get('state', {})

if state.fsdp_sharded_state_dict_enabled:
# To load optimizer states with 2.0 <= torch < 2.1.3 , the optimizer state must be at the top
# level of the state dict because the load_sharded_optimizer_state_dict function
@@ -1087,9 +1097,10 @@ def save_checkpoint(
filename: str = 'ep{epoch}-ba{batch}-rank{rank}',
*,
weights_only: bool = False,
ignore_keys: Optional[Union[List[str], Callable[[Dict], None]]] = None,
) -> Union[str, None]: # noqa: D103
save_filename = get_save_filename(state, filename)
return _save_checkpoint(state, save_filename, weights_only=weights_only)
return _save_checkpoint(state, save_filename, weights_only=weights_only, ignore_keys=ignore_keys)


save_checkpoint.__doc__ = f"""Checkpoint the training ``state``.
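In ``_save_checkpoint`` above, the list form of ``ignore_keys`` is turned into a callable by the existing ``glob_filter`` helper, which prunes matching ``/``-joined key paths from the nested ``state_dict`` in place. As a rough illustration of what a filter of that shape does (a simplified re-implementation for this write-up, not Composer's actual helper), consider:

```python
import fnmatch
from typing import Any, Callable, Dict, List


def make_glob_filter(ignore_keys: List[str]) -> Callable[[Dict[str, Any]], None]:
    """Build a callable that deletes matching '/'-joined key paths in place."""

    def prune(node: Dict[str, Any], parts: List[str]) -> None:
        head, rest = parts[0], parts[1:]
        for key in [k for k in node if fnmatch.fnmatch(k, head)]:
            if not rest:
                del node[key]  # leaf of the pattern: drop this subtree entirely
            elif isinstance(node[key], dict):
                prune(node[key], rest)  # descend into nested dicts

    def filter_fn(state_dict: Dict[str, Any]) -> None:
        for pattern in ignore_keys:
            prune(state_dict, pattern.split('/'))

    return filter_fn


# Example: strip all model weights and the RNG state before the checkpoint is written.
state_dict = {'state': {'model': {'layer1.weights': 1, 'layer1.bias': 2}}, 'rng': ['rng-state']}
make_glob_filter(['state/model/*', 'rng'])(state_dict)
assert state_dict == {'state': {'model': {}}}
```

The ``state_dict['state'] = state_dict.get('state', {})`` line after the filter call presumably guards against patterns such as ``'*'`` removing the top-level ``'state'`` entry that the rest of the save path indexes into.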
38 changes: 38 additions & 0 deletions tests/trainer/test_checkpoint.py
@@ -790,6 +790,44 @@ def test_load_ignore_keys(self, load_ignore_keys, weights_equal, callbacks_equal
assert trainer_1_rng_state is not None
deep_compare(trainer_1_rng_state, trainer_2._rng_state)

@pytest.mark.parametrize('save_ignore_keys,weights_equal,callbacks_equal,rng_equal', [
['*', False, False, False],
['state/model/*', False, True, True],
['state/callbacks/*', True, False, True],
['rng', True, True, False],
])
@pytest.mark.filterwarnings('ignore:.* is not in the state_dict.*:UserWarning')
def test_save_ignore_keys(self, save_ignore_keys, weights_equal, callbacks_equal, rng_equal):

trainer_1 = self.get_trainer(save_folder='first', save_ignore_keys=[save_ignore_keys])
trainer_1.fit()
trainer_1_rng_state = reproducibility.get_rng_state()
trainer_1.close()

last_checkpoint = os.path.join('first', 'ep2.pt')
trainer_2 = self.get_trainer(load_path=last_checkpoint)

# Check weights loaded properly
with contextlib.nullcontext() if weights_equal else pytest.raises(AssertionError):
self._assert_weights_equivalent(
trainer_1.state.model,
trainer_2.state.model,
)

# Check callbacks state
stateful_callbacks_equal = self._stateful_callbacks_equal(
trainer_1.state.callbacks,
trainer_2.state.callbacks,
)
if callbacks_equal:
assert stateful_callbacks_equal
else:
assert not stateful_callbacks_equal

if rng_equal:
assert trainer_1_rng_state is not None
deep_compare(trainer_1_rng_state, trainer_2._rng_state)

@pytest.mark.remote
@device('cpu')
@pytest.mark.parametrize('load_weights_only', [True, False])