Skip to content

Commit

Permalink
Merge pull request #46 from LouisDo2108/master
Browse files Browse the repository at this point in the history
add a rename param for download_from_wandb
  • Loading branch information
kaylode authored Feb 4, 2023
2 parents b7298bc + 75c9110 commit c4623d9
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions theseus/base/utilities/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import urllib.request as urlreq

import gdown
import os
from pathlib import Path

from theseus.base.utilities.loggers.observer import LoggerObserver

Expand Down Expand Up @@ -60,18 +62,28 @@ def download_from_url(url, root=None, filename=None):
return fpath


def download_from_wandb(filename, run_path, save_dir, generate_id_text_file=False):
import wandb
def download_from_wandb(filename, run_path, save_dir, rename=None, generate_id_text_file=False):

import wandb

try:
path = wandb.restore(filename, run_path=run_path, root=save_dir)

LOGGER.text("Successfully download {} from wandb run path {}".format(filename, run_path), level=LoggerObserver.INFO)

# Save run id to wandb_id.txt
if generate_id_text_file:
wandb_id = osp.basename(run_path)
with open(osp.join(save_dir, "wandb_id.txt"), "w") as f:
f.write(wandb_id)


if rename:
new_name = str(Path(path.name).resolve().parent / rename)
os.rename(Path(path.name).resolve(), new_name)
LOGGER.text("Saved to {}".format(new_name), level=LoggerObserver.INFO)
return new_name


LOGGER.text("Saved to {}".format((Path(save_dir) / path.name).resolve()), level=LoggerObserver.INFO)
return path.name
except:
LOGGER.text("Failed to download from wandb.", level=LoggerObserver.ERROR)
Expand Down

0 comments on commit c4623d9

Please sign in to comment.