Skip to content

Commit

Permalink
feat(wandb): log models as artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
borisdayma committed Feb 26, 2021
1 parent 67c0215 commit 3fb55f6
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions pytorch_lightning/loggers/wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
-------------------------
"""
import os
import re
import numbers
from argparse import Namespace
from typing import Any, Dict, Optional, Union

Expand Down Expand Up @@ -162,10 +164,6 @@ def experiment(self) -> Run:
**self._kwargs
) if wandb.run is None else wandb.run

# save checkpoints in wandb dir to upload on W&B servers
if self._save_dir is None:
self._save_dir = self._experiment.dir

# define default x-axis (for latest wandb versions)
if getattr(self._experiment, "define_metric", None):
self._experiment.define_metric("trainer/global_step")
Expand Down Expand Up @@ -209,6 +207,18 @@ def version(self) -> Optional[str]:

@rank_zero_only
def finalize(self, status: str) -> None:
# upload all checkpoints from saving dir
# save checkpoints as artifacts
if self._log_model:
wandb.save(os.path.join(self.save_dir, "*.ckpt"))
# use run name and ensure it's a valid Artifact name
artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self.experiment.name)
# gather interesting metadata
metadata = {
k: v
for k, v in dict(self.experiment.summary).items()
if isinstance(v, numbers.Number) and not k.startswith("_")
}
# TODO: see if we can also log data from `trainer.checkpoint_callback` (best_model_path, etc)
artifact = wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
# TODO: we need access to `trainer.checkpoint_callback.dirpath`
artifact.add_dir(trainer.checkpoint_callback.dirpath)
self.experiment.log_artifact(artifact)

0 comments on commit 3fb55f6

Please sign in to comment.