Add step index in checkpoint name #3807

Merged · Nov 2, 2020 · 27 commits
Changes from 24 commits
CHANGELOG.md (2 changes: 2 additions & 0 deletions)

@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))

+ - Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))

### Changed

- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
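As a quick illustration of the new default (a minimal sketch; the `dirpath` value is hypothetical), `ModelCheckpoint` with no `filename` now interpolates both `epoch` and `step`:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# With filename=None the template now defaults to '{epoch}-{step}',
# so two saves inside the same epoch get distinct names.
ckpt = ModelCheckpoint(dirpath="checkpoints/")  # hypothetical path
# produces e.g. checkpoints/epoch=1-step=500.ckpt and
#               checkpoints/epoch=1-step=1000.ckpt
```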
pytorch_lightning/callbacks/model_checkpoint.py (65 changes: 35 additions & 30 deletions)

@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
... filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )

- By default, filename is ``None`` and will be set to ``'{epoch}'``.
+ By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.


Example::
@@ -222,16 +222,16 @@ def save_checkpoint(self, trainer, pl_module):
monitor_candidates = self._monitor_candidates(trainer)

# ie: path/val_loss=0.5.ckpt
- filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
+ filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)

# callback supports multiple simultaneous modes
# here we call each mode sequentially
# Mode 1: save all checkpoints OR only the top k
if self.save_top_k:
- self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
+ self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)

# Mode 2: save the last checkpoint
- self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
+ self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)

def __validate_init_configuration(self):
if self.save_top_k is not None and self.save_top_k < -1:
@@ -360,16 +360,17 @@ def _format_checkpoint_name(
cls,
filename: Optional[str],
epoch: int,
+ step: int,
metrics: Dict[str, Any],
prefix: str = "",
) -> str:
if not filename:
# filename is not set, use default name
- filename = "{epoch}"
+ filename = "{epoch}-{step}"
# check and parse user passed keys in the string
groups = re.findall(r"(\{.*?)[:\}]", filename)
if len(groups) >= 0:
metrics["epoch"] = epoch
metrics.update({"epoch": epoch, 'step': step})
for group in groups:
name = group[1:]
filename = filename.replace(group, name + "={" + name)
@@ -379,32 +380,32 @@ def _format_checkpoint_name(
return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])

def format_checkpoint_name(
- self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
+ self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
) -> str:
"""Generate a filename according to the defined template.

Example::

>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
>>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
>>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
Borda marked this conversation as resolved.
Show resolved Hide resolved
'step=0.ckpt'

"""
filename = self._format_checkpoint_name(
- self.filename, epoch, metrics, prefix=self.prefix
+ self.filename, epoch, step, metrics, prefix=self.prefix
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
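For readers skimming the diff, here is a standalone sketch of the name templating that `_format_checkpoint_name` performs (a simplified re-implementation for illustration, not the library code itself):

```python
import re

def expand(template: str, **metrics) -> str:
    # '{epoch}' becomes 'epoch={epoch}', then str.format fills the values;
    # missing keys default to 0, as in the callback.
    groups = re.findall(r"(\{.*?)[:\}]", template)
    for group in groups:
        name = group[1:]
        template = template.replace(group, name + "={" + name)
        if name not in metrics:
            metrics[name] = 0
    return template.format(**metrics)

print(expand("{epoch}-{step}", epoch=1, step=500))  # epoch=1-step=500
print(expand("{epoch:03d}-{val_loss:.2f}", epoch=5, val_loss=0.1234))
# epoch=005-val_loss=0.12
```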
@@ -479,13 +480,11 @@ def _validate_monitor_key(self, trainer):
)
raise MisconfigurationException(m)

- def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
- filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
+ def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+ filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
version_cnt = 0
while self._fs.exists(filepath):
- filepath = self.format_checkpoint_name(
- epoch, ckpt_name_metrics, ver=version_cnt
- )
+ filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
# this epoch called before
version_cnt += 1
return filepath
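The `while self._fs.exists(...)` loop above avoids overwriting an existing file by regenerating the name with a version counter; a minimal sketch of that idea (illustrative only, using a plain set in place of the filesystem):

```python
def versioned(stem: str, existing: set) -> str:
    # Mirrors the collision loop in _get_metric_interpolated_filepath_name:
    # regenerate the name with ver=0, 1, ... until it is unused.
    name, ver = stem, 0
    while name in existing:
        name = f"{stem}-v{ver}"
        ver += 1
    return name

taken = {"epoch=1-step=500", "epoch=1-step=500-v0"}
print(versioned("epoch=1-step=500", taken))  # epoch=1-step=500-v1
```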
@@ -494,9 +493,10 @@ def _monitor_candidates(self, trainer):
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
> **Contributor:** this update also happens in `_format_checkpoint_name` - is it possible to consolidate to a single place? is there a risk they can fall out of sync?
>
> **Member (author):** which point do you have in mind?
>
> **Contributor:** @ananthsub could you explain further what you mean? I think I understand but need some clarity here.

return ckpt_name_metrics
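One way to read the consolidation suggestion, as a purely hypothetical sketch (the helper name `_trainer_name_metrics` is invented here, not part of this PR):

```python
# Hypothetical: a single source of truth for the step/epoch entries that
# both _monitor_candidates and _format_checkpoint_name currently add,
# so the two call sites cannot fall out of sync.
def _trainer_name_metrics(trainer) -> dict:
    return {"step": trainer.global_step, "epoch": trainer.current_epoch}
```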

- def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
+ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
should_save_last = self.monitor is None or self.save_last
if not should_save_last:
return
@@ -506,7 +506,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
# when user ALSO asked for the 'last.ckpt' change the name
if self.save_last:
last_filepath = self._format_checkpoint_name(
- self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
+ self.CHECKPOINT_NAME_LAST,
+ trainer.current_epoch,
+ trainer.global_step,
+ ckpt_name_metrics,
+ prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")

@@ -523,17 +527,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
if self.monitor is None:
self.best_model_path = self.last_model_path

- def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
+ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
current = metrics.get(self.monitor)
+ epoch = metrics.get("epoch")
+ step = metrics.get("step")
> **Contributor** (on the two added lines above): same comment as above

if not isinstance(current, torch.Tensor) and current is not None:
current = torch.tensor(current, device=pl_module.device)

if self.check_monitor_top_k(current):
- self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
+ self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
elif self.verbose:
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
)

def _is_valid_monitor_key(self, metrics):
@@ -544,11 +550,11 @@ def _update_best_and_save(
filepath: str,
current: torch.Tensor,
epoch: int,
+ step: int,
trainer,
pl_module,
):

- k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
+ k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

del_list = []
if len(self.best_k_models) == k and k > 0:
@@ -575,9 +581,8 @@

if self.verbose:
rank_zero_info(
f"Epoch {epoch:d}: {self.monitor} reached"
f" {current:0.5f} (best {self.best_model_score:0.5f}),"
f" saving model to {filepath} as top {k}"
f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f} (best {self.best_model_score:0.5f}),"
f' saving model to "{filepath}" as top {k}'
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
)
self._save_model(filepath, trainer, pl_module)
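The `save_top_k == -1` change above matters once checkpoints can be saved several times per epoch: capacity used to be tied to `epoch + 1`, which under-counted sub-epoch saves. A simplified sketch of the new bookkeeping (illustrative, not the callback's actual internals):

```python
# Simplified view of the top-k bookkeeping after this PR. With
# save_top_k == -1 the capacity tracks how many checkpoints are already
# stored (len(best_k_models) + 1) rather than epoch + 1, so several
# saves landing in the same epoch are all kept.
best_k_models = {}  # filepath -> monitored score

def capacity(save_top_k: int) -> int:
    return len(best_k_models) + 1 if save_top_k == -1 else save_top_k

for step, score in [(100, 0.9), (200, 0.7), (300, 0.8)]:
    k = capacity(-1)
    best_k_models[f"epoch=0-step={step}.ckpt"] = score
    assert len(best_k_models) <= k  # with -1, nothing is ever evicted
```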

pytorch_lightning/trainer/evaluation_loop.py (7 changes: 4 additions & 3 deletions)

@@ -250,9 +250,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
# depre warning
if eval_results is not None and user_reduced:
step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
- m = f'The {step} should not return anything as of 9.1.' \
- f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
- self.warning_cache.warn(m)
+ self.warning_cache.warn(
+ f'The {step} should not return anything as of 9.1.'
+ ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
+ )

if using_eval_result and not user_reduced:
eval_results = self.__auto_reduce_result_objs(outputs)