From 85e9ccf51f9174436aa206df514fcecd0d2ab097 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 4 Oct 2020 08:32:03 +0200 Subject: [PATCH 01/21] true final value of global step --- pytorch_lightning/trainer/trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 219df9f67301d..681b98602edd8 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -519,6 +519,8 @@ def train(self): f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has' ' not been met. Training will continue...') + # lower the final step as there is no real change + self.global_step -= 1 # hook self.train_loop.on_train_end() From d32396fbfcfcb44a2b4358a23550a70c838000c3 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 4 Oct 2020 08:34:07 +0200 Subject: [PATCH 02/21] ch check --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 2517888ad5ed7..ad27b50d53399 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -210,10 +210,10 @@ def save_checkpoint(self, trainer, pl_module): # here we call each mode sequentially # Mode 1: save all checkpoints OR only the top k if self.save_top_k: - self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath) + self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, global_step, filepath) # Mode 2: save the last checkpoint - self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath) + self._save_last_checkpoint(trainer, pl_module, epoch, global_step, monitor_candidates, filepath) def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: From 0550bc38e047905644e77c115a0d495f015207d2 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sun, 4 Oct 2020 09:02:26 +0200 Subject: [PATCH 03/21] tests --- tests/callbacks/test_lr_monitor.py | 2 +- tests/loggers/test_all.py | 4 ++-- tests/models/test_hooks.py | 4 ++-- .../__init__.py | 0 .../test_eval_loop_dict_return.py | 0 .../test_trainer_steps_dict_return.py | 0 .../test_trainer_steps_result_return.py | 0 .../test_trainer_steps_scalar_return.py | 0 .../test_validation_steps_result_return.py | 2 +- 9 files changed, 6 insertions(+), 6 deletions(-) rename tests/trainer/{legacy_deprecate_flow_log_tests => deprecate_legacy_flow_log}/__init__.py (100%) rename tests/trainer/{legacy_deprecate_flow_log_tests => deprecate_legacy_flow_log}/test_eval_loop_dict_return.py (100%) rename tests/trainer/{legacy_deprecate_flow_log_tests => deprecate_legacy_flow_log}/test_trainer_steps_dict_return.py (100%) rename tests/trainer/{legacy_deprecate_flow_log_tests => deprecate_legacy_flow_log}/test_trainer_steps_result_return.py (100%) rename tests/trainer/{legacy_deprecate_flow_log_tests => deprecate_legacy_flow_log}/test_trainer_steps_scalar_return.py (100%) rename tests/trainer/{legacy_deprecate_flow_log_tests => deprecate_legacy_flow_log}/test_validation_steps_result_return.py (99%) diff --git a/tests/callbacks/test_lr_monitor.py b/tests/callbacks/test_lr_monitor.py index 4370150768504..734771668c98f 100644 --- a/tests/callbacks/test_lr_monitor.py +++ b/tests/callbacks/test_lr_monitor.py @@ -94,7 +94,7 @@ def test_lr_monitor_multi_lrs(tmpdir, logging_interval): 'Names of learning rates 
not set correctly' if logging_interval == 'step': - expected_number_logged = trainer.global_step + expected_number_logged = trainer.global_step + 1 if logging_interval == 'epoch': expected_number_logged = trainer.max_epochs diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 6a9f3aa1c92ad..0ff4dca8e87c1 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -85,14 +85,14 @@ def log_metrics(self, metrics, step): (0, ['epoch', 'train_some_val']), (0, ['early_stop_on', 'epoch', 'val_acc']), (0, ['hp_metric']), - (1, ['epoch', 'test_acc', 'test_loss']) + (0, ['epoch', 'test_acc', 'test_loss']) ] assert log_metric_names == expected else: expected = [ (0, ['epoch', 'train_some_val']), (0, ['early_stop_on', 'epoch', 'val_acc']), - (1, ['epoch', 'test_acc', 'test_loss']) + (0, ['epoch', 'test_acc', 'test_loss']) ] assert log_metric_names == expected diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 3681484a8bcc8..ca88033127527 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -125,7 +125,7 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): trainer.fit(model) if batch_idx_ > len(model.val_dataloader()) - 1: assert trainer.batch_idx == len(model.val_dataloader()) - 1 - assert trainer.global_step == len(model.val_dataloader()) * max_epochs + assert trainer.global_step == len(model.val_dataloader()) * max_epochs - 1 else: assert trainer.batch_idx == batch_idx_ - assert trainer.global_step == (batch_idx_ + 1) * max_epochs + assert trainer.global_step == (batch_idx_ + 1) * max_epochs - 1 diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/__init__.py b/tests/trainer/deprecate_legacy_flow_log/__init__.py similarity index 100% rename from tests/trainer/legacy_deprecate_flow_log_tests/__init__.py rename to tests/trainer/deprecate_legacy_flow_log/__init__.py diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py b/tests/trainer/deprecate_legacy_flow_log/test_eval_loop_dict_return.py similarity index 100% rename from tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py rename to tests/trainer/deprecate_legacy_flow_log/test_eval_loop_dict_return.py diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py b/tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_dict_return.py similarity index 100% rename from tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py rename to tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_dict_return.py diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_result_return.py b/tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_result_return.py similarity index 100% rename from tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_result_return.py rename to tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_result_return.py diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py b/tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_scalar_return.py similarity index 100% rename from tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py rename to tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_scalar_return.py diff --git a/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py b/tests/trainer/deprecate_legacy_flow_log/test_validation_steps_result_return.py similarity index 99% 
rename from tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py rename to tests/trainer/deprecate_legacy_flow_log/test_validation_steps_result_return.py index a43b50c442dac..909b8bd28ec0c 100644 --- a/tests/trainer/legacy_deprecate_flow_log_tests/test_validation_steps_result_return.py +++ b/tests/trainer/deprecate_legacy_flow_log/test_validation_steps_result_return.py @@ -56,7 +56,7 @@ def test_val_step_result_callbacks(tmpdir): assert len(trainer.dev_debugger.early_stopping_history) == 5 # only 2 checkpoints expected - assert len(trainer.dev_debugger.checkpoint_callback_history) == 2 + assert len(trainer.dev_debugger.checkpoint_callback_history) == 3 # make sure the last known metric is correct assert trainer.logger_connector.callback_metrics['checkpoint_on'] == 171 + 15 From cbbd06c74214e59d8890c195d5303a7fc6fe13cc Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 3 Oct 2020 00:37:10 +0200 Subject: [PATCH 04/21] save each validation interval --- pytorch_lightning/callbacks/model_checkpoint.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ad27b50d53399..93e052df80a1a 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -508,7 +508,6 @@ def _update_best_and_save( trainer, pl_module, ): - k = epoch + 1 if self.save_top_k == -1 else self.save_top_k del_list = [] @@ -536,9 +535,8 @@ def _update_best_and_save( if self.verbose: rank_zero_info( - f"Epoch {epoch:d}: {self.monitor} reached" - f" {current:0.5f} (best {self.best_model_score:0.5f})," - f" saving model to {filepath} as top {k}" + f"Epoch {epoch:d}: {self.monitor} reached {current:0.5f} (best {self.best_model_score:0.5f})," + f' saving model to "{filepath}" as top {k}' ) self._save_model(filepath, trainer, pl_module) From 07bb754260d6962a2bc00ff7b88d916a130b0ab6 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 3 Oct 2020 01:40:36 +0200 Subject: [PATCH 05/21] wip --- .../callbacks/model_checkpoint.py | 33 +++++---- pytorch_lightning/trainer/evaluation_loop.py | 5 +- tests/checkpointing/test_model_checkpoint.py | 70 +++++++++++-------- 3 files changed, 62 insertions(+), 46 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 93e052df80a1a..37e0469257f5e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -204,7 +204,7 @@ def save_checkpoint(self, trainer, pl_module): monitor_candidates = self._monitor_candidates(trainer) # ie: path/val_loss=0.5.ckpt - filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates) + filepath = self._get_metric_interpolated_filepath_name(epoch, global_step, monitor_candidates) # callback supports multiple simultaneous modes # here we call each mode sequentially @@ -213,7 +213,7 @@ def save_checkpoint(self, trainer, pl_module): self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, global_step, filepath) # Mode 2: save the last checkpoint - self._save_last_checkpoint(trainer, pl_module, epoch, global_step, monitor_candidates, filepath) + self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath) def __validate_init_configuration(self): if self.save_top_k is not None and self.save_top_k < -1: @@ -323,6 +323,7 @@ def _format_checkpoint_name( cls, filename: Optional[str], 
epoch: int, + step: int, metrics: Dict[str, Any], prefix: str = "", ) -> str: @@ -332,7 +333,7 @@ def _format_checkpoint_name( # check and parse user passed keys in the string groups = re.findall(r"(\{.*?)[:\}]", filename) if len(groups) >= 0: - metrics["epoch"] = epoch + metrics.update({"epoch": epoch, 'step': step}) for group in groups: name = group[1:] filename = filename.replace(group, name + "={" + name) @@ -342,7 +343,7 @@ def _format_checkpoint_name( return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt]) def format_checkpoint_name( - self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None + self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None ) -> str: """Generate a filename according to the defined template. @@ -350,20 +351,20 @@ def format_checkpoint_name( >>> tmpdir = os.path.dirname(__file__) >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}')) - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={})) 'epoch=0.ckpt' >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}')) - >>> os.path.basename(ckpt.format_checkpoint_name(5, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={})) 'epoch=005.ckpt' >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}')) - >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456))) + >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456))) 'epoch=2-val_loss=0.12.ckpt' >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}')) - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={})) 'missing=0.ckpt' """ filename = self._format_checkpoint_name( - self.filename, epoch, metrics, prefix=self.prefix + self.filename, epoch, step, metrics, prefix=self.prefix ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) @@ -440,12 +441,12 @@ def _validate_monitor_key(self, trainer): ) raise MisconfigurationException(m) - def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics): - filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics) + def _get_metric_interpolated_filepath_name(self, epoch: int, step: int, ckpt_name_metrics: Dict[str, Any]): + filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics) version_cnt = 0 while self._fs.exists(filepath): filepath = self.format_checkpoint_name( - epoch, ckpt_name_metrics, ver=version_cnt + epoch, step, ckpt_name_metrics, ver=version_cnt ) # this epoch called before version_cnt += 1 @@ -457,7 +458,7 @@ def _monitor_candidates(self, trainer): ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics) return ckpt_name_metrics - def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath): + def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath): should_save_last = self.monitor is None or self.save_last if not should_save_last: return @@ -467,7 +468,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, fi # when user ALSO asked for the 'last.ckpt' change the name if self.save_last: last_filepath = self._format_checkpoint_name( - self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix + self.CHECKPOINT_NAME_LAST, + trainer.current_epoch, + trainer.global_step, + ckpt_name_metrics, + prefix=self.prefix ) last_filepath = 
os.path.join(self.dirpath, f"{last_filepath}.ckpt") diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 5e04173e6df5c..585a0cdd0d6fe 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -239,9 +239,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): # depre warning if eval_results is not None and user_reduced: step = 'testing_epoch_end' if self.testing else 'validation_epoch_end' - m = f'The {step} should not return anything as of 9.1.' \ + self.warning_cache.warn( + f'The {step} should not return anything as of 9.1.' f'to log, use self.log(...) or self.write(...) directly in the LightningModule' - self.warning_cache.warn(m) + ) if using_eval_result and not user_reduced: eval_results = self.__auto_reduce_result_objs(outputs) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index ee988bb8f4b60..ccc9250ba284f 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -53,7 +53,7 @@ def test_model_checkpoint_to_yaml(tmpdir, save_top_k): path_yaml = os.path.join(tmpdir, 'best_k_models.yaml') checkpoint.to_yaml(path_yaml) d = yaml.full_load(open(path_yaml, 'r')) - best_k = {k: v.item() for k, v in checkpoint.best_k_models.items()} + best_k = {k: v for k, v in checkpoint.best_k_models.items()} assert d == best_k @@ -124,7 +124,9 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): """Test to ensure that the model callback saves the checkpoints only once in distributed mode.""" model = EvalModelTemplate() num_epochs = 4 - model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1) + model_checkpoint = ModelCheckpointTestInvocations( + filepath=tmpdir, monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1 + ) trainer = Trainer( distributed_backend="ddp_cpu", num_processes=2, @@ -139,50 +141,51 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): def test_model_checkpoint_format_checkpoint_name(tmpdir): # empty filename: - ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, {}) - assert ckpt_name == 'epoch=3' - ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, {}, prefix='test') - assert ckpt_name == 'test-epoch=3' + ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {}) + assert ckpt_name == 'epoch=3-step=2' + ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test') + assert ckpt_name == 'test-epoch=3-step=2' # no groups case: - ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, {}, prefix='test') + ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test') assert ckpt_name == 'test-ckpt' # no prefix - ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, {'acc': 0.03}) + ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch:03d}-{acc}', 3, 2, {'acc': 0.03}) assert ckpt_name == 'epoch=003-acc=0.03' # prefix char_org = ModelCheckpoint.CHECKPOINT_JOIN_CHAR ModelCheckpoint.CHECKPOINT_JOIN_CHAR = '@' - ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, {'acc': 0.03}, prefix='test') + ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test') assert ckpt_name == 'test@epoch=3,acc=0.03000' ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org # no filepath set - ckpt_name = 
ModelCheckpoint(monitor='early_stop_on', filepath=None).format_checkpoint_name(3, {}) - assert ckpt_name == 'epoch=3.ckpt' - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, {}) - assert ckpt_name == 'epoch=5.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=None).format_checkpoint_name(3, 4, {}) + assert ckpt_name == 'epoch=3-step=4.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='').format_checkpoint_name(5, 4, {}) + assert ckpt_name == 'epoch=5-step=4.ckpt' # CWD - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, {}) - assert Path(ckpt_name) == Path('.') / 'epoch=3.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 4, {}) + assert Path(ckpt_name) == Path('.') / 'epoch=3-step=4.ckpt' # dir does not exist so it is used as filename filepath = tmpdir / 'dir' - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {}) + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {}) assert ckpt_name == tmpdir / 'test-dir.ckpt' # now, dir exists os.mkdir(filepath) - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, {}) - assert ckpt_name == filepath / 'test-epoch=3.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {}) + assert ckpt_name == filepath / 'test-epoch=3-step=4.ckpt' # with ver ckpt_name = ModelCheckpoint(monitor='early_stop_on', - filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, {}, ver=3) + filepath=tmpdir / 'name', prefix='test').format_checkpoint_name(3, 4, {}, ver=3) assert ckpt_name == tmpdir / 'test-name-v3.ckpt' def test_model_checkpoint_save_last(tmpdir): """Tests that save_last produces only one last checkpoint.""" + seed_everything() model = EvalModelTemplate() epochs = 3 ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}' - model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir, save_top_k=-1, save_last=True) + model_checkpoint = ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir / '{step}', save_top_k=-1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, early_stop_callback=False, @@ -191,10 +194,12 @@ def test_model_checkpoint_save_last(tmpdir): logger=False, ) trainer.fit(model) - last_filename = model_checkpoint._format_checkpoint_name(ModelCheckpoint.CHECKPOINT_NAME_LAST, epochs - 1, {}) + last_filename = model_checkpoint._format_checkpoint_name( + ModelCheckpoint.CHECKPOINT_NAME_LAST, trainer.current_epoch, trainer.global_step, {} + ) last_filename = last_filename + '.ckpt' assert str(tmpdir / last_filename) == model_checkpoint.last_model_path - assert set(os.listdir(tmpdir)) == set([f'epoch={i}.ckpt' for i in range(epochs)] + [last_filename]) + assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [19, 29, 30]] + [last_filename]) ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last' @@ -229,12 +234,13 @@ def test_none_monitor_save_last(tmpdir): def test_model_checkpoint_none_monitor(tmpdir): """ Test that it is possible to save all checkpoints when monitor=None. 
""" + seed_everything() model = EvalModelTemplate() model.validation_step = model.validation_step_no_monitor model.validation_epoch_end = model.validation_epoch_end_no_monitor epochs = 2 - checkpoint_callback = ModelCheckpoint(monitor=None, filepath=tmpdir, save_top_k=-1) + checkpoint_callback = ModelCheckpoint(monitor=None, filepath=tmpdir / '{step}', save_top_k=-1) trainer = Trainer( default_root_dir=tmpdir, early_stop_callback=False, @@ -246,13 +252,13 @@ def test_model_checkpoint_none_monitor(tmpdir): # these should not be set if monitor is None assert checkpoint_callback.monitor is None - assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1.ckpt' + assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=20.ckpt' assert checkpoint_callback.best_model_score == 0 assert checkpoint_callback.best_k_models == {} assert checkpoint_callback.kth_best_model_path == '' # check that the correct ckpts were created - expected = [f'epoch={e}.ckpt' for e in range(epochs)] + expected = [f'step={i}.ckpt' for i in [9, 19, 20]] assert set(os.listdir(tmpdir)) == set(expected) @@ -260,7 +266,7 @@ def test_model_checkpoint_none_monitor(tmpdir): def test_model_checkpoint_period(tmpdir, period): model = EvalModelTemplate() epochs = 5 - checkpoint_callback = ModelCheckpoint(filepath=tmpdir, save_top_k=-1, period=period) + checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{epoch}', save_top_k=-1, period=period) trainer = Trainer( default_root_dir=tmpdir, early_stop_callback=False, @@ -268,6 +274,7 @@ def test_model_checkpoint_period(tmpdir, period): max_epochs=epochs, limit_train_batches=0.1, limit_val_batches=0.1, + val_check_interval=1.0, logger=False, ) trainer.fit(model) @@ -304,13 +311,14 @@ def test_model_checkpoint_topk_all(tmpdir): seed_everything(1000) epochs = 2 model = EvalModelTemplate() - checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor="early_stop_on", save_top_k=-1) + checkpoint_callback = ModelCheckpoint(filepath=tmpdir / '{epoch}', monitor="early_stop_on", save_top_k=-1) trainer = Trainer( default_root_dir=tmpdir, early_stop_callback=False, checkpoint_callback=checkpoint_callback, max_epochs=epochs, logger=False, + val_check_interval=1.0, ) trainer.fit(model) assert checkpoint_callback.best_model_path == tmpdir / "epoch=1.ckpt" @@ -364,12 +372,12 @@ def test_default_checkpoint_behavior(tmpdir): assert len(results) == 1 assert results[0]['test_acc'] >= 0.80 - assert len(trainer.dev_debugger.checkpoint_callback_history) == 3 + assert len(trainer.dev_debugger.checkpoint_callback_history) == 4 # make sure the checkpoint we saved has the metric in the name ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints')) assert len(ckpts) == 1 - assert ckpts[0] == 'epoch=2.ckpt' + assert ckpts[0] == 'epoch=2-step=15.ckpt' def test_ckpt_metric_names_results(tmpdir): @@ -426,19 +434,21 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): model = EvalModelTemplate() num_epochs = 3 model_checkpoint = ModelCheckpoint( - monitor='early_stop_on', filepath=tmpdir, save_top_k=num_epochs, save_last=True + monitor='early_stop_on', filepath=tmpdir / '{epoch}', save_top_k=num_epochs, save_last=True ) trainer = Trainer( default_root_dir=tmpdir, early_stop_callback=False, checkpoint_callback=model_checkpoint, max_epochs=num_epochs, + val_check_interval=1.0, ) trainer.fit(model) path_last_epoch = str(tmpdir / f"epoch={num_epochs - 1}.ckpt") path_last = str(tmpdir / 
"last.ckpt") assert path_last == model_checkpoint.last_model_path + assert os.path.isfile(path_last_epoch) ckpt_last_epoch = torch.load(path_last_epoch) ckpt_last = torch.load(path_last) From f681d31eb5adae00b0b90ce9304e16fab18c1846 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 3 Oct 2020 21:36:46 +0200 Subject: [PATCH 06/21] add test --- tests/base/model_valid_steps.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py index 77b5e7748a649..015368ab0e0fc 100644 --- a/tests/base/model_valid_steps.py +++ b/tests/base/model_valid_steps.py @@ -34,6 +34,21 @@ def validation_step(self, batch, batch_idx, *args, **kwargs): }) return output + def validation_step__decreasing(self, batch, batch_idx, *args, **kwargs): + if not hasattr(self, 'running_loss'): + self.running_loss = 1 + if not hasattr(self, 'running_acc'): + self.running_acc = 0 + + self.running_loss -= 1e-2 + self.running_acc += 1e-2 + + output = OrderedDict({ + 'val_loss': torch.tensor(self.running_loss), + 'val_acc': torch.tensor(self.running_acc), + }) + return output + def validation_step_no_monitor(self, batch, batch_idx, *args, **kwargs): """ Lightning calls this inside the validation loop From fde3fe438fdcac0900a6c3bdcff412ac4a2ea281 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 3 Oct 2020 22:02:16 +0200 Subject: [PATCH 07/21] add test --- .../trainer/flags/test_val_check_interval.py | 62 ++----------------- 1 file changed, 5 insertions(+), 57 deletions(-) diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index 91ec234781f99..030bbd143f5e6 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -4,7 +4,8 @@ @pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_1(tmpdir, max_epochs): +@pytest.mark.parametrize('interval', [1.0, 0.25, 0.33]) +def test_val_check_interval_(tmpdir, max_epochs, interval): class TestModel(SimpleModule): def __init__(self): @@ -21,64 +22,11 @@ def on_validation_epoch_start(self) -> None: model = TestModel() trainer = Trainer( + default_root_dir=tmpdir, max_epochs=max_epochs, - val_check_interval=1.0, + val_check_interval=interval, logger=False, ) trainer.fit(model) - assert model.val_epoch_calls == max_epochs - - -@pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_quarter(tmpdir, max_epochs): - - class TestModel(SimpleModule): - def __init__(self): - super().__init__() - self.train_epoch_calls = 0 - self.val_epoch_calls = 0 - - def on_train_epoch_start(self) -> None: - self.train_epoch_calls += 1 - - def on_validation_epoch_start(self) -> None: - if not self.trainer.running_sanity_check: - self.val_epoch_calls += 1 - - model = TestModel() - trainer = Trainer( - max_epochs=max_epochs, - val_check_interval=0.25, - logger=False, - ) - trainer.fit(model) - - assert model.val_epoch_calls == max_epochs * 4 - - -@pytest.mark.parametrize('max_epochs', [1, 2, 3]) -def test_val_check_interval_third(tmpdir, max_epochs): - - class TestModel(SimpleModule): - def __init__(self): - super().__init__() - self.train_epoch_calls = 0 - self.val_epoch_calls = 0 - - def on_train_epoch_start(self) -> None: - self.train_epoch_calls += 1 - - def on_validation_epoch_start(self) -> None: - if not self.trainer.running_sanity_check: - self.val_epoch_calls += 1 - - model = TestModel() - trainer = Trainer( - max_epochs=max_epochs, - val_check_interval=0.33, - 
logger=False, - ) - trainer.fit(model) - - assert model.val_epoch_calls == max_epochs * 3 + assert model.val_epoch_calls == max_epochs * round(1 / interval) From e28d3763202240ff10b95c7a65487b8de42ad0ff Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 3 Oct 2020 23:03:36 +0200 Subject: [PATCH 08/21] wip --- pytorch_lightning/trainer/training_loop.py | 16 ++++++++-------- tests/base/model_valid_steps.py | 16 ++++++++-------- tests/loggers/test_comet.py | 2 +- tests/loggers/test_mlflow.py | 2 +- tests/loggers/test_wandb.py | 2 +- tests/trainer/deprecate_legacy_flow/__init__.py | 0 6 files changed, 19 insertions(+), 19 deletions(-) create mode 100644 tests/trainer/deprecate_legacy_flow/__init__.py diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 4cd923b45242a..201a205f69ee5 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -197,12 +197,13 @@ def on_train_end(self): def check_checkpoint_callback(self, should_save, is_last=False): # TODO bake this logic into the checkpoint callback - if should_save: - checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] - if is_last and any(c.save_last for c in checkpoint_callbacks): - rank_zero_info('Saving latest checkpoint...') - model = self.trainer.get_model() - [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks] + if not should_save: + return + checkpoint_callbacks = [c for c in self.trainer.callbacks if isinstance(c, ModelCheckpoint)] + if is_last and any(c.save_last for c in checkpoint_callbacks): + rank_zero_info('Saving latest checkpoint...') + model = self.trainer.get_model() + [c.on_validation_end(self.trainer, model) for c in checkpoint_callbacks] def on_train_epoch_start(self, epoch): model = self.trainer.get_model() @@ -599,8 +600,7 @@ def run_training_epoch(self): # epoch end hook self.run_on_epoch_end_hook() - # increment the global step once - # progress global step according to grads progress + # increment the global step once progress global step according to grads progress self.increment_accumulated_grad_global_step() def run_training_batch(self, batch, batch_idx, dataloader_idx): diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py index 015368ab0e0fc..6e2ad247f84e3 100644 --- a/tests/base/model_valid_steps.py +++ b/tests/base/model_valid_steps.py @@ -2,6 +2,7 @@ from collections import OrderedDict from pytorch_lightning.core.step_result import EvalResult +import numpy as np import torch @@ -35,17 +36,16 @@ def validation_step(self, batch, batch_idx, *args, **kwargs): return output def validation_step__decreasing(self, batch, batch_idx, *args, **kwargs): - if not hasattr(self, 'running_loss'): - self.running_loss = 1 - if not hasattr(self, 'running_acc'): - self.running_acc = 0 + if not hasattr(self, 'running'): + self.running = 0 + self.running += 1 - self.running_loss -= 1e-2 - self.running_acc += 1e-2 + running_loss = np.e ** (10 / self.running) - 1 + running_acc = np.log(self.running + 1) output = OrderedDict({ - 'val_loss': torch.tensor(self.running_loss), - 'val_acc': torch.tensor(self.running_acc), + 'val_loss': torch.tensor(running_loss), + 'val_acc': torch.tensor(running_acc), }) return output diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 0e1199e88d27a..5c0e5abede209 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -98,7 +98,7 @@ def 
test_comet_logger_dirs_creation(tmpdir, monkeypatch): trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} def test_comet_name_default(): diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index e5b871e4ec7be..8c6abca63ae82 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -41,7 +41,7 @@ def test_mlflow_logger_dirs_creation(tmpdir): assert 'epoch' in os.listdir(tmpdir / exp_id / run_id / 'metrics') assert set(os.listdir(tmpdir / exp_id / run_id / 'params')) == model.hparams.keys() assert trainer.checkpoint_callback.dirpath == (tmpdir / exp_id / run_id / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} def test_mlflow_experiment_id_retrieved_once(tmpdir): diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 9907ad9d087a2..bccf4188dd291 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -95,4 +95,4 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): trainer.fit(model) assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} diff --git a/tests/trainer/deprecate_legacy_flow/__init__.py b/tests/trainer/deprecate_legacy_flow/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d From 171cf6e71ce176c44262cc67f4dd7ca6b9e184e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 03:03:43 +0200 Subject: [PATCH 09/21] fix tests, revert old edits, fix merge conflicts, update doctests --- .../callbacks/model_checkpoint.py | 13 ++-- pytorch_lightning/trainer/trainer.py | 2 - pytorch_lightning/trainer/training_loop.py | 3 +- tests/base/model_valid_steps.py | 15 ----- tests/checkpointing/test_model_checkpoint.py | 61 +++++++++++------- tests/loggers/test_all.py | 4 +- tests/loggers/test_comet.py | 2 +- tests/loggers/test_mlflow.py | 2 +- tests/models/test_hooks.py | 4 +- .../deprecate_legacy_flow_log/__init__.py | 0 .../trainer/flags/test_val_check_interval.py | 62 +++++++++++++++++-- .../__init__.py | 0 .../test_eval_loop_dict_return.py | 0 .../test_trainer_steps_dict_return.py | 0 .../test_trainer_steps_scalar_return.py | 0 tests/trainer/test_trainer.py | 2 +- 16 files changed, 113 insertions(+), 57 deletions(-) delete mode 100644 tests/trainer/deprecate_legacy_flow_log/__init__.py rename tests/trainer/{deprecate_legacy_flow => legacy_deprecate_flow_log_tests}/__init__.py (100%) rename tests/trainer/{deprecate_legacy_flow_log => legacy_deprecate_flow_log_tests}/test_eval_loop_dict_return.py (100%) rename tests/trainer/{deprecate_legacy_flow_log => legacy_deprecate_flow_log_tests}/test_trainer_steps_dict_return.py (100%) rename tests/trainer/{deprecate_legacy_flow_log => legacy_deprecate_flow_log_tests}/test_trainer_steps_scalar_return.py (100%) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 07ef5292e9626..b680f9e09ec55 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ 
b/pytorch_lightning/callbacks/model_checkpoint.py @@ -367,7 +367,7 @@ def _format_checkpoint_name( ) -> str: if not filename: # filename is not set, use default name - filename = "{epoch}" + filename = "{epoch}-{step}" # check and parse user passed keys in the string groups = re.findall(r"(\{.*?)[:\}]", filename) if len(groups) >= 0: @@ -401,7 +401,7 @@ def format_checkpoint_name( >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={})) 'missing=0.ckpt' >>> ckpt = ModelCheckpoint(filename='{epoch}') - >>> os.path.basename(ckpt.format_checkpoint_name(0, {})) + >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {})) 'epoch=0.ckpt' """ @@ -529,17 +529,17 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath) if self.monitor is None: self.best_model_path = self.last_model_path - def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath): + def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, step, filepath): current = metrics.get(self.monitor) if not isinstance(current, torch.Tensor) and current is not None: current = torch.tensor(current, device=pl_module.device) if self.check_monitor_top_k(current): - self._update_best_and_save(filepath, current, epoch, trainer, pl_module) + self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module) elif self.verbose: rank_zero_info( - f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}" + f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}" ) def _is_valid_monitor_key(self, metrics): @@ -550,6 +550,7 @@ def _update_best_and_save( filepath: str, current: torch.Tensor, epoch: int, + global_step: int, trainer, pl_module, ): @@ -580,7 +581,7 @@ def _update_best_and_save( if self.verbose: rank_zero_info( - f"Epoch {epoch:d}: {self.monitor} reached {current:0.5f} (best {self.best_model_score:0.5f})," + f"Epoch {epoch:d}, global step {global_step:d}: {self.monitor} reached {current:0.5f} (best {self.best_model_score:0.5f})," f' saving model to "{filepath}" as top {k}' ) self._save_model(filepath, trainer, pl_module) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 80651c95da733..44250ae905aba 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -507,8 +507,6 @@ def train(self): ' not been met. Training will continue...' 
) - # lower the final step as there is no real change - self.global_step -= 1 # hook self.train_loop.on_train_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 30c9fb470a9ed..d32f47dbbd485 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -619,7 +619,8 @@ def run_training_epoch(self): # epoch end hook self.run_on_epoch_end_hook(epoch_output) - # increment the global step once progress global step according to grads progress + # increment the global step once + # progress global step according to grads progress self.increment_accumulated_grad_global_step() def run_training_batch(self, batch, batch_idx, dataloader_idx): diff --git a/tests/base/model_valid_steps.py b/tests/base/model_valid_steps.py index 0b00d77108520..e23e62dccdaba 100644 --- a/tests/base/model_valid_steps.py +++ b/tests/base/model_valid_steps.py @@ -15,7 +15,6 @@ from collections import OrderedDict from pytorch_lightning.core.step_result import EvalResult -import numpy as np import torch @@ -48,20 +47,6 @@ def validation_step(self, batch, batch_idx, *args, **kwargs): }) return output - def validation_step__decreasing(self, batch, batch_idx, *args, **kwargs): - if not hasattr(self, 'running'): - self.running = 0 - self.running += 1 - - running_loss = np.e ** (10 / self.running) - 1 - running_acc = np.log(self.running + 1) - - output = OrderedDict({ - 'val_loss': torch.tensor(running_loss), - 'val_acc': torch.tensor(running_acc), - }) - return output - def validation_step_no_monitor(self, batch, batch_idx, *args, **kwargs): """ Lightning calls this inside the validation loop diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 03084695ed962..65e5e37b357b8 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -171,9 +171,7 @@ def test_model_checkpoint_no_extraneous_invocations(tmpdir): """Test to ensure that the model callback saves the checkpoints only once in distributed mode.""" model = EvalModelTemplate() num_epochs = 4 - model_checkpoint = ModelCheckpointTestInvocations( - filepath=tmpdir, monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1 - ) + model_checkpoint = ModelCheckpointTestInvocations(monitor='early_stop_on', expected_count=num_epochs, save_top_k=-1) trainer = Trainer( distributed_backend="ddp_cpu", num_processes=2, @@ -189,8 +187,10 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): # empty filename: ckpt_name = ModelCheckpoint._format_checkpoint_name('', 3, 2, {}) assert ckpt_name == 'epoch=3-step=2' + ckpt_name = ModelCheckpoint._format_checkpoint_name(None, 3, 2, {}, prefix='test') assert ckpt_name == 'test-epoch=3-step=2' + # no groups case: ckpt_name = ModelCheckpoint._format_checkpoint_name('ckpt', 3, 2, {}, prefix='test') assert ckpt_name == 'test-ckpt' @@ -205,23 +205,43 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): ckpt_name = ModelCheckpoint._format_checkpoint_name('{epoch},{acc:.5f}', 3, 2, {'acc': 0.03}, prefix='test') assert ckpt_name == 'test@epoch=3,acc=0.03000' ModelCheckpoint.CHECKPOINT_JOIN_CHAR = char_org - # no filepath set - ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 4, {}) - assert ckpt_name == 'epoch=3-step=4.ckpt' + + # no dirpath set + ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath=None).format_checkpoint_name(3, 2, {}) + assert ckpt_name == 
'epoch=3-step=2.ckpt' ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='').format_checkpoint_name(5, 4, {}) assert ckpt_name == 'epoch=5-step=4.ckpt' + # CWD ckpt_name = ModelCheckpoint(monitor='early_stop_on', dirpath='.').format_checkpoint_name(3, 4, {}) - assert Path(ckpt_name) == str(Path('.').resolve() / 'epoch=3-step=4.ckpt') + assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt') + + # with ver + ckpt_name = ModelCheckpoint( + monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test' + ).format_checkpoint_name(3, 2, {}, ver=3) + assert ckpt_name == tmpdir / 'test-name-v3.ckpt' + + # using slashes + ckpt_name = ModelCheckpoint( + monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}' + ).format_checkpoint_name(4, 3, {'val/loss': 0.03}) + assert ckpt_name == 'epoch=4_val/loss=0.03000.ckpt' + + # TODO: Checks with filepath. To be removed in v1.2 + # CWD + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath='.').format_checkpoint_name(3, 2, {}) + assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=2.ckpt') + # dir does not exist so it is used as filename filepath = tmpdir / 'dir' - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {}) + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 2, {}) assert ckpt_name == tmpdir / 'test-dir.ckpt' # now, dir exists os.mkdir(filepath) - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 4, {}) - assert ckpt_name == filepath / 'test-epoch=3-step=4.ckpt' + ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 2, {}) + assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt' def test_model_checkpoint_save_last(tmpdir): @@ -230,7 +250,7 @@ def test_model_checkpoint_save_last(tmpdir): model = EvalModelTemplate() epochs = 3 ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last-{epoch}' - model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir / '{step}', save_top_k=-1, save_last=True) + model_checkpoint = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=-1, save_last=True) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=model_checkpoint, @@ -243,7 +263,7 @@ def test_model_checkpoint_save_last(tmpdir): ) last_filename = last_filename + '.ckpt' assert str(tmpdir / last_filename) == model_checkpoint.last_model_path - assert set(os.listdir(tmpdir)) == set([f'step={i}.ckpt' for i in [19, 29, 30]] + [last_filename]) + assert set(os.listdir(tmpdir)) == set([f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename]) ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last' @@ -284,7 +304,7 @@ def test_model_checkpoint_none_monitor(tmpdir): model.validation_epoch_end = model.validation_epoch_end_no_monitor epochs = 2 - checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir / '{step}', save_top_k=-1) + checkpoint_callback = ModelCheckpoint(monitor=None, dirpath=tmpdir, save_top_k=-1) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, @@ -295,13 +315,13 @@ def test_model_checkpoint_none_monitor(tmpdir): # these should not be set if monitor is None assert checkpoint_callback.monitor is None - assert checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'step=20.ckpt' + assert 
checkpoint_callback.best_model_path == checkpoint_callback.last_model_path == tmpdir / 'epoch=1-step=19.ckpt' assert checkpoint_callback.best_model_score == 0 assert checkpoint_callback.best_k_models == {} assert checkpoint_callback.kth_best_model_path == '' # check that the correct ckpts were created - expected = [f'step={i}.ckpt' for i in [9, 19, 20]] + expected = [f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19])] assert set(os.listdir(tmpdir)) == set(expected) @@ -309,7 +329,7 @@ def test_model_checkpoint_none_monitor(tmpdir): def test_model_checkpoint_period(tmpdir, period): model = EvalModelTemplate() epochs = 5 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir / '{epoch}', save_top_k=-1, period=period) + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', save_top_k=-1, period=period) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, @@ -352,7 +372,7 @@ def test_model_checkpoint_topk_all(tmpdir): seed_everything(1000) epochs = 2 model = EvalModelTemplate() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir / '{epoch}', monitor="early_stop_on", save_top_k=-1) + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}', monitor="early_stop_on", save_top_k=-1) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, @@ -412,12 +432,12 @@ def test_default_checkpoint_behavior(tmpdir): assert len(results) == 1 assert results[0]['test_acc'] >= 0.80 - assert len(trainer.dev_debugger.checkpoint_callback_history) == 4 + assert len(trainer.dev_debugger.checkpoint_callback_history) == 3 # make sure the checkpoint we saved has the metric in the name ckpts = os.listdir(os.path.join(tmpdir, 'lightning_logs', 'version_0', 'checkpoints')) assert len(ckpts) == 1 - assert ckpts[0] == 'epoch=2-step=15.ckpt' + assert ckpts[0] == 'epoch=2-step=14.ckpt' def test_ckpt_metric_names_results(tmpdir): @@ -475,13 +495,12 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir): model = EvalModelTemplate() num_epochs = 3 model_checkpoint = ModelCheckpoint( - monitor='early_stop_on', filepath=dirpath / '{epoch}', save_top_k=num_epochs, save_last=True + monitor='early_stop_on', dirpath=tmpdir, filename="{epoch}", save_top_k=num_epochs, save_last=True ) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=model_checkpoint, max_epochs=num_epochs, - val_check_interval=1.0, ) trainer.fit(model) diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 35a3cc5d66abc..5405349a289f8 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -120,14 +120,14 @@ def log_metrics(self, metrics, step): (0, ['epoch', 'train_some_val']), (0, ['early_stop_on', 'epoch', 'val_acc']), (0, ['hp_metric']), - (0, ['epoch', 'test_acc', 'test_loss']) + (1, ['epoch', 'test_acc', 'test_loss']) ] assert log_metric_names == expected else: expected = [ (0, ['epoch', 'train_some_val']), (0, ['early_stop_on', 'epoch', 'val_acc']), - (0, ['epoch', 'test_acc', 'test_loss']) + (1, ['epoch', 'test_acc', 'test_loss']) ] assert log_metric_names == expected diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index a50d07a952374..d1679bf59f26d 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -127,7 +127,7 @@ def test_comet_logger_dirs_creation(comet, comet_experiment, tmpdir, monkeypatch trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=1, limit_val_batches=3) trainer.fit(model) - assert 
trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / version / 'checkpoints') + assert trainer.checkpoint_callback.dirpath == (tmpdir / 'test' / "1" / 'checkpoints') assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=9.ckpt'} diff --git a/tests/loggers/test_mlflow.py b/tests/loggers/test_mlflow.py index f6f5f7e7d777e..a200fbf549e6a 100644 --- a/tests/loggers/test_mlflow.py +++ b/tests/loggers/test_mlflow.py @@ -115,7 +115,7 @@ def test_mlflow_log_dir(client, mlflow, tmpdir): ) trainer.fit(model) assert trainer.checkpoint_callback.dirpath == (tmpdir / "exp-id" / "run-id" / 'checkpoints') - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} + assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0-step=0.ckpt'} def test_mlflow_logger_dirs_creation(tmpdir): diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 33e4f10dd05f9..886e0db4e7854 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -139,10 +139,10 @@ def on_train_batch_start(self, batch, batch_idx, dataloader_idx): trainer.fit(model) if batch_idx_ > len(model.val_dataloader()) - 1: assert trainer.batch_idx == len(model.val_dataloader()) - 1 - assert trainer.global_step == len(model.val_dataloader()) * max_epochs - 1 + assert trainer.global_step == len(model.val_dataloader()) * max_epochs else: assert trainer.batch_idx == batch_idx_ - assert trainer.global_step == (batch_idx_ + 1) * max_epochs - 1 + assert trainer.global_step == (batch_idx_ + 1) * max_epochs def test_trainer_model_hook_system(tmpdir): diff --git a/tests/trainer/deprecate_legacy_flow_log/__init__.py b/tests/trainer/deprecate_legacy_flow_log/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/tests/trainer/flags/test_val_check_interval.py b/tests/trainer/flags/test_val_check_interval.py index 26718715af145..9e8d5c23fb077 100644 --- a/tests/trainer/flags/test_val_check_interval.py +++ b/tests/trainer/flags/test_val_check_interval.py @@ -17,8 +17,7 @@ @pytest.mark.parametrize('max_epochs', [1, 2, 3]) -@pytest.mark.parametrize('interval', [1.0, 0.25, 0.33]) -def test_val_check_interval_(tmpdir, max_epochs, interval): +def test_val_check_interval_1(tmpdir, max_epochs): class TestModel(SimpleModule): def __init__(self): @@ -35,11 +34,64 @@ def on_validation_epoch_start(self) -> None: model = TestModel() trainer = Trainer( - default_root_dir=tmpdir, max_epochs=max_epochs, - val_check_interval=interval, + val_check_interval=1.0, logger=False, ) trainer.fit(model) - assert model.val_epoch_calls == max_epochs * round(1 / interval) + assert model.val_epoch_calls == max_epochs + + +@pytest.mark.parametrize('max_epochs', [1, 2, 3]) +def test_val_check_interval_quarter(tmpdir, max_epochs): + + class TestModel(SimpleModule): + def __init__(self): + super().__init__() + self.train_epoch_calls = 0 + self.val_epoch_calls = 0 + + def on_train_epoch_start(self) -> None: + self.train_epoch_calls += 1 + + def on_validation_epoch_start(self) -> None: + if not self.trainer.running_sanity_check: + self.val_epoch_calls += 1 + + model = TestModel() + trainer = Trainer( + max_epochs=max_epochs, + val_check_interval=0.25, + logger=False, + ) + trainer.fit(model) + + assert model.val_epoch_calls == max_epochs * 4 + + +@pytest.mark.parametrize('max_epochs', [1, 2, 3]) +def test_val_check_interval_third(tmpdir, max_epochs): + + class TestModel(SimpleModule): + def __init__(self): + super().__init__() + self.train_epoch_calls = 0 + 
self.val_epoch_calls = 0 + + def on_train_epoch_start(self) -> None: + self.train_epoch_calls += 1 + + def on_validation_epoch_start(self) -> None: + if not self.trainer.running_sanity_check: + self.val_epoch_calls += 1 + + model = TestModel() + trainer = Trainer( + max_epochs=max_epochs, + val_check_interval=0.33, + logger=False, + ) + trainer.fit(model) + + assert model.val_epoch_calls == max_epochs * 3 diff --git a/tests/trainer/deprecate_legacy_flow/__init__.py b/tests/trainer/legacy_deprecate_flow_log_tests/__init__.py similarity index 100% rename from tests/trainer/deprecate_legacy_flow/__init__.py rename to tests/trainer/legacy_deprecate_flow_log_tests/__init__.py diff --git a/tests/trainer/deprecate_legacy_flow_log/test_eval_loop_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py similarity index 100% rename from tests/trainer/deprecate_legacy_flow_log/test_eval_loop_dict_return.py rename to tests/trainer/legacy_deprecate_flow_log_tests/test_eval_loop_dict_return.py diff --git a/tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_dict_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py similarity index 100% rename from tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_dict_return.py rename to tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_dict_return.py diff --git a/tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_scalar_return.py b/tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py similarity index 100% rename from tests/trainer/deprecate_legacy_flow_log/test_trainer_steps_scalar_return.py rename to tests/trainer/legacy_deprecate_flow_log_tests/test_trainer_steps_scalar_return.py diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 4841d1461fec6..ff02abac177cf 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -429,7 +429,7 @@ def mock_save_function(filepath, *args): losses = [10, 9, 2.8, 5, 2.5] checkpoint_callback = ModelCheckpoint( - dirpath=tmpdir, monitor='checkpoint_on', save_top_k=save_top_k, + dirpath=tmpdir, filename='{epoch}', monitor='checkpoint_on', save_top_k=save_top_k, save_last=save_last, prefix=file_prefix, verbose=1 ) checkpoint_callback.save_function = mock_save_function From dee1820d36a518db807239f597124039853c0333 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 03:27:01 +0100 Subject: [PATCH 10/21] test + bugfix --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 16 ++++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index b680f9e09ec55..7f95f72af8223 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -554,7 +554,7 @@ def _update_best_and_save( trainer, pl_module, ): - k = epoch + 1 if self.save_top_k == -1 else self.save_top_k + k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k del_list = [] if len(self.best_k_models) == k and k > 0: diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 65e5e37b357b8..414b21e12ed5b 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -753,3 +753,19 @@ def 
test_filepath_decomposition_dirpath_filename(tmpdir, filepath, dirpath, file assert mc_cb.dirpath == dirpath assert mc_cb.filename == filename + + +def test_val_check_interval_checkpoint_files(tmpdir): + """ Test correct checkpoint naming when validating/checkpointing multiple times per epoch. """ + model = EvalModelTemplate() + model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="val_acc", mode="max", verbose=True) + trainer = Trainer( + default_root_dir=tmpdir, + val_check_interval=0.2, + max_epochs=1, + limit_train_batches=10, + callbacks=[model_checkpoint] + ) + trainer.fit(model) + files = [p.name for p in Path(tmpdir).glob("*.ckpt")] + assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]] From 999925df303877b7000e571001b28d5e3867f5fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 03:51:17 +0100 Subject: [PATCH 11/21] sort files --- tests/checkpointing/test_model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 414b21e12ed5b..e9c1d5c4fa0a3 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -767,5 +767,5 @@ def test_val_check_interval_checkpoint_files(tmpdir): callbacks=[model_checkpoint] ) trainer.fit(model) - files = [p.name for p in Path(tmpdir).glob("*.ckpt")] + files = sorted([p.name for p in Path(tmpdir).glob("*.ckpt")]) assert files == [f"epoch=0-step={s}.ckpt" for s in [1, 3, 5, 7, 9]] From 9c6ff0053141c4cc38e4a46dad63de814ecb6070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 03:52:25 +0100 Subject: [PATCH 12/21] format test --- tests/checkpointing/test_model_checkpoint.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e9c1d5c4fa0a3..47969cfdcd929 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -758,7 +758,13 @@ def test_filepath_decomposition_dirpath_filename(tmpdir, filepath, dirpath, file def test_val_check_interval_checkpoint_files(tmpdir): """ Test correct checkpoint naming when validating/checkpointing multiple times per epoch. 
""" model = EvalModelTemplate() - model_checkpoint = ModelCheckpoint(dirpath=tmpdir, save_top_k=-1, monitor="val_acc", mode="max", verbose=True) + model_checkpoint = ModelCheckpoint( + dirpath=tmpdir, + save_top_k=-1, + monitor="val_acc", + mode="max", + verbose=True + ) trainer = Trainer( default_root_dir=tmpdir, val_check_interval=0.2, From eb69d4530a09f9ab6390d5d463eec57e3a785b7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 03:56:37 +0100 Subject: [PATCH 13/21] suggestion by ananth --- pytorch_lightning/callbacks/model_checkpoint.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 7f95f72af8223..fd269fdd06324 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -223,13 +223,13 @@ def save_checkpoint(self, trainer, pl_module): monitor_candidates = self._monitor_candidates(trainer) # ie: path/val_loss=0.5.ckpt - filepath = self._get_metric_interpolated_filepath_name(epoch, global_step, monitor_candidates) + filepath = self._get_metric_interpolated_filepath_name(monitor_candidates) # callback supports multiple simultaneous modes # here we call each mode sequentially # Mode 1: save all checkpoints OR only the top k if self.save_top_k: - self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, global_step, filepath) + self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath) # Mode 2: save the last checkpoint self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath) @@ -481,13 +481,13 @@ def _validate_monitor_key(self, trainer): ) raise MisconfigurationException(m) - def _get_metric_interpolated_filepath_name(self, epoch: int, step: int, ckpt_name_metrics: Dict[str, Any]): + def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any]): + epoch = ckpt_name_metrics.get("current_epoch") + step = ckpt_name_metrics.get("global_step") filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics) version_cnt = 0 while self._fs.exists(filepath): - filepath = self.format_checkpoint_name( - epoch, step, ckpt_name_metrics, ver=version_cnt - ) + filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt) # this epoch called before version_cnt += 1 return filepath @@ -496,6 +496,7 @@ def _monitor_candidates(self, trainer): ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics) ckpt_name_metrics.update(trainer.logger_connector.callback_metrics) ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics) + ckpt_name_metrics.update({"global_step": trainer.global_step, "current_epoch": trainer.current_epoch}) return ckpt_name_metrics def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath): @@ -529,8 +530,10 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath) if self.monitor is None: self.best_model_path = self.last_model_path - def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, step, filepath): + def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath): current = metrics.get(self.monitor) + epoch = metrics.get("current_epoch") + step = metrics.get("global_step") if not isinstance(current, torch.Tensor) and current is not None: current = torch.tensor(current, device=pl_module.device) From 5fb867d057493f5ad631ff58e4143406fb9c8f14 Mon 
Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 04:14:04 +0100 Subject: [PATCH 14/21] added changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8d979b41893f5..01f86c766f78a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `strict` option to the scheduler dictionary ([#3586](https://github.com/PyTorchLightning/pytorch-lightning/pull/3586)) +- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807)) + ### Changed - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587)) From d475a4e5ea7bf8d59fa7ca7aefad4b3ac1d09162 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 04:18:09 +0100 Subject: [PATCH 15/21] naming --- pytorch_lightning/callbacks/model_checkpoint.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index fd269fdd06324..ea396b2584037 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -482,8 +482,8 @@ def _validate_monitor_key(self, trainer): raise MisconfigurationException(m) def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any]): - epoch = ckpt_name_metrics.get("current_epoch") - step = ckpt_name_metrics.get("global_step") + epoch = ckpt_name_metrics.get("epoch") + step = ckpt_name_metrics.get("step") filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics) version_cnt = 0 while self._fs.exists(filepath): @@ -496,7 +496,7 @@ def _monitor_candidates(self, trainer): ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics) ckpt_name_metrics.update(trainer.logger_connector.callback_metrics) ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics) - ckpt_name_metrics.update({"global_step": trainer.global_step, "current_epoch": trainer.current_epoch}) + ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch}) return ckpt_name_metrics def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath): @@ -532,8 +532,8 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath) def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath): current = metrics.get(self.monitor) - epoch = metrics.get("current_epoch") - step = metrics.get("global_step") + epoch = metrics.get("epoch") + step = metrics.get("step") if not isinstance(current, torch.Tensor) and current is not None: current = torch.tensor(current, device=pl_module.device) @@ -553,7 +553,7 @@ def _update_best_and_save( filepath: str, current: torch.Tensor, epoch: int, - global_step: int, + step: int, trainer, pl_module, ): @@ -584,7 +584,7 @@ def _update_best_and_save( if self.verbose: rank_zero_info( - f"Epoch {epoch:d}, global step {global_step:d}: {self.monitor} reached {current:0.5f} (best {self.best_model_score:0.5f})," + f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f} (best {self.best_model_score:0.5f})," f' saving model to "{filepath}" as top {k}' ) self._save_model(filepath, trainer, pl_module) From 
bc376db4593c669a87c374c03c5a875011ae7d44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 25 Oct 2020 04:18:16 +0100 Subject: [PATCH 16/21] docs --- pytorch_lightning/callbacks/model_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index ea396b2584037..6ece83d03cc29 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -101,7 +101,7 @@ class ModelCheckpoint(Callback): ... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}' ... ) - By default, filename is ``None`` and will be set to ``'{epoch}'``. + By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``. Example: From 254f56c15664c3f5319736d4a944f5b7d52e7946 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 26 Oct 2020 10:55:26 +0100 Subject: [PATCH 17/21] example --- pytorch_lightning/callbacks/model_checkpoint.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 6ece83d03cc29..6cc2281f397be 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -223,7 +223,7 @@ def save_checkpoint(self, trainer, pl_module): monitor_candidates = self._monitor_candidates(trainer) # ie: path/val_loss=0.5.ckpt - filepath = self._get_metric_interpolated_filepath_name(monitor_candidates) + filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step) # callback supports multiple simultaneous modes # here we call each mode sequentially @@ -400,9 +400,9 @@ def format_checkpoint_name( >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}') >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={})) 'missing=0.ckpt' - >>> ckpt = ModelCheckpoint(filename='{epoch}') + >>> ckpt = ModelCheckpoint(filename='{step}') >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {})) - 'epoch=0.ckpt' + 'step=0.ckpt' """ filename = self._format_checkpoint_name( @@ -481,9 +481,7 @@ def _validate_monitor_key(self, trainer): ) raise MisconfigurationException(m) - def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any]): - epoch = ckpt_name_metrics.get("epoch") - step = ckpt_name_metrics.get("step") + def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int): filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics) version_cnt = 0 while self._fs.exists(filepath): From febd53b2630d273a11aeb79e03d802f401224e59 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 26 Oct 2020 23:33:53 +0530 Subject: [PATCH 18/21] suggestion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/trainer/evaluation_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 1e3ece65bfd28..ffc72f8f0022e 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -252,7 +252,7 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): step = 'testing_epoch_end' if self.testing else 'validation_epoch_end' self.warning_cache.warn( f'The {step} should not return anything as of 
9.1.' - f'to log, use self.log(...) or self.write(...) directly in the LightningModule' + ' To log, use self.log(...) or self.write(...) directly in the LightningModule' ) if using_eval_result and not user_reduced: From 9bb5fd2ef02ddf9c46134e4f78988cc584e09f60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 1 Nov 2020 05:48:23 +0100 Subject: [PATCH 19/21] fix test --- tests/checkpointing/test_model_checkpoint.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e992b8724ad36..565858f7410e6 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -377,7 +377,13 @@ def validation_epoch_end(self, outputs): return {'epoch': self.current_epoch} model = CustomModel() - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", mode='max', save_top_k=-1) + checkpoint_callback = ModelCheckpoint( + dirpath=tmpdir, + filename="{epoch}", + monitor="epoch", + mode='max', + save_top_k=-1, + ) trainer = Trainer( default_root_dir=tmpdir, checkpoint_callback=checkpoint_callback, From ad02f330e80d9332fb3209e299d04d6c77aef211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sun, 1 Nov 2020 05:51:08 +0100 Subject: [PATCH 20/21] pep --- tests/loggers/test_wandb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 67463bbf0ddd2..5d1c092eb7737 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -143,4 +143,4 @@ def wrapper_something(): params = WandbLogger._sanitize_callable_params(params) assert params["gpus"] == '_gpus_arg_default' assert params["something"] == "something" - assert params["wrapper_something"] == "wrapper_something" \ No newline at end of file + assert params["wrapper_something"] == "wrapper_something" From c0882c48863c0eeaa71504ca192c81c24406b103 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 2 Nov 2020 15:19:07 +0530 Subject: [PATCH 21/21] pep --- pytorch_lightning/callbacks/model_checkpoint.py | 4 ++-- tests/checkpointing/test_model_checkpoint.py | 13 ++++++++++--- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index bd0ede6c0efa3..53fcb4382cea4 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -581,8 +581,8 @@ def _update_best_and_save( if self.verbose: rank_zero_info( - f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f} (best {self.best_model_score:0.5f})," - f' saving model to "{filepath}" as top {k}' + f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}" + f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}' ) self._save_model(filepath, trainer, pl_module) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 565858f7410e6..fc46eee1bf725 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -235,12 +235,16 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir): # dir does not exist so it is used as filename filepath = tmpdir / 'dir' - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 2, {}) + ckpt_name 
= ModelCheckpoint( + monitor='early_stop_on', filepath=filepath, prefix='test' + ).format_checkpoint_name(3, 2, {}) assert ckpt_name == tmpdir / 'test-dir.ckpt' # now, dir exists os.mkdir(filepath) - ckpt_name = ModelCheckpoint(monitor='early_stop_on', filepath=filepath, prefix='test').format_checkpoint_name(3, 2, {}) + ckpt_name = ModelCheckpoint( + monitor='early_stop_on', filepath=filepath, prefix='test' + ).format_checkpoint_name(3, 2, {}) assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt' @@ -263,7 +267,10 @@ def test_model_checkpoint_save_last(tmpdir): ) last_filename = last_filename + '.ckpt' assert str(tmpdir / last_filename) == model_checkpoint.last_model_path - assert set(os.listdir(tmpdir)) == set([f'epoch={i}-step={j}.ckpt' for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename]) + assert set(os.listdir(tmpdir)) == set( + [f"epoch={i}-step={j}.ckpt" for i, j in zip(range(epochs), [9, 19, 29])] + [last_filename] + ) + ModelCheckpoint.CHECKPOINT_NAME_LAST = 'last'
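For reference, the behavior this series converges on — the default filename template '{epoch}-{step}', sub-epoch checkpointing via val_check_interval, and save_top_k=-1 keeping every checkpoint — can be exercised with a minimal standalone sketch like the one below. It mirrors the new test_val_check_interval_checkpoint_files test; the BoringClassifier module, the random tensors, and the "checkpoints" output directory are placeholders invented for illustration and are not part of the patches.

# Hedged sketch of the checkpoint naming introduced above; names marked as
# placeholders are not taken from the patch series.
import os
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class BoringClassifier(pl.LightningModule):
    # placeholder model, only here to drive the Trainer
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        acc = (self(x).argmax(dim=1) == y).float().mean()
        # logged metric that the checkpoint callback monitors
        self.log("val_acc", acc, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


def _loader(n=64):
    # random placeholder data
    return DataLoader(TensorDataset(torch.randn(n, 32), torch.randint(0, 2, (n,))), batch_size=8)


if __name__ == "__main__":
    ckpt_cb = ModelCheckpoint(
        dirpath="checkpoints",   # assumed output directory for this example
        save_top_k=-1,           # keep every checkpoint
        monitor="val_acc",
        mode="max",
    )
    trainer = pl.Trainer(
        max_epochs=1,
        limit_train_batches=10,
        val_check_interval=0.2,  # validate (and checkpoint) five times per epoch
        callbacks=[ckpt_cb],
    )
    trainer.fit(BoringClassifier(), _loader(), _loader())
    # With the default '{epoch}-{step}' template, the files are expected to look like
    # epoch=0-step=1.ckpt, epoch=0-step=3.ckpt, ... instead of colliding on the epoch alone.
    print(sorted(os.listdir("checkpoints")))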