Skip to content

Commit

Permalink
Remove save overwrite (mosaicml#3431)
Browse files Browse the repository at this point in the history
* remove save overwrite

* fix tests

* lint

* remove bad test
  • Loading branch information
mvpatel2000 authored Jun 27, 2024
1 parent fc24d64 commit 5f8265d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 14 deletions.
5 changes: 0 additions & 5 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,11 +1732,6 @@ def __init__(
error_message = ''
if save_folder is None:
error_message += 'The `save_folder` must be specified when autoresume is enabled. '
if save_overwrite:
error_message += textwrap.dedent(
'The flag `save_overwrite` must be False when autoresume is enabled as autoresume always loads the '
'latest existing checkpoint in `save_folder`. ',
)
if save_latest_filename is None:
error_message += 'The `save_latest_filename` must be specified so autoresume knows where to load checkpoints from. '
if error_message != '':
Expand Down
26 changes: 17 additions & 9 deletions tests/trainer/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,11 +730,19 @@ def get_logger(self, tmp_path: pathlib.Path):

@world_size(1, 2)
@device('cpu', 'gpu')
@pytest.mark.parametrize('file_extension', ['.pt', '.tar.gz', '.pt.lz4'])
@pytest.mark.parametrize('use_object_store', [True, False])
@pytest.mark.parametrize('delete_local', [True, False])
@pytest.mark.parametrize('test_slashed', [True, False])
@pytest.mark.parametrize('save_metrics', [True, False])
@pytest.mark.parametrize(
'file_extension,save_metrics,save_overwrite',
[
['.pt', False, False],
['.tar.gz', False, False],
['.pt.lz4', False, False],
['.pt', True, False],
['.pt', False, True],
],
)
def test_autoresume(
self,
device: str,
Expand All @@ -744,6 +752,7 @@ def test_autoresume(
delete_local: bool,
test_slashed: bool,
save_metrics: bool,
save_overwrite: bool,
world_size: int,
):
if delete_local and not use_object_store:
Expand Down Expand Up @@ -786,6 +795,7 @@ def test_autoresume(
autoresume=True,
load_path='ignore_me.pt', # this should be ignored
load_ignore_keys=['*'], # this should be ignored
save_overwrite=save_overwrite,
loggers=[self.get_logger(tmp_path)] if use_object_store else [],
)

Expand Down Expand Up @@ -1212,19 +1222,17 @@ def test_load_weights_object_store(self, tmp_path):
)

@pytest.mark.parametrize(
'run_name,save_folder,save_overwrite,latest_filename',
'run_name,save_folder,latest_filename',
[
[None, 'first', False, 'latest-rank{rank}.pt'],
['big-chungus', None, False, 'latest-rank{rank}.pt'],
['big-chungus', 'first', True, 'latest-rank{rank}.pt'],
['big-chungus', 'first', False, None],
[None, 'first', 'latest-rank{rank}.pt'],
['big-chungus', None, 'latest-rank{rank}.pt'],
['big-chungus', 'first', None],
],
)
def test_autoresume_fail(self, run_name, save_folder, save_overwrite, latest_filename):
def test_autoresume_fail(self, run_name, save_folder, latest_filename):
with pytest.raises(ValueError):
self.get_trainer(
latest_filename=latest_filename,
save_overwrite=save_overwrite,
save_folder=save_folder,
run_name=run_name,
autoresume=True,
Expand Down

0 comments on commit 5f8265d

Please sign in to comment.