Skip to content

Commit

Permalink
Update pytorch_lightning/callbacks/model_checkpoint.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Sep 26, 2020
1 parent ecab426 commit 9997453
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,10 +469,13 @@ def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
self.best_model_score = checkpointed_state["best_model_score"]
self.best_model_path = checkpointed_state["best_model_path"]

def to_yaml(self, save_path: Optional[Union[str, Path]]=None):
""" Saves `{'checkpoint_name': score}` dict as a YAML file."""
def to_yaml(self, filepath: Optional[Union[str, Path]] = None):
"""
Saves the `best_k_models` dict containing the checkpoint
paths with the corresponding scores to a YAML file.
"""
best_k = {k: v.item() for k, v in self.best_k_models.items()}
if save_path is None:
save_path = os.path.join(self.dirpath, "best_k_models.yaml")
with open(save_path, "w") as fp:
if filepath is None:
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
with open(filepath, "w") as fp:
yaml.dump(best_k, fp)

0 comments on commit 9997453

Please sign in to comment.