Add step index in checkpoint name (#3807)
* true final value of global step

* ch check

* tests

* save each validation interval

* wip

* add test

* add test

* wip

* fix tests, revert old edits, fix merge conflicts, update doctests

* test + bugfix

* sort files

* format test

* suggestion by ananth

* added changelog

* naming

* docs

* example

* suggestion

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* fix test

* pep

* pep

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
4 people committed Nov 21, 2020
1 parent 69ce966 commit 861e8a3
Showing 8 changed files with 117 additions and 67 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))

+ - Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))

### Changed

- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
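For context: with this change, a checkpoint saved mid-epoch carries the global step in its name, so repeated saves within one epoch no longer collide. A minimal usage sketch (the `my/path` directory and the step counts are hypothetical):

Example::

    >>> from pytorch_lightning.callbacks import ModelCheckpoint
    >>> # the default filename template is now '{epoch}-{step}' instead of '{epoch}'
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path')
    >>> # two saves within epoch 0 now yield distinct files, e.g.
    >>> # my/path/epoch=0-step=100.ckpt
    >>> # my/path/epoch=0-step=200.ckpt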
65 changes: 35 additions & 30 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
- By default, filename is ``None`` and will be set to ``'{epoch}'``.
+ By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
Example::
@@ -222,16 +222,16 @@ def save_checkpoint(self, trainer, pl_module):
monitor_candidates = self._monitor_candidates(trainer)

# ie: path/val_loss=0.5.ckpt
- filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
+ filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
- self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
+ self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)

# Mode 2: save the last checkpoint
- self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
+ self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
@@ -360,16 +360,17 @@ def _format_checkpoint_name(
cls,
filename: Optional[str],
epoch: int,
+ step: int,
metrics: Dict[str, Any],
prefix: str = "",
) -> str:
if not filename:
# filename is not set, use default name
filename = "{epoch}"
filename = "{epoch}-{step}"
# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
if len(groups) >= 0:
metrics["epoch"] = epoch
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)
@@ -379,32 +380,32 @@ def _format_checkpoint_name(
return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])

def format_checkpoint_name(
- self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
+ self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
) -> str:
"""Generate a filename according to the defined template.
Example::
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
- >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+ >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
- >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+ >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
- >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+ >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
- >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+ >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
- >>> ckpt = ModelCheckpoint(filename='{epoch}')
- >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
- 'epoch=0.ckpt'
+ >>> ckpt = ModelCheckpoint(filename='{step}')
+ >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
+ 'step=0.ckpt'
"""
filename = self._format_checkpoint_name(
- self.filename, epoch, metrics, prefix=self.prefix
+ self.filename, epoch, step, metrics, prefix=self.prefix
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
@@ -479,13 +480,11 @@ def _validate_monitor_key(self, trainer):
)
raise MisconfigurationException(m)

- def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
- filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
+ def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+ filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
version_cnt = 0
while self._fs.exists(filepath):
- filepath = self.format_checkpoint_name(
- epoch, ckpt_name_metrics, ver=version_cnt
- )
+ filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
return filepath
@@ -494,9 +493,10 @@ def _monitor_candidates(self, trainer):
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
return ckpt_name_metrics

- def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
+ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
@@ -506,7 +506,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
last_filepath = self._format_checkpoint_name(
- self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
+ self.CHECKPOINT_NAME_LAST,
+ trainer.current_epoch,
+ trainer.global_step,
+ ckpt_name_metrics,
+ prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")

@@ -523,17 +527,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
if self.monitor is None:
self.best_model_path = self.last_model_path

- def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
+ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
current = metrics.get(self.monitor)
+ epoch = metrics.get("epoch")
+ step = metrics.get("step")

if not isinstance(current, torch.Tensor) and current is not None:
current = torch.tensor(current, device=pl_module.device)

if self.check_monitor_top_k(current):
- self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
+ self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
elif self.verbose:
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
)

def _is_valid_monitor_key(self, metrics):
@@ -544,11 +550,11 @@ def _update_best_and_save(
filepath: str,
current: torch.Tensor,
epoch: int,
+ step: int,
trainer,
pl_module,
):

- k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
+ k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

del_list = []
if len(self.best_k_models) == k and k > 0:
@@ -575,9 +581,8 @@

if self.verbose:
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} reached"
f" {current:0.5f} (best {self.best_model_score:0.5f}),"
f" saving model to {filepath} as top {k}"
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
)
self._save_model(filepath, trainer, pl_module)

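One behavioural detail worth noting alongside the renaming: in `_update_best_and_save`, the retention budget for `save_top_k == -1` is now derived from the number of tracked checkpoints rather than from the epoch index, since with sub-epoch checkpointing several checkpoints can be saved per epoch. A standalone sketch of that bookkeeping, with hypothetical file names (not the callback's actual internals):

Example::

    >>> best_k_models = {}
    >>> def track(filepath, score, save_top_k=-1):
    ...     # save_top_k == -1 means "keep everything": the budget grows by one
    ...     # per save, so no previously saved checkpoint is ever evicted
    ...     k = len(best_k_models) + 1 if save_top_k == -1 else save_top_k
    ...     best_k_models[filepath] = score
    ...     return k
    >>> track('epoch=0-step=100.ckpt', 0.5)
    1
    >>> track('epoch=0-step=200.ckpt', 0.4)
    2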
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -250,9 +250,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
# depre warning
if eval_results is not None and user_reduced:
step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
- m = f'The {step} should not return anything as of 9.1.' \
- f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
- self.warning_cache.warn(m)
+ self.warning_cache.warn(
+ f'The {step} should not return anything as of 9.1.'
+ ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
+ )

if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)
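The reworded warning points at the replacement pattern: log from inside the hook instead of returning a value. A minimal sketch, assuming each `validation_step` returns a dict with a 'loss' key (the class and metric names here are arbitrary):

Example::

    import torch
    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def validation_epoch_end(self, outputs):
            # returning a value from this hook is deprecated;
            # call self.log(...) directly instead
            avg_loss = torch.stack([o['loss'] for o in outputs]).mean()
            self.log('avg_val_loss', avg_loss)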