diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py index ec6cc705b..6c2cc4f70 100644 --- a/lerobot/common/logger.py +++ b/lerobot/common/logger.py @@ -30,6 +30,7 @@ def __init__(self, log_dir, job_name, cfg): self._model_dir = self._log_dir / "models" self._buffer_dir = self._log_dir / "buffers" self._save_model = cfg.save_model + self._disable_wandb_artifact = cfg.wandb.disable_artifact self._save_buffer = cfg.save_buffer self._group = cfg_to_group(cfg) self._seed = cfg.seed @@ -71,9 +72,10 @@ def save_model(self, policy, identifier): self._model_dir.mkdir(parents=True, exist_ok=True) fp = self._model_dir / f"{str(identifier)}.pt" policy.save(fp) - if self._wandb: + if self._wandb and not self._disable_wandb_artifact: + # note wandb artifact does not accept ":" in its name artifact = self._wandb.Artifact( - self._group + "-" + str(self._seed) + "-" + str(identifier), + self._group.replace(":", "_") + "-" + str(self._seed) + "-" + str(identifier), type="model", ) artifact.add_file(fp) diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index 6841cb828..9a97b50d1 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -30,5 +30,7 @@ policy: ??? wandb: enable: true + # Set to true to disable saving an artifact despite save_model == True + disable_artifact: false project: lerobot notes: ""