Skip to content

Commit

Permalink
Update pytorch_lightning/callbacks/model_checkpoint.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 authored and Borda committed Sep 27, 2020
1 parent 8522d78 commit b0b96e2
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,10 +533,13 @@ def _do_check_save(
if cur_path != filepath:
self._del_model(cur_path)

def to_yaml(self, save_path: Optional[Union[str, Path]]=None):
""" Saves `{'checkpoint_name': score}` dict as a YAML file."""
def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
"""
Saves the `best_k_models` dict containing the checkpoint
paths with the corresponding scores to a YAML file.
"""
best_k = {k: v.item() for k, v in self.best_k_models.items()}
if save_path is None:
save_path = os.path.join(self.dirpath, "best_k_models.yaml")
with open(save_path, "w") as fp:
if filepath is None:
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
with open(filepath, "w") as fp:
yaml.dump(best_k, fp)

0 comments on commit b0b96e2

Please sign in to comment.