diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 91dd0b1e19..f5a6b57d77 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -1732,11 +1732,6 @@ def __init__( error_message = '' if save_folder is None: error_message += 'The `save_folder` must be specified when autoresume is enabled. ' - if save_overwrite: - error_message += textwrap.dedent( - 'The flag `save_overwrite` must be False when autoresume is enabled as autoresume always loads the ' - 'latest existing checkpoint in `save_folder`. ', - ) if save_latest_filename is None: error_message += 'The `save_latest_filename` must be specified so autoresume knows where to load checkpoints from. ' if error_message != '': diff --git a/tests/trainer/test_checkpoint.py b/tests/trainer/test_checkpoint.py index d23b55875f..9912563eb8 100644 --- a/tests/trainer/test_checkpoint.py +++ b/tests/trainer/test_checkpoint.py @@ -730,11 +730,19 @@ def get_logger(self, tmp_path: pathlib.Path): @world_size(1, 2) @device('cpu', 'gpu') - @pytest.mark.parametrize('file_extension', ['.pt', '.tar.gz', '.pt.lz4']) @pytest.mark.parametrize('use_object_store', [True, False]) @pytest.mark.parametrize('delete_local', [True, False]) @pytest.mark.parametrize('test_slashed', [True, False]) - @pytest.mark.parametrize('save_metrics', [True, False]) + @pytest.mark.parametrize( + 'file_extension,save_metrics,save_overwrite', + [ + ['.pt', False, False], + ['.tar.gz', False, False], + ['.pt.lz4', False, False], + ['.pt', True, False], + ['.pt', False, True], + ], + ) def test_autoresume( self, device: str, @@ -744,6 +752,7 @@ def test_autoresume( delete_local: bool, test_slashed: bool, save_metrics: bool, + save_overwrite: bool, world_size: int, ): if delete_local and not use_object_store: @@ -786,6 +795,7 @@ def test_autoresume( autoresume=True, load_path='ignore_me.pt', # this should be ignored load_ignore_keys=['*'], # this should be ignored + save_overwrite=save_overwrite, loggers=[self.get_logger(tmp_path)] if use_object_store else [], ) @@ -1212,19 +1222,17 @@ def test_load_weights_object_store(self, tmp_path): ) @pytest.mark.parametrize( - 'run_name,save_folder,save_overwrite,latest_filename', + 'run_name,save_folder,latest_filename', [ - [None, 'first', False, 'latest-rank{rank}.pt'], - ['big-chungus', None, False, 'latest-rank{rank}.pt'], - ['big-chungus', 'first', True, 'latest-rank{rank}.pt'], - ['big-chungus', 'first', False, None], + [None, 'first', 'latest-rank{rank}.pt'], + ['big-chungus', None, 'latest-rank{rank}.pt'], + ['big-chungus', 'first', None], ], ) - def test_autoresume_fail(self, run_name, save_folder, save_overwrite, latest_filename): + def test_autoresume_fail(self, run_name, save_folder, latest_filename): with pytest.raises(ValueError): self.get_trainer( latest_filename=latest_filename, - save_overwrite=save_overwrite, save_folder=save_folder, run_name=run_name, autoresume=True,